try making tf-qas

This commit is contained in:
Noa Aarts 2026-01-12 09:30:18 +01:00
parent 804d6acbee
commit b7a89f007b
Signed by: noa
GPG key ID: 1850932741EFF672

View file

@ -2,6 +2,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from enum import IntEnum from enum import IntEnum
import math import math
from multiprocessing import Pool
import random import random
import numpy as np import numpy as np
from typing import assert_type, override from typing import assert_type, override
@ -49,7 +50,7 @@ class QuantumCircuit:
two_qubit_gates: int two_qubit_gates: int
params: int params: int
paths: int = 0 paths: int = 0
expressibility: float = -10000.0 expressibility: float = float("inf")
def calculate_paths(self): def calculate_paths(self):
path_counts = [1 for _ in range(self.qubits)] path_counts = [1 for _ in range(self.qubits)]
@ -139,7 +140,8 @@ class QuantumCircuit:
kl = kl_divergence(p, q) kl = kl_divergence(p, q)
self.expressibility = kl self.expressibility = kl
return kl print(f"completed with seed {seed}")
return self
@override @override
def __str__(self): def __str__(self):
@ -243,10 +245,17 @@ if __name__ == "__main__":
for circuit in tqdm(circuits): for circuit in tqdm(circuits):
circuit.calculate_paths() circuit.calculate_paths()
circuits.sort(key=lambda qc: qc.paths, reverse=True) circuits.sort(key=lambda qc: qc.paths, reverse=True)
for circuit in tqdm(circuits[:proxy_pass_amount]): seeds = [rng.randint(0, 1000000000) for _ in range(proxy_pass_amount)]
circuit.expressibility_estimate(
expressibility_samples, rng.randint(0, 100000000) def starmap_func(circ, seed):
return circ.expressibility_estimate(expressibility_samples, seed)
with Pool() as p:
final_circuits = p.starmap(
starmap_func,
zip(circuits[:proxy_pass_amount], seeds),
) )
circuits.sort(key=lambda qc: qc.expressibility)
for i, circuit in enumerate(circuits[:proxy_pass_amount]): final_circuits.sort(key=lambda qc: qc.expressibility)
for i, circuit in enumerate(final_circuits):
print(f"circuit {i}:\n{circuit}") print(f"circuit {i}:\n{circuit}")