batch the calls in grad

This commit is contained in:
Noa Aarts 2026-01-12 15:14:46 +01:00
parent 4d4751e64d
commit 0865fddcb5
Signed by: noa
GPG key ID: 1850932741EFF672

View file

@ -400,14 +400,35 @@ def adam_optimize_tfim_energy(
return energy_expectation_from_sv(H, sv) return energy_expectation_from_sv(H, sv)
def grad(th: np.ndarray) -> np.ndarray: def grad(th: np.ndarray) -> np.ndarray:
# parameter-shift: d/dθ f(θ) = 0.5*(f(θ+π/2)-f(θ-π/2)) theta_list = []
g = np.zeros_like(th)
for i in range(p): for i in range(p):
plus = th.copy() plus = th.copy()
minus = th.copy()
plus[i] += param_shift plus[i] += param_shift
theta_list.append(plus)
minus = th.copy()
minus[i] -= param_shift minus[i] -= param_shift
g[i] = 0.5 * (E(plus) - E(minus)) theta_list.append(minus)
tla = np.array(theta_list).T
binds = [{param: list(row) for param, row in zip(thetas.params, tla)}]
job = backend.run([tqc], parameter_binds=binds)
res = job.result()
energies = np.array(
[
energy_expectation_from_sv(H, sv)
for sv in [
np.asarray(res.get_statevector(k), dtype=np.complex128)
for k in range(2 * p)
]
]
)
g = np.empty(p, dtype=np.float64)
for i in range(p):
g[i] = 0.5 * (energies[2 * i] - energies[2 * i + 1])
return g return g
m = np.zeros_like(theta) m = np.zeros_like(theta)
@ -614,6 +635,7 @@ def main():
) )
if circ.gt_error is not None and circ.gt_error < min_error: if circ.gt_error is not None and circ.gt_error < min_error:
print(f"new best error for {queried}: {circ.gt_error}") print(f"new best error for {queried}: {circ.gt_error}")
min_error = circ.gt_error
queried += 1 queried += 1
if circ.gt_success: if circ.gt_success:
success_idx = i success_idx = i