I get 4/sec for expressibility on my laptop

This commit is contained in:
Noa Aarts 2026-01-12 10:24:48 +01:00
parent b7a89f007b
commit c0a27cecce
Signed by: noa
GPG key ID: 1850932741EFF672

View file

@ -1,11 +1,10 @@
#!/usr/bin/env python
from dataclasses import dataclass
from enum import IntEnum
import math
from multiprocessing import Pool
import random
import numpy as np
from typing import assert_type, override
from typing import override
from qiskit.circuit import ParameterVector
from tqdm import tqdm
from qiskit import QuantumCircuit as QiskitCircuit, transpile
@ -99,30 +98,28 @@ class QuantumCircuit:
backend = AerSimulator(method="statevector", seed_simulator=seed)
tqc = transpile(qc, backend)
tqc = transpile(qc, backend, optimization_level=0)
# 1) build 2*samples parameterized circuits (left+right)
binds = []
for _ in range(2 * samples):
binds.append(
{thetas: [rng.random() for _ in range(self.params)]}
) # SAFE binding
nexp = 2 * samples
# 2) compute statevectors (no backend/job overhead)
job = backend.run(
[tqc.assign_parameters(bind, inplace=False) for bind in binds]
)
binds = [
{param: [rng.random() for _ in range(nexp)] for param in thetas.params}
]
job = backend.run([tqc], parameter_binds=binds)
result = job.result()
sv = [result.get_statevector(i) for i in range(len(binds))]
sv = [
np.asarray(result.get_statevector(i), dtype=np.complex128)
for i in range(nexp)
]
left = sv[:samples]
right = sv[samples:]
# 3) fidelities F = |<ψ|φ>|^2
inners = np.array(
[l.inner(r) for l, r in zip(left, right)], dtype=np.complex128
[np.vdot(l, r) for l, r in zip(left, right)], dtype=np.complex128
)
F = (inners * inners.conjugate()).real
F = np.clip(F, 0.0, 1.0) # numerical safety
F = (inners.conjugate() * inners).real
# 4) empirical histogram (as a probability mass over bins)
hist, edges = np.histogram(F, bins=bins, range=(0.0, 1.0), density=False)
@ -140,7 +137,6 @@ class QuantumCircuit:
kl = kl_divergence(p, q)
self.expressibility = kl
print(f"completed with seed {seed}")
return self
@override
@ -232,8 +228,8 @@ def more_single_than_double(qc: QuantumCircuit) -> bool:
if __name__ == "__main__":
rng = random.Random()
qubits = 6
depth = 10
sample_amount = 50000
depth = 15
sample_amount = 5000
expressibility_samples = 2000
proxy_pass_amount = 5000
circuits = (
@ -247,14 +243,21 @@ if __name__ == "__main__":
circuits.sort(key=lambda qc: qc.paths, reverse=True)
seeds = [rng.randint(0, 1000000000) for _ in range(proxy_pass_amount)]
def starmap_func(circ, seed):
return circ.expressibility_estimate(expressibility_samples, seed)
def starmap_func(args):
(circ, seed, idx) = args
res = circ.expressibility_estimate(expressibility_samples, seed)
return res
final_circuits: list[QuantumCircuit] = []
with Pool() as p:
final_circuits = p.starmap(
starmap_func,
zip(circuits[:proxy_pass_amount], seeds),
)
for circ in tqdm(
p.imap_unordered(
starmap_func,
zip(circuits[:proxy_pass_amount], seeds, range(proxy_pass_amount)),
),
total=proxy_pass_amount,
):
final_circuits.append(circ)
final_circuits.sort(key=lambda qc: qc.expressibility)
for i, circuit in enumerate(final_circuits):