-
Notifications
You must be signed in to change notification settings - Fork 360
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Added support for aten::unflatten converter #2097
Conversation
Signed-off-by: Anurag Dixit <anurag.dixit@getcruise.com>
Signed-off-by: Anurag Dixit <anurag.dixit@getcruise.com>
@peri044 : Following commands have no impact on the linting of my code changes:
I don't see buildifier triggering auto lint with commits. Can you please share the instructions for the same? |
Signed-off-by: Anurag Dixit <anurag.dixit@getcruise.com>
Signed-off-by: Anurag Dixit <anurag.dixit@getcruise.com>
tests/cpp/test_dynamic_size.cpp
Outdated
|
||
TEST(Converters, ATenUnflattenDynShapeShapeCorrectly) { | ||
const auto graph = R"IR( | ||
graph(%x.1 : Tensor): | ||
%2 : int = prim::Constant[value=1]() | ||
%3 : int = prim::Constant[value=512]() | ||
%4 : int = prim::Constant[value=1]() | ||
%5 : int = prim::Constant[value=1]() | ||
%6 : int[] = prim::ListConstruct(%3, %4, %5) | ||
%7 : Tensor = aten::unflatten(%x.1, %2, %6) | ||
return (%7))IR"; | ||
|
||
auto g = std::make_shared<torch::jit::Graph>(); | ||
|
||
torch::jit::parseIR(graph, g.get()); | ||
|
||
auto in = at::randint(0, 10, {1, 512}, {at::kCUDA}); | ||
|
||
auto jit_in = at::clone(in); | ||
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); | ||
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); | ||
|
||
auto trt_in = at::clone(in); | ||
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); | ||
auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true); | ||
|
||
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); | ||
} | ||
|
||
TEST(Converters, ATenUnflattenDynShapeNegativeDimsShapeCorrectly) { | ||
const auto graph = R"IR( | ||
graph(%x.1 : Tensor): | ||
%2 : int = prim::Constant[value=-2]() | ||
%3 : int = prim::Constant[value=512]() | ||
%4 : int = prim::Constant[value=1]() | ||
%5 : int = prim::Constant[value=1]() | ||
%6 : int[] = prim::ListConstruct(%3, %4, %5) | ||
%7 : Tensor = aten::unflatten(%x.1, %2, %6) | ||
return (%7))IR"; | ||
|
||
auto g = std::make_shared<torch::jit::Graph>(); | ||
|
||
torch::jit::parseIR(graph, g.get()); | ||
|
||
auto in = at::randint(0, 10, {1, 512, 2}, {at::kCUDA}); | ||
|
||
auto jit_in = at::clone(in); | ||
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); | ||
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); | ||
|
||
auto trt_in = at::clone(in); | ||
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); | ||
auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true); | ||
|
||
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you move these to test_shuffle.cpp since these are not ITensor based (aten::size + dyn_shapes) ? The aten::size + aten::unflatten.cpp tests can be left in this file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed redundant test cases. Mentioned test case variants are already present in test_shuffle.
We use |
Signed-off-by: Anurag Dixit <anurag.dixit@getcruise.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Description
Added a converter support for aten::unflatten.
Created a separate PR for #1808
Fixes # (issue)
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: