Skip to content

Commit

Permalink
Avoid Python 3.11+ starred expressions in indexes.
Browse files Browse the repository at this point in the history
  • Loading branch information
ScottTodd committed Dec 10, 2024
1 parent b2dd271 commit 74194fd
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 13 deletions.
10 changes: 6 additions & 4 deletions iree/turbine/kernel/compiler/kernel_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,12 @@ def only_write_dependencies(node):
# Create new Memory type with the correct usage
memory_type = self.bindings[index].kernel_buffer_type
self.bindings[index].kernel_buffer_type = Memory[
*memory_type.symbolic_shape,
memory_type.address_space,
memory_type.dtype,
usage,
(
*memory_type.symbolic_shape,
memory_type.address_space,
memory_type.dtype,
usage,
)
]
return

Expand Down
18 changes: 9 additions & 9 deletions iree/turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,7 @@ def indexing_dims(self) -> list[IndexSymbol]:

@property
def type(self) -> "Memory":
return Memory[*self.shape, self.address_space, self.dtype]
return Memory[(*self.shape, self.address_space, self.dtype)]


@define_op("shared_memory_barrier")
Expand Down Expand Up @@ -855,7 +855,7 @@ def indexing_dims(self) -> list[IndexSymbol]:
return list(self.shape)

def infer_type(self):
self.type = Register[*self.shape, self.dtype]
self.type = Register[(*self.shape, self.dtype)]


@define_op("mma")
Expand Down Expand Up @@ -960,7 +960,7 @@ def indexing_dims(self) -> list[IndexSymbol]:

def infer_type(self):
dtype = self.memory_type.dtype
self.type = Register[*self.indexing_dims, dtype]
self.type = Register[(*self.indexing_dims, dtype)]

@property
def memory_type(self) -> "Memory":
Expand Down Expand Up @@ -1168,7 +1168,7 @@ def indexing_dims(self) -> list[IndexSymbol]:
def infer_type(self):
address_space = self.memory_type.address_space
dtype = self.memory_type.dtype
self.type = Memory[*self.indexing_dims, address_space, dtype]
self.type = Memory[(*self.indexing_dims, address_space, dtype)]

@property
def memory_type(self) -> "Memory":
Expand Down Expand Up @@ -1304,7 +1304,7 @@ def infer_type(self):
dst_shape = list(src_type.symbolic_shape)
dim_to_remove = dst_shape[-1] if not non_unit_dim else non_unit_dim[0]
dst_shape.remove(dim_to_remove)
dst_type = Register[*dst_shape, src_type.dtype]
dst_type = Register[(*dst_shape, src_type.dtype)]
self.type = dst_type


Expand Down Expand Up @@ -1354,7 +1354,7 @@ def indexing_dims(self) -> list[IndexSymbol]:

def infer_type(self):
src_dtype = get_custom(self.arg).type.dtype
self.type = Register[*self.target_shape, src_dtype]
self.type = Register[(*self.target_shape, src_dtype)]


@define_interface_op("max")
Expand Down Expand Up @@ -1406,7 +1406,7 @@ def infer_type(self):
else:
src_type = get_custom(self.arg).type
reduced_dims = [dims for dims in src_type.symbolic_shape if dims != self.dim]
dst_type = Register[*reduced_dims, src_type.dtype]
dst_type = Register[(*reduced_dims, src_type.dtype)]
self.type = dst_type

@property
Expand Down Expand Up @@ -1465,7 +1465,7 @@ def indexing_dims(self) -> list[IndexSymbol]:

def infer_type(self):
src_shape = get_custom(self.arg).type.symbolic_shape
self.type = Register[*src_shape, self.dtype]
self.type = Register[(*src_shape, self.dtype)]


@define_op("permute")
Expand All @@ -1488,7 +1488,7 @@ def infer_type(self):
assert set(src_type.symbolic_shape) == set(
self.target_shape
), f"Target shape {self.target_shape} must be a permutation of source shape {src_type.symbolic_shape}"
self.type = Register[*self.target_shape, src_type.dtype]
self.type = Register[(*self.target_shape, src_type.dtype)]

def transform_index(
self, index: dict[IndexSymbol, IndexSequence]
Expand Down

0 comments on commit 74194fd

Please sign in to comment.