diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index 726d5d1c988c..c2c158c77f78 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -179,7 +179,7 @@ def offset_of(self, indices): def __getitem__(self, indices): from ..arith import Analyzer # pylint: disable=import-outside-toplevel - from .expr import BufferLoad, Ramp # pylint: disable=import-outside-toplevel + from .expr import BufferLoad, Ramp, const # pylint: disable=import-outside-toplevel from .stmt import BufferRegion # pylint: disable=import-outside-toplevel if not isinstance(indices, (tuple, list)): @@ -195,7 +195,11 @@ def __getitem__(self, indices): stop = self.shape[i] if index.stop is None else index.stop region.append(Range.from_min_extent(start, analyzer.simplify(stop - start))) else: - region.append(Range.from_min_extent(index, 1)) + region.append( + Range.from_min_extent( + index, const(1, index.dtype) if isinstance(index, PrimExpr) else 1 + ) + ) return BufferRegion(self, region) else: expr_indices = [] diff --git a/tests/python/unittest/test_tvmscript_regression.py b/tests/python/unittest/test_tvmscript_regression.py index d063c0fcab7f..44d3036596ba 100644 --- a/tests/python/unittest/test_tvmscript_regression.py +++ b/tests/python/unittest/test_tvmscript_regression.py @@ -17,6 +17,7 @@ import numpy import tvm +import tvm.testing from tvm.script import tir as T @@ -73,9 +74,17 @@ def func_ref(): tvm.ir.assert_structural_equal(test_case, func_ref) +def test_tir_buffer_region_extent_correct_dtype(): + @T.prim_func + def func(A: T.Buffer[(T.int64(16), T.int64(1)), "float32"]): + for i in T.grid(T.int64(16)): + with T.block("block"): + vi = T.axis.remap("S", [i]) + T.reads(A[vi, T.int64(0) : T.int64(1)]) + T.evaluate(0) + + assert func.body.block.body.body.block.reads[0].region[0].extent.dtype == "int64" + + if __name__ == "__main__": - a = numpy.zeros((10, 10), dtype="int8") - test_multi_element_array_in_outmost_namespace() - test_different_dtype_assignment_to_var() - b = 1 - test_var_capturing_order() + tvm.testing.main()