I have what I believe is a simple call to `odeint`:
```python
import jax
from jax import jit
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")
import jax.numpy as jnp
from jax.experimental.ode import odeint

def param_sim(a):
    def rhs(y, t):
        return a * y
    return odeint(
        rhs,
        y0=jnp.array([0.0, 1.0], dtype=complex),
        t=jnp.array([0., 10.]),
        atol=1e-10,
        rtol=1e-10
    )[-1]

jit(param_sim)(1.)
```
This results in the following error message (note that changing `y0` to `dtype=float` removes the error; that is of course fine for solving this ODE, but this example is a stripped-down version of a problem I'm encountering where I need a complex data type):
causes the code to run correctly. I think the issue is that `jnp.array([0.0, 1.0], dtype=complex)` is at some point treated as a real array (probably during compilation).
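One way to probe this hypothesis is to check the dtype JAX records for `y0` when it is traced under `jit`, before it ever reaches `odeint`. The sketch below is only a diagnostic, not part of the original report; the helper `identity` and the list `traced_dtypes` are hypothetical names used for illustration. If the traced dtype is `complex128`, the array enters the compiled program as complex, which would suggest the real-valued interpretation happens somewhere inside `odeint` rather than in `jnp.array` itself.

```python
import jax
jax.config.update("jax_enable_x64", True)  # must run before arrays are created
import jax.numpy as jnp

# Build y0 exactly as in the report (imaginary part identically zero).
y0 = jnp.array([0.0, 1.0], dtype=complex)
print(y0.dtype)  # complex128 when x64 is enabled

# Hypothetical helper: records the dtype seen at trace time under jit.
traced_dtypes = []

@jax.jit
def identity(y):
    traced_dtypes.append(y.dtype)  # Python side effect runs once, during tracing
    return y

out = identity(y0)
print(traced_dtypes[0], out.dtype)
```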
Similarly, changing `y0` to something like `y0=jnp.array([0.0, 1.0 + 1e-20j], dtype=complex)` also causes the code to run properly. It seems that having `y0.imag` be exactly the zero vector somehow breaks the interpretation of `y0` as a complex array.
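Until the underlying problem is fixed, a common workaround (assuming the trigger is complex-valued state inside `odeint`) is to hand `odeint` only real data: carry the real and imaginary parts as a stacked float64 array and rebuild the complex value inside the right-hand side. This is a minimal sketch for the ODE in this report, not a fix for `odeint` itself; `param_sim_split` is a hypothetical name.

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import jit
from jax.experimental.ode import odeint


def param_sim_split(a):
    # Hypothetical workaround: state is a real (2, n) array holding
    # (real part, imaginary part), so odeint never sees complex dtypes.
    def rhs(state, t):
        y = state[0] + 1j * state[1]           # rebuild the complex state
        dy = a * y                             # same right-hand side as the report
        return jnp.stack([dy.real, dy.imag])   # split back into real components

    y0 = jnp.array([0.0, 1.0], dtype=complex)
    state0 = jnp.stack([y0.real, y0.imag])     # float64, shape (2, 2)
    sol = odeint(rhs, state0, jnp.array([0.0, 10.0]), atol=1e-10, rtol=1e-10)
    final = sol[-1]
    return final[0] + 1j * final[1]            # recombine into a complex result


result = jit(param_sim_split)(1.0)
print(result)
```

For this linear ODE the exact solution is `y(t) = y0 * exp(a * t)`, so with `a = 1` the second component should come out close to `exp(10) ≈ 22026.47`.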