diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp index 62804b6f2f2..91822a3ba4a 100644 --- a/torch_xla/csrc/tensor_impl.cpp +++ b/torch_xla/csrc/tensor_impl.cpp @@ -161,6 +161,11 @@ at::IntArrayRef XLATensorImpl::strides_custom() const { return strides_default(); } +c10::SymIntArrayRef XLATensorImpl::sym_strides_custom() const { + const_cast(this)->SetupSizeProperties(); + return c10::SymIntArrayRef(sym_strides_.data(), sym_strides_.size()); +} + int64_t XLATensorImpl::dim_custom() const { const_cast(this)->SetupSizeProperties(); return dim_default(); @@ -205,6 +210,8 @@ void XLATensorImpl::SetupSymSizeProperties() { auto rank = shape.get().rank(); std::vector sym_sizes; sym_sizes.reserve(rank); + std::vector sym_strides(rank); + size_t index = rank; XLAIrBuilder a = XLAIrBuilder(); for (auto i : c10::irange(rank)) { @@ -219,6 +226,16 @@ void XLATensorImpl::SetupSymSizeProperties() { } } sym_sizes_ = sym_sizes; + + c10::SymInt prod{1}; + + while (index > 0) { + --index; + sym_strides[index] = prod; + prod *= sym_sizes[index]; + } + + sym_strides_ = sym_strides; } caffe2::TypeMeta XLATensorImpl::GetTypeMeta(const XLATensor& tensor) { diff --git a/torch_xla/csrc/tensor_impl.h b/torch_xla/csrc/tensor_impl.h index 9c9bda1df54..d9d2f5379d1 100644 --- a/torch_xla/csrc/tensor_impl.h +++ b/torch_xla/csrc/tensor_impl.h @@ -46,6 +46,7 @@ class XLATensorImpl : public c10::TensorImpl { c10::SymIntArrayRef sym_sizes_custom() const override; c10::SymInt sym_numel_custom() const override; at::IntArrayRef strides_custom() const override; + c10::SymIntArrayRef sym_strides_custom() const override; int64_t dim_custom() const override; @@ -67,6 +68,7 @@ class XLATensorImpl : public c10::TensorImpl { XLATensorPtr tensor_; std::vector sym_sizes_; + std::vector sym_strides_; size_t generation_ = 0; };