It works? but slowly

This commit is contained in:
Noa Aarts 2026-01-11 10:25:51 +01:00
parent 6dfbffd05d
commit 804d6acbee
Signed by: noa
GPG key ID: 1850932741EFF672
8 changed files with 420 additions and 28 deletions

4
src/qas_flow/__init__.py Normal file
View file

@ -0,0 +1,4 @@
from .stream import Stream
from .funcs import map, filter, take, skip, batch, enumerate, collect
__all__ = ["Stream", "map", "filter", "take", "skip", "batch", "enumerate", "collect"]

76
src/qas_flow/funcs.py Normal file
View file

@ -0,0 +1,76 @@
from typing import Callable
from .stream import Stream, T, U
@Stream.extension()
def map(stream: Stream[T], fn: Callable[[T], U]) -> Stream[U]:
def gen():
for x in stream:
yield fn(x)
return Stream(gen())
@Stream.extension()
def filter(stream: Stream[T], pred: Callable[[T], bool]) -> Stream[T]:
def gen():
for x in stream:
if pred(x):
yield x
return Stream(gen())
@Stream.extension()
def take(stream: Stream[T], n: int) -> Stream[T]:
def gen():
c = 0
for x in stream:
if c < n:
c += 1
yield x
else:
return
return Stream(gen())
@Stream.extension()
def skip(stream: Stream[T], n: int) -> Stream[T]:
def gen():
c = 0
for x in stream:
c += 1
if c > n:
yield x
return Stream(gen())
@Stream.extension()
def batch(stream: Stream[T], n: int) -> Stream[list[T]]:
def gen():
ls: list[T] = []
for x in stream:
ls.append(x)
if len(ls) == n:
yield ls
ls = []
return Stream(gen())
@Stream.extension()
def enumerate(stream: Stream[T]) -> Stream[tuple[int, T]]:
def gen():
idx = 0
for x in stream:
yield (idx, x)
idx += 1
return Stream(gen())
@Stream.extension()
def collect(stream: Stream[T]) -> list[T]:
return [v for v in stream]

39
src/qas_flow/stream.py Normal file
View file

@ -0,0 +1,39 @@
from collections.abc import Iterator
from typing import Any, Callable, Generic, TypeVar, final
T = TypeVar("T")
U = TypeVar("U")
Op = Callable[["Stream[T]"], "Stream[U]"]
@final
class Stream(Generic[T]):
_extensions: dict[str, Callable[..., Any]] = {}
def __init__(self, it: Iterator[T]) -> None:
self._it = it
def __iter__(self) -> Iterator[T]:
return self._it
@classmethod
def extension(cls, name: str | None = None):
"""Register a function as Stream.<name>(...). First arg will be the stream."""
def deco(fn: Callable[..., Any]):
cls._extensions[name or fn.__name__] = fn
return fn
return deco
def __getattr__(self, attr: str):
fn = self._extensions.get(attr)
if fn is None:
raise AttributeError(attr)
def bound(*args, **kwargs):
return fn(self, *args, **kwargs)
return bound

252
src/tf-qas.py Executable file
View file

@ -0,0 +1,252 @@
#!/usr/bin/env python
from dataclasses import dataclass
from enum import IntEnum
import math
import random
import numpy as np
from typing import assert_type, override
from qiskit.circuit import ParameterVector
from tqdm import tqdm
from qiskit import QuantumCircuit as QiskitCircuit, transpile
from qiskit_aer import AerSimulator
from qas_flow import Stream
TWO_QUBIT_GATE_PROBABILITY = 0.5
class GateType(IntEnum):
RX = 1
RY = 2
RZ = 3
XX = 4
YY = 5
ZZ = 6
@dataclass(frozen=True)
class Gate:
type: GateType
qubits: int | tuple[int, int]
param_idx: int
def haar_fidelity_pdf(F: np.ndarray, d: int) -> np.ndarray:
# p(F) = (d-1) * (1-F)^(d-2), for F in [0,1]
return (d - 1) * np.power(1.0 - F, d - 2)
def kl_divergence(p: np.ndarray, q: np.ndarray) -> float:
# assumes p, q are normalized and strictly positive
return float(np.sum(p * np.log(p / q)))
@dataclass()
class QuantumCircuit:
qubits: int
gates: list[list[Gate]]
single_qubit_gates: int
two_qubit_gates: int
params: int
paths: int = 0
expressibility: float = -10000.0
def calculate_paths(self):
path_counts = [1 for _ in range(self.qubits)]
for layer in self.gates:
for gate in layer:
if gate.type <= 3 or type(gate.qubits) is int:
continue
(q1, q2) = gate.qubits
path_counts[q1] = path_counts[q1] + path_counts[q2]
path_counts[q2] = path_counts[q1]
self.paths = sum(path_counts)
def to_qiskit(self):
qc = QiskitCircuit(self.qubits)
thetas = ParameterVector("theta", self.params)
for layer in self.gates:
for gate in layer:
theta = thetas[gate.param_idx]
if gate.type == GateType.RX:
qc.rx(theta, gate.qubits)
elif gate.type == GateType.RY:
qc.ry(theta, gate.qubits)
elif gate.type == GateType.RZ:
qc.rz(theta, gate.qubits)
elif gate.type == GateType.XX:
qc.rxx(theta, *gate.qubits)
elif gate.type == GateType.YY:
qc.ryy(theta, *gate.qubits)
elif gate.type == GateType.ZZ:
qc.rzz(theta, *gate.qubits)
qc.save_statevector()
return qc, thetas
def expressibility_estimate(
self, samples: int, seed: int, bins: int = 75, eps: float = 1e-12
):
qc, thetas = self.to_qiskit()
if self.params <= 0:
return float("inf")
rng = random.Random(seed)
d = 1 << self.qubits
backend = AerSimulator(method="statevector", seed_simulator=seed)
tqc = transpile(qc, backend)
# 1) build 2*samples parameterized circuits (left+right)
binds = []
for _ in range(2 * samples):
binds.append(
{thetas: [rng.random() for _ in range(self.params)]}
) # SAFE binding
# 2) compute statevectors (no backend/job overhead)
job = backend.run(
[tqc.assign_parameters(bind, inplace=False) for bind in binds]
)
result = job.result()
sv = [result.get_statevector(i) for i in range(len(binds))]
left = sv[:samples]
right = sv[samples:]
# 3) fidelities F = |<ψ|φ>|^2
inners = np.array(
[l.inner(r) for l, r in zip(left, right)], dtype=np.complex128
)
F = (inners * inners.conjugate()).real
F = np.clip(F, 0.0, 1.0) # numerical safety
# 4) empirical histogram (as a probability mass over bins)
hist, edges = np.histogram(F, bins=bins, range=(0.0, 1.0), density=False)
p = hist.astype(np.float64)
p = p + eps
p = p / p.sum()
# 5) Haar distribution mass per bin (integrate PDF over each bin)
# approximate via midpoint rule
mids = 0.5 * (edges[:-1] + edges[1:])
widths = edges[1:] - edges[:-1]
q = haar_fidelity_pdf(mids, d) * widths
q = q + eps
q = q / q.sum()
kl = kl_divergence(p, q)
self.expressibility = kl
return kl
@override
def __str__(self):
strs = ["-" for _ in range(self.qubits)]
for layer in self.gates:
for i in range(self.qubits):
strs[i] += "---"
idx = 0
for gate in layer:
match gate.type:
case GateType.RX:
strs[gate.qubits] = strs[gate.qubits][:-2] + "RX"
case GateType.RY:
strs[gate.qubits] = strs[gate.qubits][:-2] + "RY"
case GateType.RZ:
strs[gate.qubits] = strs[gate.qubits][:-2] + "RZ"
case GateType.XX:
(q1, q2) = gate.qubits
strs[q1] = strs[q1][:-2] + f"X{idx}"
strs[q2] = strs[q2][:-2] + f"X{idx}"
idx += 1
case GateType.YY:
(q1, q2) = gate.qubits
strs[q1] = strs[q1][:-2] + f"Y{idx}"
strs[q2] = strs[q2][:-2] + f"Y{idx}"
idx += 1
case GateType.ZZ:
(q1, q2) = gate.qubits
strs[q1] = strs[q1][:-2] + f"Z{idx}"
strs[q2] = strs[q2][:-2] + f"Z{idx}"
idx += 1
return (
"\n".join(strs)
+ f"\npaths: {self.paths}, expressibility: {self.expressibility}"
)
def even_parity(qubits: int):
return [(x, x + 1) for x in range(0, qubits, 2)]
def odd_parity(qubits: int):
return [(x, x + 1) for x in range(1, qubits, 2)]
def sample_circuit(rng: random.Random, qubits: int, depth: int) -> QuantumCircuit:
even = even_parity(qubits)
odd = odd_parity(qubits)
total_single = 0
total_double = 0
params = 0
gates: list[list[Gate]] = []
for _ in range(depth):
gate_type_offset = 3 if rng.random() < TWO_QUBIT_GATE_PROBABILITY else 0
gate_type = rng.randint(1, 3)
gate_locations = even if rng.random() < 0.5 else odd
if gate_type_offset == 0:
gates.append(
[Gate(GateType(gate_type), x, params) for (x, _) in gate_locations]
)
total_single += len(gate_locations)
else:
gates.append(
[
Gate(GateType(gate_type + gate_type_offset), xy, params)
for xy in gate_locations
if xy[1] != qubits
]
)
total_double += len(gate_locations)
params += 1
return QuantumCircuit(qubits, gates, total_single, total_double, params)
def sample_generator(rng: random.Random, qubits: int, depth: int):
while True:
yield sample_circuit(rng, qubits, depth)
def more_single_than_double(qc: QuantumCircuit) -> bool:
return qc.single_qubit_gates >= qc.two_qubit_gates
if __name__ == "__main__":
rng = random.Random()
qubits = 6
depth = 10
sample_amount = 50000
expressibility_samples = 2000
proxy_pass_amount = 5000
circuits = (
Stream(sample_generator(rng, qubits, depth))
.filter(more_single_than_double)
.take(sample_amount)
.collect()
)
for circuit in tqdm(circuits):
circuit.calculate_paths()
circuits.sort(key=lambda qc: qc.paths, reverse=True)
for circuit in tqdm(circuits[:proxy_pass_amount]):
circuit.expressibility_estimate(
expressibility_samples, rng.randint(0, 100000000)
)
circuits.sort(key=lambda qc: qc.expressibility)
for i, circuit in enumerate(circuits[:proxy_pass_amount]):
print(f"circuit {i}:\n{circuit}")