diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 77fe79b327b6b..75c5c2921ff4f 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -103,7 +103,8 @@ def __getitem__(self, index): index = self._linear_index(index) if t.lanes > 1: base = index * t.lanes - index = _expr.Ramp(base, const(1, base.dtype), t.lanes) + stride = 1 if (not hasattr(base, "dtype")) else const(1, base.dtype) + index = _expr.Ramp(base, stride, t.lanes) return _expr.Load(self._content_type, self._buffer_var, index) def __setitem__(self, index, value): @@ -116,7 +117,8 @@ def __setitem__(self, index, value): t = DataType(self._content_type) if t.lanes > 1: base = index * t.lanes - index = _expr.Ramp(base, const(1, base.dtype), t.lanes) + stride = 1 if (not hasattr(base, "dtype")) else const(1, base.dtype) + index = _expr.Ramp(base, stride, t.lanes) self._builder.emit(_stmt.Store(self._buffer_var, value, index)) diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index b1a9eae7893a3..cb8968cfc8809 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -126,9 +126,10 @@ def check(m, lanes, target_bits, target_dtype): B = ib.buffer_ptr(Bb) with ib.for_range(0, m, name="i", dtype=m.dtype) as i: B[i] = A[i] + 1 + A[0] = B[1] stmt = ib.get() stmt = lower_stmt([Ab, Bb], stmt, target_bits) - assert stmt.loop_var.dtype == target_dtype + assert stmt.seq[0].loop_var.dtype == target_dtype # i32 -> i32 check(const(2 ** 10, dtype="int32"), 2, target_bits=32, target_dtype="int32")