diff --git a/iree/turbine/kernel/compiler/kernel_codegen.py b/iree/turbine/kernel/compiler/kernel_codegen.py index 67ca217f2..5574dcba4 100644 --- a/iree/turbine/kernel/compiler/kernel_codegen.py +++ b/iree/turbine/kernel/compiler/kernel_codegen.py @@ -279,10 +279,12 @@ def only_write_dependencies(node): # Create new Memory type with the correct usage memory_type = self.bindings[index].kernel_buffer_type self.bindings[index].kernel_buffer_type = Memory[ - *memory_type.symbolic_shape, - memory_type.address_space, - memory_type.dtype, - usage, + ( + *memory_type.symbolic_shape, + memory_type.address_space, + memory_type.dtype, + usage, + ) ] return diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index 30f7241ce..30f8bb50c 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -792,7 +792,7 @@ def indexing_dims(self) -> list[IndexSymbol]: @property def type(self) -> "Memory": - return Memory[*self.shape, self.address_space, self.dtype] + return Memory[(*self.shape, self.address_space, self.dtype)] @define_op("shared_memory_barrier") @@ -855,7 +855,7 @@ def indexing_dims(self) -> list[IndexSymbol]: return list(self.shape) def infer_type(self): - self.type = Register[*self.shape, self.dtype] + self.type = Register[(*self.shape, self.dtype)] @define_op("mma") @@ -960,7 +960,7 @@ def indexing_dims(self) -> list[IndexSymbol]: def infer_type(self): dtype = self.memory_type.dtype - self.type = Register[*self.indexing_dims, dtype] + self.type = Register[(*self.indexing_dims, dtype)] @property def memory_type(self) -> "Memory": @@ -1168,7 +1168,7 @@ def indexing_dims(self) -> list[IndexSymbol]: def infer_type(self): address_space = self.memory_type.address_space dtype = self.memory_type.dtype - self.type = Memory[*self.indexing_dims, address_space, dtype] + self.type = Memory[(*self.indexing_dims, address_space, dtype)] @property def memory_type(self) -> "Memory": @@ -1304,7 +1304,7 @@ def infer_type(self): dst_shape = list(src_type.symbolic_shape) dim_to_remove = dst_shape[-1] if not non_unit_dim else non_unit_dim[0] dst_shape.remove(dim_to_remove) - dst_type = Register[*dst_shape, src_type.dtype] + dst_type = Register[(*dst_shape, src_type.dtype)] self.type = dst_type @@ -1354,7 +1354,7 @@ def indexing_dims(self) -> list[IndexSymbol]: def infer_type(self): src_dtype = get_custom(self.arg).type.dtype - self.type = Register[*self.target_shape, src_dtype] + self.type = Register[(*self.target_shape, src_dtype)] @define_interface_op("max") @@ -1406,7 +1406,7 @@ def infer_type(self): else: src_type = get_custom(self.arg).type reduced_dims = [dims for dims in src_type.symbolic_shape if dims != self.dim] - dst_type = Register[*reduced_dims, src_type.dtype] + dst_type = Register[(*reduced_dims, src_type.dtype)] self.type = dst_type @property @@ -1465,7 +1465,7 @@ def indexing_dims(self) -> list[IndexSymbol]: def infer_type(self): src_shape = get_custom(self.arg).type.symbolic_shape - self.type = Register[*src_shape, self.dtype] + self.type = Register[(*src_shape, self.dtype)] @define_op("permute") @@ -1488,7 +1488,7 @@ def infer_type(self): assert set(src_type.symbolic_shape) == set( self.target_shape ), f"Target shape {self.target_shape} must be a permutation of source shape {src_type.symbolic_shape}" - self.type = Register[*self.target_shape, src_type.dtype] + self.type = Register[(*self.target_shape, src_type.dtype)] def transform_index( self, index: dict[IndexSymbol, IndexSequence]