[Codegen] Emit tir::Let as var assignment explicitly (#17278)
Prior to this PR, the PrimExpr `tir::Let` was inlined during codegen, which made any common subexpression elimination (CSE) performed with `tir::Let` at the TIR level ineffective: the bound value was re-expanded at every use site.

This PR updates codegen so that `tir::Let` is emitted as an explicit variable assignment, allowing TIR-level CSE to carry through to the generated code.
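To make the effect concrete, here is a minimal sketch (not actual CodeGenC output) of the C that codegen might produce for the TIR expression `Let x = (a + b) * c in x + x`; the function and variable names are invented for illustration.

float let_inlined_before(float a, float b, float c) {
  // Prior behavior: the Let value was inlined at every use site,
  // so the common subexpression was recomputed.
  return ((a + b) * c) + ((a + b) * c);
}

float let_assigned_after(float a, float b, float c) {
  // New behavior: the Let binding is emitted as an explicit local
  // assignment, so the shared value is computed exactly once.
  float x = (a + b) * c;
  return x + x;
}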
MasterJH5574 authored Aug 21, 2024
1 parent dc24781 commit b76ebad
Showing 3 changed files with 26 additions and 7 deletions.
python/tvm/relax/frontend/nn/op.py — 6 changes: 3 additions & 3 deletions

@@ -2544,7 +2544,7 @@ def _cumsum_mask(cumsum_sorted, top_p, top_k, i, j):
 
     @T.prim_func(private=True)
    def _get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
-        batch, vocab_size = T.int64(), T.int64()
+        batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
         cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype)
         top_p = T.match_buffer(B, (batch, 1), prob_dtype)
         top_k = T.match_buffer(C, (batch, 1), index_dtype)
@@ -2564,8 +2564,8 @@ def _get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
     def _get_index_from_sorted(
         A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle
     ):
-        batch, vocab_size = T.int64(), T.int64()
-        out_batch = T.int64()
+        batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
+        out_batch = T.int64(is_size_var=True)
         cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype)
         indices = T.match_buffer(B, (batch, vocab_size), index_dtype)
         renorm_prob = T.match_buffer(C, (batch, 1), prob_dtype)
src/target/source/codegen_c.cc — 21 changes: 20 additions & 1 deletion

@@ -887,8 +887,27 @@ void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*)
     let_binding_[op->var] = op;
   }
   std::string value = PrintExpr(op->value);
-  var_idmap_[op->var.get()] = value;
+  if (print_ssa_form_) {
+    ICHECK(!var_idmap_.count(op->var.get()));
+    var_idmap_[op->var.get()] = value;
+  } else {
+    PrintIndent();
+    if (op->var.dtype() == DataType::Handle() && handle_data_type_.count(op->var.get())) {
+      PrintType(handle_data_type_.at(op->var.get()), this->stream);
+      this->stream << "* " << AllocVarID(op->var.get()) << " = (";
+      PrintType(handle_data_type_.at(op->var.get()), this->stream);
+      this->stream << "*)" << value << ";\n";
+    } else {
+      PrintType(op->var.dtype(), this->stream);
+      this->stream << ' ' << AllocVarID(op->var.get()) << " = " << value << ";\n";
+    }
+  }
   os << PrintExpr(op->body);
+  // Pop the defined var from var_idmap when exiting its scope.
+  // We do this because it is hard to completely avoid a same LetNode appearing
+  // at different places.
+  bool removed = var_idmap_.erase(op->var.get());
+  ICHECK(removed);
 }
 
 void CodeGenC::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*)
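For reference, a hypothetical sketch of what the two non-SSA emission paths above produce, assuming a handle var with registered data type float and a plain int32 var; `some_handle`, `buf`, `m`, and `n` are invented names, not actual generated output.

#include <stdint.h>

void emitted_let_paths(void* some_handle, int32_t m) {
  // Handle dtype with a registered pointed-to type: a typed pointer
  // declaration whose initializer casts the bound value.
  float* buf = (float*)some_handle;
  // Any other dtype: a plain typed local initialized with the value.
  int32_t n = m * 4;
  (void)buf;  // silence unused-variable warnings in this sketch
  (void)n;
}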
tests/python/relax/test_frontend_nn_op.py — 6 changes: 3 additions & 3 deletions

@@ -947,11 +947,11 @@ def foo(
     class Expected:
         @T.prim_func(private=True)
         def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle):
-            batch, vocab_size = T.int64(), T.int64()
+            batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
             cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
             indices = T.match_buffer(B, (batch, vocab_size), "int64")
             renorm_prob = T.match_buffer(C, (batch, 1))
-            out_batch = T.int64()
+            out_batch = T.int64(is_size_var=True)
             usample = T.match_buffer(D, (out_batch, 1))
             sample_indices = T.match_buffer(E, (out_batch, 1), "int64")
             output_index = T.match_buffer(F, (out_batch, 1), "int64")
@@ -970,7 +970,7 @@ def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E:
 
         @T.prim_func(private=True)
         def get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
-            batch, vocab_size = T.int64(), T.int64()
+            batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
             cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
             top_p = T.match_buffer(B, (batch, 1))
             top_k = T.match_buffer(C, (batch, 1), "int64")
