Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run unit tests across Python versions 3.10-3.12. #326

Merged
merged 13 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,17 @@ concurrency:

jobs:
test:
name: "Unit Tests and Type Checking"
name: "${{ matrix.os }} :: ${{ matrix.version }} :: Unit Tests and Type Checking "
strategy:
fail-fast: false
matrix:
version: [3.11]
version:
- "3.10"
- "3.11"
- "3.12"

# Support for Python 3.13 depends on https://github.com/pytorch/pytorch/issues/130249
# - "3.13"
ScottTodd marked this conversation as resolved.
Show resolved Hide resolved
os: [ubuntu-22.04]
runs-on: ${{matrix.os}}
env:
Expand Down
10 changes: 6 additions & 4 deletions iree/turbine/kernel/compiler/kernel_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,12 @@ def only_write_dependencies(node):
# Create new Memory type with the correct usage
memory_type = self.bindings[index].kernel_buffer_type
self.bindings[index].kernel_buffer_type = Memory[
*memory_type.symbolic_shape,
memory_type.address_space,
memory_type.dtype,
usage,
(
*memory_type.symbolic_shape,
memory_type.address_space,
memory_type.dtype,
usage,
)
]
return

Expand Down
2 changes: 1 addition & 1 deletion iree/turbine/kernel/lang/wave_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
ClassVar,
Iterable,
Optional,
Self,
Type,
TypeAlias,
TypeVar,
Expand All @@ -17,6 +16,7 @@

from sympy import Symbol
from sympy.core.expr import Expr
from typing_extensions import Self

from itertools import chain

Expand Down
20 changes: 10 additions & 10 deletions iree/turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
Any,
Callable,
Optional,
Self,
Sequence,
Type,
TypeVar,
final,
)
from typing_extensions import Self
import torch.fx as fx

from ..lang.wave_types import Memory, Register, IndexMapping
Expand Down Expand Up @@ -792,7 +792,7 @@ def indexing_dims(self) -> list[IndexSymbol]:

@property
def type(self) -> "Memory":
return Memory[*self.shape, self.address_space, self.dtype]
return Memory[(*self.shape, self.address_space, self.dtype)]


@define_op("shared_memory_barrier")
Expand Down Expand Up @@ -855,7 +855,7 @@ def indexing_dims(self) -> list[IndexSymbol]:
return list(self.shape)

def infer_type(self):
self.type = Register[*self.shape, self.dtype]
self.type = Register[(*self.shape, self.dtype)]


@define_op("mma")
Expand Down Expand Up @@ -960,7 +960,7 @@ def indexing_dims(self) -> list[IndexSymbol]:

def infer_type(self):
dtype = self.memory_type.dtype
self.type = Register[*self.indexing_dims, dtype]
self.type = Register[(*self.indexing_dims, dtype)]

@property
def memory_type(self) -> "Memory":
Expand Down Expand Up @@ -1168,7 +1168,7 @@ def indexing_dims(self) -> list[IndexSymbol]:
def infer_type(self):
address_space = self.memory_type.address_space
dtype = self.memory_type.dtype
self.type = Memory[*self.indexing_dims, address_space, dtype]
self.type = Memory[(*self.indexing_dims, address_space, dtype)]

@property
def memory_type(self) -> "Memory":
Expand Down Expand Up @@ -1304,7 +1304,7 @@ def infer_type(self):
dst_shape = list(src_type.symbolic_shape)
dim_to_remove = dst_shape[-1] if not non_unit_dim else non_unit_dim[0]
dst_shape.remove(dim_to_remove)
dst_type = Register[*dst_shape, src_type.dtype]
dst_type = Register[(*dst_shape, src_type.dtype)]
self.type = dst_type


Expand Down Expand Up @@ -1354,7 +1354,7 @@ def indexing_dims(self) -> list[IndexSymbol]:

