Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PT FE] Support aten::atan2 for pytorch models #27026

Merged
merged 13 commits into from
Oct 23, 2024
101 changes: 101 additions & 0 deletions src/frontends/pytorch/src/op/atan2.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#define _USE_MATH_DEFINES

#include <math.h>

#include <memory>

#include "openvino/core/type/element_type.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/atan.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/equal.hpp"
#include "openvino/op/greater.hpp"
#include "openvino/op/greater_eq.hpp"
#include "openvino/op/less.hpp"
#include "openvino/op/logical_and.hpp"
#include "openvino/op/logical_or.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/subtract.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_atan2(const NodeContext& context) {
// atan2(input, other, *) → Tensor
num_inputs_check(context, 2, 2);
Output<Node> lhs;
Output<Node> rhs;

std::tie(lhs, rhs) = get_inputs_with_promoted_types(context, 0, 1);

auto div = context.mark_node(std::make_shared<v1::Divide>(lhs, rhs));

auto atan = context.mark_node(std::make_shared<v0::Atan>(div));

// create some constants to adjust result according to quadrant.
auto zero = context.mark_node(v0::Constant::create(ov::element::i32, Shape{}, {0}));
auto pi = context.mark_node(v0::Constant::create(ov::element::f64, Shape{}, {M_PI}));
auto neg_pi = context.mark_node(v0::Constant::create(ov::element::f64, Shape{}, {-M_PI}));
auto half_pi = context.mark_node(v0::Constant::create(ov::element::f64, Shape{}, {M_PI_2}));
auto neg_half_pi = context.mark_node(v0::Constant::create(ov::element::f64, Shape{}, {-M_PI_2}));

zero = context.mark_node(std::make_shared<v1::ConvertLike>(zero, rhs));
pi = context.mark_node(std::make_shared<v1::ConvertLike>(pi, rhs));
neg_pi = context.mark_node(std::make_shared<v1::ConvertLike>(neg_pi, rhs));
half_pi = context.mark_node(std::make_shared<v1::ConvertLike>(half_pi, rhs));
neg_half_pi = context.mark_node(std::make_shared<v1::ConvertLike>(neg_half_pi, rhs));

// x > 0, no adjustment needed
auto x_greater_than_zero = context.mark_node(std::make_shared<v1::Greater>(rhs, zero));

// x < 0 and y >= 0, need to plus pi
auto y_greater_equal_zero = context.mark_node(std::make_shared<v1::GreaterEqual>(lhs, zero));
auto x_less_than_zero = context.mark_node(std::make_shared<v1::Less>(rhs, zero));
auto add_pi_condition = context.mark_node(std::make_shared<v1::LogicalAnd>(x_less_than_zero, y_greater_equal_zero));

// x < 0 and y < 0, need to minus pi
auto y_less_than_zero = std::make_shared<v1::Less>(lhs, zero);
auto subtract_pi_condition =
context.mark_node(std::make_shared<v1::LogicalAnd>(x_less_than_zero, y_less_than_zero));

// x = 0 and y > 0, pi/2
auto x_equal_zero = std::make_shared<v1::Equal>(rhs, zero);
auto y_greater_than_zero = std::make_shared<v1::Greater>(lhs, zero);
auto half_pi_condition = context.mark_node(std::make_shared<v1::LogicalAnd>(x_equal_zero, y_greater_than_zero));

// x = 0 and y < 0, -pi/2
auto neg_half_pi_condition = context.mark_node(std::make_shared<v1::LogicalAnd>(x_equal_zero, y_less_than_zero));

auto special_case_condition =
context.mark_node(std::make_shared<v1::LogicalOr>(half_pi_condition, neg_half_pi_condition));

// do adjustment
auto atan_plus_pi = context.mark_node(std::make_shared<v1::Add>(atan, pi));
auto atan_minus_pi = context.mark_node(std::make_shared<v1::Subtract>(atan, pi));

