
The stateless API can't be used with Jitted modules #19

Open
y0ngzq opened this issue Jan 31, 2025 · 1 comment

y0ngzq commented Jan 31, 2025

titans-pytorch runs normally when I use it on its own to process data, but when I add it as an intermediate layer in some networks, the following error occurs. How should I handle this?

```
  File "/usr/local/miniconda3/envs/pointmamba/lib/python3.9/site-packages/titans_pytorch/neural_memory.py", line 910, in forward
    next_updates, next_neural_mem_state = self.store_memories(
  File "/usr/local/miniconda3/envs/pointmamba/lib/python3.9/site-packages/titans_pytorch/neural_memory.py", line 624, in store_memories
    grads = self.per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)
  File "/usr/local/miniconda3/envs/pointmamba/lib/python3.9/site-packages/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
  File "/usr/local/miniconda3/envs/pointmamba/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
  File "/usr/local/miniconda3/envs/pointmamba/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/usr/local/miniconda3/envs/pointmamba/lib/python3.9/site-packages/torch/_functorch/apis.py", line 399, in wrapper
    return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
  File "/usr/local/miniconda3/envs/pointmamba/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 1406, in grad_impl
    results = grad_and_value_impl(func, argnums, has_aux, args, kwargs)
  File "/usr/local/miniconda3/envs/pointmamba/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 48, in fn
    return f(*args, **kwargs)
  File "/usr/local/miniconda3/envs/pointmamba/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 1364, in grad_and_value_impl
    output = func(*args, **kwargs)
  File "/usr/local/miniconda3/envs/pointmamba/lib/python3.9/site-packages/titans_pytorch/neural_memory.py", line 349, in forward_and_loss
    pred = functional_call(self.memory_model, params, inputs)
  File "/usr/local/miniconda3/envs/pointmamba/lib/python3.9/site-packages/torch/_functorch/functional_call.py", line 148, in functional_call
    return nn.utils.stateless._functional_call(
  File "/usr/local/miniconda3/envs/pointmamba/lib/python3.9/site-packages/torch/nn/utils/stateless.py", line 286, in _functional_call
    raise RuntimeError("The stateless API can't be used with Jitted modules")
RuntimeError: The stateless API can't be used with Jitted modules
```
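The last frames show the error coming from the guard in `torch/nn/utils/stateless.py`, which `torch.func.functional_call` reaches. As far as I know, that guard rejects TorchScript modules and also any call made while `torch.jit.trace` is running. The sketch below reproduces both cases outside titans-pytorch. It assumes a plain `nn.Linear` as a stand-in for the memory model, and `MemoryWrapper` is a hypothetical layer, not a titans-pytorch class. Which of the two cases applies to the network in this issue is a guess, not something confirmed in the thread.

```python
import torch
from torch import nn
from torch.func import functional_call

x = torch.randn(2, 4)

# Hypothetical stand-in for the neural memory's inner model (not the real titans-pytorch class)
memory_model = nn.Linear(4, 4)
params = dict(memory_model.named_parameters())

# 1. Eager module: functional_call runs the module with `params` swapped in
eager_out = functional_call(memory_model, params, (x,))
print(eager_out.shape)  # torch.Size([2, 4])

# 2. The same module compiled with torch.jit.script is rejected by the stateless API
scripted = torch.jit.script(memory_model)
try:
    functional_call(scripted, params, (x,))
    scripted_error = None
except RuntimeError as err:
    scripted_error = str(err)
print(scripted_error)


# 3. An eager module is also rejected while an enclosing network is being traced
class MemoryWrapper(nn.Module):
    """Hypothetical layer that calls functional_call internally, as the memory does."""

    def __init__(self):
        super().__init__()
        self.memory_model = nn.Linear(4, 4)

    def forward(self, inp):
        params = dict(self.memory_model.named_parameters())
        return functional_call(self.memory_model, params, (inp,))


try:
    torch.jit.trace(MemoryWrapper(), (x,))
    traced_error = None
except RuntimeError as err:
    traced_error = str(err)
print(traced_error)
```

If either case matches the host network, one option is to keep the memory layer out of `torch.jit.script`. Another is to avoid running tools built on `torch.jit.trace` (for example, some FLOP counters or export scripts) over a network that contains it. I have not tested this against titans-pytorch itself.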

@looper99

@y0ngzq I also ran into this problem and could not resolve it. Did you find any elegant solution?
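One way to narrow down the cause is to look for TorchScript submodules in the host network, since `torch.func.functional_call` refuses any module of that kind. This rests on the assumption that part of the host network was scripted. `find_jitted_submodules` and `HostNet` are hypothetical names made up for this sketch.

```python
import torch
from torch import nn


def find_jitted_submodules(model: nn.Module) -> list:
    """Return the qualified names of submodules that are TorchScript modules.

    Hypothetical helper: functional_call rejects these modules, so the names
    point at where the host code scripted part of the model.
    """
    return [
        name or "<root>"  # named_modules() reports the root module with an empty name
        for name, module in model.named_modules()
        if isinstance(module, torch.jit.ScriptModule)
    ]


# Hypothetical host network: an eager encoder plus a memory layer that was scripted
class HostNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 4)
        self.memory_layer = torch.jit.script(nn.Linear(4, 4))

    def forward(self, x):
        return self.memory_layer(self.encoder(x))


net = HostNet()
print(find_jitted_submodules(net))  # ['memory_layer']
```

Any reported name marks a spot where the host code converted part of the model to TorchScript. Keeping that submodule eager, or scripting only the layers that do not contain the neural memory, should avoid the error. This is untested with titans-pytorch. An empty result would instead point toward the tracing case.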
