Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
  • Loading branch information
harsh-nod committed Jan 17, 2025
1 parent f874ef3 commit f6d081e
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 9 deletions.
5 changes: 3 additions & 2 deletions iree/turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,7 +1075,7 @@ def transform_index_backwards(
iters = self.mapping.iters
mapping = self.mapping.dynamic_val_mappings[i]

# This logic relies on fact our mapping is identity.
# This logic assumes that the output mapping is identity.
subs = {
k: index[v] for k, v in zip(iters, self.mapping.output_mapping.keys())
}
Expand Down Expand Up @@ -1333,13 +1333,14 @@ def transform_index_backwards(
iters = self.mapping.iters
mapping = self.mapping.dynamic_val_mappings[i]

# This logic relies on fact in mapping is identity.
# This logic assumes that the input mapping is identity.
subs = {
k: index[v] for k, v in zip(iters, self.mapping.input_mapping.keys())
}
return {
k: IndexSequence.from_expr(mapping[k], subs)
for k in arg.type.symbolic_shape
if k in mapping
}

return index
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 The IREE Authors
# Copyright 2025 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
Expand Down
4 changes: 0 additions & 4 deletions iree/turbine/kernel/wave/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,10 +1102,6 @@ def device_randn(*args, **kwargs):
return to_default_device(torch.randn(*args, **kwargs))


def device_randn_like(*args, **kwargs):
return to_default_device(torch.randn_like(*args, **kwargs))


def device_randint(*args, **kwargs):
return to_default_device(torch.randint(*args, **kwargs))

Expand Down
3 changes: 1 addition & 2 deletions tests/kernel/wave/attention/paged_attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
device_arange,
device_randn,
device_randint,
device_randn_like,
device_zeros,
)
from iree.turbine.kernel.wave.constraints import MMAType
Expand Down Expand Up @@ -221,7 +220,7 @@ def testPagedFlashDecoding(
key_cache = device_randn(
num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
)
value_cache = device_randn_like(key_cache)
value_cache = torch.randn_like(key_cache)
# TODO: The block table entries should be able to be a random number
# in the range [0, num_blocks * block_size), but that fails for now.
# As a workaround, the maximum value is set to num_seqs - 1.
Expand Down

0 comments on commit f6d081e

Please sign in to comment.