[original author: mrnikwaws] Neuron glu operator support #5466

Closed · wants to merge 11 commits
1 change: 1 addition & 0 deletions codegen/xla_native_functions.yaml
@@ -38,6 +38,7 @@ full_codegen:
- frac
- ge.Scalar
- ge.Tensor
- glu
- gt.Scalar
- gt.Tensor
- hardshrink
21 changes: 21 additions & 0 deletions test/cpp/test_aten_xla_tensor_6.cpp
@@ -1319,5 +1319,26 @@ TEST_F(AtenXlaTensorTest, TestCdistForward) {
ExpectCounterChanged("xla::_cdist_forward", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestGlu) {
std::vector<std::vector<int64_t>> sizes{
{3, 8}, {3, 5, 6}, {3, 8, 5}, {3, 8, 8, 16}};
std::vector<int64_t> dims{-1, -1, 1, 3};

auto size_it = sizes.begin();
auto dim_it = dims.begin();
for (; size_it != sizes.end() && dim_it != dims.end(); ++size_it, ++dim_it) {
torch::Tensor input =
torch::rand(*size_it, torch::TensorOptions(torch::kFloat));
torch::Tensor output = torch::glu(input, *dim_it);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor xla_output = torch::glu(xla_input, *dim_it);
AllClose(output, xla_output);
});
}
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::glu", cpp_test::GetIgnoredCounters());
}

} // namespace cpp_test
} // namespace torch_xla
22 changes: 21 additions & 1 deletion torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -347,6 +347,26 @@ torch_xla::XlaOpVector GeTensor::Lower(LoweringContext* loctx) const {
return ReturnOp(BuildComparisonOp(at::aten::ge, xla_input, xla_other), loctx);
}

torch_xla::XlaOpVector Glu::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));

// Compute half the input size along the target dim, since the input is split in two
const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(xla_input);
int64_t ldim = dim;
if (ldim < 0) ldim += input_shape.rank();
absl::Span<const int64_t> inp_dimensions = input_shape.dimensions();
int64_t split_size = inp_dimensions[ldim] / 2;

// Split the input tensor into two halves along the target dim, take the
// sigmoid of the second half, and multiply element-wise
xla::XlaOp a = xla::SliceInDim(xla_input, 0, split_size, 1, ldim);
xla::XlaOp b =
xla::SliceInDim(xla_input, split_size, split_size + split_size, 1, ldim);
xla::XlaOp result = a * BuildSigmoid(b);

return ReturnOp(result, loctx);
}
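
For reference, the lowering above implements the standard GLU definition: the input is split into two halves a and b along dim, and the result is a * sigmoid(b). A minimal eager-mode sketch of the same computation (not part of this PR; it uses only public ATen ops) that the XLA lowering should agree with:

// Reference sketch, not PR code: GLU(x, dim) = a * sigmoid(b), where a and b
// are the two halves of x along dim.
#include <torch/torch.h>

torch::Tensor GluReference(const torch::Tensor& input, int64_t dim) {
  // chunk() splits the tensor into two equal halves along `dim`.
  auto halves = input.chunk(/*chunks=*/2, /*dim=*/dim);
  return halves[0] * torch::sigmoid(halves[1]);
}

// GluReference(x, d) should match torch::glu(x, d) whenever the size of x
// along d is even.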

torch_xla::XlaOpVector GtScalar::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
@@ -652,4 +672,4 @@ torch_xla::XlaOpVector Trunc::Lower(LoweringContext* loctx) const {
return ReturnOp(xla::Floor(BuildAbs(xla_input)) * BuildSgn(xla_input), loctx);
}

} // namespace torch_xla
15 changes: 15 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -421,6 +421,21 @@ xla::Shape GeTensorOutputShape(const torch::lazy::Value& self,
return GeScalarOutputShape(self, other);
}

xla::Shape GluOutputShape(const torch::lazy::Value& input, int64_t dim) {
const xla::Shape& input_shape = GetXlaShape(input);

if (dim < 0) dim += input_shape.rank();

absl::Span<const int64_t> inp_dimensions = input_shape.dimensions();
std::vector<int64_t> output_sizes(std::begin(inp_dimensions),
std::end(inp_dimensions));

// The output size is always half the input size along the specified dimension
output_sizes[dim] = inp_dimensions[dim] / 2;

return xla::ShapeUtil::MakeShape(input_shape.element_type(), output_sizes);
}
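
As a concrete illustration of the shape rule above (a hedged sketch, not part of the PR): the split dimension is halved and every other dimension is unchanged.

// Illustrative sketch only: glu halves the split dimension in the output.
#include <torch/torch.h>
#include <cassert>

void GluShapeExample() {
  torch::Tensor x = torch::rand({3, 8, 5});
  torch::Tensor y = torch::glu(x, /*dim=*/1);
  // Dim 1 is halved (8 -> 4); dims 0 and 2 are unchanged.
  assert(y.size(0) == 3 && y.size(1) == 4 && y.size(2) == 5);
  // torch::glu(x, /*dim=*/2) would fail: size 5 on the split dim is not even.
}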

xla::Shape GtScalarOutputShape(const torch::lazy::Value& self,
Collaborator: can you add a cpp test similar to TEST_F(AtenXlaTensorTest, TestGelu)?

Collaborator (author): There should be a cpp test already added or another one?

Collaborator: you should add another one similar to Gelu.

Contributor (@mrnikwaws, Aug 18, 2023): Hey @JackCaoG can you help me understand what is missing? The test in https://github.com/pytorch/xla/pull/5466/files#diff-0c11cc2fbeb0cfbc8fe223f9ac3df0904dcbb1277946a55a56cb102642f4faefR1322-R1342 was intended to be similar to the Gelu one; it tests a variety of input permutations across both input variables and uses the same checks.

Collaborator: You can build it with BUILD_CPP_TESTS=1 and run CPP test with ./test/cpp/run_test.sh

const torch::lazy::Value& other) {
auto lower_for_shape_fn =
2 changes: 2 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.h
@@ -137,6 +137,8 @@ xla::Shape GeScalarOutputShape(const torch::lazy::Value& self,
xla::Shape GeTensorOutputShape(const torch::lazy::Value& self,
const torch::lazy::Value& other);

xla::Shape GluOutputShape(const torch::lazy::Value& input, int64_t dim);

xla::Shape GtScalarOutputShape(const torch::lazy::Value& self,
const torch::lazy::Value& other);
