fix/feat: Add lowering pass to resolve most aten::Int.Tensor uses #1937

Merged · 2 commits · May 30, 2023
1 change: 1 addition & 0 deletions core/lowering/lowering.cpp
@@ -143,6 +143,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
  passes::SiluToSigmoidMultipication(g);
  passes::RemoveSingleUse0DTensors(g);
  passes::RemoveUnnecessaryCasts(g);
  passes::ReplaceAtenInt(g);
  passes::UnpackAndCastMaskedFill(g, lower_info.getGPUDeviceString());
  passes::UnpackAndCastNumToTensor(g, lower_info.getGPUDeviceString());
  passes::UnpackAndCastFull(g, lower_info.getGPUDeviceString());
1 change: 1 addition & 0 deletions core/lowering/passes/passes.h
@@ -32,6 +32,7 @@ void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g);
void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph);
void ReplaceAtenInt(std::shared_ptr<torch::jit::Graph>& g);
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
145 changes: 145 additions & 0 deletions core/lowering/passes/remove_unnecessary_casts.cpp
@@ -1,4 +1,5 @@
#include "torch/csrc/jit/ir/constants.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/util/prelude.h"
@@ -211,6 +212,150 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
  LOG_GRAPH("Post removing single use 0-dim Tensor operations: " << *g);
}

// Schemas for aten::Int which can be replaced by scalar equivalents
const std::unordered_set<c10::Symbol> AtenIntReplacementNodeKinds = {
    torch::jit::aten::mul,
    torch::jit::aten::floor_divide,
};

c10::optional<torch::jit::Value*> Validate0DTensor(torch::jit::Value* value) {
Review comment (Collaborator, Author): Refactored to use a c10::optional wrapper instead of nullptr, and replaced pointer checks with .has_value().

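As a standalone sketch of the pattern described in that comment (illustrative only, not part of the diff; find_positive is a hypothetical helper):

#include <c10/util/Optional.h>

// Return the wrapped value on success; an empty optional replaces the old
// nullptr sentinel as the failure signal.
c10::optional<int> find_positive(int x) {
  if (x > 0) {
    return x;
  }
  return {};
}

// Call sites test .has_value() rather than comparing a pointer against nullptr:
// auto v = find_positive(5);
// if (v.has_value()) { /* use v.value() */ }
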
  // Validates that the input Value* is a 0D Tensor (or int/float)
  // Returns the stored int/float Value* if so, otherwise an empty optional
  c10::optional<torch::jit::Value*> enclosed_scalar_value = {};

  // Regular Int/Float case
  if (value->type()->isSubtypeOf(c10::IntType::get()) || value->type()->isSubtypeOf(c10::FloatType::get())) {
    enclosed_scalar_value = value;
    return enclosed_scalar_value;
  }

  // Constant Tensor case
  if (value->node()->kind() == torch::jit::prim::Constant && value->type()->isSubtypeOf(c10::TensorType::get())) {
    // Retrieve the Tensor stored in the constant
    at::Tensor t = *torch::jit::constant_as<at::Tensor>(value);
    // Validate that the Tensor is 0D (single-element) and integral
    if (t.sizes() == std::vector<int64_t>({}) && t.item().isIntegral(false)) {
      // Extract the stored value and add it to the graph as a constant
      torch::jit::WithInsertPoint guard(value->node());
      auto new_const_val = value->owningGraph()->insertConstant(t.item(), c10::nullopt, value->node()->scope());
      new_const_val->copyMetadata(value);
      new_const_val->setType(c10::IntType::get());
      enclosed_scalar_value = new_const_val;
      return enclosed_scalar_value;
    } else {
      LOG_DEBUG("In aten::Int.Tensor removal, encountered a const which was either not 0D or not integral");
    }
  }

  // NumToTensor case
  if (value->node()->kind() == torch::jit::prim::NumToTensor && value->type()->isSubtypeOf(c10::TensorType::get())) {
    // The input to NumToTensor is the relevant scalar
    enclosed_scalar_value = value->node()->input();
    return enclosed_scalar_value;
  }

  return enclosed_scalar_value;
}

c10::optional<torch::jit::Value*> TracebackAndEliminate0DTensors(torch::jit::Node* node) {
  // Trace back through a node and all parents to eliminate 0D Tensors
  // and update schemas to their scalar alternatives, returning the final
  // Value* to the user

  // Requires a valid schema with at least two inputs
  if (AtenIntReplacementNodeKinds.find(node->kind()) == AtenIntReplacementNodeKinds.end() ||
      node->inputs().size() < 2) {
    LOG_DEBUG(
        "Encountered node " << node->kind().toQualString()
                            << " which is unsupported in the aten::Int.Tensor replacement lowering pass.");
    return {};
  }

  // Validate that the first and second function inputs are 0D tensors or scalars
  c10::optional<torch::jit::Value*> first_input_scalar_value = Validate0DTensor(node->inputs()[0]);
  c10::optional<torch::jit::Value*> second_input_scalar_value = Validate0DTensor(node->inputs()[1]);

  // If the first input is not a scalar, recursively trace back on parent nodes
  if (!first_input_scalar_value.has_value()) {
    LOG_DEBUG("In aten::Int.Tensor lowering, now tracing " << node->inputs()[0]->node()->kind().toQualString());
    first_input_scalar_value = TracebackAndEliminate0DTensors(node->inputs()[0]->node());
  }

  // If the second input is not a scalar, recursively trace back on parent nodes
  if (!second_input_scalar_value.has_value()) {
    LOG_DEBUG("In aten::Int.Tensor lowering, now tracing " << node->inputs()[1]->node()->kind().toQualString());
    second_input_scalar_value = TracebackAndEliminate0DTensors(node->inputs()[1]->node());
  }

  if (!first_input_scalar_value.has_value() || !second_input_scalar_value.has_value()) {
    LOG_DEBUG(
        "In aten::Int.Tensor lowering, recursive trace through node input "
        << "parents failed to return a Scalar value for at least one parent node.");
    return {};
  }

  // Set the default insert point at the node
  torch::jit::WithInsertPoint guard(node);
  torch::jit::Node* new_node;

  switch (node->kind()) {
    // In the aten::floor_divide case, the schema syntax changes, so a new node
    // must be inserted
    case torch::jit::aten::floor_divide:
      new_node = node->owningGraph()->create(
          torch::jit::aten::floordiv, {first_input_scalar_value.value(), second_input_scalar_value.value()}, 1);
      new_node->insertAfter(node);
      new_node->output()->setType(c10::IntType::get());
      return new_node->output();

    // In the aten::mul case, the schema syntax is the same, so we can use the
    // existing node kind with new inputs
    default:
      new_node = node->owningGraph()->create(
          node->kind(), {first_input_scalar_value.value(), second_input_scalar_value.value()}, 1);
      new_node->insertAfter(node);
      new_node->output()->setType(c10::IntType::get());
      return new_node->output();
  }
}

void ReplaceAtenInt(std::shared_ptr<torch::jit::Graph>& g) {
  // Find all nodes with the aten::Int.Tensor schema and replace them
  // by tracing through the node and resolving the use of 0D tensors
  // to their corresponding scalar alternatives

  // Iterate over all nodes in the graph
  for (auto it = g->block()->nodes().begin(), end = g->block()->nodes().end(); it != end; ++it) {
    // Validate schema requirements for aten::Int.Tensor
    if (it->kind() == torch::jit::aten::Int && it->inputs().size() == 1 &&
        it->input()->type()->isSubtypeOf(c10::TensorType::get())) {
      LOG_DEBUG("Found an aten::Int.Tensor case, attempting to resolve input scalars.");

      // If the parent node's schema is of a supported type, trace back through the graph
      if (AtenIntReplacementNodeKinds.find(it->input()->node()->kind()) != AtenIntReplacementNodeKinds.end()) {
        LOG_DEBUG(
            "Tracing parent node " << it->input()->node()->kind().toQualString()
                                   << " to eliminate 0D Tensors for aten::Int.Tensor case.");
        auto scalar_input_value = TracebackAndEliminate0DTensors(it->input()->node());
        if (scalar_input_value.has_value()) {
          it->output()->replaceAllUsesWith(scalar_input_value.value());
          LOG_DEBUG("Tracing parent nodes for aten::Int.Tensor case succeeded.");
        } else {
          LOG_DEBUG("Tracing parent nodes for aten::Int.Tensor case failed.");
        }
      } else {
        LOG_DEBUG(
            "Parent node schema " << it->input()->node()->kind().toQualString()
                                  << " is currently unsupported for aten::Int.Tensor case.");
      }
    }
  }

  // Clean up remnant operators in the graph
  torch::jit::EliminateDeadCode(g);
  LOG_GRAPH("Post removing aten::Int.Tensor operations: " << *g);
}

} // namespace passes
} // namespace lowering
} // namespace core
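
For context, a minimal sketch of how this pass might be exercised on its own, mirroring the tests below (the IR graph here is illustrative, not taken from the PR):

#include <memory>
#include <string>

#include "core/lowering/passes/passes.h"
#include "torch/csrc/jit/ir/irparser.h"

void run_replace_aten_int_example() {
  // Illustrative pattern: aten::Int.Tensor fed by NumToTensor/mul parents
  const std::string ir = R"IR(
    graph():
      %0 : int = prim::Constant[value=4]()
      %1 : Tensor = prim::NumToTensor(%0)
      %2 : Tensor = aten::mul(%1, %1)
      %3 : int = aten::Int(%2)
      %4 : Tensor = prim::NumToTensor(%3)
      return (%4))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir, g.get());

  // After the pass, the aten::Int node is gone and the multiply
  // operates directly on ints.
  torch_tensorrt::core::lowering::passes::ReplaceAtenInt(g);
}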
152 changes: 152 additions & 0 deletions tests/core/lowering/test_remove_unnecessary_casts.cpp
@@ -437,3 +437,155 @@ TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivFloatValuesAgree) {
  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
}

TEST(LoweringPasses, RemoveAtenIntTensorValuesAgree) {
  std::string source_graph_no_inputs = R"IR(
    graph():
      %0: int = prim::Constant[value=2]()
      %11: int = prim::Constant[value=7]()
      %3: Tensor = prim::NumToTensor(%0)
      %1: Tensor = prim::NumToTensor(%11)
      %4: Tensor = aten::floor_divide(%1, %3)
      %7: Tensor = aten::mul(%3, %4)
      %8: Tensor = aten::mul(%7, %1)
      %50: int = aten::Int(%8)
      %5: Tensor = prim::NumToTensor(%50)
      return (%5))IR";
  std::string target_graph_no_inputs = R"IR(
    graph():
      %0: int = prim::Constant[value=2]()
      %1: int = prim::Constant[value=7]()
      %4: int = aten::floordiv(%1, %0)
      %7: int = aten::mul(%0, %4)
      %40: int = aten::mul(%7, %1)
      %5: Tensor = prim::NumToTensor(%40)
      return (%5))IR";

  auto g_in = std::make_shared<torch::jit::Graph>();
  auto g_out = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(source_graph_no_inputs, g_in.get());
  torch::jit::parseIR(target_graph_no_inputs, g_out.get());

  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_in, {});
  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_out, {});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));

  // Ensure the lowering pass transforms the first graph into the second
  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph_no_inputs, sg.get());

  torch_tensorrt::core::lowering::passes::ReplaceAtenInt(sg);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph_no_inputs, tg.get());

  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

TEST(LoweringPasses, RemoveAtenIntSizeTensorValuesAgree) {
  std::string source_graph = R"IR(
    graph(%x.0: Tensor):
      %10: int = prim::Constant[value=0]()
      %100: int = aten::size(%x.0, %10)
      %0: Tensor = prim::NumToTensor(%100)
      %11: int = prim::Constant[value=9]()
      %1: Tensor = prim::NumToTensor(%11)
      %4: Tensor = aten::floor_divide(%1, %0)
      %7: Tensor = aten::mul(%0, %4)
      %8: Tensor = aten::mul(%7, %1)
      %50: int = aten::Int(%8)
      %5: Tensor = prim::NumToTensor(%50)
      return (%5))IR";
  std::string target_graph = R"IR(
    graph(%x.0: Tensor):
      %10: int = prim::Constant[value=0]()
      %0: int = aten::size(%x.0, %10)
      %1: int = prim::Constant[value=9]()
      %4: int = aten::floordiv(%1, %0)
      %7: int = aten::mul(%0, %4)
      %40: int = aten::mul(%7, %1)
      %5: Tensor = prim::NumToTensor(%40)
      return (%5))IR";

  auto g_in = std::make_shared<torch::jit::Graph>();
  auto g_out = std::make_shared<torch::jit::Graph>();

  auto in_0 = at::rand({2, 3, 5, 5}, {at::kCUDA});

  torch::jit::parseIR(source_graph, g_in.get());
  torch::jit::parseIR(target_graph, g_out.get());

  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_in, {in_0});
  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_out, {in_0});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));

  // Ensure the lowering pass transforms the first graph into the second
  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());

  torch_tensorrt::core::lowering::passes::ReplaceAtenInt(sg);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}

TEST(LoweringPasses, RemoveAtenIntConstTensorValuesAgree) {
  // Ensure the lowering pass transforms the first graph into the second
  std::string source_graph = R"IR(
    graph(%0: int):
      %1: Tensor = prim::Constant[value=[8]]()
      %3: Tensor = prim::NumToTensor(%0)
      %4: Tensor = aten::floor_divide(%3, %1)
      %5: int = aten::Int(%4)
      return (%5))IR";

  std::string target_graph = R"IR(
    graph(%0 : int):
      %1 : Tensor = prim::Constant[value=[8]]()
      %2 : int = prim::Constant[value=8]()
      %3 : int = aten::floordiv(%0, %2)
      return (%3))IR";

  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());

  // Manually enter a 0D tensor constant for the source
  // (the IR parser materializes value=[8] as a 1D tensor)
  auto first_op_sg = *(sg->block()->nodes().begin());
  torch::jit::Value* r_sg = sg->insertConstant(c10::scalar_to_tensor(8), c10::nullopt, first_op_sg->scope());
  r_sg->copyMetadata(first_op_sg->output());
  r_sg->setType(c10::TensorType::get());
  first_op_sg->output()->replaceAllUsesWith(r_sg);
  first_op_sg->destroy();

  torch_tensorrt::core::lowering::passes::ReplaceAtenInt(sg);
  torch::jit::ConstantPooling(sg);
  sg = torch::jit::Canonicalize(sg, false);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  // Manually enter a 0D tensor constant for the target
  auto first_op_tg = *(tg->block()->nodes().begin());
  torch::jit::Value* r_tg = tg->insertConstant(c10::scalar_to_tensor(8), c10::nullopt, first_op_tg->scope());
  r_tg->copyMetadata(first_op_tg->output());
  r_tg->setType(c10::TensorType::get());
  first_op_tg->output()->replaceAllUsesWith(r_tg);
  first_op_tg->destroy();

  torch::jit::ConstantPooling(tg);
  tg = torch::jit::Canonicalize(tg, false);

  // Validate identical graphs after pooling constants and canonicalizing
  ASSERT_TRUE(tg->toString() == sg->toString());
}