I get 4/sec for expressibility on my laptop

This commit is contained in:
Noa Aarts 2026-01-12 10:24:48 +01:00
parent b7a89f007b
commit c0a27cecce
Signed by: noa
GPG key ID: 1850932741EFF672

View file

@ -1,11 +1,10 @@
#!/usr/bin/env python #!/usr/bin/env python
from dataclasses import dataclass from dataclasses import dataclass
from enum import IntEnum from enum import IntEnum
import math
from multiprocessing import Pool from multiprocessing import Pool
import random import random
import numpy as np import numpy as np
from typing import assert_type, override from typing import override
from qiskit.circuit import ParameterVector from qiskit.circuit import ParameterVector
from tqdm import tqdm from tqdm import tqdm
from qiskit import QuantumCircuit as QiskitCircuit, transpile from qiskit import QuantumCircuit as QiskitCircuit, transpile
@ -99,30 +98,28 @@ class QuantumCircuit:
backend = AerSimulator(method="statevector", seed_simulator=seed) backend = AerSimulator(method="statevector", seed_simulator=seed)
tqc = transpile(qc, backend) tqc = transpile(qc, backend, optimization_level=0)
# 1) build 2*samples parameterized circuits (left+right) nexp = 2 * samples
binds = []
for _ in range(2 * samples):
binds.append(
{thetas: [rng.random() for _ in range(self.params)]}
) # SAFE binding
# 2) compute statevectors (no backend/job overhead) binds = [
job = backend.run( {param: [rng.random() for _ in range(nexp)] for param in thetas.params}
[tqc.assign_parameters(bind, inplace=False) for bind in binds] ]
)
job = backend.run([tqc], parameter_binds=binds)
result = job.result() result = job.result()
sv = [result.get_statevector(i) for i in range(len(binds))]
sv = [
np.asarray(result.get_statevector(i), dtype=np.complex128)
for i in range(nexp)
]
left = sv[:samples] left = sv[:samples]
right = sv[samples:] right = sv[samples:]
# 3) fidelities F = |<ψ|φ>|^2
inners = np.array( inners = np.array(
[l.inner(r) for l, r in zip(left, right)], dtype=np.complex128 [np.vdot(l, r) for l, r in zip(left, right)], dtype=np.complex128
) )
F = (inners * inners.conjugate()).real F = (inners.conjugate() * inners).real
F = np.clip(F, 0.0, 1.0) # numerical safety
# 4) empirical histogram (as a probability mass over bins) # 4) empirical histogram (as a probability mass over bins)
hist, edges = np.histogram(F, bins=bins, range=(0.0, 1.0), density=False) hist, edges = np.histogram(F, bins=bins, range=(0.0, 1.0), density=False)
@ -140,7 +137,6 @@ class QuantumCircuit:
kl = kl_divergence(p, q) kl = kl_divergence(p, q)
self.expressibility = kl self.expressibility = kl
print(f"completed with seed {seed}")
return self return self
@override @override
@ -232,8 +228,8 @@ def more_single_than_double(qc: QuantumCircuit) -> bool:
if __name__ == "__main__": if __name__ == "__main__":
rng = random.Random() rng = random.Random()
qubits = 6 qubits = 6
depth = 10 depth = 15
sample_amount = 50000 sample_amount = 5000
expressibility_samples = 2000 expressibility_samples = 2000
proxy_pass_amount = 5000 proxy_pass_amount = 5000
circuits = ( circuits = (
@ -247,14 +243,21 @@ if __name__ == "__main__":
circuits.sort(key=lambda qc: qc.paths, reverse=True) circuits.sort(key=lambda qc: qc.paths, reverse=True)
seeds = [rng.randint(0, 1000000000) for _ in range(proxy_pass_amount)] seeds = [rng.randint(0, 1000000000) for _ in range(proxy_pass_amount)]
def starmap_func(circ, seed): def starmap_func(args):
return circ.expressibility_estimate(expressibility_samples, seed) (circ, seed, idx) = args
res = circ.expressibility_estimate(expressibility_samples, seed)
return res
final_circuits: list[QuantumCircuit] = []
with Pool() as p: with Pool() as p:
final_circuits = p.starmap( for circ in tqdm(
p.imap_unordered(
starmap_func, starmap_func,
zip(circuits[:proxy_pass_amount], seeds), zip(circuits[:proxy_pass_amount], seeds, range(proxy_pass_amount)),
) ),
total=proxy_pass_amount,
):
final_circuits.append(circ)
final_circuits.sort(key=lambda qc: qc.expressibility) final_circuits.sort(key=lambda qc: qc.expressibility)
for i, circuit in enumerate(final_circuits): for i, circuit in enumerate(final_circuits):