Skip to content

Commit

Permalink
Merge pull request #1259 from pytorch/assorted_small_fixes
Browse files Browse the repository at this point in the history
Assorted small fixes
  • Loading branch information
narendasan authored Aug 12, 2022
2 parents 679ea21 + 460fc9b commit 6f61c6f
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 34 deletions.
Empty file modified core/conversion/converters/impl/element_wise.cpp
100755 → 100644
Empty file.
9 changes: 6 additions & 3 deletions core/conversion/converters/impl/unary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ namespace impl {
namespace {

auto abs_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
{"aten::abs (Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensor();
{"aten::abs(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensorOrFreeze(ctx);
bool unary_supported_input = in->getType() == nvinfer1::DataType::kFLOAT ||
in->getType() == nvinfer1::DataType::kHALF || in->getType() == nvinfer1::DataType::kINT8;
if (unary_supported_input) {
Expand All @@ -23,6 +23,9 @@ auto abs_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
return true;
} else {
LOG_GRAPH(
"Tensor is of unsupported type "
<< in->getType() << " for IUnaryLayer::kABS. Using backup implementation via IElementWise (max(x, -x)");
// For types not supported by kABS, use an elementwise implementation abs(x) = max(x, -1 * x)
at::Tensor neg_one = torch::full({1}, -1).to(util::TRTDataTypeToScalarType(in->getType()));
auto neg_one_const = tensor_to_const(ctx, neg_one);
Expand Down Expand Up @@ -50,7 +53,7 @@ auto abs_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern
auto unary##_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( \
{"aten::" #unary "(Tensor self) -> Tensor", \
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { \
auto in = args[0].ITensor(); \
auto in = args[0].ITensorOrFreeze(ctx); \
auto unary = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::trt_type); \
\
TORCHTRT_CHECK(unary, "Unable to create " #unary " layer from node: " << *n); \
Expand Down
18 changes: 14 additions & 4 deletions core/conversion/evaluators/aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ auto aten_registrations TORCHTRT_UNUSED =
auto self = args.at(n->input(0)).IValue();
auto obj = args.at(n->input(1)).IValue();

return self->isSameIdentity(*obj);
return self->is(*obj);
},
EvalOptions().validSchemas({
"aten::__is__(t1 self, t2 obj) -> bool",
Expand All @@ -527,7 +527,7 @@ auto aten_registrations TORCHTRT_UNUSED =
auto self = args.at(n->input(0)).IValue();
auto obj = args.at(n->input(1)).IValue();

return !self->isSameIdentity(*obj);
return !self->is(*obj);
},
EvalOptions().validSchemas({
"aten::__isnot__(t1 self, t2 obj) -> bool",
Expand Down Expand Up @@ -806,9 +806,19 @@ auto aten_registrations TORCHTRT_UNUSED =
return 0;
}
},
EvalOptions().validSchemas({"aten::__range_length(int lo, int hi, int step) -> int"})});
EvalOptions().validSchemas({"aten::__range_length(int lo, int hi, int step) -> int"})})
.evaluator(
{c10::Symbol::fromQualString("aten::__derive_index"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto idx = args.at(n->input(0)).unwrapToInt();
auto start = args.at(n->input(1)).unwrapToInt();
auto step = args.at(n->input(2)).unwrapToInt();
return start + idx * step;
},
EvalOptions().validSchemas({"aten::__derive_index(int idx, int start, int step) -> int"})});

} // namespace
} // namespace evaluators
} // namespace conversion
} // namespace core
} // namespace torch_tensorrt
} // namespace torch_tensorrt
37 changes: 10 additions & 27 deletions core/conversion/evaluators/prim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,36 +270,19 @@ auto prim_registrations =
.evaluator(
{torch::jit::prim::TupleConstruct,
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto num_inputs = n->inputs().size();
c10::IValue tuple = c10::ivalue::Tuple::create();
switch (num_inputs) {
case 0:
tuple = c10::ivalue::Tuple::create();
break;
case 1:
tuple = c10::ivalue::Tuple::create(std::move((*args.at(n->input(0)).IValue())));
break;
case 2: {
tuple = c10::ivalue::Tuple::create(
std::move(*(args.at(n->input(0)).IValue())), std::move(*(args.at(n->input(1)).IValue())));
break;
}
case 3: {
tuple = c10::ivalue::Tuple::create(
std::move(*(args.at(n->input(0)).IValue())),
std::move(*(args.at(n->input(1)).IValue())),
std::move(*(args.at(n->input(2)).IValue())));
break;
}
default: {
std::vector<c10::IValue> elems;
for (size_t i = 0; i < num_inputs; i++) {
elems.push_back(*(args.at(n->input(i)).IValue()));
}
tuple = c10::ivalue::Tuple::create(std::move(elems));
break;
std::vector<c10::IValue> elems;
for (auto in : n->inputs()) {
if (args.at(in).isITensor()) {
auto tensor_holder = TensorContainer();
tensor_holder.hold_tensor(args.at(in).ITensor());
auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
elems.push_back(std::move(ival));
} else {
elems.push_back(*(args.at(in).IValue()));
}
}
tuple = c10::ivalue::Tuple::create(std::move(elems));
return c10::optional<torch::jit::IValue>(std::move(tuple));
}})
.evaluator(
Expand Down
69 changes: 69 additions & 0 deletions tests/core/conversion/evaluators/test_aten_evaluators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -797,3 +797,72 @@ TEST(Evaluators, PowFloatIntEvaluatesCorrectly) {

ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

TEST(Evaluators, DeriveIndexEvaluatesCorrectly) {
  // IR computing aten::__derive_index(idx=9, start=4, step=2); the evaluator
  // must agree with the JIT interpreter on the resulting int.
  const auto source_ir = R"IR(
      graph():
        %1 : int = prim::Constant[value=9]()
        %2 : int = prim::Constant[value=4]()
        %3 : int = prim::Constant[value=2]()
        %4 : int = aten::__derive_index(%1, %2, %3)
        return (%4))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_ir, parsed_graph.get());

  // Reference result from the TorchScript interpreter vs. the TRT evaluator path.
  auto expected = torch_tensorrt::tests::util::EvaluateGraphJIT(parsed_graph, {});
  auto actual = torch_tensorrt::tests::util::EvaluateGraph(parsed_graph->block(), {});

  ASSERT_TRUE(expected[0] == actual[0]);
}

TEST(Evaluators, IsTrueEvaluatesCorrectly) {
  // aten::__is__ between two int constants with equal value; the evaluator's
  // identity semantics must match the JIT interpreter's answer.
  const auto source_ir = R"IR(
      graph():
        %1 : int = prim::Constant[value=1]()
        %2 : int = prim::Constant[value=1]()
        %4 : bool = aten::__is__(%1, %2)
        return (%4))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_ir, parsed_graph.get());

  // Reference result from the TorchScript interpreter vs. the TRT evaluator path.
  auto expected = torch_tensorrt::tests::util::EvaluateGraphJIT(parsed_graph, {});
  auto actual = torch_tensorrt::tests::util::EvaluateGraph(parsed_graph->block(), {});

  ASSERT_TRUE(expected[0] == actual[0]);
}

