try making tf-qas

This commit is contained in:
Noa Aarts 2026-01-12 09:30:18 +01:00
parent 804d6acbee
commit b7a89f007b
Signed by: noa
GPG key ID: 1850932741EFF672

View file

@ -2,6 +2,7 @@
from dataclasses import dataclass
from enum import IntEnum
import math
from multiprocessing import Pool
import random
import numpy as np
from typing import assert_type, override
@ -49,7 +50,7 @@ class QuantumCircuit:
two_qubit_gates: int
params: int
paths: int = 0
expressibility: float = -10000.0
expressibility: float = float("inf")
def calculate_paths(self):
path_counts = [1 for _ in range(self.qubits)]
@ -139,7 +140,8 @@ class QuantumCircuit:
kl = kl_divergence(p, q)
self.expressibility = kl
return kl
print(f"completed with seed {seed}")
return self
@override
def __str__(self):
@ -243,10 +245,17 @@ if __name__ == "__main__":
for circuit in tqdm(circuits):
circuit.calculate_paths()
circuits.sort(key=lambda qc: qc.paths, reverse=True)
for circuit in tqdm(circuits[:proxy_pass_amount]):
circuit.expressibility_estimate(
expressibility_samples, rng.randint(0, 100000000)
seeds = [rng.randint(0, 1000000000) for _ in range(proxy_pass_amount)]
def starmap_func(circ, seed):
return circ.expressibility_estimate(expressibility_samples, seed)
with Pool() as p:
final_circuits = p.starmap(
starmap_func,
zip(circuits[:proxy_pass_amount], seeds),
)
circuits.sort(key=lambda qc: qc.expressibility)
for i, circuit in enumerate(circuits[:proxy_pass_amount]):
final_circuits.sort(key=lambda qc: qc.expressibility)
for i, circuit in enumerate(final_circuits):
print(f"circuit {i}:\n{circuit}")