diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 306ec91c2..3cdd86415 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -22,11 +22,12 @@ concurrency: jobs: test: - name: "Unit Tests and Type Checking" + name: "${{ matrix.os }} :: ${{ matrix.version }} :: Unit Tests and Type Checking" strategy: fail-fast: false matrix: - version: [3.11] + # Support for Python 3.13 depends on https://github.com/pytorch/pytorch/issues/130249 + version: ["3.10", "3.11", "3.12"] os: [ubuntu-22.04] runs-on: ${{matrix.os}} env: @@ -71,3 +72,23 @@ jobs: if: ${{ !cancelled() }} run: | mypy + + # Depends on all other jobs to provide an aggregate job status. + ci_summary: + if: always() + runs-on: ubuntu-20.04 + needs: + - test + steps: + - name: Getting failed jobs + run: | + echo '${{ toJson(needs) }}' + FAILED_JOBS="$(echo '${{ toJson(needs) }}' \ + | jq --raw-output \ + 'map_values(select(.result!="success" and .result!="skipped")) | keys | join(",")' \ + )" + echo "failed-jobs=${FAILED_JOBS}" >> $GITHUB_OUTPUT + if [[ "${FAILED_JOBS}" != "" ]]; then + echo "The following jobs failed: ${FAILED_JOBS}" + exit 1 + fi diff --git a/iree/turbine/kernel/compiler/kernel_codegen.py b/iree/turbine/kernel/compiler/kernel_codegen.py index 67ca217f2..5574dcba4 100644 --- a/iree/turbine/kernel/compiler/kernel_codegen.py +++ b/iree/turbine/kernel/compiler/kernel_codegen.py @@ -279,10 +279,12 @@ def only_write_dependencies(node): # Create new Memory type with the correct usage memory_type = self.bindings[index].kernel_buffer_type self.bindings[index].kernel_buffer_type = Memory[ - *memory_type.symbolic_shape, - memory_type.address_space, - memory_type.dtype, - usage, + ( + *memory_type.symbolic_shape, + memory_type.address_space, + memory_type.dtype, + usage, + ) ] return diff --git a/iree/turbine/kernel/lang/wave_types.py b/iree/turbine/kernel/lang/wave_types.py index de6139df5..f87a95702 100644 --- a/iree/turbine/kernel/lang/wave_types.py +++ b/iree/turbine/kernel/lang/wave_types.py @@ -4,7 +4,6 @@ ClassVar, Iterable, Optional, - Self, Type, TypeAlias, TypeVar, @@ -17,6 +16,7 @@ from sympy import Symbol from sympy.core.expr import Expr +from typing_extensions import Self from itertools import chain diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index 912fadd4f..30f8bb50c 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -9,12 +9,12 @@ Any, Callable, Optional, - Self, Sequence, Type, TypeVar, final, ) +from typing_extensions import Self import torch.fx as fx from ..lang.wave_types import Memory, Register, IndexMapping @@ -792,7 +792,7 @@ def indexing_dims(self) -> list[IndexSymbol]: @property def type(self) -> "Memory": - return Memory[*self.shape, self.address_space, self.dtype] + return Memory[(*self.shape, self.address_space, self.dtype)] @define_op("shared_memory_barrier") @@ -855,7 +855,7 @@ def indexing_dims(self) -> list[IndexSymbol]: return list(self.shape) def infer_type(self): - self.type = Register[*self.shape, self.dtype] + self.type = Register[(*self.shape, self.dtype)] @define_op("mma") @@ -960,7 +960,7 @@ def indexing_dims(self) -> list[IndexSymbol]: def infer_type(self): dtype = self.memory_type.dtype - self.type = Register[*self.indexing_dims, dtype] + self.type = Register[(*self.indexing_dims, dtype)] @property def memory_type(self) -> "Memory": @@ -1168,7 +1168,7 @@ def indexing_dims(self) -> list[IndexSymbol]: def infer_type(self): address_space = self.memory_type.address_space dtype = self.memory_type.dtype - self.type = Memory[*self.indexing_dims, address_space, dtype] + self.type = Memory[(*self.indexing_dims, address_space, dtype)] @property def memory_type(self) -> "Memory": @@ -1304,7 +1304,7 @@ def infer_type(self): dst_shape = list(src_type.symbolic_shape) dim_to_remove = dst_shape[-1] if not non_unit_dim else non_unit_dim[0] dst_shape.remove(dim_to_remove) - dst_type = Register[*dst_shape, src_type.dtype] + dst_type = Register[(*dst_shape, src_type.dtype)] self.type = dst_type @@ -1354,7 +1354,7 @@ def indexing_dims(self) -> list[IndexSymbol]: def infer_type(self): src_dtype = get_custom(self.arg).type.dtype - self.type = Register[*self.target_shape, src_dtype] + self.type = Register[(*self.target_shape, src_dtype)] @define_interface_op("max") @@ -1406,7 +1406,7 @@ def infer_type(self): else: src_type = get_custom(self.arg).type reduced_dims = [dims for dims in src_type.symbolic_shape if dims != self.dim] - dst_type = Register[*reduced_dims, src_type.dtype] + dst_type = Register[(*reduced_dims, src_type.dtype)] self.type = dst_type @property @@ -1465,7 +1465,7 @@ def indexing_dims(self) -> list[IndexSymbol]: def infer_type(self): src_shape = get_custom(self.arg).type.symbolic_shape - self.type = Register[*src_shape, self.dtype] + self.type = Register[(*src_shape, self.dtype)] @define_op("permute") @@ -1488,7 +1488,7 @@ def infer_type(self): assert set(src_type.symbolic_shape) == set( self.target_shape ), f"Target shape {self.target_shape} must be a permutation of source shape {src_type.symbolic_shape}" - self.type = Register[*self.target_shape, src_type.dtype] + self.type = Register[(*self.target_shape, src_type.dtype)] def transform_index( self, index: dict[IndexSymbol, IndexSequence] diff --git a/iree/turbine/kernel/wave/scheduling/schedule.py b/iree/turbine/kernel/wave/scheduling/schedule.py index b8a34191a..93cfe52ea 100644 --- a/iree/turbine/kernel/wave/scheduling/schedule.py +++ b/iree/turbine/kernel/wave/scheduling/schedule.py @@ -105,7 +105,7 @@ def schedule_reduction( # is not dynamic. max_induction_variable = int(max_induction_variable) if max_induction_variable <= scheduler.num_stages - 1: - logger.warn( + logger.warning( "Not enough iterations to pipeline the loop. Skipping pipelining." ) return {} @@ -113,7 +113,7 @@ def schedule_reduction( # Otherwise, we need to rely on assumptions provided by the author. assumptions = get_assumptions(constraints) if not assumptions: - logger.warn( + logger.warning( "No assumptions provided to determine if the loop can be pipelined. Skipping pipelining." ) return {} @@ -122,7 +122,7 @@ def schedule_reduction( constraints, max_induction_variable > scheduler.num_stages - 1 ) if not result: - logger.warn( + logger.warning( "Not enough iterations to pipeline the loop. Skipping pipelining." ) return {} diff --git a/iree/turbine/support/debugging.py b/iree/turbine/support/debugging.py index daa6ead03..5078ddc07 100644 --- a/iree/turbine/support/debugging.py +++ b/iree/turbine/support/debugging.py @@ -11,6 +11,7 @@ import logging import re import os +import sys import torch import numpy as np @@ -54,7 +55,7 @@ class DebugFlags: def set(self, part: str): m = re.match(SETTING_PART_PATTERN, part) if not m: - logger.warn("Syntax error in %s flag: '%s'", FLAGS_ENV_NAME, part) + logger.warning("Syntax error in %s flag: '%s'", FLAGS_ENV_NAME, part) return name = m.group(2) value = m.group(4) @@ -64,11 +65,14 @@ def set(self, part: str): logical_sense = m.group(1) != "-" if name == "log_level": - log_level_mapping = logging.getLevelNamesMapping() - try: - self.log_level = log_level_mapping[value.upper()] - except KeyError: - logger.warn("Log level '%s' unknown (ignored)", value) + if sys.version_info >= (3, 11): + log_level_mapping = logging.getLevelNamesMapping() # Added in 3.11 + try: + self.log_level = log_level_mapping[value.upper()] + except KeyError: + logger.warning("Log level '%s' unknown (ignored)", value) + else: + logger.warning("'log_level' flag requires Python >= 3.11") elif name == "asserts": self.asserts = logical_sense global NDEBUG @@ -76,7 +80,7 @@ def set(self, part: str): elif name == "runtime_trace_dir": self.runtime_trace_dir = value else: - logger.warn("Unrecognized %s flag: '%s'", FLAGS_ENV_NAME, name) + logger.warning("Unrecognized %s flag: '%s'", FLAGS_ENV_NAME, name) @staticmethod def parse(settings: str) -> "DebugFlags": diff --git a/iree/turbine/tools/interpreter.py b/iree/turbine/tools/interpreter.py index 5022933d0..7f901eb40 100644 --- a/iree/turbine/tools/interpreter.py +++ b/iree/turbine/tools/interpreter.py @@ -151,7 +151,7 @@ def callback(self, op: Operation) -> None: offset = [0 for _ in range(len(load_indices))] for i in range(*result_shape): ind = [int(x) + y for x, y in zip(load_indices, offset)] - value[i] = memref[*ind] + value[i] = memref[(*ind,)] offset[-1] += 1 case vector_d.ExtractStridedSliceOp: vector = self.symbol_table[op.vector] @@ -168,7 +168,7 @@ def callback(self, op: Operation) -> None: offset = [0 for _ in range(len(store_indices))] for i in range(*result_shape): memref[ - *[int(x) + y for x, y in zip(store_indices, offset)] + (*[int(x) + y for x, y in zip(store_indices, offset)],) ] = vector[i] offset[-1] += 1 case vector_d.MaskedStoreOp: @@ -185,7 +185,7 @@ def callback(self, op: Operation) -> None: for i in range(*result_shape): if mask[i]: ind = [int(x) + y for x, y in zip(store_indices, offset)] - memref[*ind] = vector[i] + memref[(*ind,)] = vector[i] offset[-1] += 1 case vector_d.ConstantMaskOp: @@ -313,7 +313,7 @@ def interpret_ndrange( ): for wg in np.ndindex(*workgroup_count): for t in np.ndindex(*workgroup_size): - Interpreter([*wg], [*t]).interpret(asm) + Interpreter([(*wg,)], [(*t,)]).interpret(asm) if __name__ == "__main__": diff --git a/lit_tests/kernel/wave/barriers.py b/lit_tests/kernel/wave/barriers.py index f50c6facf..ac8cd6f52 100644 --- a/lit_tests/kernel/wave/barriers.py +++ b/lit_tests/kernel/wave/barriers.py @@ -1,8 +1,6 @@ # RUN: python %s | FileCheck %s import logging -from typing import Callable -import unittest import iree.turbine.kernel as tk import iree.turbine.kernel.lang as tkl import iree.turbine.kernel.wave as tkw @@ -262,4 +260,3 @@ def test_gemm(): if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py index 58e6bab96..129085dfa 100644 --- a/lit_tests/kernel/wave/expansion.py +++ b/lit_tests/kernel/wave/expansion.py @@ -1,7 +1,6 @@ # RUN: python %s | FileCheck %s import logging -import unittest import iree.turbine.kernel as tk import iree.turbine.kernel.lang as tkl import iree.turbine.kernel.wave as tkw @@ -1109,4 +1108,3 @@ def test_chained_gemm_32x32x8(): if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/lit_tests/kernel/wave/index_sequence_analysis.py b/lit_tests/kernel/wave/index_sequence_analysis.py index 8ca0b9857..814c0089e 100644 --- a/lit_tests/kernel/wave/index_sequence_analysis.py +++ b/lit_tests/kernel/wave/index_sequence_analysis.py @@ -1,7 +1,6 @@ # RUN: python %s | FileCheck %s import logging -import unittest import iree.turbine.kernel as tk import iree.turbine.kernel.lang as tkl import iree.turbine.kernel.wave as tkw @@ -350,4 +349,3 @@ def test_gemm(): if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/lit_tests/kernel/wave/minimize_global_loads.py b/lit_tests/kernel/wave/minimize_global_loads.py index c5ec6860b..5a5402fdf 100644 --- a/lit_tests/kernel/wave/minimize_global_loads.py +++ b/lit_tests/kernel/wave/minimize_global_loads.py @@ -1,8 +1,6 @@ # RUN: python %s | FileCheck %s import logging -from typing import Callable -import unittest import iree.turbine.kernel as tk import iree.turbine.kernel.lang as tkl import iree.turbine.kernel.wave as tkw @@ -267,4 +265,3 @@ def test_gemm(): if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/lit_tests/kernel/wave/promotion.py b/lit_tests/kernel/wave/promotion.py index 45fcb50a0..7784885c8 100644 --- a/lit_tests/kernel/wave/promotion.py +++ b/lit_tests/kernel/wave/promotion.py @@ -1,7 +1,6 @@ # RUN: python %s | FileCheck %s import logging -import unittest import iree.turbine.kernel as tk import iree.turbine.kernel.lang as tkl import iree.turbine.kernel.wave as tkw @@ -216,4 +215,3 @@ def test_gemm(): if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/lit_tests/kernel/wave/scheduling.py b/lit_tests/kernel/wave/scheduling.py index 83f6053ab..3f6004fd3 100644 --- a/lit_tests/kernel/wave/scheduling.py +++ b/lit_tests/kernel/wave/scheduling.py @@ -1,7 +1,6 @@ # RUN: python %s | FileCheck %s import logging -import unittest import iree.turbine.kernel as tk import iree.turbine.kernel.lang as tkl import iree.turbine.kernel.wave as tkw @@ -246,4 +245,3 @@ def test_gemm_pipelined(): if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/requirements.txt b/requirements.txt index 3ad12af32..9ba15741c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ lit==18.1.7 mypy==1.8.0 ml_dtypes==0.5.0 setuptools +typing_extensions wheel # It is expected that you have installed a PyTorch version/variant specific diff --git a/setup.py b/setup.py index 6861d2306..5a94d24eb 100644 --- a/setup.py +++ b/setup.py @@ -110,6 +110,7 @@ def initialize_options(self): "torch>=2.3.0", f"Jinja2{get_version_spec('Jinja2')}", f"ml_dtypes{get_version_spec('ml_dtypes')}", + f"typing_extensions{get_version_spec('typing_extensions')}", ], extras_require={ "testing": [