TEST(Evaluators, IsFalseEvaluatesCorrectly) {
  // aten::__is__ between an int and None — identity check should be false;
  // the evaluator must agree with the JIT interpreter.
  const auto source_ir = R"IR(
      graph():
        %1 : int = prim::Constant[value=9]()
        %2 : None = prim::Constant()
        %4 : bool = aten::__is__(%1, %2)
        return (%4))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_ir, parsed_graph.get());

  // Reference result from the TorchScript interpreter vs. the TRT evaluator path.
  auto expected = torch_tensorrt::tests::util::EvaluateGraphJIT(parsed_graph, {});
  auto actual = torch_tensorrt::tests::util::EvaluateGraph(parsed_graph->block(), {});

  ASSERT_TRUE(expected[0] == actual[0]);
}

TEST(Evaluators, IsNotTrueEvaluatesCorrectly) {
  // aten::__isnot__ between an int and None — negated identity should be true;
  // the evaluator must agree with the JIT interpreter.
  const auto source_ir = R"IR(
      graph():
        %1 : int = prim::Constant[value=1]()
        %2 : None = prim::Constant()
        %4 : bool = aten::__isnot__(%1, %2)
        return (%4))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_ir, parsed_graph.get());

  // Reference result from the TorchScript interpreter vs. the TRT evaluator path.
  auto expected = torch_tensorrt::tests::util::EvaluateGraphJIT(parsed_graph, {});
  auto actual = torch_tensorrt::tests::util::EvaluateGraph(parsed_graph->block(), {});

  ASSERT_TRUE(expected[0] == actual[0]);
}

0 comments on commit 6f61c6f

Please sign in to comment.