batch the calls in grad
This commit is contained in:
parent
4d4751e64d
commit
0865fddcb5
1 changed files with 26 additions and 4 deletions
|
|
@ -400,14 +400,35 @@ def adam_optimize_tfim_energy(
|
||||||
return energy_expectation_from_sv(H, sv)
|
return energy_expectation_from_sv(H, sv)
|
||||||
|
|
||||||
def grad(th: np.ndarray) -> np.ndarray:
|
def grad(th: np.ndarray) -> np.ndarray:
|
||||||
# parameter-shift: d/dθ f(θ) = 0.5*(f(θ+π/2)-f(θ-π/2))
|
theta_list = []
|
||||||
g = np.zeros_like(th)
|
|
||||||
for i in range(p):
|
for i in range(p):
|
||||||
plus = th.copy()
|
plus = th.copy()
|
||||||
minus = th.copy()
|
|
||||||
plus[i] += param_shift
|
plus[i] += param_shift
|
||||||
|
theta_list.append(plus)
|
||||||
|
minus = th.copy()
|
||||||
minus[i] -= param_shift
|
minus[i] -= param_shift
|
||||||
g[i] = 0.5 * (E(plus) - E(minus))
|
theta_list.append(minus)
|
||||||
|
tla = np.array(theta_list).T
|
||||||
|
|
||||||
|
binds = [{param: list(row) for param, row in zip(thetas.params, tla)}]
|
||||||
|
|
||||||
|
job = backend.run([tqc], parameter_binds=binds)
|
||||||
|
res = job.result()
|
||||||
|
|
||||||
|
energies = np.array(
|
||||||
|
[
|
||||||
|
energy_expectation_from_sv(H, sv)
|
||||||
|
for sv in [
|
||||||
|
np.asarray(res.get_statevector(k), dtype=np.complex128)
|
||||||
|
for k in range(2 * p)
|
||||||
|
]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
g = np.empty(p, dtype=np.float64)
|
||||||
|
for i in range(p):
|
||||||
|
g[i] = 0.5 * (energies[2 * i] - energies[2 * i + 1])
|
||||||
|
|
||||||
return g
|
return g
|
||||||
|
|
||||||
m = np.zeros_like(theta)
|
m = np.zeros_like(theta)
|
||||||
|
|
@ -614,6 +635,7 @@ def main():
|
||||||
)
|
)
|
||||||
if circ.gt_error is not None and circ.gt_error < min_error:
|
if circ.gt_error is not None and circ.gt_error < min_error:
|
||||||
print(f"new best error for {queried}: {circ.gt_error}")
|
print(f"new best error for {queried}: {circ.gt_error}")
|
||||||
|
min_error = circ.gt_error
|
||||||
queried += 1
|
queried += 1
|
||||||
if circ.gt_success:
|
if circ.gt_success:
|
||||||
success_idx = i
|
success_idx = i
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue