diff --git a/src/phlashlib/hmm.py b/src/phlashlib/hmm.py index 4ee5cf9..076ad3f 100644 --- a/src/phlashlib/hmm.py +++ b/src/phlashlib/hmm.py @@ -30,6 +30,7 @@ def forward( ) -> tuple[Float[Array, "M"], Scalar]: emis = jnp.stack([pp.emis0, pp.emis1, jnp.ones_like(pp.emis0)]) + @jax.remat def fwd(tup, ob): alpha_hat, ll = tup alpha_hat = _matvec_smc(alpha_hat, pp)