def infer_type(self):
src_dtype = get_custom(self.arg).type.dtype
self.type = Register[*self.target_shape, src_dtype]
self.type = Register[(*self.target_shape, src_dtype)]


@define_interface_op("max")
Expand Down Expand Up @@ -1406,7 +1406,7 @@ def infer_type(self):
else:
src_type = get_custom(self.arg).type
reduced_dims = [dims for dims in src_type.symbolic_shape if dims != self.dim]
dst_type = Register[*reduced_dims, src_type.dtype]
dst_type = Register[(*reduced_dims, src_type.dtype)]
self.type = dst_type

@property
Expand Down Expand Up @@ -1465,7 +1465,7 @@ def indexing_dims(self) -> list[IndexSymbol]:

def infer_type(self):
src_shape = get_custom(self.arg).type.symbolic_shape
self.type = Register[*src_shape, self.dtype]
self.type = Register[(*src_shape, self.dtype)]


@define_op("permute")
Expand All @@ -1488,7 +1488,7 @@ def infer_type(self):
assert set(src_type.symbolic_shape) == set(
self.target_shape
), f"Target shape {self.target_shape} must be a permutation of source shape {src_type.symbolic_shape}"
self.type = Register[*self.target_shape, src_type.dtype]
self.type = Register[(*self.target_shape, src_type.dtype)]

