Skip to content

Commit

Permalink
Fixing aten::slice invalid schema and implementing aten::list evaluat…
Browse files Browse the repository at this point in the history
…or (#1695)
  • Loading branch information
apbose authored Jun 2, 2023
1 parent 6a26856 commit 4494699
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 7 deletions.
26 changes: 20 additions & 6 deletions core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,20 @@ auto aten_registrations TORCHTRT_UNUSED =
{c10::Symbol::fromQualString("aten::slice"),
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
c10::List<c10::IValue> list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();

int64_t start = 0;
int64_t end = 9223372036854775807;
auto startIVal = args.at(n->input(1)).IValue();
auto endIVal = args.at(n->input(2)).IValue();

if (!startIVal->isNone()) {
start = args.at(n->input(1)).unwrapToInt();
}
int64_t end = args.at(n->input(2)).unwrapToInt();
if (!endIVal->isNone()) {
end = args.at(n->input(2)).unwrapToInt();
}
if (start > end) {
LOG_DEBUG("The end should be greater than start");
}
int64_t step = args.at(n->input(3)).unwrapToInt();

const int64_t list_size = list.size();
Expand All @@ -253,8 +260,9 @@ auto aten_registrations TORCHTRT_UNUSED =

return sliced_list;
},
EvalOptions().validSchemas(
{"aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[])"})})
EvalOptions().validSchemas({"aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> (t[])"})})
// EvalOptions().validSchemas(
// {"aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[])"})})
.evaluator(
{c10::Symbol::fromQualString("aten::len"),
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
Expand Down Expand Up @@ -896,8 +904,14 @@ auto aten_registrations TORCHTRT_UNUSED =
auto step = args.at(n->input(2)).unwrapToInt();
return start + idx * step;
},
EvalOptions().validSchemas({"aten::__derive_index(int idx, int start, int step) -> int"})});

EvalOptions().validSchemas({"aten::__derive_index(int idx, int start, int step) -> int"})})
.evaluator(
{c10::Symbol::fromQualString("aten::list"),
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
c10::List<c10::IValue> list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
return list.copy();
},
EvalOptions().validSchemas({"aten::list.t(t[] l) -> (t[])"})});
} // namespace
} // namespace evaluators
} // namespace conversion
Expand Down
2 changes: 1 addition & 1 deletion tests/core/conversion/evaluators/evaluator_test.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@ def evaluator_test(name, visibility = None):
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
"//conditions:default": ["@libtorch//:libtorch"],
}),
timeout = "short",
timeout = "long",
)
35 changes: 35 additions & 0 deletions tests/core/conversion/evaluators/test_aten_evaluators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -931,3 +931,38 @@ TEST(Evaluators, IsNotTrueEvaluatesCorrectly) {

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

TEST(Evaluators, IsAtenSliceEvaluateCorrectly) {
const auto graph = R"IR(
graph():
%1 : int[] = prim::Constant[value= [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]()
%2 : int = prim::Constant[value = 0]()
%3 : int = prim::Constant[value = 7]()
%4 : int = prim::Constant[value = 2]()
%5 : int[] = aten::slice(%1, %2, %3, %4)
return (%5))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

TEST(Evaluators, IsAtenListEvaluateCorrectly) {
const auto graph = R"IR(
graph():
%1 : int[] = prim::Constant[value= [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]()
%2 : int[] = aten::list(%1)
return (%2))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

0 comments on commit 4494699

Please sign in to comment.