From ab2dd2ef050327d5701a34fef55c49a1740dba20 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 17 Mar 2022 11:13:20 -0700 Subject: [PATCH 1/6] iterate through sorted keys --- .../transforms/common_subexpr_elim_tools.cc | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/tir/transforms/common_subexpr_elim_tools.cc b/src/tir/transforms/common_subexpr_elim_tools.cc index 218667c331a5..fa49b9ef2754 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.cc +++ b/src/tir/transforms/common_subexpr_elim_tools.cc @@ -748,8 +748,26 @@ std::vector> SyntacticToSemanticComputations( // We do this reservation even if it might reserve slightly more space than is needed in the end result.reserve(table.size()); - // For each element in the hashtable + // Traverse through keys in a sorted order to maintain deterministic behavior + // We do this by comparing the string repr of each PrimExpr to get a determinstic ordering + std::vector table_keys; + table_keys.reserve(table.size()); for (auto elem : table) { + table_keys.push_back(elem.first); + } + sort(table_keys.begin(), table_keys.end(), [](PrimExpr a, PrimExpr b) { + std::stringstream a_stream; + std::stringstream b_stream; + a_stream << a; + b_stream << b; + return a_stream.str().compare(b_stream.str()) < 0; + }); + + // For each element in the hashtable + for (PrimExpr key : table_keys) { + size_t value = table.find(key)->second; + std::pair elem = {key, value}; + // We try to see if a semantically equivalent term is already in the resulting vector auto it_found = std::find_if(result.begin(), result.end(), [elem](std::pair already_seen) { @@ -763,7 +781,6 @@ std::vector> SyntacticToSemanticComputations( result.push_back(elem); } } - return result; } From 4381ac304302400166852259ea6969686950e6f9 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 17 Mar 2022 12:29:10 -0700 Subject: [PATCH 2/6] masa comments -- simplify iteration --- .../transforms/common_subexpr_elim_tools.cc | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/src/tir/transforms/common_subexpr_elim_tools.cc b/src/tir/transforms/common_subexpr_elim_tools.cc index fa49b9ef2754..a6104011fcb5 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.cc +++ b/src/tir/transforms/common_subexpr_elim_tools.cc @@ -743,31 +743,30 @@ bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b) { std::vector> SyntacticToSemanticComputations( const ComputationTable& table) { std::vector> result; + // table.size() is an upper-bound of the number of elements in the resulting vector, // as we might merge semantically equivalent computations. // We do this reservation even if it might reserve slightly more space than is needed in the end result.reserve(table.size()); - // Traverse through keys in a sorted order to maintain deterministic behavior + // Traverse through map in a sorted order on keys to maintain deterministic behavior // We do this by comparing the string repr of each PrimExpr to get a determinstic ordering - std::vector table_keys; - table_keys.reserve(table.size()); + std::vector> sorted_map_items; + sorted_map_items.reserve(table.size()); for (auto elem : table) { - table_keys.push_back(elem.first); + sorted_map_items.push_back(elem); } - sort(table_keys.begin(), table_keys.end(), [](PrimExpr a, PrimExpr b) { - std::stringstream a_stream; - std::stringstream b_stream; - a_stream << a; - b_stream << b; - return a_stream.str().compare(b_stream.str()) < 0; - }); + sort(sorted_map_items.begin(), sorted_map_items.end(), + [](std::pair a, std::pair b) { + std::stringstream a_stream; + std::stringstream b_stream; + a_stream << a.first; + b_stream << b.first; + return a_stream.str().compare(b_stream.str()) < 0; + }); // For each element in the hashtable - for (PrimExpr key : table_keys) { - size_t value = table.find(key)->second; - std::pair elem = {key, value}; - + for (auto elem : sorted_map_items) { // We try to see if a semantically equivalent term is already in the resulting vector auto it_found = std::find_if(result.begin(), result.end(), [elem](std::pair already_seen) { From 404c4bbb011d103dbfc2be67a9e1a6db369ccd71 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 17 Mar 2022 13:48:58 -0700 Subject: [PATCH 3/6] test --- .../test_tir_transform_common_subexpr_elim.py | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py index 17c0cbdd99c6..2744e2c3ab51 100644 --- a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py @@ -14,8 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import hashlib + import tvm from tvm import te +from tvm.ir.base import save_json +from tvm.ir.module import IRModule + # A test program which gives the opportunity for the CSE pass to introduce two new variables, at two different levels def test_cse(): @@ -133,6 +138,50 @@ def test_cse(): assert isinstance(body.body, tvm.tir.BufferStore) +def test_deterministic_cse(): + import random + + """Test deterministic allocation of CSE vars + + We expect something like + + result = (x + 1) + (x + 2) + (x + 3) + (x + 1) + (x + 2) + (x + 3) + --> + cse_var_3 = (x + 1) + cse_var_2 = (x + 2) + cse_var_1 = (x + 3) + result = cse_var_3 + cse_var_2 + cse_var_1 + cse_var_3 + cse_var_2 + cse_var_1 + """ + NUM_TERMS = 10 + REPEATS = 10 + + x = te.var("x") + result = te.var("result") + + rand_ints = sorted([random.randint(1, 10) for i in range(NUM_TERMS)]) + inc1 = [(x + rand_ints[i]) for i in range(NUM_TERMS)] + inc2 = [(x + rand_ints[i]) for i in range(NUM_TERMS)] + + expression = x + for add in inc1 + inc2: + expression = expression + add + let_stmt = tvm.tir.LetStmt(result, expression, tvm.tir.Evaluate(result)) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([x], let_stmt)) + + initial_hash = None + for _ in range(REPEATS): + body = tvm.tir.transform.CommonSubexprElimTIR()(mod)["main"] + print(body) + + # Hash and ensure serialize json is the same every time + json_val = save_json(body) + json_hash = hashlib.sha256(json_val.encode()).hexdigest() + + if initial_hash is None: + initial_hash = json_hash + assert json_hash == initial_hash + + # First specific test for if nodes : Some duplicated computations appear only in one branch (here the Then branch), not in both branches. # In this case, the CSE pass should introduce the redundant computation at the top if the Then branch, not before the whole If # (otherwise that would lead to some computations being computed for nothing when it is the Else branch that is executed). From d104b1cbf31d7e853a92ca88695fd6708f47d15c Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 17 Mar 2022 13:49:46 -0700 Subject: [PATCH 4/6] tests --- .../unittest/test_tir_transform_common_subexpr_elim.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py index 2744e2c3ab51..c12e27a46e3f 100644 --- a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py @@ -158,9 +158,9 @@ def test_deterministic_cse(): x = te.var("x") result = te.var("result") - rand_ints = sorted([random.randint(1, 10) for i in range(NUM_TERMS)]) - inc1 = [(x + rand_ints[i]) for i in range(NUM_TERMS)] - inc2 = [(x + rand_ints[i]) for i in range(NUM_TERMS)] + offsets = sorted([i + 1 for i in range(NUM_TERMS)]) + inc1 = [(x + offsets[i]) for i in range(NUM_TERMS)] + inc2 = [(x + offsets[i]) for i in range(NUM_TERMS)] expression = x for add in inc1 + inc2: @@ -171,7 +171,6 @@ def test_deterministic_cse(): initial_hash = None for _ in range(REPEATS): body = tvm.tir.transform.CommonSubexprElimTIR()(mod)["main"] - print(body) # Hash and ensure serialize json is the same every time json_val = save_json(body) From d0a0b7f2832c33afd0fa8ee1f1f2d03f25270b1f Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 17 Mar 2022 13:51:02 -0700 Subject: [PATCH 5/6] simplify vector construciton --- src/tir/transforms/common_subexpr_elim_tools.cc | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/tir/transforms/common_subexpr_elim_tools.cc b/src/tir/transforms/common_subexpr_elim_tools.cc index a6104011fcb5..d39d211ba182 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.cc +++ b/src/tir/transforms/common_subexpr_elim_tools.cc @@ -751,11 +751,8 @@ std::vector> SyntacticToSemanticComputations( // Traverse through map in a sorted order on keys to maintain deterministic behavior // We do this by comparing the string repr of each PrimExpr to get a determinstic ordering - std::vector> sorted_map_items; - sorted_map_items.reserve(table.size()); - for (auto elem : table) { - sorted_map_items.push_back(elem); - } + std::vector> sorted_map_items(table.begin(), table.end()); + sort(sorted_map_items.begin(), sorted_map_items.end(), [](std::pair a, std::pair b) { std::stringstream a_stream; From ac6a4ea30d90b18a189b0974bec753e708614350 Mon Sep 17 00:00:00 2001 From: Andrew Luo Date: Thu, 17 Mar 2022 15:08:33 -0700 Subject: [PATCH 6/6] jostle ci