// select result
auto ajusted_case = context.mark_node(std::make_shared<v1::Select>(add_pi_condition, atan_plus_pi, atan_minus_pi));
auto special_case = context.mark_node(std::make_shared<v1::Select>(half_pi_condition, half_pi, neg_half_pi));
auto adjusted_atan = context.mark_node(std::make_shared<v1::Select>(x_greater_than_zero, atan, ajusted_case));
auto result = context.mark_node(std::make_shared<v1::Select>(special_case_condition, special_case, adjusted_atan));

return {result};
}

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ OP_CONVERTER(translate_argmax);
OP_CONVERTER(translate_argmin);
OP_CONVERTER(translate_as_strided);
OP_CONVERTER(translate_as_tensor);
OP_CONVERTER(translate_atan2);
OP_CONVERTER(translate_avg_poolnd);
OP_CONVERTER(translate_bool);
OP_CONVERTER(translate_batch_norm);
Expand Down Expand Up @@ -380,6 +381,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::atanh",
op::optional_out<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Atanh>, 1>},
{"aten::atanh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Atanh>>},
{"aten::atan2", op::translate_atan2},
{"aten::avg_pool1d", op::quantizable_op<op::translate_avg_poolnd>},
{"aten::avg_pool2d", op::quantizable_op<op::translate_avg_poolnd>},
{"aten::avg_pool3d", op::quantizable_op<op::translate_avg_poolnd>},
Expand Down
80 changes: 80 additions & 0 deletions tests/layer_tests/pytorch_tests/test_atan2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest

@pytest.mark.parametrize("input_shape_rhs", [
[2, 5, 3, 4],
[1, 5, 3, 4],
[1]
])
class TestAtan2(PytorchLayerTest):

def _prepare_input(self):
return (np.random.randn(2, 5, 3, 4).astype(np.float32), self.input_rhs)

def create_model(self):

class aten_atan2(torch.nn.Module):
def __init__(self):
super(aten_atan2, self).__init__()

def forward(self, lhs, rhs):
return torch.arctan2(lhs, rhs)

ref_net = None

return aten_atan2(), ref_net, "aten::atan2"
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_torch_export
@pytest.mark.precommit_fx_backend
def test_atan2(self, ie_device, precision, ir_version, input_shape_rhs):
self.input_rhs = np.random.randn(*input_shape_rhs).astype(np.float32)
self._test(*self.create_model(), ie_device, precision, ir_version, use_convert_model=True)

class TestAtan2Types(PytorchLayerTest):

def _prepare_input(self):
return (torch.randn(self.lhs_shape).to(self.lhs_type).numpy(),
torch.randn(self.rhs_shape).to(self.rhs_type).numpy())

def create_model(self, lhs_type, rhs_type):

class aten_atan2(torch.nn.Module):
def __init__(self, lhs_type, rhs_type):
super(aten_atan2, self).__init__()
self.lhs_type = lhs_type
self.rhs_type = rhs_type

def forward(self, lhs, rhs):
return torch.arctan2(lhs.to(self.lhs_type), rhs.to(self.rhs_type))

ref_net = None

return aten_atan2(lhs_type, rhs_type), ref_net, "aten::atan2"

@pytest.mark.parametrize(("lhs_type", "rhs_type"),
[[torch.int, torch.float32],
[torch.int, torch.float64],
[torch.float32, torch.float64],
[torch.int64, torch.float32]
])
@pytest.mark.parametrize(("lhs_shape", "rhs_shape"), [([2, 3], [2, 3]),
([2, 3], [1, 3]),
([3, 2, 3], [2, 3]),
])
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_torch_export
def test_atan2_types(self, ie_device, precision, ir_version, lhs_type, lhs_shape, rhs_type, rhs_shape):
self.lhs_type = lhs_type
self.lhs_shape = lhs_shape
self.rhs_type = rhs_type
self.rhs_shape = rhs_shape
self._test(*self.create_model(lhs_type, rhs_type),
ie_device, precision, ir_version, freeze_model=False, trace_model=True)
Loading