Skip to content

Commit

Permalink
masked_fill test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Sep 5, 2023
1 parent c7ad4dd commit e2a87a9
Showing 1 changed file with 56 additions and 0 deletions.
56 changes: 56 additions & 0 deletions tests/core/conversion/converters/test_masked_fill.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,62 @@ TEST(Converters, ATenMaskedFillMixedTypesFloatIntConvertsCorrectly) {
ASSERT_TRUE(jit_results[0].dtype() == trt_results[0].dtype());
}

// Checks aten::masked_fill when the mask has LOWER rank than self: the 3-D
// boolean mask must be broadcast (left-padded with a leading dimension) to
// match the 4-D self tensor. Both the TorchScript interpreter result and the
// TensorRT engine result must agree in values and dtype.
TEST(Converters, ATenMaskedFillBroadcastMaskPad) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor, %x.2 : Tensor):
        %val : int = prim::Constant[value=4]()
        %out : Tensor = aten::masked_fill(%x.1, %x.2, %val)
        return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*g);

  // Self is a 4-D int tensor; the mask is 3-D and must be broadcast by the
  // converter. The fill value is the int constant 4 (not a float — the old
  // comment here was copy-pasted from the mixed-types test).
  auto in1 = at::rand({2, 3, 5, 7}, {at::kCUDA}).to(torch::kInt32);
  // rand() in [0, 2) cast to bool: nonzero -> true, so the mask is mostly true.
  auto in2 = (2 * at::rand({3, 5, 7}, {at::kCUDA})).to(torch::kBool);

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});

  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));

  // Ensure data types match in outputs
  ASSERT_TRUE(jit_results[0].dtype() == trt_results[0].dtype());
}

// Checks aten::masked_fill when the mask has HIGHER rank than self: the 3-D
// self tensor must be broadcast (left-padded) up to the 4-D mask's shape.
// TorchScript and TensorRT outputs must agree in values and dtype.
TEST(Converters, ATenMaskedFillBroadcastSelfPad) {
  const auto source_ir = R"IR(
    graph(%x.1 : Tensor, %x.2 : Tensor):
        %val : int = prim::Constant[value=4]()
        %out : Tensor = aten::masked_fill(%x.1, %x.2, %val)
        return (%out))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_ir, parsed_graph.get());

  // 3-D int self tensor and a 4-D boolean mask; the fill value is the int
  // constant 4 from the IR above.
  auto self_tensor = at::rand({3, 5, 7}, {at::kCUDA}).to(torch::kInt32);
  auto mask_tensor = (2 * at::rand({1, 3, 5, 7}, {at::kCUDA})).to(torch::kBool);

  auto static_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
  auto jit_out = torch_tensorrt::tests::util::RunGraph(parsed_graph, static_params, {self_tensor, mask_tensor});

  static_params = torch_tensorrt::core::ir::get_static_params(parsed_graph->inputs(), {});
  auto trt_out = torch_tensorrt::tests::util::RunGraphEngine(parsed_graph, static_params, {self_tensor, mask_tensor});

  // Values must agree between the two backends...
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_out[0], trt_out[0].reshape_as(jit_out[0]), 2e-6));

  // ...and so must the output dtypes.
  ASSERT_TRUE(jit_out[0].dtype() == trt_out[0].dtype());
}

TEST(Converters, ATenMaskedFillMixedTypesIntFloatConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor, %x.2 : Tensor):
Expand Down

0 comments on commit e2a87a9

Please sign in to comment.