[original author: mrnikwaws] fixing incorrect stride information on xla tensors #5486

Closed
wants to merge 4 commits into from
17 changes: 17 additions & 0 deletions torch_xla/csrc/tensor_impl.cpp
@@ -161,6 +161,11 @@ at::IntArrayRef XLATensorImpl::strides_custom() const {
  return strides_default();
}

c10::SymIntArrayRef XLATensorImpl::sym_strides_custom() const {
Collaborator

Can we add a test? It is a bit hard for me to tell what this PR is fixing and how it affects users.

Contributor @mrnikwaws Aug 31, 2023

Let me check in with Ryan and create a test.

This issue occurs when PyTorch checks that the strides of a tensor have the same rank as its shape. By default, XLA returns strides of one, i.e. `(1,)`, even for a multi-dimensional tensor. This causes log_softmax to fail lowering, because the input XLA tensor fails an assertion in PyTorch code before lowering is reached, so using this lowering should form a simple test.
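
To illustrate why a stride tuple whose rank differs from the shape is a problem, here is a minimal pure-Python sketch. The helper `element_offset` and its `ValueError` are hypothetical stand-ins for the kind of rank check described above; they are not PyTorch's actual code.

```python
def element_offset(index, sizes, strides):
    """Flat storage offset of `index`, given per-dimension sizes and strides."""
    if len(strides) != len(sizes):
        # Hypothetical stand-in for the rank check; not PyTorch's real assertion.
        raise ValueError(
            f"stride rank {len(strides)} does not match shape rank {len(sizes)}")
    # Each index component is scaled by the stride of its own dimension.
    return sum(i * s for i, s in zip(index, strides))


sizes = (1, 80, 643)

# With one stride per dimension the offset is well defined.
print(element_offset((0, 2, 5), sizes, (51440, 643, 1)))  # 2 * 643 + 5 = 1291

# With a single stride there is no entry for two of the three dimensions.
try:
    element_offset((0, 2, 5), sizes, (1,))
except ValueError as err:
    print(err)
```

The point of the sketch is only that every dimension needs its own stride for index arithmetic to be meaningful, which is why a rank mismatch is rejected before any lowering happens.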

Contributor

(test_env) ubuntu@ip-172-31-63-138:~/waldronn/asr$ python
Python 3.8.10 (default, May 26 2023, 14:05:08) 
[GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import torch_neuronx
>>> example_input = torch.rand(1, 80, 643, dtype=torch.float)
>>> print(example_input.stride())
(51440, 643, 1)
>>> example_input = example_input.to('xla')
>>> print(example_input.stride())
(1,)

With the change, the strides will match.
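
As a rough sketch of where the expected values come from, the following pure-Python function mirrors the loop this PR adds to `SetupSymSizeProperties`, using plain integers in place of `c10::SymInt`. The function name `contiguous_strides` is an assumption made for illustration.

```python
def contiguous_strides(sizes):
    """Row-major strides for `sizes`, computed from the innermost dimension out."""
    strides = [0] * len(sizes)
    prod = 1
    index = len(sizes)
    while index > 0:
        index -= 1
        # The stride of a dimension is the product of all sizes to its right.
        strides[index] = prod
        prod *= sizes[index]
    return tuple(strides)


print(contiguous_strides((1, 80, 643)))  # (51440, 643, 1)
```

For the shape in the session above this reproduces the CPU tensor's strides, which is the behaviour the XLA tensor is expected to show after the change.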

Collaborator @JackCaoG Sep 5, 2023

Great, can we make this a test case? You can add it as a new test somewhere in test_operations.py: https://github.com/pytorch/xla/blob/master/test/test_operations.py#L907
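
One possible shape for such a test is sketched below. The class name `StrideRankTest` and the test method name are hypothetical, and the sketch assumes `torch_xla.core.xla_model.xla_device()` is used to obtain the device; the test is skipped when `torch` or `torch_xla` cannot be imported.

```python
import unittest

try:
    import torch
    import torch_xla.core.xla_model as xm
    HAVE_XLA = True
except ImportError:
    HAVE_XLA = False


@unittest.skipUnless(HAVE_XLA, "requires torch and torch_xla")
class StrideRankTest(unittest.TestCase):  # hypothetical name

    def test_xla_stride_matches_cpu(self):
        # Same shape as in the interactive session above.
        cpu_t = torch.rand(1, 80, 643, dtype=torch.float)
        xla_t = cpu_t.to(xm.xla_device())
        # The stride tuple should have one entry per dimension ...
        self.assertEqual(len(xla_t.stride()), xla_t.dim())
        # ... and should agree with the strides of the CPU tensor.
        self.assertEqual(xla_t.stride(), cpu_t.stride())
```

Comparing against the CPU tensor avoids hard-coding the expected tuple, so the same check can be reused for other shapes.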

  const_cast<XLATensorImpl*>(this)->SetupSizeProperties();
  return c10::SymIntArrayRef(sym_strides_.data(), sym_strides_.size());
}

int64_t XLATensorImpl::dim_custom() const {
  const_cast<XLATensorImpl*>(this)->SetupSizeProperties();
  return dim_default();
@@ -205,6 +210,8 @@ void XLATensorImpl::SetupSymSizeProperties() {
  auto rank = shape.get().rank();
  std::vector<c10::SymInt> sym_sizes;
  sym_sizes.reserve(rank);
  std::vector<c10::SymInt> sym_strides(rank);
  size_t index = rank;

  XLAIrBuilder a = XLAIrBuilder();
  for (auto i : c10::irange(rank)) {
@@ -219,6 +226,16 @@ void XLATensorImpl::SetupSymSizeProperties() {
    }
  }
  sym_sizes_ = sym_sizes;

  c10::SymInt prod{1};

  while (index > 0) {
    --index;
    sym_strides[index] = prod;
    prod *= sym_sizes[index];
  }

  sym_strides_ = sym_strides;
}

caffe2::TypeMeta XLATensorImpl::GetTypeMeta(const XLATensor& tensor) {
2 changes: 2 additions & 0 deletions torch_xla/csrc/tensor_impl.h
@@ -46,6 +46,7 @@ class XLATensorImpl : public c10::TensorImpl {
  c10::SymIntArrayRef sym_sizes_custom() const override;
  c10::SymInt sym_numel_custom() const override;
  at::IntArrayRef strides_custom() const override;
  c10::SymIntArrayRef sym_strides_custom() const override;

  int64_t dim_custom() const override;

@@ -67,6 +68,7 @@ class XLATensorImpl : public c10::TensorImpl {

  XLATensorPtr tensor_;
  std::vector<c10::SymInt> sym_sizes_;
  std::vector<c10::SymInt> sym_strides_;
  size_t generation_ = 0;
};
