Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[original author: mrnikwaws] fixing incorrect stride information on xla tensors #5486

Closed
wants to merge 4 commits into from

Conversation

aws-kingrj
Copy link
Collaborator

@aws-kingrj aws-kingrj requested a review from JackCaoG August 23, 2023 17:35
@aws-kingrj
Copy link
Collaborator Author

Sorry, had to create a new PR because of the rebase conflicts

@@ -161,6 +161,11 @@ at::IntArrayRef XLATensorImpl::strides_custom() const {
return strides_default();
}

c10::SymIntArrayRef XLATensorImpl::sym_strides_custom() const {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add a test? It is a bit hard for me to tell what this pr is fixing and how it affect user.

Copy link
Contributor

@mrnikwaws mrnikwaws Aug 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me check in with Ryan and create a test.

This issue occurs when pytorch checks the strides of a tensor have the same rank as the shape of a tensor. By default XLA returns strides of one. This will cause log_softmax to fail lowering (based on the input XLA tensor failing an assertion in pytorch code prior to lowering), so using this lowering should form a simple test.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(test_env) ubuntu@ip-172-31-63-138:~/waldronn/asr$ python
Python 3.8.10 (default, May 26 2023, 14:05:08) 
[GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import torch_neuronx
>>> example_input = torch.rand(1, 80, 643, dtype=torch.float)
>>> print(example_input.stride())
(51440, 643, 1)
>>> example_input = example_input.to('xla')
>>> print(example_input.stride())
(1,)

With the change the strides will match

Copy link
Collaborator

@JackCaoG JackCaoG Sep 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great, can we make this a test case? You can add it somewhere in https://github.com/pytorch/xla/blob/master/test/test_operations.py#L907 test_operation and create a new test for it.

@aws-kingrj
Copy link
Collaborator Author

This has already been fixed in 2.1, after doing testing with log softmax lowering and getting the correct stride information

@aws-kingrj aws-kingrj closed this Nov 20, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants