titans-pytorch runs normally when it processes data on its own, but when its neural memory module is added as an intermediate layer inside another network, the following error occurs. How should I handle this?
File "/usr/local/miniconda3/envs/pointmamba/lib/python3.9/site-packages/titans_pytorch/neural_memory.py", line 910, in forward
next_updates, next_neural_mem_state = self.store_memories(
File "/usr/local/miniconda3/envs/pointmamba/lib/python3.9/site-packages/titans_pytorch/neural_memory.py", line 624, in store_memories
grads = self.per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)
File "/usr/local/miniconda3/envs/pointmamba/lib/python3.9/site-packages/torch/_functorch/apis.py", line 203, in wrapped
return vmap_impl(
File "/usr/local/miniconda3/envs/pointmamba/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
return _flat_vmap(
File "/usr/local/miniconda3/envs/pointmamba/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
File "/usr/local/miniconda3/envs/pointmamba/lib/python3.9/site-packages/torch/_functorch/apis.py", line 399, in wrapper
return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
File "/usr/local/miniconda3/envs/pointmamba/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 1406, in grad_impl
results = grad_and_value_impl(func, argnums, has_aux, args, kwargs)
File "/usr/local/miniconda3/envs/pointmamba/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 48, in fn
return f(*args, **kwargs)
File "/usr/local/miniconda3/envs/pointmamba/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 1364, in grad_and_value_impl
output = func(*args, **kwargs)
File "/usr/local/miniconda3/envs/pointmamba/lib/python3.9/site-packages/titans_pytorch/neural_memory.py", line 349, in forward_and_loss
pred = functional_call(self.memory_model, params, inputs)
File "/usr/local/miniconda3/envs/pointmamba/lib/python3.9/site-packages/torch/_functorch/functional_call.py", line 148, in functional_call
return nn.utils.stateless._functional_call(
File "/usr/local/miniconda3/envs/pointmamba/lib/python3.9/site-packages/torch/nn/utils/stateless.py", line 286, in _functional_call
raise RuntimeError("The stateless API can't be used with Jitted modules")
RuntimeError: The stateless API can't be used with Jitted modules
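The error comes from the last frame: `titans_pytorch` computes per-sample gradients by calling `torch.func.functional_call` on its inner memory model, and that stateless API explicitly rejects any `torch.jit.ScriptModule`. So the likely cause is that the surrounding network (or something wrapping the memory model) was passed through `torch.jit.script` / `torch.jit.trace`. Below is a minimal, hedged reproduction of just that failure mode, using a plain `nn.Linear` as a hypothetical stand-in for the inner memory model (assumes PyTorch ≥ 2.0 for `torch.func`):

```python
import torch
import torch.nn as nn
from torch.func import functional_call

# Hypothetical stand-in for the inner memory model; the real one lives
# inside titans_pytorch's NeuralMemory.
model = nn.Linear(4, 4)
params = dict(model.named_parameters())
x = torch.randn(2, 4)

# Eager module: functional_call works, as the per-sample-grad path expects.
out = functional_call(model, params, (x,))

# If the module (or a parent network wrapping it) has been jit-compiled,
# the stateless API refuses it with the same RuntimeError as above.
scripted = torch.jit.script(model)
try:
    functional_call(scripted, params, (x,))
except RuntimeError as e:
    print(e)  # The stateless API can't be used with Jitted modules
```

If this matches your setup, the practical workaround is to keep the network that contains the titans-pytorch layer as a plain eager `nn.Module` and avoid `torch.jit.script`/`torch.jit.trace` on it; I have not verified whether `torch.compile` is compatible with this code path, so treat that as an alternative to test rather than a guaranteed fix.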