[original author: mrnikwaws] fixing incorrect stride information on xla tensors #5486
aws-kingrj
commented
Aug 23, 2023
- fixing incorrect xla tensor stride information
- the incorrect strides break log softmax lowering
- tested on AWS Neuron internal testing
- original author mrnikwaws
- old PR [original author: mrnikwaws] fixing incorrect stride information on xla tensors #5468
Sorry, had to create a new PR because of the rebase conflicts.
@@ -161,6 +161,11 @@ at::IntArrayRef XLATensorImpl::strides_custom() const {
  return strides_default();
}

c10::SymIntArrayRef XLATensorImpl::sym_strides_custom() const {
Can we add a test? It is a bit hard for me to tell what this PR is fixing and how it affects users.
Let me check in with Ryan and create a test.
This issue occurs when PyTorch checks that a tensor's strides have the same rank as its shape. By default XLA returns strides of one (a single-element `(1,)`), so the input XLA tensor fails an assertion in PyTorch code before lowering, and log_softmax fails to lower. Using this lowering should therefore make a simple test.
(test_env) ubuntu@ip-172-31-63-138:~/waldronn/asr$ python
Python 3.8.10 (default, May 26 2023, 14:05:08)
[GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import torch_neuronx
>>> example_input = torch.rand(1, 80, 643, dtype=torch.float)
>>> print(example_input.stride())
(51440, 643, 1)
>>> example_input = example_input.to('xla')
>>> print(example_input.stride())
(1,)
With the change, the XLA tensor's strides will match the CPU tensor's strides.
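To make the expected values concrete, here is a minimal pure-Python sketch of how row-major (contiguous) strides follow from a shape, plus the rank check described above. The helper names `contiguous_strides` and `strides_match_rank` are hypothetical and only illustrate the idea; they are not the PyTorch or PyTorch/XLA implementation.

```python
from typing import Sequence, Tuple


def contiguous_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Row-major strides (in elements) for a dense tensor of the given shape."""
    strides = []
    running = 1
    # Walk dimensions from innermost to outermost; each stride is the
    # product of the sizes of all dimensions to its right.
    for dim in reversed(shape):
        strides.append(running)
        # Choice made for this sketch: treat zero-sized dims as 1 so the
        # strides stay positive.
        running *= max(dim, 1)
    return tuple(reversed(strides))


def strides_match_rank(shape: Sequence[int], strides: Sequence[int]) -> bool:
    """The rank check discussed above: one stride per dimension."""
    return len(shape) == len(strides)


shape = (1, 80, 643)
print(contiguous_strides(shape))                              # (51440, 643, 1)
print(strides_match_rank(shape, contiguous_strides(shape)))   # True
print(strides_match_rank(shape, (1,)))                        # False
```

The last call mirrors the `(1,)` reported for the XLA tensor in the session above: a single stride for a rank-3 shape fails the rank check.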
Great, can we make this a test case? You can add a new test for it somewhere in https://github.com/pytorch/xla/blob/master/test/test_operations.py#L907 test_operation.
This has already been fixed in 2.1; I tested with log softmax lowering and got the correct stride information.