
Run unit tests across Python versions 3.10-3.12. #326

Merged (13 commits), Dec 11, 2024
10 changes: 8 additions & 2 deletions .github/workflows/ci.yaml
@@ -22,11 +22,17 @@ concurrency:

jobs:
test:
name: "Unit Tests and Type Checking"
name: "${{ matrix.os }} :: ${{ matrix.version }} :: Unit Tests and Type Checking"
Member Author:

@marbre you may have opinions on the job naming. I'm not sure about this.

BTW, type checking does make sense to run on multiple Python versions. The use of `logging.getLevelNamesMapping()` (added in 3.11) was caught by running mypy on 3.10. I originally wanted to move type checking to pre-commit since it's similar to other lint checks, but I now think it makes sense to keep it here.
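As an illustrative sketch (not code from this PR): a 3.10-compatible stand-in can fall back to the documented level constants when `logging.getLevelNamesMapping()` is unavailable:

```python
import logging

def level_names_mapping() -> dict:
    """Portable stand-in for logging.getLevelNamesMapping(), added in 3.11."""
    if hasattr(logging, "getLevelNamesMapping"):
        return logging.getLevelNamesMapping()
    # Fallback for Python 3.10: enumerate the documented level names.
    names = ("CRITICAL", "FATAL", "ERROR", "WARN", "WARNING",
             "INFO", "DEBUG", "NOTSET")
    return {name: getattr(logging, name) for name in names}

print(level_names_mapping()["WARNING"])  # → 30
```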

Member:

Following the naming in other repos we would have "Unit Tests and Type Checking :: ${{ matrix.os }} :: ${{ matrix.version }}", but to be honest I don't really care, and if we start to dislike it we can change it in the future. Thus I don't want to be picky here 😉

Member Author:

Yeah, I'd like to put important information at the start of the name so it doesn't get cut off.

(screenshots of truncated job names in the GitHub checks UI)

Member:

Works for me :)

Member Author:

Argh, GitHub. This was also being used for a required check matching `Unit Tests and Type Checking (3.11, ubuntu-22.04)`, but that check is actually picking up `TK CI / Unit Tests and Type Checking (3.11, ubuntu-22.04)` as well.

I might just copy this snippet from IREE to have a "ci-summary" job with a unique name we can set as required: https://github.com/iree-org/iree/blob/7177c29f9b2d9e255b63987f5dfff174ec2afc2f/.github/workflows/ci.yml#L201-L236
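For context, that pattern gives the workflow one aggregate job whose name never changes, so branch protection can require it instead of a matrix entry. A hedged sketch of the idea (job and step names here are illustrative, not the actual IREE snippet):

```yaml
  # Aggregate job with a stable name that can be marked "required" in branch
  # protection, independent of how the matrix job names evolve.
  ci-summary:
    if: always()
    runs-on: ubuntu-22.04
    needs: [test]
    steps:
      - name: Check all dependencies succeeded
        run: |
          echo "test: ${{ needs.test.result }}"
          test "${{ needs.test.result }}" = "success"
```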

strategy:
fail-fast: false
matrix:
version: [3.11]
version:
- "3.10"
- "3.11"
- "3.12"

# Support for Python 3.13 depends on https://github.com/pytorch/pytorch/issues/130249
# - "3.13"
ScottTodd marked this conversation as resolved.
os: [ubuntu-22.04]
runs-on: ${{matrix.os}}
env:
10 changes: 6 additions & 4 deletions iree/turbine/kernel/compiler/kernel_codegen.py
@@ -279,10 +279,12 @@ def only_write_dependencies(node):
# Create new Memory type with the correct usage
memory_type = self.bindings[index].kernel_buffer_type
self.bindings[index].kernel_buffer_type = Memory[
*memory_type.symbolic_shape,
memory_type.address_space,
memory_type.dtype,
usage,
(
*memory_type.symbolic_shape,
memory_type.address_space,
memory_type.dtype,
usage,
)
]
return
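The parenthesized tuple here is the key compatibility change: `Memory[*shape, space, dtype]` uses star-unpacking inside a subscript, which is new syntax in Python 3.11 (PEP 646) and a SyntaxError on 3.10, while an explicit tuple subscript parses everywhere and is received identically. A minimal sketch (the `Demo` class is illustrative, not the real `Memory` type):

```python
class Demo:
    # Parameterized types receive the whole subscript as one object;
    # a multi-element subscript arrives as a tuple either way.
    def __class_getitem__(cls, item):
        return item

shape = ("M", "N")
# Demo[*shape, "f32"] is a SyntaxError before Python 3.11 (PEP 646 syntax);
# wrapping the elements in an explicit tuple is equivalent and portable:
print(Demo[(*shape, "f32")])  # → ('M', 'N', 'f32')
```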

2 changes: 1 addition & 1 deletion iree/turbine/kernel/lang/wave_types.py
@@ -4,7 +4,6 @@
ClassVar,
Iterable,
Optional,
Self,
Type,
TypeAlias,
TypeVar,
@@ -17,6 +16,7 @@

from sympy import Symbol
from sympy.core.expr import Expr
from typing_extensions import Self

from itertools import chain
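`typing.Self` only exists from Python 3.11 on, which is why the import moves to `typing_extensions` (added as a dependency in this PR). A hedged sketch of the general pattern; the `Builder` class is purely illustrative:

```python
import sys

if sys.version_info >= (3, 11):
    from typing import Self
else:
    # Backport; requires the typing_extensions package on Python 3.10.
    from typing_extensions import Self

class Builder:
    def set_name(self, name: str) -> Self:  # returns the same (sub)class
        self.name = name
        return self

print(Builder().set_name("demo").name)  # → demo
```

The PR itself imports unconditionally from `typing_extensions`, which is simpler once the package is an unconditional dependency.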

20 changes: 10 additions & 10 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -9,12 +9,12 @@
Any,
Callable,
Optional,
Self,
Sequence,
Type,
TypeVar,
final,
)
from typing_extensions import Self
import torch.fx as fx

from ..lang.wave_types import Memory, Register, IndexMapping
@@ -792,7 +792,7 @@ def indexing_dims(self) -> list[IndexSymbol]:

@property
def type(self) -> "Memory":
return Memory[*self.shape, self.address_space, self.dtype]
return Memory[(*self.shape, self.address_space, self.dtype)]


@define_op("shared_memory_barrier")
@@ -855,7 +855,7 @@ def indexing_dims(self) -> list[IndexSymbol]:
return list(self.shape)

def infer_type(self):
self.type = Register[*self.shape, self.dtype]
self.type = Register[(*self.shape, self.dtype)]


@define_op("mma")
@@ -960,7 +960,7 @@ def indexing_dims(self) -> list[IndexSymbol]:

def infer_type(self):
dtype = self.memory_type.dtype
self.type = Register[*self.indexing_dims, dtype]
self.type = Register[(*self.indexing_dims, dtype)]

@property
def memory_type(self) -> "Memory":
@@ -1168,7 +1168,7 @@ def indexing_dims(self) -> list[IndexSymbol]:
def infer_type(self):
address_space = self.memory_type.address_space
dtype = self.memory_type.dtype
self.type = Memory[*self.indexing_dims, address_space, dtype]
self.type = Memory[(*self.indexing_dims, address_space, dtype)]

@property
def memory_type(self) -> "Memory":
@@ -1304,7 +1304,7 @@ def infer_type(self):
dst_shape = list(src_type.symbolic_shape)
dim_to_remove = dst_shape[-1] if not non_unit_dim else non_unit_dim[0]
dst_shape.remove(dim_to_remove)
dst_type = Register[*dst_shape, src_type.dtype]
dst_type = Register[(*dst_shape, src_type.dtype)]
self.type = dst_type


@@ -1354,7 +1354,7 @@ def indexing_dims(self) -> list[IndexSymbol]:

def infer_type(self):
src_dtype = get_custom(self.arg).type.dtype
self.type = Register[*self.target_shape, src_dtype]
self.type = Register[(*self.target_shape, src_dtype)]


@define_interface_op("max")
@@ -1406,7 +1406,7 @@ def infer_type(self):
else:
src_type = get_custom(self.arg).type
reduced_dims = [dims for dims in src_type.symbolic_shape if dims != self.dim]
dst_type = Register[*reduced_dims, src_type.dtype]
dst_type = Register[(*reduced_dims, src_type.dtype)]
self.type = dst_type

@property
@@ -1465,7 +1465,7 @@ def indexing_dims(self) -> list[IndexSymbol]:

def infer_type(self):
src_shape = get_custom(self.arg).type.symbolic_shape
self.type = Register[*src_shape, self.dtype]
self.type = Register[(*src_shape, self.dtype)]


@define_op("permute")
@@ -1488,7 +1488,7 @@ def infer_type(self):
assert set(src_type.symbolic_shape) == set(
self.target_shape
), f"Target shape {self.target_shape} must be a permutation of source shape {src_type.symbolic_shape}"
self.type = Register[*self.target_shape, src_type.dtype]
self.type = Register[(*self.target_shape, src_type.dtype)]

def transform_index(
self, index: dict[IndexSymbol, IndexSequence]
6 changes: 3 additions & 3 deletions iree/turbine/kernel/wave/scheduling/schedule.py
@@ -105,15 +105,15 @@ def schedule_reduction(
# is not dynamic.
max_induction_variable = int(max_induction_variable)
if max_induction_variable <= scheduler.num_stages - 1:
logger.warn(
logger.warning(
"Not enough iterations to pipeline the loop. Skipping pipelining."
)
return {}
else:
# Otherwise, we need to rely on assumptions provided by the author.
assumptions = get_assumptions(constraints)
if not assumptions:
logger.warn(
logger.warning(
"No assumptions provided to determine if the loop can be pipelined. Skipping pipelining."
)
return {}
@@ -122,7 +122,7 @@
constraints, max_induction_variable > scheduler.num_stages - 1
)
if not result:
logger.warn(
logger.warning(
"Not enough iterations to pipeline the loop. Skipping pipelining."
)
return {}
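`Logger.warn` is a deprecated alias for `Logger.warning` and emits a `DeprecationWarning` on every call, which is why the spelling changes throughout this file. A small demonstration:

```python
import logging
import warnings

logger = logging.getLogger("demo")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    logger.warn("old spelling")  # deprecated alias, triggers DeprecationWarning

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
logger.warning("supported spelling")  # no deprecation warning
```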
18 changes: 11 additions & 7 deletions iree/turbine/support/debugging.py
@@ -11,6 +11,7 @@
import logging
import re
import os
import sys
import torch
import numpy as np

@@ -54,7 +55,7 @@ class DebugFlags:
def set(self, part: str):
m = re.match(SETTING_PART_PATTERN, part)
if not m:
logger.warn("Syntax error in %s flag: '%s'", FLAGS_ENV_NAME, part)
logger.warning("Syntax error in %s flag: '%s'", FLAGS_ENV_NAME, part)
return
name = m.group(2)
value = m.group(4)
@@ -64,19 +65,22 @@ def set(self, part: str):
logical_sense = m.group(1) != "-"

if name == "log_level":
log_level_mapping = logging.getLevelNamesMapping()
try:
self.log_level = log_level_mapping[value.upper()]
except KeyError:
logger.warn("Log level '%s' unknown (ignored)", value)
if sys.version_info >= (3, 11):
log_level_mapping = logging.getLevelNamesMapping() # Added in 3.11
try:
self.log_level = log_level_mapping[value.upper()]
except KeyError:
logger.warning("Log level '%s' unknown (ignored)", value)
else:
logger.warning("'log_level' flag requires Python >= 3.11")
elif name == "asserts":
self.asserts = logical_sense
global NDEBUG
NDEBUG = not logical_sense
elif name == "runtime_trace_dir":
self.runtime_trace_dir = value
else:
logger.warn("Unrecognized %s flag: '%s'", FLAGS_ENV_NAME, name)
logger.warning("Unrecognized %s flag: '%s'", FLAGS_ENV_NAME, name)

@staticmethod
def parse(settings: str) -> "DebugFlags":
8 changes: 4 additions & 4 deletions iree/turbine/tools/interpreter.py
@@ -151,7 +151,7 @@ def callback(self, op: Operation) -> None:
offset = [0 for _ in range(len(load_indices))]
for i in range(*result_shape):
ind = [int(x) + y for x, y in zip(load_indices, offset)]
value[i] = memref[*ind]
value[i] = memref[(*ind,)]
offset[-1] += 1
case vector_d.ExtractStridedSliceOp:
vector = self.symbol_table[op.vector]
@@ -168,7 +168,7 @@ def callback(self, op: Operation) -> None:
offset = [0 for _ in range(len(store_indices))]
for i in range(*result_shape):
memref[
*[int(x) + y for x, y in zip(store_indices, offset)]
(*[int(x) + y for x, y in zip(store_indices, offset)],)
] = vector[i]
offset[-1] += 1
case vector_d.MaskedStoreOp:
@@ -185,7 +185,7 @@ def callback(self, op: Operation) -> None:
for i in range(*result_shape):
if mask[i]:
ind = [int(x) + y for x, y in zip(store_indices, offset)]
memref[*ind] = vector[i]
memref[(*ind,)] = vector[i]

offset[-1] += 1
case vector_d.ConstantMaskOp:
@@ -313,7 +313,7 @@ def interpret_ndrange(
):
for wg in np.ndindex(*workgroup_count):
for t in np.ndindex(*workgroup_size):
Interpreter([*wg], [*t]).interpret(asm)
Interpreter([(*wg,)], [(*t,)]).interpret(asm)


if __name__ == "__main__":
3 changes: 0 additions & 3 deletions lit_tests/kernel/wave/barriers.py
@@ -1,8 +1,6 @@
# RUN: python %s | FileCheck %s

import logging
from typing import Callable
import unittest
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
@@ -262,4 +260,3 @@ def test_gemm():

if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
2 changes: 0 additions & 2 deletions lit_tests/kernel/wave/expansion.py
@@ -1,7 +1,6 @@
# RUN: python %s | FileCheck %s

import logging
import unittest
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
@@ -1109,4 +1108,3 @@ def test_chained_gemm_32x32x8():

if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
2 changes: 0 additions & 2 deletions lit_tests/kernel/wave/index_sequence_analysis.py
@@ -1,7 +1,6 @@
# RUN: python %s | FileCheck %s

import logging
import unittest
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
@@ -350,4 +349,3 @@ def test_gemm():

if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
3 changes: 0 additions & 3 deletions lit_tests/kernel/wave/minimize_global_loads.py
@@ -1,8 +1,6 @@
# RUN: python %s | FileCheck %s

import logging
from typing import Callable
import unittest
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
@@ -267,4 +265,3 @@ def test_gemm():

if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
2 changes: 0 additions & 2 deletions lit_tests/kernel/wave/promotion.py
@@ -1,7 +1,6 @@
# RUN: python %s | FileCheck %s

import logging
import unittest
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
@@ -216,4 +215,3 @@ def test_gemm():

if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
2 changes: 0 additions & 2 deletions lit_tests/kernel/wave/scheduling.py
@@ -1,7 +1,6 @@
# RUN: python %s | FileCheck %s

import logging
import unittest
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
@@ -246,4 +245,3 @@ def test_gemm_pipelined():

if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
1 change: 1 addition & 0 deletions requirements.txt
@@ -9,6 +9,7 @@ lit==18.1.7
mypy==1.8.0
ml_dtypes==0.5.0
setuptools
typing_extensions
wheel

# It is expected that you have installed a PyTorch version/variant specific
1 change: 1 addition & 0 deletions setup.py
@@ -110,6 +110,7 @@ def initialize_options(self):
"torch>=2.3.0",
f"Jinja2{get_version_spec('Jinja2')}",
f"ml_dtypes{get_version_spec('ml_dtypes')}",
f"typing_extensions{get_version_spec('typing_extensions')}",
],
extras_require={
"testing": [