diff --git a/src/openfermion/resource_estimates/pbc/thc/factorizations/thc_jax.py b/src/openfermion/resource_estimates/pbc/thc/factorizations/thc_jax.py index a5aacea30..3ffbf0919 100644 --- a/src/openfermion/resource_estimates/pbc/thc/factorizations/thc_jax.py +++ b/src/openfermion/resource_estimates/pbc/thc/factorizations/thc_jax.py @@ -27,6 +27,7 @@ # pylint: disable=wrong-import-position import math import time +from typing import cast import h5py import numpy as np @@ -213,7 +214,7 @@ def compute_objective_batched( lambda_z = jnp.sum(jnp.einsum("jpq->j", 0.5 * jnp.abs(mpq_normalized)) ** 2.0) res = 0.5 * jnp.sum((jnp.abs(deri)) ** 2) + penalty_param * lambda_z - return float(res) + return cast(float, res) def prepare_batched_data_indx_arrays(