[Wave] Add self_index, predicate, and selectOp to implement causal attention #452

Merged: 23 commits, Feb 5, 2025

Commits

6762c5d  hack iota / self_index (ftynse, Jan 28, 2025)
1a34aca  self_index fixes and stress test (nicolasvasilache, Jan 30, 2025)
86c4106  Initial support for T5 RPE in vanilla attention (nicolasvasilache, Jan 30, 2025)
5d05791  Add support for tkw.minimum (nicolasvasilache, Jan 31, 2025)
0b9dc22  Finish T5 RPE on vanilla attention (nicolasvasilache, Jan 31, 2025)
5659218  Rework T5 RPE test (nicolasvasilache, Feb 1, 2025)
09c5e3d  Rename files (nicolasvasilache, Feb 3, 2025)
46a7bba  add signed integer comparisons, select and pow ops (ftynse, Feb 3, 2025)
eb53619  Add simple triangular test that will be useful for causal attention (nicolasvasilache, Feb 3, 2025)
9ff210c  Add causal attention (nicolasvasilache, Feb 3, 2025)
9051cdd  Debug causal attention (nicolasvasilache, Feb 3, 2025)
bbd11c9  [Wave] Remove/clean all non causal related changes (raikonenfnu, Feb 4, 2025)
3cf3eb2  [Wave] Refactor sgt/slt/sge/sle op into pyoperator (raikonenfnu, Feb 4, 2025)
e37b36a  Refactor/cleanliness of codegen.py (raikonenfnu, Feb 4, 2025)
2b61f84  Add Causal template mask + small cleanup (raikonenfnu, Feb 4, 2025)
02724d8  pre-commit fix (raikonenfnu, Feb 4, 2025)
c1cc2a3  Add LIT test for self_index and causal (raikonenfnu, Feb 4, 2025)
b7b9a5f  Add GPR_NUM partition support for SelfIndex to enable more MMA intrinsic (raikonenfnu, Feb 4, 2025)
94a2e45  [Wave] add type checks and more type support for cmps (raikonenfnu, Feb 4, 2025)
f2804ae  [Wave] make element_per_thread VS index.size one or another and cleanups (raikonenfnu, Feb 4, 2025)
e080622  Clean up couple other NITs (raikonenfnu, Feb 4, 2025)
5531c94  Add support to overload multiple ops in same handle_fn (raikonenfnu, Feb 4, 2025)
2fc4f1c  Add verbose test (raikonenfnu, Feb 4, 2025)

Changes from all commits

4 changes: 4 additions & 0 deletions iree/turbine/aot/support/ir_utils.py
@@ -489,6 +489,10 @@ def _is_float_type(type):
return isinstance(type, (BF16Type, F16Type, F32Type, F64Type, Float8E4M3FNUZType))


def _is_index_type(type):
return isinstance(type, (IndexType))


def _is_integer_like_type(type):
return isinstance(type, (IntegerType, IndexType))

1 change: 1 addition & 0 deletions iree/turbine/kernel/_support/dtype.py
@@ -72,6 +72,7 @@ def bitwidth(self):

bf16 = DataType("bf16")
bool = DataType("bool", "i1")
i1 = bool
i4 = DataType("i4")
i8 = DataType("i8")
i16 = DataType("i16")
147 changes: 135 additions & 12 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -20,7 +20,7 @@
from ..lang.wave_types import Memory, Register, IndexMapping
from ..lang.global_symbols import *
from .._support.indexing import IndexExpr, IndexSymbol, IndexSequence
from .._support.dtype import DataType
from .._support.dtype import DataType, i1
from .._support.regions import RegionGraph
from .base import OpDispatcher
import numpy as np
@@ -45,6 +45,14 @@ def allocate(
...


def self_index(
idx: IndexExpr,
dtype: DataType,
elements_per_thread: Optional[IndexExpr | int] = None,
) -> "Register":
...


def extract(
register: "Register",
offsets: tuple[IndexExpr],
@@ -166,6 +174,22 @@ def shuffle(src: "Register", offset: int, width: int) -> "Register":
...


def gt(lhs: "Register", rhs: "Register") -> "Register":
...


def ge(lhs: "Register", rhs: "Register") -> "Register":
...


def lt(lhs: "Register", rhs: "Register") -> "Register":
...


def le(lhs: "Register", rhs: "Register") -> "Register":
...


def cast(src: "Register", dtype: DataType) -> "Register":
...

@@ -178,6 +202,10 @@ def reshape(inputs: Sequence["Register"]) -> "Register":
...


def select(cond: "Register", if_true: "Register", if_false: "Register") -> "Register":
...


def define_op(op_name: str) -> Callable[[T], T]:
def decorator(cls: T) -> T:
cls.tkw_op_name = op_name
@@ -680,14 +708,8 @@ def transform_index(
return index


@define_py_op(operator.add)
@define_py_op(operator.sub)
@define_py_op(operator.mul)
@define_py_op(operator.truediv)
@define_interface_op("maximum")
@define_interface_op("minimum")
@dataclass
class BinaryPyOp(CustomOp, ABC):
class BinaryOpBase(CustomOp, ABC):
"""
Represents an elementwise binary python operator.

@@ -715,21 +737,51 @@ def indexing_dims(self) -> list[IndexSymbol]:
def py_operator(self) -> str:
return self.tkw_op_name

def infer_type(self):
def infer_shape(self) -> Any:
lhs_type = get_custom(self.lhs).type
rhs_type = get_custom(self.rhs).type
has_same_type = has_same_custom_type(lhs_type, rhs_type)
if has_same_type:
self.type = lhs_type
return
return lhs_type.symbolic_shape

lhs_dim_set = set(lhs_type.symbolic_shape)
rhs_dim_set = set(rhs_type.symbolic_shape)
if lhs_dim_set.isdisjoint(rhs_dim_set):
raise ValueError(
"BinaryPyOp requires lhs and rhs shape to be at least broadcastable."
f" got {lhs_type.symbolic_shape} vs {rhs_type.symbolic_shape}"
)

# TODO: this logic looks suspicious. Specifically, there's no check that
# rhs_dim_set subsumes lhs_dim_set, they may partially overlap.
broadcasted_type = lhs_type if lhs_dim_set > rhs_dim_set else rhs_type
self.type = broadcasted_type
return broadcasted_type.symbolic_shape
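
The shape rule above hinges on Python set comparison: lhs_dim_set > rhs_dim_set is a proper-superset test, so the lhs shape wins only when it strictly contains every rhs dimension; otherwise the rhs shape is used, which is exactly the partially-overlapping case the TODO flags. A minimal standalone sketch of that rule (plain tuples of strings stand in for symbolic shapes, and the has_same_custom_type check is simplified to equality):

def resolve_binary_shape(lhs_shape, rhs_shape):
    if lhs_shape == rhs_shape:
        return lhs_shape
    lhs_dims, rhs_dims = set(lhs_shape), set(rhs_shape)
    if lhs_dims.isdisjoint(rhs_dims):
        raise ValueError("shapes must be at least broadcastable")
    # Proper-superset test: lhs wins only if it strictly contains rhs.
    # Partially overlapping sets fall through to rhs_shape, dropping any
    # lhs-only dimensions, which is the case the TODO above calls out.
    return lhs_shape if lhs_dims > rhs_dims else rhs_shape

assert resolve_binary_shape(("M", "K"), ("K",)) == ("M", "K")
assert resolve_binary_shape(("M", "K"), ("K", "N")) == ("K", "N")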


@define_py_op(operator.add)
@define_py_op(operator.sub)
@define_py_op(operator.mul)
@define_py_op(operator.truediv)
@define_interface_op("maximum")
@define_interface_op("minimum")
@dataclass
class BinaryPyOp(BinaryOpBase, ABC):
def infer_type(self):
self.type = Register[(*self.infer_shape(), get_custom(self.lhs).type.dtype)]


@define_py_op(operator.gt)
@define_py_op(operator.ge)
@define_py_op(operator.lt)
@define_py_op(operator.le)
@define_interface_op("gt")
@define_interface_op("ge")
@define_interface_op("lt")
@define_interface_op("le")
@dataclass
class ComparisonPyOp(BinaryOpBase, ABC):
def infer_type(self):
self.type = Register[(*self.infer_shape(), i1)]
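
The comparison ops reuse BinaryOpBase's shape inference but pin the element type to i1, so comparing two registers yields a boolean register over the broadcasted shape. As a rough analogy (NumPy arrays standing in for Wave registers purely to illustrate the shape/dtype rule, not the Wave lowering):

import numpy as np

a = np.zeros((4, 8), dtype=np.float16)  # shape (M, K)
b = np.zeros((8,), dtype=np.float16)    # shape (K,)
m = a <= b                              # broadcasts to (M, K), dtype bool (the i1 analogue)
assert m.shape == (4, 8) and m.dtype == np.bool_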


@define_interface_op("log2")
@@ -759,6 +811,42 @@ def infer_type(self):
self.type = src_type


@define_op("select")
@dataclass
class SelectOp(CustomOp):
cond: fx.Node
if_true: fx.Node
if_false: fx.Node

@property
def indexing_dims(self) -> list[IndexSymbol]:
combined_dims = []
combined_dims += get_custom(self.cond).indexing_dims
combined_dims += get_custom(self.if_true).indexing_dims
combined_dims += get_custom(self.if_false).indexing_dims
return list(dict.fromkeys(combined_dims))

def infer_type(self):
cond_type = get_custom(self.cond).type
if_true_type = get_custom(self.if_true).type
if_false_type = get_custom(self.if_false).type

if cond_type.dtype != i1:
raise ValueError("SelectOp expects condition type to be i1.")

if if_true_type.dtype != if_false_type.dtype:
raise ValueError("SelectOp expects lhs and rhs dtype to match.")

# TODO: support broadcasting behavior.
if (
cond_type.symbolic_shape != if_true_type.symbolic_shape
or cond_type.symbolic_shape != if_false_type.symbolic_shape
):
raise ValueError("SelectOp doesn't support broadcasting. (yet?)")

self.type = if_true_type
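
SelectOp is an elementwise ternary choose: wherever cond holds, take the element from if_true, otherwise from if_false; the checks above require all three operands to share one symbolic shape and the two value operands to share a dtype. A minimal NumPy analogue of these semantics (illustrative only, not the Wave lowering):

import numpy as np

cond = np.array([[True, False], [False, True]])     # i1-typed condition
if_true = np.zeros((2, 2), dtype=np.float32)
if_false = np.full((2, 2), -1e9, dtype=np.float32)  # e.g. a masking value
out = np.where(cond, if_true, if_false)             # elementwise select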


@final
@dataclass
class Unknown(CustomOp):
@@ -940,6 +1028,22 @@ def type(self) -> "Memory":
return Memory[(*self.shape, self.address_space, self.dtype)]


@define_op("self_index")
@dataclass
class SelfIndex(CustomOp):
dim: IndexExpr
dtype: DataType
elements_per_thread: Optional[IndexExpr | int] = None

@property
def indexing_dims(self) -> list[IndexSymbol]:
return [self.dim]

@property
def type(self) -> "Register":
return Register[(self.dim, self.dtype)]
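
SelfIndex materializes, for one symbolic dimension, a register whose elements hold their own position along that dimension, i.e. an iota/arange over dim; elements_per_thread controls how many consecutive indices a single thread carries. A small reference of the value it denotes (assumed semantics, sketched with NumPy rather than the Wave codegen):

import numpy as np

def self_index_reference(dim_size: int, dtype=np.int64) -> np.ndarray:
    # Element i holds the index i along the chosen dimension.
    return np.arange(dim_size, dtype=dtype)

assert (self_index_reference(4) == np.array([0, 1, 2, 3])).all()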


@define_op("shared_memory_barrier")
@dataclass
class SharedMemoryBarrier(CustomOp):
@@ -1657,6 +1761,25 @@ class Broadcast(CustomOp, ABC):
arg: fx.Node
target_shape: Sequence[IndexSymbol] = None

def __post_init__(self):
# Required for setting up hash.
super().__post_init__()
# Verify for valid src type.
if isinstance(self.arg, fx.Node):
src = self.arg
elif isinstance(self.arg, fx.Proxy):
src = self.arg.node
else:
raise ValueError(f"Unexpected broadcast src type of {type(self.arg)}")

# Verifies target broadcast shape is valid.
src_type = get_custom(src).type
src_shape = set(getattr(src_type, "symbolic_shape", []))
dst_shape = set(self.target_shape)
assert src_shape.issubset(
dst_shape
), "Fail to initialize broadcast because of invalid target_shape."

@property
def indexing_dims(self) -> list[IndexSymbol]:
return self.target_shape
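
Taken together, these ops are what the PR title promises for causal attention: self_index over the query and key dimensions, a comparison to build the i1 predicate k <= q, and a select that substitutes a large negative value before the softmax. A NumPy reference of that composition (the Wave kernel expresses the same idea with the tkw ops added here; the function and variable names below are illustrative):

import numpy as np

def causal_mask_reference(scores: np.ndarray) -> np.ndarray:
    # scores is an (M, K2) tile of attention logits.
    m, k2 = scores.shape
    q_idx = np.arange(m, dtype=np.int64)[:, None]   # self_index over the query dim
    k_idx = np.arange(k2, dtype=np.int64)[None, :]  # self_index over the key dim
    keep = k_idx <= q_idx                           # comparison op: i1 predicate
    # select: keep the score where the predicate holds, otherwise a large
    # negative value so the subsequent softmax assigns ~zero weight.
    return np.where(keep, scores, np.float32(-1e9))

scores = np.random.rand(4, 4).astype(np.float32)
masked = causal_mask_reference(scores)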