Skip to content

Commit

Permalink
Bug-fix] Fix tir allocation with multiple lanes (apache#6941)
Browse files Browse the repository at this point in the history
* Bug-fix] Fix tir allocation with multiple lanes

This PR stemmed from apache#6907
and it is fixing a small error in the getter and setter of a buffer for
the case where `t.lanes > 1`. I also added a test to stress the issue.

* Address dtyped vs non-dtyped constant cases
  • Loading branch information
Giuseppe Rossini authored and Trevor Morris committed Dec 4, 2020
1 parent 4e238e8 commit 5ec48f5
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
6 changes: 4 additions & 2 deletions python/tvm/tir/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def __getitem__(self, index):
index = self._linear_index(index)
if t.lanes > 1:
base = index * t.lanes
index = _expr.Ramp(base, const(1, base.dtype), t.lanes)
stride = 1 if (not hasattr(base, "dtype")) else const(1, base.dtype)
index = _expr.Ramp(base, stride, t.lanes)
return _expr.Load(self._content_type, self._buffer_var, index)

def __setitem__(self, index, value):
Expand All @@ -116,7 +117,8 @@ def __setitem__(self, index, value):
t = DataType(self._content_type)
if t.lanes > 1:
base = index * t.lanes
index = _expr.Ramp(base, const(1, base.dtype), t.lanes)
stride = 1 if (not hasattr(base, "dtype")) else const(1, base.dtype)
index = _expr.Ramp(base, stride, t.lanes)
self._builder.emit(_stmt.Store(self._buffer_var, value, index))


Expand Down
3 changes: 2 additions & 1 deletion tests/python/unittest/test_tir_transform_narrow_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,10 @@ def check(m, lanes, target_bits, target_dtype):
B = ib.buffer_ptr(Bb)
with ib.for_range(0, m, name="i", dtype=m.dtype) as i:
B[i] = A[i] + 1
A[0] = B[1]
stmt = ib.get()
stmt = lower_stmt([Ab, Bb], stmt, target_bits)
assert stmt.loop_var.dtype == target_dtype
assert stmt.seq[0].loop_var.dtype == target_dtype

# i32 -> i32
check(const(2 ** 10, dtype="int32"), 2, target_bits=32, target_dtype="int32")
Expand Down

0 comments on commit 5ec48f5

Please sign in to comment.