Skip to content

Commit

Permalink
[BugFix][TVMScript] Fix the roundtripability of intrinsic pow (#13692)
Browse files Browse the repository at this point in the history
* Fix the roundtripability of pow intrinsic.

* fix the lint.

* Fix the lint.

* add tir.pow to make it consistent.

Co-authored-by: lightzhan-intellif <zhan.liang@intellif.com>
  • Loading branch information
lightzhan-intellif and lightzhan-intellif authored Jan 5, 2023
1 parent 07a5a9e commit 048028b
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 3 deletions.
4 changes: 2 additions & 2 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1532,7 +1532,7 @@ def wrapped(*args, **kwargs):
nearbyint = _op_wrapper(_tir_op.nearbyint)
nextafter = _op_wrapper(_tir_op.nextafter)
popcount = _op_wrapper(_tir_op.popcount)
power = _op_wrapper(_tir_op.power)
pow = _op_wrapper(_tir_op.pow) # pylint: disable=redefined-builtin
q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift)
q_multiply_shift_per_axis = _op_wrapper(_tir_op.q_multiply_shift_per_axis)
ret = _op_wrapper(_tir_op.ret)
Expand Down Expand Up @@ -1713,7 +1713,7 @@ def f():
"nearbyint",
"nextafter",
"popcount",
"power",
"pow",
"q_multiply_shift",
"q_multiply_shift_per_axis",
"ret",
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
from .op import cos, cosh, acos, acosh
from .op import tan, tanh, atan, atan2, atanh
from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else
from .op import trunc, abs, round, nextafter, nearbyint, power, pow, popcount, fmod, if_then_else
from .op import likely, isnan, isnullptr, isfinite, isinf, copysign
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv
from .op import comm_reducer, min, max, sum
Expand Down
22 changes: 22 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2236,6 +2236,28 @@ def power(x, y, span=None):
return _ffi_api._OpPow(convert(x), convert(y), span) # type: ignore


def pow(x, y, span=None):
"""x power y
Parameters
----------
x : PrimExpr
Input argument.
y : PrimExpr
The exponent
span : Optional[Span]
The location of this operator in the source code.
Returns
-------
z : PrimExpr
The result.
"""
return _ffi_api._OpPow(convert(x), convert(y), span) # type: ignore


def popcount(x):
"""Count the number of set bits in input x.
Expand Down
9 changes: 9 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3550,6 +3550,14 @@ def func(A: T.Buffer[128, "float32"], C: T.Buffer[128, "float32"]):
return mod["main"]


def intrinsic_pow():
@T.prim_func
def func():
T.pow(T.float32(1), T.float32(1))

return func


ir_generator = tvm.testing.parameter(
opt_gemm_normalize,
opt_gemm_lower,
Expand Down Expand Up @@ -3607,6 +3615,7 @@ def func(A: T.Buffer[128, "float32"], C: T.Buffer[128, "float32"]):
elif_chain_with_else,
*nested_boolean_expressions(),
multi_env_threads,
intrinsic_pow,
)


Expand Down

0 comments on commit 048028b

Please sign in to comment.