def transform_index(
self, index: dict[IndexSymbol, IndexSequence]
Expand Down
6 changes: 3 additions & 3 deletions iree/turbine/kernel/wave/scheduling/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,15 @@ def schedule_reduction(
# is not dynamic.
max_induction_variable = int(max_induction_variable)
if max_induction_variable <= scheduler.num_stages - 1:
logger.warn(
logger.warning(
"Not enough iterations to pipeline the loop. Skipping pipelining."
)
return {}
else:
# Otherwise, we need to rely on assumptions provided by the author.
assumptions = get_assumptions(constraints)
if not assumptions:
logger.warn(
logger.warning(
"No assumptions provided to determine if the loop can be pipelined. Skipping pipelining."
)
return {}
Expand All @@ -122,7 +122,7 @@ def schedule_reduction(
constraints, max_induction_variable > scheduler.num_stages - 1
)
if not result:
logger.warn(
logger.warning(
"Not enough iterations to pipeline the loop. Skipping pipelining."
)
return {}
Expand Down
18 changes: 11 additions & 7 deletions iree/turbine/support/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import logging
import re
import os
import sys
import torch
import numpy as np

Expand Down Expand Up @@ -54,7 +55,7 @@ class DebugFlags:
def set(self, part: str):
m = re.match(SETTING_PART_PATTERN, part)
if not m:
logger.warn("Syntax error in %s flag: '%s'", FLAGS_ENV_NAME, part)
logger.warning("Syntax error in %s flag: '%s'", FLAGS_ENV_NAME, part)
return
name = m.group(2)
value = m.group(4)
Expand All @@ -64,19 +65,22 @@ def set(self, part: str):
logical_sense = m.group(1) != "-"

if name == "log_level":
log_level_mapping = logging.getLevelNamesMapping()
try:
self.log_level = log_level_mapping[value.upper()]
except KeyError:
logger.warn("Log level '%s' unknown (ignored)", value)
if sys.version_info >= (3, 11):
log_level_mapping = logging.getLevelNamesMapping() # Added in 3.11
try:
self.log_level = log_level_mapping[value.upper()]
except KeyError:
logger.warning("Log level '%s' unknown (ignored)", value)
else:
logger.warning("'log_level' flag requires Python >= 3.11")
elif name == "asserts":
self.asserts = logical_sense
global NDEBUG
NDEBUG = not logical_sense
elif name == "runtime_trace_dir":
self.runtime_trace_dir = value
else:
logger.warn("Unrecognized %s flag: '%s'", FLAGS_ENV_NAME, name)
logger.warning("Unrecognized %s flag: '%s'", FLAGS_ENV_NAME, name)

@staticmethod
def parse(settings: str) -> "DebugFlags":
Expand Down
8 changes: 4 additions & 4 deletions iree/turbine/tools/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def callback(self, op: Operation) -> None:
offset = [0 for _ in range(len(load_indices))]
for i in range(*result_shape):
ind = [int(x) + y for x, y in zip(load_indices, offset)]
value[i] = memref[*ind]
value[i] = memref[(*ind,)]
offset[-1] += 1
case vector_d.ExtractStridedSliceOp:
vector = self.symbol_table[op.vector]
Expand All @@ -168,7 +168,7 @@ def callback(self, op: Operation) -> None:
offset = [0 for _ in range(len(store_indices))]
for i in range(*result_shape):
memref[
*[int(x) + y for x, y in zip(store_indices, offset)]
(*[int(x) + y for x, y in zip(store_indices, offset)],)
] = vector[i]
offset[-1] += 1
case vector_d.MaskedStoreOp:
Expand All @@ -185,7 +185,7 @@ def callback(self, op: Operation) -> None:
for i in range(*result_shape):
if mask[i]:
ind = [int(x) + y for x, y in zip(store_indices, offset)]
memref[*ind] = vector[i]
memref[(*ind,)] = vector[i]

offset[-1] += 1
case vector_d.ConstantMaskOp:
Expand Down Expand Up @@ -313,7 +313,7 @@ def interpret_ndrange(
):
for wg in np.ndindex(*workgroup_count):
for t in np.ndindex(*workgroup_size):
Interpreter([*wg], [*t]).interpret(asm)
Interpreter([(*wg,)], [(*t,)]).interpret(asm)


if __name__ == "__main__":
Expand Down
3 changes: 0 additions & 3 deletions lit_tests/kernel/wave/barriers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# RUN: python %s | FileCheck %s

import logging
from typing import Callable
import unittest
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
Expand Down Expand Up @@ -262,4 +260,3 @@ def test_gemm():

if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
ScottTodd marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 0 additions & 2 deletions lit_tests/kernel/wave/expansion.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# RUN: python %s | FileCheck %s

import logging
import unittest
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
Expand Down Expand Up @@ -1109,4 +1108,3 @@ def test_chained_gemm_32x32x8():

if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
2 changes: 0 additions & 2 deletions lit_tests/kernel/wave/index_sequence_analysis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# RUN: python %s | FileCheck %s

import logging
import unittest
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
Expand Down Expand Up @@ -350,4 +349,3 @@ def test_gemm():

if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
3 changes: 0 additions & 3 deletions lit_tests/kernel/wave/minimize_global_loads.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# RUN: python %s | FileCheck %s

import logging
from typing import Callable
import unittest
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
Expand Down Expand Up @@ -267,4 +265,3 @@ def test_gemm():

if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
2 changes: 0 additions & 2 deletions lit_tests/kernel/wave/promotion.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# RUN: python %s | FileCheck %s

import logging
import unittest
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
Expand Down Expand Up @@ -216,4 +215,3 @@ def test_gemm():

if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
2 changes: 0 additions & 2 deletions lit_tests/kernel/wave/scheduling.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# RUN: python %s | FileCheck %s

import logging
import unittest
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
Expand Down Expand Up @@ -246,4 +245,3 @@ def test_gemm_pipelined():

if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ lit==18.1.7
mypy==1.8.0
ml_dtypes==0.5.0
setuptools
typing_extensions
wheel

# It is expected that you have installed a PyTorch version/variant specific
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def initialize_options(self):
"torch>=2.3.0",
f"Jinja2{get_version_spec('Jinja2')}",
f"ml_dtypes{get_version_spec('ml_dtypes')}",
f"typing_extensions{get_version_spec('typing_extensions')}",
],
extras_require={
"testing": [
Expand Down
Loading