diff --git a/src/tf-qas.py b/src/tf-qas.py index 90d6477..0947e11 100755 --- a/src/tf-qas.py +++ b/src/tf-qas.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from enum import IntEnum import math +from multiprocessing import Pool import random import numpy as np from typing import assert_type, override @@ -49,7 +50,7 @@ class QuantumCircuit: two_qubit_gates: int params: int paths: int = 0 - expressibility: float = -10000.0 + expressibility: float = float("inf") def calculate_paths(self): path_counts = [1 for _ in range(self.qubits)] @@ -139,7 +140,8 @@ class QuantumCircuit: kl = kl_divergence(p, q) self.expressibility = kl - return kl + print(f"completed with seed {seed}") + return self @override def __str__(self): @@ -243,10 +245,17 @@ if __name__ == "__main__": for circuit in tqdm(circuits): circuit.calculate_paths() circuits.sort(key=lambda qc: qc.paths, reverse=True) - for circuit in tqdm(circuits[:proxy_pass_amount]): - circuit.expressibility_estimate( - expressibility_samples, rng.randint(0, 100000000) + seeds = [rng.randint(0, 1000000000) for _ in range(proxy_pass_amount)] + + def starmap_func(circ, seed): + return circ.expressibility_estimate(expressibility_samples, seed) + + with Pool() as p: + final_circuits = p.starmap( + starmap_func, + zip(circuits[:proxy_pass_amount], seeds), ) - circuits.sort(key=lambda qc: qc.expressibility) - for i, circuit in enumerate(circuits[:proxy_pass_amount]): + + final_circuits.sort(key=lambda qc: qc.expressibility) + for i, circuit in enumerate(final_circuits): print(f"circuit {i}:\n{circuit}")