try making tf-qas
This commit is contained in:
parent
804d6acbee
commit
b7a89f007b
1 changed files with 16 additions and 7 deletions
|
|
@ -2,6 +2,7 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
import math
|
import math
|
||||||
|
from multiprocessing import Pool
|
||||||
import random
|
import random
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import assert_type, override
|
from typing import assert_type, override
|
||||||
|
|
@ -49,7 +50,7 @@ class QuantumCircuit:
|
||||||
two_qubit_gates: int
|
two_qubit_gates: int
|
||||||
params: int
|
params: int
|
||||||
paths: int = 0
|
paths: int = 0
|
||||||
expressibility: float = -10000.0
|
expressibility: float = float("inf")
|
||||||
|
|
||||||
def calculate_paths(self):
|
def calculate_paths(self):
|
||||||
path_counts = [1 for _ in range(self.qubits)]
|
path_counts = [1 for _ in range(self.qubits)]
|
||||||
|
|
@ -139,7 +140,8 @@ class QuantumCircuit:
|
||||||
|
|
||||||
kl = kl_divergence(p, q)
|
kl = kl_divergence(p, q)
|
||||||
self.expressibility = kl
|
self.expressibility = kl
|
||||||
return kl
|
print(f"completed with seed {seed}")
|
||||||
|
return self
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
|
@ -243,10 +245,17 @@ if __name__ == "__main__":
|
||||||
for circuit in tqdm(circuits):
|
for circuit in tqdm(circuits):
|
||||||
circuit.calculate_paths()
|
circuit.calculate_paths()
|
||||||
circuits.sort(key=lambda qc: qc.paths, reverse=True)
|
circuits.sort(key=lambda qc: qc.paths, reverse=True)
|
||||||
for circuit in tqdm(circuits[:proxy_pass_amount]):
|
seeds = [rng.randint(0, 1000000000) for _ in range(proxy_pass_amount)]
|
||||||
circuit.expressibility_estimate(
|
|
||||||
expressibility_samples, rng.randint(0, 100000000)
|
def starmap_func(circ, seed):
|
||||||
|
return circ.expressibility_estimate(expressibility_samples, seed)
|
||||||
|
|
||||||
|
with Pool() as p:
|
||||||
|
final_circuits = p.starmap(
|
||||||
|
starmap_func,
|
||||||
|
zip(circuits[:proxy_pass_amount], seeds),
|
||||||
)
|
)
|
||||||
circuits.sort(key=lambda qc: qc.expressibility)
|
|
||||||
for i, circuit in enumerate(circuits[:proxy_pass_amount]):
|
final_circuits.sort(key=lambda qc: qc.expressibility)
|
||||||
|
for i, circuit in enumerate(final_circuits):
|
||||||
print(f"circuit {i}:\n{circuit}")
|
print(f"circuit {i}:\n{circuit}")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue