Skip to content

Commit

Permalink
[TIR] CSE-TIR Pass - More deterministic behavior (apache#10663)
Browse files Browse the repository at this point in the history
* iterate through sorted keys

* masa comments -- simplify iteration

* test

* tests

* simplify vector construction

* jostle ci
  • Loading branch information
AndrewZhaoLuo authored and pfk-beta committed Apr 11, 2022
1 parent b849843 commit 69fbecd
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 2 deletions.
17 changes: 15 additions & 2 deletions src/tir/transforms/common_subexpr_elim_tools.cc
Original file line number Diff line number Diff line change
Expand Up @@ -743,13 +743,27 @@ bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b) {
std::vector<std::pair<PrimExpr, size_t>> SyntacticToSemanticComputations(
const ComputationTable& table) {
std::vector<std::pair<PrimExpr, size_t>> result;

// table.size() is an upper-bound of the number of elements in the resulting vector,
// as we might merge semantically equivalent computations.
// We do this reservation even if it might reserve slightly more space than is needed in the end
result.reserve(table.size());

// Traverse through map in a sorted order on keys to maintain deterministic behavior
// We do this by comparing the string repr of each PrimExpr to get a deterministic ordering
std::vector<std::pair<PrimExpr, size_t>> sorted_map_items(table.begin(), table.end());

sort(sorted_map_items.begin(), sorted_map_items.end(),
[](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
std::stringstream a_stream;
std::stringstream b_stream;
a_stream << a.first;
b_stream << b.first;
return a_stream.str().compare(b_stream.str()) < 0;
});

// For each element in the hashtable
for (auto elem : table) {
for (auto elem : sorted_map_items) {
// We try to see if a semantically equivalent term is already in the resulting vector
auto it_found = std::find_if(result.begin(), result.end(),
[elem](std::pair<PrimExpr, size_t> already_seen) {
Expand All @@ -763,7 +777,6 @@ std::vector<std::pair<PrimExpr, size_t>> SyntacticToSemanticComputations(
result.push_back(elem);
}
}

return result;
}

Expand Down
48 changes: 48 additions & 0 deletions tests/python/unittest/test_tir_transform_common_subexpr_elim.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import hashlib

import tvm
from tvm import te
from tvm.ir.base import save_json
from tvm.ir.module import IRModule


# A test program which gives the opportunity for the CSE pass to introduce two new variables, at two different levels
def test_cse():
Expand Down Expand Up @@ -133,6 +138,49 @@ def test_cse():
assert isinstance(body.body, tvm.tir.BufferStore)


def test_deterministic_cse():
    """Test deterministic allocation of CSE vars.

    We expect something like

        result = (x + 1) + (x + 2) + (x + 3) + (x + 1) + (x + 2) + (x + 3)
        -->
        cse_var_3 = (x + 1)
        cse_var_2 = (x + 2)
        cse_var_1 = (x + 3)
        result = cse_var_3 + cse_var_2 + cse_var_1 + cse_var_3 + cse_var_2 + cse_var_1

    The pass is applied REPEATS times to the same module, and the serialized
    JSON of the resulting "main" function must hash to the same value each time.
    """
    NUM_TERMS = 10
    REPEATS = 10

    x = te.var("x")
    result = te.var("result")

    # Build every (x + k) term twice, as separate expression objects, so that
    # each term appears as a repeated computation in the summed expression.
    offsets = range(1, NUM_TERMS + 1)
    terms = [x + offset for _ in range(2) for offset in offsets]

    expression = x
    for term in terms:
        expression = expression + term
    let_stmt = tvm.tir.LetStmt(result, expression, tvm.tir.Evaluate(result))
    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([x], let_stmt))

    initial_hash = None
    for _ in range(REPEATS):
        func = tvm.tir.transform.CommonSubexprElimTIR()(mod)["main"]

        # Hash the serialized JSON and ensure it is the same on every run
        json_hash = hashlib.sha256(save_json(func).encode()).hexdigest()

        if initial_hash is None:
            initial_hash = json_hash
        assert json_hash == initial_hash


# First specific test for if nodes : Some duplicated computations appear only in one branch (here the Then branch), not in both branches.
# In this case, the CSE pass should introduce the redundant computation at the top of the Then branch, not before the whole If
# (otherwise that would lead to some computations being computed for nothing when it is the Else branch that is executed).
Expand Down

0 comments on commit 69fbecd

Please sign in to comment.