RET_CHECK failure using odeint with simple case #26051

Open
DanPuzzuoli opened this issue Jan 22, 2025 · 1 comment
Labels
bug Something isn't working

DanPuzzuoli (Contributor) commented Jan 22, 2025

Description

I have what I believe is a simple call to odeint:

import jax

from jax import jit
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

import jax.numpy as jnp

from jax.experimental.ode import odeint


def param_sim(a):

    def rhs(y, t):
        return a * y

    return odeint(
        rhs,
        y0=jnp.array([0.0, 1.0], dtype=complex),  # complex128 under x64; this dtype triggers the failure
        t=jnp.array([0., 10.]),
        atol=1e-10,
        rtol=1e-10
    )[-1]  # state at the final time t = 10

jit(param_sim)(1.)

This results in the following error message. (Changing y0 to dtype=float removes the error, which is of course fine for solving this ODE, but this example is a stripped-down version of a problem I'm encountering where I need to use a complex data type.)

E0122 15:40:43.877828 1241887 status_macros.cc:57] INTERNAL: RET_CHECK failure (external/xla/xla/hlo/ir/hlo_computation.cc:1397) ShapeUtil::Compatible(old_instruction->shape(), new_instruction->shape()) "abs.0" (f64[2]{0}) vs "iota" (c128[2]{0})
*** Begin stack trace ***
	tsl::CurrentStackTrace()
	xla::status_macros::MakeError(char const*, int, absl::lts_20230802::StatusCode, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&, bool, absl::lts_20230802::LogSeverity, bool)
	xla::status_macros::MakeErrorStream::Impl::GetStatus()
	xla::status_macros::MakeErrorStream::MakeErrorStreamWithOutput::operator absl::lts_20230802::StatusOr<bool><bool>()
	xla::HloComputation::ReplaceInstruction(xla::HloInstruction*, xla::HloInstruction*, bool, bool, bool)
	xla::DfsHloRewriteVisitor::ReplaceInstruction(xla::HloInstruction*, xla::HloInstruction*, bool)
	xla::AlgebraicSimplifierVisitor::HandleAbs(xla::HloInstruction*)
	absl::lts_20230802::Status xla::PostOrderDFS<xla::DfsHloVisitorBase<xla::HloInstruction*>>(xla::HloInstruction*, xla::DfsHloVisitorBase<xla::HloInstruction*>*, std::__1::optional<absl::lts_20230802::FunctionRef<bool (std::__1::pair<int, xla::HloInstruction const*>, std::__1::pair<int, xla::HloInstruction const*>)>>, bool, bool)
	absl::lts_20230802::Status xla::HloInstruction::Accept<xla::HloInstruction*>(xla::DfsHloVisitorBase<xla::HloInstruction*>*, bool, bool, bool)
	absl::lts_20230802::Status xla::HloComputation::Accept<xla::HloInstruction*>(xla::DfsHloVisitorBase<xla::HloInstruction*>*) const
	xla::AlgebraicSimplifierVisitor::Run(xla::HloComputation*, xla::AlgebraicSimplifierOptions const&, xla::AlgebraicSimplifier*)
	xla::AlgebraicSimplifier::Run(xla::HloModule*, absl::lts_20230802::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&)
	xla::HloPassPipeline::RunHelper(xla::HloPassInterface*, xla::HloModule*, absl::lts_20230802::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&)
	absl::lts_20230802::StatusOr<bool> xla::HloPassPipeline::RunPassesInternal<xla::HloModule>(xla::HloModule*, xla::DebugOptions const&, absl::lts_20230802::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&)
	xla::HloPassPipeline::Run(xla::HloModule*, absl::lts_20230802::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&)
	xla::HloPassFix<xla::HloPassPipeline, 25>::RunOnChangedComputationsOnce(xla::HloModule*, xla::HloPassInterface::RunState*, absl::lts_20230802::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&)
	xla::HloPassFix<xla::HloPassPipeline, 25>::RunToFixPoint(xla::HloModule*, xla::HloPassInterface::RunState*, absl::lts_20230802::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&)
	xla::HloPassFix<xla::HloPassPipeline, 25>::Run(xla::HloModule*, absl::lts_20230802::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&)
	xla::HloPassPipeline::RunHelper(xla::HloPassInterface*, xla::HloModule*, absl::lts_20230802::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&)
	absl::lts_20230802::StatusOr<bool> xla::HloPassPipeline::RunPassesInternal<xla::HloModule>(xla::HloModule*, xla::DebugOptions const&, absl::lts_20230802::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&)
	xla::HloPassPipeline::Run(xla::HloModule*, absl::lts_20230802::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&)
	xla::cpu::CpuCompiler::RunHloPassesThroughLayoutAssn(xla::HloModule*, bool, xla::cpu::TargetMachineFeatures*, bool)
	xla::cpu::CpuCompiler::RunHloPasses(xla::HloModule*, bool, llvm::TargetMachine*, xla::Compiler::CompileOptions const&, bool)
	xla::cpu::CpuCompiler::RunHloPasses(std::__1::unique_ptr<xla::HloModule, std::__1::default_delete<xla::HloModule>>, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&)
	xla::TfrtCpuClient::Compile(xla::XlaComputation const&, xla::CompileOptions)
	xla::TfrtCpuClient::Compile(mlir::ModuleOp, xla::CompileOptions)
	xla::ifrt::PjRtLoadedExecutable::Create(xla::ifrt::PjRtCompatibleClient*, mlir::ModuleOp, xla::CompileOptions, std::__1::vector<tsl::RCReference<xla::ifrt::LoadedHostCallback>, std::__1::allocator<tsl::RCReference<xla::ifrt::LoadedHostCallback>>>)
	xla::ifrt::PjRtCompiler::Compile(std::__1::unique_ptr<xla::ifrt::Program, std::__1::default_delete<xla::ifrt::Program>>, std::__1::unique_ptr<xla::ifrt::CompileOptions, std::__1::default_delete<xla::ifrt::CompileOptions>>)
	xla::PyClient::CompileIfrtProgram(xla::nb_class_ptr<xla::PyClient>, std::__1::unique_ptr<xla::ifrt::Program, std::__1::default_delete<xla::ifrt::Program>>, std::__1::unique_ptr<xla::ifrt::CompileOptions, std::__1::default_delete<xla::ifrt::CompileOptions>>)
	xla::PyClient::Compile(xla::nb_class_ptr<xla::PyClient>, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, xla::CompileOptions, std::__1::vector<nanobind::capsule, std::__1::allocator<nanobind::capsule>>)
	_object* nanobind::detail::func_create<false, true, xla::PyClient::RegisterPythonTypes(nanobind::module_&)::$_7, xla::nb_class_ptr<xla::PyLoadedExecutable>, xla::nb_class_ptr<xla::PyClient>, nanobind::bytes, xla::CompileOptions, std::__1::vector<nanobind::capsule, std::__1::allocator<nanobind::capsule>>, 0ul, 1ul, 2ul, 3ul, nanobind::scope, nanobind::name, nanobind::is_method, nanobind::arg, nanobind::arg_v, nanobind::arg_v>(xla::PyClient::RegisterPythonTypes(nanobind::module_&)::$_7&&, xla::nb_class_ptr<xla::PyLoadedExecutable> (*)(xla::nb_class_ptr<xla::PyClient>, nanobind::bytes, xla::CompileOptions, std::__1::vector<nanobind::capsule, std::__1::allocator<nanobind::capsule>>), std::__1::integer_sequence<unsigned long, 0ul, 1ul, 2ul, 3ul>, nanobind::scope const&, nanobind::name const&, nanobind::is_method const&, nanobind::arg const&, nanobind::arg_v const&, nanobind::arg_v const&)::'lambda'(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)::__invoke(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)
	nanobind::detail::nb_func_vectorcall_complex(_object*, _object* const*, unsigned long, _object*)
	nanobind::detail::nb_bound_method_vectorcall(_object*, _object* const*, unsigned long, _object*)
	PyObject_Vectorcall
	_PyEval_EvalFrameDefault
	PyObject_Vectorcall
	nanobind::detail::obj_vectorcall(_object*, _object* const*, unsigned long, _object*, bool)
	nanobind::object nanobind::detail::api<nanobind::handle>::operator()<(nanobind::rv_policy)1, nanobind::object&, nanobind::detail::args_proxy, nanobind::detail::kwargs_proxy>(nanobind::object&, nanobind::detail::args_proxy&&, nanobind::detail::kwargs_proxy&&) const
	jax::WeakrefLRUCache::Call(nanobind::object, nanobind::args, nanobind::kwargs)
	_object* nanobind::detail::func_create<false, true, void nanobind::cpp_function_def<jax::WeakrefLRUCache, nanobind::object, jax::WeakrefLRUCache, nanobind::object, nanobind::args, nanobind::kwargs, nanobind::scope, nanobind::name, nanobind::is_method, nanobind::lock_self>(nanobind::object (jax::WeakrefLRUCache::*)(nanobind::object, nanobind::args, nanobind::kwargs), nanobind::scope const&, nanobind::name const&, nanobind::is_method const&, nanobind::lock_self const&)::'lambda'(jax::WeakrefLRUCache*, nanobind::object, nanobind::args, nanobind::kwargs), nanobind::object, jax::WeakrefLRUCache*, nanobind::object, nanobind::args, nanobind::kwargs, 0ul, 1ul, 2ul, 3ul, nanobind::scope, nanobind::name, nanobind::is_method, nanobind::lock_self>(jax::WeakrefLRUCache&&, nanobind::object (*)(nanobind::scope, nanobind::name, nanobind::is_method, nanobind::lock_self), std::__1::integer_sequence<unsigned long, 0ul, 1ul, 2ul, 3ul>, nanobind::scope const&, nanobind::name const&, nanobind::is_method const&, nanobind::lock_self const&)::'lambda'(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)::__invoke(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)
	nanobind::detail::nb_func_vectorcall_complex(_object*, _object* const*, unsigned long, _object*)
	_PyObject_FastCallDictTstate
	slot_tp_call
	_PyObject_MakeTpCall
	_PyEval_EvalFrameDefault
	PyObject_Vectorcall
	jax::(anonymous namespace)::PjitFunction::Call(nanobind::handle, _object* const*, unsigned long, _object*)
	PjitFunction_tp_vectorcall
	PyObject_Vectorcall
	_PyEval_EvalFrameDefault
	method_vectorcall
	_PyEval_EvalFrameDefault
	_PyObject_FastCallDictTstate
	slot_tp_call
	_PyObject_MakeTpCall
	_PyEval_EvalFrameDefault
	method_vectorcall
	_PyEval_EvalFrameDefault
	_PyObject_FastCallDictTstate
	slot_tp_call
	_PyObject_MakeTpCall
	_PyEval_EvalFrameDefault
	method_vectorcall
	_PyEval_EvalFrameDefault
	_PyObject_FastCallDictTstate
	slot_tp_call
	_PyObject_MakeTpCall
	_PyEval_EvalFrameDefault
	_PyObject_FastCallDictTstate
	slot_tp_init
	type_call
	_PyObject_MakeTpCall
	_PyEval_EvalFrameDefault
	PyEval_EvalCode
	builtin_exec
	cfunction_vectorcall_FASTCALL_KEYWORDS
	PyObject_Vectorcall
	_PyEval_EvalFrameDefault
	pymain_run_module
	Py_RunMain
	pymain_main
	main
	start
*** End stack trace ***

2025-01-22 15:40:43.878555: F external/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc:578] Non-OK-status: computation->Accept(this)
Status: INTERNAL: RET_CHECK failure (external/xla/xla/hlo/ir/hlo_computation.cc:1397) ShapeUtil::Compatible(old_instruction->shape(), new_instruction->shape()) "abs.0" (f64[2]{0}) vs "iota" (c128[2]{0})
zsh: abort      python -m unittest
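For comparison, the description notes that switching y0 to dtype=float avoids the crash. A minimal sketch of that real-valued variant, identical to the repro except for the dtype of y0 (the name `param_sim_real` is hypothetical, and the `jax_platform_name` setting is omitted):

```python
import jax
from jax import jit
import jax.numpy as jnp
from jax.experimental.ode import odeint

jax.config.update("jax_enable_x64", True)


def param_sim_real(a):
    # Same ODE as the repro, dy/dt = a * y, but with a float64 initial state.
    def rhs(y, t):
        return a * y

    return odeint(
        rhs,
        y0=jnp.array([0.0, 1.0]),  # float64 under x64; reported not to crash
        t=jnp.array([0.0, 10.0]),
        atol=1e-10,
        rtol=1e-10,
    )[-1]  # state at t = 10


result = jit(param_sim_real)(1.0)
expected = jnp.exp(10.0)  # exact solution: y(10) = [0, e**10]
print(result.dtype, result, expected)
```

This only sidesteps the problem for real-valued ODEs; it does not help the complex case the issue is about.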

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.5.0
jaxlib: 0.5.0
numpy:  2.2.2
python: 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:13:44) [Clang 16.0.6 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='Daniels-MBP.lan', release='23.6.0', version='Darwin Kernel Version 23.6.0: Mon Jul 29 21:14:30 PDT 2024; root:xnu-10063.141.2~1/RELEASE_ARM64_T6030', machine='arm64')
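The system information above has the layout produced by JAX's `jax.print_environment_info()` helper (an assumption about how it was generated here). A minimal sketch for collecting the same summary as a string:

```python
import jax

# Return the jax/jaxlib/numpy/python/device summary instead of printing it directly.
info = jax.print_environment_info(return_string=True)
print(info)
```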
DanPuzzuoli added the "bug" label Jan 22, 2025
DanPuzzuoli (Contributor, Author) commented Jan 23, 2025

An update: changing

y0=jnp.array([0.0, 1.0], dtype=complex)

to

y0=jnp.array([0.0, 1.0j], dtype=complex)

causes the code to run correctly. I think jnp.array([0.0, 1.0], dtype=complex) is at some point being treated as a real array (probably during compilation).
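A minimal sketch of the variant reported to work, with a nonzero imaginary part in y0 (the name `param_sim_imag` is hypothetical). Since dy/dt = a * y is linear, the exact result at t = 10 is [0, 1j * e**10], which gives a simple correctness check:

```python
import jax
from jax import jit
import jax.numpy as jnp
from jax.experimental.ode import odeint

jax.config.update("jax_enable_x64", True)


def param_sim_imag(a):
    # dy/dt = a * y with a complex initial state whose imaginary part is not all zero.
    def rhs(y, t):
        return a * y

    return odeint(
        rhs,
        y0=jnp.array([0.0, 1.0j], dtype=complex),  # complex128 under x64
        t=jnp.array([0.0, 10.0]),
        atol=1e-10,
        rtol=1e-10,
    )[-1]  # state at t = 10


result = jit(param_sim_imag)(1.0)
expected = 1.0j * jnp.exp(10.0)  # exact second component
print(result.dtype, result)
```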

Similarly, changing it to something like

y0=jnp.array([0.0, 1.0 + 1e-20j], dtype=complex)

also causes the code to run properly. It seems that when y0.imag is exactly the zero vector, y0 somewhere stops being interpreted as a complex array.
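The stack trace places the failure inside XLA's `AlgebraicSimplifier` (`HandleAbs`) during `CpuCompiler::RunHloPasses`, i.e. at compile time. One way to check that the array is still complex before that stage is to lower the original function to StableHLO without compiling it. A sketch under that assumption: `jax.jit(...).lower(...)` stops before the XLA passes that aborted, and the check for `complex<f64>` assumes that is how the lowered text prints complex128 element types:

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

jax.config.update("jax_enable_x64", True)


def param_sim(a):
    # The original failing repro: complex128 y0 with an all-zero imaginary part.
    def rhs(y, t):
        return a * y

    return odeint(
        rhs,
        y0=jnp.array([0.0, 1.0], dtype=complex),
        t=jnp.array([0.0, 10.0]),
        atol=1e-10,
        rtol=1e-10,
    )[-1]


# At the Python level the initial state is complex, with imaginary part exactly zero.
y0 = jnp.array([0.0, 1.0], dtype=complex)
print(y0.dtype, bool(jnp.all(y0.imag == 0)))

# Lower only; do NOT call .compile(), which is the step that hit the RET_CHECK abort.
hlo_text = jax.jit(param_sim).lower(1.0).as_text()
print("complex<f64>" in hlo_text)
```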
