Skip to content

Commit

Permalink
Add attention with bias tests
Browse files Browse the repository at this point in the history
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
  • Loading branch information
harsh-nod committed Nov 25, 2024
1 parent 965247e commit 9a03b35
Show file tree
Hide file tree
Showing 7 changed files with 760 additions and 131 deletions.
42 changes: 28 additions & 14 deletions iree/turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,20 +538,6 @@ def expanded_dims(self, value: dict[IndexSymbol, int]):
raise ValueError("Expanded dims must be a dict")
self.fx_node.expanded_dims = value

@property
def anchor(self) -> fx.Node:
"""
The anchor is a node that provides information to the node
such as vector_shapes, indexing information etc.
"""
if hasattr(self.fx_node, "anchor"):
return self.fx_node.anchor
return None

@anchor.setter
def anchor(self, value: fx.Node):
self.fx_node.anchor = value

@property
def vector_shapes(self) -> dict[IndexSymbol, int]:
if hasattr(self.fx_node, "vector_shapes"):
Expand Down Expand Up @@ -590,6 +576,14 @@ def align_index(self, constraints: list["Constraint"]) -> None:
"""
pass

def transform_index(
self, index: dict[IndexSymbol, IndexSequence]
) -> dict[IndexSymbol, IndexSequence]:
"""
Transform the index of the node based on the provided mapping.
"""
return index


@define_py_op(operator.add)
@define_py_op(operator.sub)
Expand Down Expand Up @@ -1426,6 +1420,26 @@ def infer_type(self):
), f"Target shape {self.target_shape} must be a permutation of source shape {src_type.symbolic_shape}"
self.type = Register[*self.target_shape, src_type.dtype]

def transform_index(
self, index: dict[IndexSymbol, IndexSequence]
) -> dict[IndexSymbol, IndexSequence]:
"""
The permute operation swaps the strides of the permuted indices.
So say we have a permute operation that swaps [B, M, N] to
[M, N, B], then we swap the strides of the dimensions.
"""
custom_src = get_custom(self.arg)
src_shape = custom_src.type.symbolic_shape
src_to_target = {
src: self.target_shape[src_shape.index(src)] for src in src_shape
}
permuted_index = {
k: IndexSequence(v.start, v.size, index[src_to_target[k]].stride)
for k, v in index.items()
if k in src_shape
}
return permuted_index


def _to_sequence(input: Any | Sequence[Any]) -> Sequence[Any]:
return input if isinstance(input, Sequence) else (input,)
Expand Down
Loading

0 comments on commit 9a03b35

Please sign in to comment.