[original author: mrnikwaws] Neuron operator support (#5471)
* adding glu operator support

* adding glu operator

* fixing yaml

* fixing linter issues

* fixing linter issues

* fixing spacing

* fixing spacing

* fixing spacing

* fixing spacing

* fixing shape helper

* fixing spacing
aws-kingrj authored and will-cromar committed Sep 14, 2023
1 parent 42cadd2 commit 6378fca
Showing 5 changed files with 60 additions and 1 deletion.
1 change: 1 addition & 0 deletions codegen/xla_native_functions.yaml
@@ -38,6 +38,7 @@ full_codegen:
- frac
- ge.Scalar
- ge.Tensor
- glu
- gt.Scalar
- gt.Tensor
- hardshrink
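Note: because glu is listed under full_codegen, torch_xla's code generator should emit the Glu IR node and ATen dispatch entry automatically; only the lowering (Glu::Lower) and the shape function (GluOutputShape) in the files below are written by hand.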
21 changes: 21 additions & 0 deletions test/cpp/test_aten_xla_tensor_6.cpp
@@ -1319,5 +1319,26 @@ TEST_F(AtenXlaTensorTest, TestCdistForward) {
ExpectCounterChanged("xla::_cdist_forward", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestGlu) {
std::vector<std::vector<int64_t>> sizes{
{3, 8}, {3, 5, 6}, {3, 8, 5}, {3, 8, 8, 16}};
std::vector<int64_t> dims{-1, -1, 1, 3};

auto size_it = sizes.begin();
auto dim_it = dims.begin();
for (; size_it != sizes.end() && dim_it != dims.end(); ++size_it, ++dim_it) {
torch::Tensor input =
torch::rand(*size_it, torch::TensorOptions(torch::kFloat));
torch::Tensor output = torch::glu(input, *dim_it);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor xla_output = torch::glu(xla_input, *dim_it);
AllClose(output, xla_output);
});
}
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::glu", cpp_test::GetIgnoredCounters());
}

} // namespace cpp_test
} // namespace torch_xla
22 changes: 21 additions & 1 deletion torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -347,6 +347,26 @@ torch_xla::XlaOpVector GeTensor::Lower(LoweringContext* loctx) const {
return ReturnOp(BuildComparisonOp(at::aten::ge, xla_input, xla_other), loctx);
}

torch_xla::XlaOpVector Glu::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));

// Compute half the input size along the target dim, since the input is split in two
const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(xla_input);
int64_t ldim = dim;
if (ldim < 0) ldim += input_shape.rank();
absl::Span<const int64_t> inp_dimensions = input_shape.dimensions();
int64_t split_size = inp_dimensions[ldim] / 2;

// Split the input tensor into two halves, take the sigmoid of the second
// half, and multiply element-wise
xla::XlaOp a = xla::SliceInDim(xla_input, 0, split_size, 1, ldim);
xla::XlaOp b =
xla::SliceInDim(xla_input, split_size, split_size + split_size, 1, ldim);
xla::XlaOp result = a * BuildSigmoid(b);

return ReturnOp(result, loctx);
}

torch_xla::XlaOpVector GtScalar::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
@@ -652,4 +672,4 @@ torch_xla::XlaOpVector Trunc::Lower(LoweringContext* loctx) const {
return ReturnOp(xla::Floor(BuildAbs(xla_input)) * BuildSgn(xla_input), loctx);
}

-} // namespace torch_xla
+} // namespace torch_xla
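For reference, the Glu lowering above computes glu(x, dim) = a * sigmoid(b), where a and b are the first and second halves of x along dim. Below is a minimal standalone C++ sketch of the same computation for a rank-1 input, independent of XLA; the function name glu_1d and the std::vector interface are illustrative only, not part of this change.

#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Reference semantics of the Glu lowering, restricted to a rank-1
// buffer: split x into halves [a | b] and return a * sigmoid(b).
std::vector<float> glu_1d(const std::vector<float>& x) {
  assert(x.size() % 2 == 0);  // the split dimension must have even size
  const std::size_t half = x.size() / 2;
  std::vector<float> out(half);
  for (std::size_t i = 0; i < half; ++i) {
    const float sigmoid_b = 1.0f / (1.0f + std::exp(-x[half + i]));
    out[i] = x[i] * sigmoid_b;  // a * sigmoid(b), element-wise
  }
  return out;
}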
15 changes: 15 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -421,6 +421,21 @@ xla::Shape GeTensorOutputShape(const torch::lazy::Value& self,
return GeScalarOutputShape(self, other);
}

xla::Shape GluOutputShape(const torch::lazy::Value& input, int64_t dim) {
const xla::Shape& input_shape = GetXlaShape(input);

if (dim < 0) dim += input_shape.rank();

absl::Span<const int64_t> inp_dimensions = input_shape.dimensions();
std::vector<int64_t> output_sizes(std::begin(inp_dimensions),
std::end(inp_dimensions));

// The output is half the input size along the specified dimension
output_sizes[dim] = inp_dimensions[dim] / 2;

return xla::ShapeUtil::MakeShape(input_shape.element_type(), output_sizes);
}

xla::Shape GtScalarOutputShape(const torch::lazy::Value& self,
const torch::lazy::Value& other) {
auto lower_for_shape_fn =
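As a quick worked example of the GluOutputShape rule above: an input of shape {3, 8, 5} with dim = 1 yields an output of shape {3, 4, 5}. A minimal sketch of the same size arithmetic follows; the helper name glu_output_sizes is hypothetical, not part of this change.

#include <cstdint>
#include <vector>

// Hypothetical helper mirroring GluOutputShape's size arithmetic:
// copy the input sizes and halve the split dimension.
std::vector<int64_t> glu_output_sizes(std::vector<int64_t> sizes, int64_t dim) {
  if (dim < 0) dim += static_cast<int64_t>(sizes.size());  // normalize negative dim
  sizes[dim] /= 2;  // e.g. {3, 8, 5} with dim = 1 -> {3, 4, 5}
  return sizes;
}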
2 changes: 2 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.h
@@ -137,6 +137,8 @@ xla::Shape GeScalarOutputShape(const torch::lazy::Value& self,
xla::Shape GeTensorOutputShape(const torch::lazy::Value& self,
const torch::lazy::Value& other);

xla::Shape GluOutputShape(const torch::lazy::Value& input, int64_t dim);

xla::Shape GtScalarOutputShape(const torch::lazy::Value& self,
const torch::lazy::Value& other);

