From 86c90f61a77af159d19df43c7aaa0a808f6a0aa2 Mon Sep 17 00:00:00 2001
From: King
Date: Fri, 18 Aug 2023 11:51:00 -0700
Subject: [PATCH 01/11] adding glu operator support

---
 codegen/xla_native_functions.yaml       |  1 +
 test/cpp/test_aten_xla_tensor_6.cpp     | 25 +++++++++++++++++++++++++
 torch_xla/csrc/ops/ops_lower_fn.cpp     | 18 ++++++++++++++++++
 torch_xla/csrc/ops/ops_xla_shape_fn.cpp | 14 ++++++++++++++
 torch_xla/csrc/ops/ops_xla_shape_fn.h   |  2 ++
 5 files changed, 60 insertions(+)

diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml
index baf623f05cb..f8e21fd9bfa 100644
--- a/codegen/xla_native_functions.yaml
+++ b/codegen/xla_native_functions.yaml
@@ -38,6 +38,7 @@ full_codegen:
   - frac
   - ge.Scalar
   - ge.Tensor
+  - glu
   - gt.Scalar
   - gt.Tensor
   - hardshrink
diff --git a/test/cpp/test_aten_xla_tensor_6.cpp b/test/cpp/test_aten_xla_tensor_6.cpp
index b77392dc4c4..9a1f3f33f59 100644
--- a/test/cpp/test_aten_xla_tensor_6.cpp
+++ b/test/cpp/test_aten_xla_tensor_6.cpp
@@ -1319,5 +1319,30 @@ TEST_F(AtenXlaTensorTest, TestCdistForward) {
   ExpectCounterChanged("xla::_cdist_forward", cpp_test::GetIgnoredCounters());
 }
 
+TEST_F(AtenXlaTensorTest, TestGlu) {
+  std::vector<std::vector<int64_t>> sizes {
+    {3, 8},
+    {3, 5, 6},
+    {3, 8, 5},
+    {3, 8, 8, 16}
+  };
+  std::vector<int64_t> dims {-1, -1, 1, 3};
+
+  auto size_it = sizes.begin();
+  auto dim_it = dims.begin();
+  for ( ; size_it != sizes.end() && dim_it != dims.end() ; ++size_it, ++dim_it ) {
+    torch::Tensor input =
+        torch::rand(*size_it, torch::TensorOptions(torch::kFloat));
+    torch::Tensor output = torch::glu(input, *dim_it);
+    ForEachDevice([&](const torch::Device& device) {
+      torch::Tensor xla_input = CopyToDevice(input, device);
+      torch::Tensor xla_output = torch::glu(xla_input, *dim_it);
+      AllClose(output, xla_output);
+    });
+  }
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::glu", cpp_test::GetIgnoredCounters());
+}
+
 }  // namespace cpp_test
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp
index 5f588bdd912..32f8bb0a2c3 100644
--- a/torch_xla/csrc/ops/ops_lower_fn.cpp
+++ b/torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -347,6 +347,24 @@ torch_xla::XlaOpVector GeTensor::Lower(LoweringContext* loctx) const {
   return ReturnOp(BuildComparisonOp(at::aten::ge, xla_input, xla_other), loctx);
 }
 
+torch_xla::XlaOpVector Glu::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+
+  // Calculate half input shape on target dim - since input must be sliced in 2
+  const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(xla_input);
+  int64_t ldim = dim;
+  if (ldim < 0) ldim += input_shape.rank();
+  absl::Span<const int64_t> inp_dimensions = input_shape.dimensions();
+  int64_t split_size = inp_dimensions[ldim]/2;
+
+  // Split the input tensor into two parts, take sigmoid of RHS and multiply element-wise
+  xla::XlaOp a = xla::SliceInDim(xla_input, 0, split_size, 1, ldim);
+  xla::XlaOp b = xla::SliceInDim(xla_input, split_size, split_size + split_size, 1, ldim);
+  xla::XlaOp result = a * BuildSigmoid(b);
+
+  return ReturnOp(result, loctx);
+}
+
 torch_xla::XlaOpVector GtScalar::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
index 0d153ab78b5..148f6e23553 100644
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -421,6 +421,20 @@ xla::Shape GeTensorOutputShape(const torch::lazy::Value& self,
   return GeScalarOutputShape(self, other);
 }
 
+xla::Shape GluOutputShape(const torch::lazy::Value& input, int64_t dim) {
+  const xla::Shape& input_shape = GetXlaShape(input);
+
+  if (dim < 0) dim += input_shape.rank();
+
+  absl::Span<const int64_t> inp_dimensions = input_shape.dimensions();
+  std::vector<int64_t> output_sizes(std::begin(inp_dimensions),std::end(inp_dimensions));
+
+  // Output shape is always half the input shape on the specified dimension
+  output_sizes[dim] = inp_dimensions[dim]/2;
+
+  return xla::ShapeUtil::MakeShape(input_shape.element_type(), output_sizes);
+}
+
 xla::Shape GtScalarOutputShape(const torch::lazy::Value& self,
                                const torch::lazy::Value& other) {
   auto lower_for_shape_fn =
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.h b/torch_xla/csrc/ops/ops_xla_shape_fn.h
index 537574c3b37..76a0144a50e 100644
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.h
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.h
@@ -137,6 +137,8 @@ xla::Shape GeScalarOutputShape(const torch::lazy::Value& self,
 xla::Shape GeTensorOutputShape(const torch::lazy::Value& self,
                                const torch::lazy::Value& other);
 
+xla::Shape GluOutputShape(const torch::lazy::Value& input, int64_t dim);
+
 xla::Shape GtScalarOutputShape(const torch::lazy::Value& self,
                                const torch::lazy::Value& other);
 
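For context, the math being lowered: torch::glu(x, dim) splits x into two equal halves a and b along dim and returns a * sigmoid(b), so the output keeps every input dimension except dim, which is halved. That is exactly what the two SliceInDim calls, BuildSigmoid, and the element-wise multiply above build, and what GluOutputShape computes. A standalone reference sketch of the same arithmetic on a flat buffer (illustrative only; GluReference is a made-up name, not part of this patch):

#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Reference semantics of glu on a contiguous buffer: the first half is the
// linear part a, the second half is the gate b, and out[i] = a[i] * sigmoid(b[i]).
std::vector<float> GluReference(const std::vector<float>& input) {
  assert(input.size() % 2 == 0);  // the glu dimension must have an even extent
  const int64_t split_size = static_cast<int64_t>(input.size()) / 2;
  std::vector<float> out(split_size);
  for (int64_t i = 0; i < split_size; ++i) {
    const float a = input[i];               // LHS slice: [0, split_size)
    const float b = input[split_size + i];  // RHS slice: [split_size, 2 * split_size)
    out[i] = a * (1.0f / (1.0f + std::exp(-b)));
  }
  return out;
}

For example, GluReference({1, 2, 0, 0}) returns {0.5, 1.0}, since sigmoid(0) = 0.5; the halved extent matches the shapes the new test exercises (e.g. a {3, 8, 8, 16} input with dim 3 produces a {3, 8, 8, 8} output).
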
From 50ff0fb2922de57615f29619c5c018c2a042fc45 Mon Sep 17 00:00:00 2001
From: King
Date: Fri, 18 Aug 2023 13:17:13 -0700
Subject: [PATCH 02/11] adding glu operator

---
 codegen/xla_native_functions.yaml | 2 --
 1 file changed, 2 deletions(-)

diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml
index f8e21fd9bfa..e8de6f8aee5 100644
--- a/codegen/xla_native_functions.yaml
+++ b/codegen/xla_native_functions.yaml
@@ -162,8 +162,6 @@ supported:
   - convolution_overrideable
   - copy
   - copy_
-  - count_nonzero
-  - count_nonzero.dim_IntList
   - cross
   - cumprod
   - cumsum

From 569bf7e14a8e4f877d21eb2d599bebe8502faf41 Mon Sep 17 00:00:00 2001
From: King
Date: Fri, 18 Aug 2023 13:19:02 -0700
Subject: [PATCH 03/11] fixing yaml

---
 codegen/xla_native_functions.yaml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml
index e8de6f8aee5..f8e21fd9bfa 100644
--- a/codegen/xla_native_functions.yaml
+++ b/codegen/xla_native_functions.yaml
@@ -162,6 +162,8 @@ supported:
   - convolution_overrideable
   - copy
   - copy_
+  - count_nonzero
+  - count_nonzero.dim_IntList
   - cross
   - cumprod
   - cumsum

From 3501a1b82815fe696f034b50d68c575031f4f6f7 Mon Sep 17 00:00:00 2001
From: King
Date: Fri, 18 Aug 2023 13:39:51 -0700
Subject: [PATCH 04/11] fixing linter issues

---
 test/cpp/test_aten_xla_tensor_6.cpp | 16 ++++++----------
 torch_xla/csrc/ops/ops_lower_fn.cpp |  2 +-
 2 files changed, 7 insertions(+), 11 deletions(-)

diff --git a/test/cpp/test_aten_xla_tensor_6.cpp b/test/cpp/test_aten_xla_tensor_6.cpp
index 9a1f3f33f59..19e6a303254 100644
--- a/test/cpp/test_aten_xla_tensor_6.cpp
+++ b/test/cpp/test_aten_xla_tensor_6.cpp
@@ -1320,21 +1320,17 @@ TEST_F(AtenXlaTensorTest, TestCdistForward) {
 }
 
 TEST_F(AtenXlaTensorTest, TestGlu) {
-  std::vector<std::vector<int64_t>> sizes {
-    {3, 8},
-    {3, 5, 6},
-    {3, 8, 5},
-    {3, 8, 8, 16}
-  };
-  std::vector<int64_t> dims {-1, -1, 1, 3};
-
+  std::vector<std::vector<int64_t>> sizes{
+      {3, 8}, {3, 5, 6}, {3, 8, 5}, {3, 8, 8, 16}};
+  std::vector<int64_t> dims{-1, -1, 1, 3};
+
   auto size_it = sizes.begin();
   auto dim_it = dims.begin();
-  for ( ; size_it != sizes.end() && dim_it != dims.end() ; ++size_it, ++dim_it ) {
+  for (; size_it != sizes.end() && dim_it != dims.end(); ++size_it, ++dim_it) {
     torch::Tensor input =
         torch::rand(*size_it, torch::TensorOptions(torch::kFloat));
     torch::Tensor output = torch::glu(input, *dim_it);
-    ForEachDevice([&](const torch::Device& device) {
+    ForEachDevice([&](const torch::Device &device) {
       torch::Tensor xla_input = CopyToDevice(input, device);
       torch::Tensor xla_output = torch::glu(xla_input, *dim_it);
       AllClose(output, xla_output);
diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp
index 32f8bb0a2c3..30b467a973b 100644
--- a/torch_xla/csrc/ops/ops_lower_fn.cpp
+++ b/torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -670,4 +670,4 @@ torch_xla::XlaOpVector Trunc::Lower(LoweringContext* loctx) const {
   return ReturnOp(xla::Floor(BuildAbs(xla_input)) * BuildSgn(xla_input), loctx);
 }
 
-}  // namespace torch_xla
+}  // namespace torch_xla
\ No newline at end of file
From afce34da67bad2c1f5f75c4cd9fd4e8c3e7c7c3d Mon Sep 17 00:00:00 2001
From: King
Date: Fri, 18 Aug 2023 13:53:29 -0700
Subject: [PATCH 05/11] fixing linter issues

---
 test/cpp/test_aten_xla_tensor_6.cpp     |  2 +-
 torch_xla/csrc/ops/ops_lower_fn.cpp     | 12 +++++++-----
 torch_xla/csrc/ops/ops_xla_shape_fn.cpp |  5 +++--
 3 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/test/cpp/test_aten_xla_tensor_6.cpp b/test/cpp/test_aten_xla_tensor_6.cpp
index 19e6a303254..f8f11a2253c 100644
--- a/test/cpp/test_aten_xla_tensor_6.cpp
+++ b/test/cpp/test_aten_xla_tensor_6.cpp
@@ -1330,7 +1330,7 @@ TEST_F(AtenXlaTensorTest, TestGlu) {
     torch::Tensor input =
         torch::rand(*size_it, torch::TensorOptions(torch::kFloat));
     torch::Tensor output = torch::glu(input, *dim_it);
-    ForEachDevice([&](const torch::Device &device) {
+    ForEachDevice([&](const torch::Device& device) {
       torch::Tensor xla_input = CopyToDevice(input, device);
      torch::Tensor xla_output = torch::glu(xla_input, *dim_it);
       AllClose(output, xla_output);
diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp
index 30b467a973b..20a4803a904 100644
--- a/torch_xla/csrc/ops/ops_lower_fn.cpp
+++ b/torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -349,19 +349,21 @@ torch_xla::XlaOpVector GeTensor::Lower(LoweringContext* loctx) const {
 
 torch_xla::XlaOpVector Glu::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
-
+
   // Calculate half input shape on target dim - since input must be sliced in 2
   const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(xla_input);
   int64_t ldim = dim;
   if (ldim < 0) ldim += input_shape.rank();
   absl::Span<const int64_t> inp_dimensions = input_shape.dimensions();
-  int64_t split_size = inp_dimensions[ldim]/2;
+  int64_t split_size = inp_dimensions[ldim] / 2;
 
-  // Split the input tensor into two parts, take sigmoid of RHS and multiply element-wise
+  // Split the input tensor into two parts, take sigmoid of RHS and multiply
+  // element-wise
   xla::XlaOp a = xla::SliceInDim(xla_input, 0, split_size, 1, ldim);
-  xla::XlaOp b = xla::SliceInDim(xla_input, split_size, split_size + split_size, 1, ldim);
+  xla::XlaOp b =
+      xla::SliceInDim(xla_input, split_size, split_size + split_size, 1, ldim);
   xla::XlaOp result = a * BuildSigmoid(b);
-
+
   return ReturnOp(result, loctx);
 }
 
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
index 148f6e23553..0f761ac2c49 100644
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -427,6 +427,20 @@ xla::Shape GluOutputShape(const torch::lazy::Value& input, int64_t dim) {
   if (dim < 0) dim += input_shape.rank();
 
   absl::Span<const int64_t> inp_dimensions = input_shape.dimensions();
-  std::vector<int64_t> output_sizes(std::begin(inp_dimensions),std::end(inp_dimensions));
+  std::vector<int64_t> output_sizes(std::begin(inp_dimensions),
+                                    std::end(inp_dimensions));
 
   // Output shape is always half the input shape on the specified dimension
-  output_sizes[dim] = inp_dimensions[dim]/2;
+  output_sizes[dim] = inp_dimensions[dim] / 2;
 
   return xla::ShapeUtil::MakeShape(input_shape.element_type(), output_sizes);
 }
From e34435e52f0de836f39f9dff23e1c489298a7e76 Mon Sep 17 00:00:00 2001
From: King
Date: Fri, 18 Aug 2023 13:57:49 -0700
Subject: [PATCH 06/11] fixing spacing

---
 torch_xla/csrc/ops/ops_lower_fn.cpp     | 4 ++--
 torch_xla/csrc/ops/ops_xla_shape_fn.cpp | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp
index 20a4803a904..c27d8e478ea 100644
--- a/torch_xla/csrc/ops/ops_lower_fn.cpp
+++ b/torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -356,11 +356,11 @@ torch_xla::XlaOpVector Glu::Lower(LoweringContext* loctx) const {
   if (ldim < 0) ldim += input_shape.rank();
   absl::Span<const int64_t> inp_dimensions = input_shape.dimensions();
   int64_t split_size = inp_dimensions[ldim] / 2;
-
+
   // Split the input tensor into two parts, take sigmoid of RHS and multiply
   // element-wise
   xla::XlaOp a = xla::SliceInDim(xla_input, 0, split_size, 1, ldim);
-  xla::XlaOp b =
+  xla::XlaOp b =
       xla::SliceInDim(xla_input, split_size, split_size + split_size, 1, ldim);
   xla::XlaOp result = a * BuildSigmoid(b);
 
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
index 0f761ac2c49..6f7e07a77d2 100644
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -429,10 +429,10 @@ xla::Shape GluOutputShape(const torch::lazy::Value& input, int64_t dim) {
   absl::Span<const int64_t> inp_dimensions = input_shape.dimensions();
   std::vector<int64_t> output_sizes(std::begin(inp_dimensions),
                                     std::end(inp_dimensions));
-
+
   // Output shape is always half the input shape on the specified dimension
   output_sizes[dim] = inp_dimensions[dim] / 2;
-
+
   return xla::ShapeUtil::MakeShape(input_shape.element_type(), output_sizes);
 }

From 3e516e8dd5734613f98028f59bb0203487b24466 Mon Sep 17 00:00:00 2001
From: King
Date: Fri, 18 Aug 2023 14:03:22 -0700
Subject: [PATCH 07/11] fixing spacing

---
 torch_xla/csrc/ops/ops_lower_fn.cpp     | 4 ++--
 torch_xla/csrc/ops/ops_xla_shape_fn.cpp | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp
index c27d8e478ea..ee0d8c3d39c 100644
--- a/torch_xla/csrc/ops/ops_lower_fn.cpp
+++ b/torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -356,14 +356,14 @@ torch_xla::XlaOpVector Glu::Lower(LoweringContext* loctx) const {
   if (ldim < 0) ldim += input_shape.rank();
   absl::Span<const int64_t> inp_dimensions = input_shape.dimensions();
   int64_t split_size = inp_dimensions[ldim] / 2;
-
+
   // Split the input tensor into two parts, take sigmoid of RHS and multiply
   // element-wise
   xla::XlaOp a = xla::SliceInDim(xla_input, 0, split_size, 1, ldim);
   xla::XlaOp b =
       xla::SliceInDim(xla_input, split_size, split_size + split_size, 1, ldim);
   xla::XlaOp result = a * BuildSigmoid(b);
-
+
   return ReturnOp(result, loctx);
 }
 
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
index 6f7e07a77d2..296073c1115 100644
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -429,7 +429,7 @@ xla::Shape GluOutputShape(const torch::lazy::Value& input, int64_t dim) {
   absl::Span<const int64_t> inp_dimensions = input_shape.dimensions();
   std::vector<int64_t> output_sizes(std::begin(inp_dimensions),
                                     std::end(inp_dimensions));
-
+
   // Output shape is always half the input shape on the specified dimension
   output_sizes[dim] = inp_dimensions[dim] / 2;
 
From e34435e52f0de836f39f9dff23e1c489298a7e76 Mon Sep 17 00:00:00 2001
From: King
Date: Fri, 18 Aug 2023 14:10:05 -0700
Subject: [PATCH 08/11] fixing spacing

---
 torch_xla/csrc/ops/ops_lower_fn.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp
index ee0d8c3d39c..7e5a50400ba 100644
--- a/torch_xla/csrc/ops/ops_lower_fn.cpp
+++ b/torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -356,7 +356,7 @@ torch_xla::XlaOpVector Glu::Lower(LoweringContext* loctx) const {
   if (ldim < 0) ldim += input_shape.rank();
   absl::Span<const int64_t> inp_dimensions = input_shape.dimensions();
   int64_t split_size = inp_dimensions[ldim] / 2;
-
+
   // Split the input tensor into two parts, take sigmoid of RHS and multiply
   // element-wise
   xla::XlaOp a = xla::SliceInDim(xla_input, 0, split_size, 1, ldim);

From 01cc1b4a5c62aa6f8d35234b42dba1f78215a1cf Mon Sep 17 00:00:00 2001
From: King
Date: Fri, 18 Aug 2023 14:13:14 -0700
Subject: [PATCH 09/11] fixing spacing

---
 torch_xla/csrc/ops/ops_lower_fn.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp
index 7e5a50400ba..ee0d8c3d39c 100644
--- a/torch_xla/csrc/ops/ops_lower_fn.cpp
+++ b/torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -356,7 +356,7 @@ torch_xla::XlaOpVector Glu::Lower(LoweringContext* loctx) const {
   if (ldim < 0) ldim += input_shape.rank();
   absl::Span<const int64_t> inp_dimensions = input_shape.dimensions();
   int64_t split_size = inp_dimensions[ldim] / 2;
-
+
   // Split the input tensor into two parts, take sigmoid of RHS and multiply
   // element-wise
   xla::XlaOp a = xla::SliceInDim(xla_input, 0, split_size, 1, ldim);

From b2df22155fa3694c852034e26723f24f5353eac4 Mon Sep 17 00:00:00 2001
From: King
Date: Fri, 18 Aug 2023 15:17:35 -0700
Subject: [PATCH 10/11] fixing shape helper

---
 torch_xla/csrc/ops/ops_lower_fn.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp
index ee0d8c3d39c..89d71e8e636 100644
--- a/torch_xla/csrc/ops/ops_lower_fn.cpp
+++ b/torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -351,7 +351,7 @@ torch_xla::XlaOpVector Glu::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
 
   // Calculate half input shape on target dim - since input must be sliced in 2
-  const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(xla_input);
+  const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(xla_input);
   int64_t ldim = dim;
   if (ldim < 0) ldim += input_shape.rank();
   absl::Span<const int64_t> inp_dimensions = input_shape.dimensions();
From 3f497624ddc6d7c264fd19a2aef2de9fd6ce13b0 Mon Sep 17 00:00:00 2001
From: King
Date: Fri, 18 Aug 2023 15:21:37 -0700
Subject: [PATCH 11/11] fixing spacing

---
 torch_xla/csrc/ops/ops_lower_fn.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp
index 89d71e8e636..9b564f93232 100644
--- a/torch_xla/csrc/ops/ops_lower_fn.cpp
+++ b/torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -357,7 +357,7 @@ torch_xla::XlaOpVector Glu::Lower(LoweringContext* loctx) const {
   absl::Span<const int64_t> inp_dimensions = input_shape.dimensions();
   int64_t split_size = inp_dimensions[ldim] / 2;
 
-  // Split the input tensor into two parts, take sigmoid of RHS and multiply
+  // Split the input tensor into two parts, take sigmoid of RHS and multiply
   // element-wise
   xla::XlaOp a = xla::SliceInDim(xla_input, 0, split_size, 1, ldim);
   xla::XlaOp b =