Skip to content

Commit

Permalink
stricten checks and add fuzzing tests
Browse files Browse the repository at this point in the history
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
  • Loading branch information
Hardcode84 committed Dec 4, 2024
1 parent 565f81f commit ce8b6de
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 10 deletions.
22 changes: 17 additions & 5 deletions iree/turbine/kernel/wave/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,13 +1071,19 @@ def check_mul(mul):
ret = None
for arg in mul.args:
if arg.is_number:
if arg < 0:
return None

if ret is not None:
return None

ret = arg
continue

if not isinstance(arg, (sympy.floor, sympy.Mod)):
if not (isinstance(arg, (sympy.floor, sympy.Mod)) or arg.is_integer):
return None

if not arg.is_nonnegative:
return None

return ret
Expand All @@ -1092,7 +1098,7 @@ def transform_mod(expr):
return None

p, q = expr.args
if not q.is_number:
if not q.is_number or q < 0:
return None

if not isinstance(p, sympy.Add):
Expand Down Expand Up @@ -1131,10 +1137,16 @@ def check_mul_rational(mul):
if ret is not None:
return None

if arg.p < 0 or arg.q < 0:
return None

ret = arg
continue

if not isinstance(arg, (sympy.floor, sympy.Mod)):
if not (isinstance(arg, (sympy.floor, sympy.Mod)) or arg.is_integer):
return None

if not arg.is_nonnegative:
return None

return ret
Expand Down Expand Up @@ -1172,10 +1184,10 @@ def transform_floor(expr):
return None

r = check_mul_rational(arg)
if r is None:
if r is None or r.p != 1:
return None

if r < c:
if r <= c:
return None

terms.append(arg)
Expand Down
4 changes: 3 additions & 1 deletion iree/turbine/kernel/wave/wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,9 @@ def is_reduction(node: fx.Node):
reduction_nodes = trace.walk(is_reduction)
for node in reduction_nodes:
custom = get_custom(node)
self.induction_vars[custom] = tkl.IndexSymbol("$ARG" + str(custom.axis))
self.induction_vars[custom] = tkl.IndexSymbol(
"$ARG" + str(custom.axis), integer=True, nonnegative=True
)
for tiling_constraint in self.tiling_constraints:
if tiling_constraint.dim == custom.axis:
tiling_constraint.induction_var = self.induction_vars[custom]
Expand Down
112 changes: 108 additions & 4 deletions tests/kernel/wave/wave_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import pytest
from iree.turbine.kernel.lang import sym
from iree.turbine.kernel.wave.utils import delinearize_index, _simplify_sympy_expr
import sympy
Expand All @@ -20,8 +21,111 @@ def test_delinearize_index():


def test_custom_sympy_simplifications():
mod_expr = sympy.sympify("(floor(a) * 4 + 3) % 16")
assert str(_simplify_sympy_expr(mod_expr)) == "4*(Mod(floor(a), 4)) + 3"
a = sympy.Symbol("a", integer=True, nonnegative=True)
mod_expr = (sympy.floor(a) * 4 + 3) % 16
assert str(_simplify_sympy_expr(mod_expr)) == "4*(Mod(a, 4)) + 3"

floor_expr = sympy.sympify("floor(floor(a)/3 + 1/6)")
assert str(_simplify_sympy_expr(floor_expr)) == "floor(floor(a)/3)"
floor_expr = sympy.floor(sympy.floor(a) / 3 + sympy.sympify(1) / 6)
assert str(_simplify_sympy_expr(floor_expr)) == "floor(a/3)"


@pytest.mark.skip("Too slow")
def test_fuzz_custom_sympy_simplifications_mod():
x = sympy.Symbol("x", integer=True, nonnegative=True)
a = sympy.Symbol("a")
b = sympy.Symbol("b")
c = sympy.Symbol("c")

import random

expr = (sympy.floor(x) * a + b) % c
total = 0
outer_num_iters = 1000
for i in range(outer_num_iters):

a_val = random.randint(2, 50)
b_val = random.randint(1, a_val - 1)
c_val = a_val * random.randint(1, 10)

vals = [a_val, b_val, c_val]
expr = expr.subs({a: vals[0], b: vals[1], c: vals[2]})
expr = sympy.simplify(expr)

expr2 = _simplify_sympy_expr(expr)

if i % 50 == 0 and i > 0:
print(f"{100*i/outer_num_iters}%")

if expr == expr2:
print("skip", vals)
continue

vals2 = vals + [0, 1]
for j in range(100):
val = vals2[j] if j < len(vals2) else random.randint(0, c_val * 2)
if expr.subs({x: val}) != expr2.subs({x: val}):
print(f"Failed: {vals}, {val}")

assert expr.subs({x: val}) == expr2.subs({x: val})
total += 1

print(f"Sucess: {total} checks")


@pytest.mark.skip("Too slow")
def test_fuzz_custom_sympy_simplifications_floor():
x = sympy.Symbol("x", integer=True, nonnegative=True)
a = sympy.Symbol("a")
b = sympy.Symbol("b")
c = sympy.Symbol("c")
d = sympy.Symbol("d")

import random

orig_expr = sympy.floor(sympy.floor(x) * a / b + c / d)

def check_specific(*vals):
expr1 = orig_expr.subs({a: vals[0], b: vals[1], c: vals[2], d: vals[3]})
expr1 = sympy.simplify(expr1)

expr2 = _simplify_sympy_expr(expr1)
assert expr1.subs({x: vals[4]}) == expr2.subs({x: vals[4]})

check_specific(10, 11, 6, 10, 6)
check_specific(8, 5, 1, 5, 8)

total = 0
outer_num_iters = 500
for i in range(outer_num_iters):
while True:
a_val = 1 # random.randint(1, 10)
b_val = random.randint(1, 10)
if b_val == a_val:
b_val += 1

c_val = random.randint(1, 10)
d_val = random.randint(1, 10)
if d_val == c_val:
d_val += 1

vals = [a_val, b_val, c_val, d_val]
expr = orig_expr.subs({a: vals[0], b: vals[1], c: vals[2], d: vals[3]})
expr = sympy.simplify(expr)

expr2 = _simplify_sympy_expr(expr)
if expr != expr2:
break

if i % 50 == 0 and i > 0:
print(f"{100*i/outer_num_iters}%")

vals2 = vals + [-1, 0, 1]
for j in range(100):
val = vals2[j] if j < len(vals2) else random.randint(0, c_val * 2)
if expr.subs({x: val}) != expr2.subs({x: val}):
print(f"Failed: {vals}, {val}")

assert expr.subs({x: val}) == expr2.subs({x: val})
total += 1

print(f"Sucess: {total} checks")

0 comments on commit ce8b6de

Please sign in to comment.