diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml
index baf623f05cb..f8e21fd9bfa 100644
--- a/codegen/xla_native_functions.yaml
+++ b/codegen/xla_native_functions.yaml
@@ -38,6 +38,7 @@ full_codegen:
   - frac
   - ge.Scalar
   - ge.Tensor
+  - glu
   - gt.Scalar
   - gt.Tensor
   - hardshrink
diff --git a/test/cpp/test_aten_xla_tensor_6.cpp b/test/cpp/test_aten_xla_tensor_6.cpp
index b77392dc4c4..f8f11a2253c 100644
--- a/test/cpp/test_aten_xla_tensor_6.cpp
+++ b/test/cpp/test_aten_xla_tensor_6.cpp
@@ -1319,5 +1319,26 @@ TEST_F(AtenXlaTensorTest, TestCdistForward) {
   ExpectCounterChanged("xla::_cdist_forward", cpp_test::GetIgnoredCounters());
 }
 
+TEST_F(AtenXlaTensorTest, TestGlu) {
+  std::vector<std::vector<int64_t>> sizes{
+      {3, 8}, {3, 5, 6}, {3, 8, 5}, {3, 8, 8, 16}};
+  std::vector<int64_t> dims{-1, -1, 1, 3};
+
+  auto size_it = sizes.begin();
+  auto dim_it = dims.begin();
+  for (; size_it != sizes.end() && dim_it != dims.end(); ++size_it, ++dim_it) {
+    torch::Tensor input =
+        torch::rand(*size_it, torch::TensorOptions(torch::kFloat));
+    torch::Tensor output = torch::glu(input, *dim_it);
+    ForEachDevice([&](const torch::Device& device) {
+      torch::Tensor xla_input = CopyToDevice(input, device);
+      torch::Tensor xla_output = torch::glu(xla_input, *dim_it);
+      AllClose(output, xla_output);
+    });
+  }
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::glu", cpp_test::GetIgnoredCounters());
+}
+
 }  // namespace cpp_test
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp
index 5f588bdd912..9b564f93232 100644
--- a/torch_xla/csrc/ops/ops_lower_fn.cpp
+++ b/torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -347,6 +347,26 @@ torch_xla::XlaOpVector GeTensor::Lower(LoweringContext* loctx) const {
   return ReturnOp(BuildComparisonOp(at::aten::ge, xla_input, xla_other), loctx);
 }
 
+torch_xla::XlaOpVector Glu::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+
+  // Compute the split size on the target dim; the input is sliced in two
+  const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(xla_input);
+  int64_t ldim = dim;
+  if (ldim < 0) ldim += input_shape.rank();
+  absl::Span<const int64_t> inp_dimensions = input_shape.dimensions();
+  int64_t split_size = inp_dimensions[ldim] / 2;
+
+  // Split the input tensor into two halves, take the sigmoid of the second
+  // half, and multiply the two element-wise
+  xla::XlaOp a = xla::SliceInDim(xla_input, 0, split_size, 1, ldim);
+  xla::XlaOp b =
+      xla::SliceInDim(xla_input, split_size, split_size + split_size, 1, ldim);
+  xla::XlaOp result = a * BuildSigmoid(b);
+
+  return ReturnOp(result, loctx);
+}
+
 torch_xla::XlaOpVector GtScalar::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
@@ -652,4 +672,4 @@ torch_xla::XlaOpVector Trunc::Lower(LoweringContext* loctx) const {
   return ReturnOp(xla::Floor(BuildAbs(xla_input)) * BuildSgn(xla_input), loctx);
 }
 
-}  // namespace torch_xla
+}  // namespace torch_xla
\ No newline at end of file
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
index 0d153ab78b5..296073c1115 100644
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -421,6 +421,21 @@ xla::Shape GeTensorOutputShape(const torch::lazy::Value& self,
                                const torch::lazy::Value& other) {
   return GeScalarOutputShape(self, other);
 }
 
+xla::Shape GluOutputShape(const torch::lazy::Value& input, int64_t dim) {
+  const xla::Shape& input_shape = GetXlaShape(input);
+
+  if (dim < 0) dim += input_shape.rank();
+
+  absl::Span<const int64_t> inp_dimensions = input_shape.dimensions();
+  std::vector<int64_t> output_sizes(std::begin(inp_dimensions),
+                                    std::end(inp_dimensions));
+
+  // Output shape is always half the input shape on the specified dimension
+  output_sizes[dim] = inp_dimensions[dim] / 2;
+
+  return xla::ShapeUtil::MakeShape(input_shape.element_type(), output_sizes);
+}
+
 xla::Shape GtScalarOutputShape(const torch::lazy::Value& self,
                                const torch::lazy::Value& other) {
   auto lower_for_shape_fn =
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.h b/torch_xla/csrc/ops/ops_xla_shape_fn.h
index 537574c3b37..76a0144a50e 100644
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.h
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.h
@@ -137,6 +137,8 @@ xla::Shape GeScalarOutputShape(const torch::lazy::Value& self,
                                const torch::lazy::Value& other);
 
 xla::Shape GeTensorOutputShape(const torch::lazy::Value& self,
                                const torch::lazy::Value& other);
 
+xla::Shape GluOutputShape(const torch::lazy::Value& input, int64_t dim);
+
 xla::Shape GtScalarOutputShape(const torch::lazy::Value& self,
                                const torch::lazy::Value& other);
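
For reference, the lowering above implements the standard GLU identity, glu(x) = a * sigmoid(b), where a and b are the two halves of the input along dim. Below is a minimal Python sketch (not part of the patch) of the resulting end-user behavior, assuming a working torch_xla install; xm.xla_device() is the usual torch_xla device helper and is unrelated to this diff:

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    x = torch.rand(3, 8, device=device)

    # torch.glu halves x along dim, applies sigmoid to the second half,
    # and multiplies element-wise; with this patch it lowers to xla::glu.
    out = torch.glu(x, dim=-1)

    # The same computation written out with the identity above.
    a, b = torch.chunk(x, 2, dim=-1)
    ref = a * torch.sigmoid(b)
    assert torch.allclose(out.cpu(), ref.cpu())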