Skip to content

Commit

Permalink
fix: Refactor implementation to remove nullptr
Browse files Browse the repository at this point in the history
- Edit in favor of `c10::optional` type usage
  • Loading branch information
gs-olive committed May 22, 2023
1 parent a86ac93 commit d3c0c7a
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions core/lowering/passes/remove_unnecessary_casts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,10 @@ const std::unordered_set<c10::Symbol> AtenIntReplacementNodeKinds = {
torch::jit::aten::floor_divide,
};

torch::jit::Value* Validate0DTensor(torch::jit::Value* value) {
c10::optional<torch::jit::Value*> Validate0DTensor(torch::jit::Value* value) {
// Validates that the input Value* is a 0D Tensor (or int/float)
// Return the stored int/float Value* if so, otherwise null
torch::jit::Value* enclosed_scalar_value = nullptr;
c10::optional<torch::jit::Value*> enclosed_scalar_value = {};

// Regular Int/Float case
if (value->type()->isSubtypeOf(c10::IntType::get()) || value->type()->isSubtypeOf(c10::FloatType::get())) {
Expand Down Expand Up @@ -257,7 +257,7 @@ torch::jit::Value* Validate0DTensor(torch::jit::Value* value) {
return enclosed_scalar_value;
}

torch::jit::Value* TracebackAndEliminate0DTensors(torch::jit::Node* node) {
c10::optional<torch::jit::Value*> TracebackAndEliminate0DTensors(torch::jit::Node* node) {
// Trace back through a node and all parents to eliminate 0D Tensors
// and update schemas to their scalar alternatives, returning final
// Value* to user
Expand All @@ -268,30 +268,30 @@ torch::jit::Value* TracebackAndEliminate0DTensors(torch::jit::Node* node) {
LOG_DEBUG(
"Encountered node " << node->kind().toQualString()
<< " which is unsupported in the aten::Int.Tensor replacement lowering pass.");
return nullptr;
return {};
}

// Validate the first and second function inputs are 0D tensors or scalars
torch::jit::Value* first_input_scalar_value = Validate0DTensor(node->inputs()[0]);
torch::jit::Value* second_input_scalar_value = Validate0DTensor(node->inputs()[1]);
c10::optional<torch::jit::Value*> first_input_scalar_value = Validate0DTensor(node->inputs()[0]);
c10::optional<torch::jit::Value*> second_input_scalar_value = Validate0DTensor(node->inputs()[1]);

// If the first input is not a scalar, recursively traceback on parent nodes
if (!first_input_scalar_value) {
if (!first_input_scalar_value.has_value()) {
LOG_DEBUG("In aten::Int.Tensor lowering, now tracing " << node->inputs()[0]->node()->kind().toQualString());
first_input_scalar_value = TracebackAndEliminate0DTensors(node->inputs()[0]->node());
}

// If the second input is not a scalar, recursively traceback on parent nodes
if (!second_input_scalar_value) {
if (!second_input_scalar_value.has_value()) {
LOG_DEBUG("In aten::Int.Tensor lowering, now tracing " << node->inputs()[0]->node()->kind().toQualString());
second_input_scalar_value = TracebackAndEliminate0DTensors(node->inputs()[1]->node());
}

if (!first_input_scalar_value || !second_input_scalar_value) {
if (!first_input_scalar_value.has_value() || !second_input_scalar_value.has_value()) {
LOG_DEBUG(
"In aten::Int.Tensor lowering, recursive trace through node input "
<< "parents failed to return a Scalar value for at least one parent node.");
return nullptr;
return {};
}

// Set default insert point at node
Expand All @@ -303,15 +303,16 @@ torch::jit::Value* TracebackAndEliminate0DTensors(torch::jit::Node* node) {
// must be inserted
case torch::jit::aten::floor_divide:
new_node = node->owningGraph()->create(
torch::jit::aten::floordiv, {first_input_scalar_value, second_input_scalar_value}, 1);
torch::jit::aten::floordiv, {first_input_scalar_value.value(), second_input_scalar_value.value()}, 1);
new_node->insertAfter(node);
new_node->output()->setType(c10::IntType::get());
return new_node->output();

// In the aten::mul case, the schema syntax is the same, so we can use the existing schema
// with new inputs
default:
new_node = node->owningGraph()->create(node->kind(), {first_input_scalar_value, second_input_scalar_value}, 1);
new_node = node->owningGraph()->create(
node->kind(), {first_input_scalar_value.value(), second_input_scalar_value.value()}, 1);
new_node->insertAfter(node);
new_node->output()->setType(c10::IntType::get());
return new_node->output();
Expand All @@ -336,8 +337,8 @@ void ReplaceAtenInt(std::shared_ptr<torch::jit::Graph>& g) {
"Tracing parent node " << it->input()->node()->kind().toQualString()
<< " to eliminate 0D Tensors for aten::Int.Tensor case.");
auto scalar_input_value = TracebackAndEliminate0DTensors(it->input()->node());
if (scalar_input_value) {
it->output()->replaceAllUsesWith(scalar_input_value);
if (scalar_input_value.has_value()) {
it->output()->replaceAllUsesWith(scalar_input_value.value());
LOG_DEBUG("Tracing parent nodes for aten::Int.Tensor case succeeded.");
} else {
LOG_DEBUG("Tracing parent nodes for aten::Int.Tensor case failed.");
Expand Down

0 comments on commit d3c0c7a

Please sign in to comment.