diff --git a/iree/turbine/kernel/lang/wave_types.py b/iree/turbine/kernel/lang/wave_types.py index de6139df5..f87a95702 100644 --- a/iree/turbine/kernel/lang/wave_types.py +++ b/iree/turbine/kernel/lang/wave_types.py @@ -4,7 +4,6 @@ ClassVar, Iterable, Optional, - Self, Type, TypeAlias, TypeVar, @@ -17,6 +16,7 @@ from sympy import Symbol from sympy.core.expr import Expr +from typing_extensions import Self from itertools import chain diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index 912fadd4f..30f7241ce 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -9,12 +9,12 @@ Any, Callable, Optional, - Self, Sequence, Type, TypeVar, final, ) +from typing_extensions import Self import torch.fx as fx from ..lang.wave_types import Memory, Register, IndexMapping diff --git a/requirements.txt b/requirements.txt index 3ad12af32..9ba15741c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ lit==18.1.7 mypy==1.8.0 ml_dtypes==0.5.0 setuptools +typing_extensions wheel # It is expected that you have installed a PyTorch version/variant specific diff --git a/setup.py b/setup.py index 6861d2306..5a94d24eb 100644 --- a/setup.py +++ b/setup.py @@ -110,6 +110,7 @@ def initialize_options(self): "torch>=2.3.0", f"Jinja2{get_version_spec('Jinja2')}", f"ml_dtypes{get_version_spec('ml_dtypes')}", + f"typing_extensions{get_version_spec('typing_extensions')}", ], extras_require={ "testing": [