Commit
Revert "Remove hardcoded value from softmax in flat_pa (HabanaAI#280)"
This reverts commit 35a4a98.
xuechendi committed Sep 12, 2024
1 parent 181babf commit 1f94b52
Showing 1 changed file with 1 addition and 12 deletions.
13 changes: 1 addition & 12 deletions vllm/hpu/ops.py
@@ -40,18 +40,7 @@ def block2batch(tensor, block_mapping):
 
 
 def block_softmax(batch_size, attn, block_mapping):
-    # We're using global maximum to decrease the exponent as
-    # it's fast to compute and performs reasonably well.
-    # This is by no means a final solution and needs to
-    # be properly addressed in the future.
-    #
-    # Additionally there's a bug where 'max' is not parallelized
-    # across TPC cores, so we need to split the tensor manually
-    # instead of simply doing attn_max = attn.max()
-
-    tail_dims = tuple(range(1, attn.dim()))
-    attn_max = attn.amax(tail_dims).amax()
-    attn.sub_(attn_max)
+    attn.sub_(10.0)
     attn = attn.exp_()
     sums = attn.sum(dim=-1).unsqueeze(-1)
     sums = block2batch(sums, block_mapping)
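For context: the change being reverted replaced the hardcoded shift `attn.sub_(10.0)` with subtraction of the global maximum before exponentiation. Subtracting the maximum is the standard way to keep `exp` from overflowing, since every shifted logit is then at most zero, while the softmax output is unchanged because the shift cancels in the normalization. A fixed constant like `10.0` only helps when logits stay near that magnitude. Below is a minimal plain-Python sketch of the max-subtraction trick (the function name is hypothetical, not part of the vLLM code):

```python
import math

def stable_softmax(xs):
    # Shift by the maximum: exp(x - m) <= 1 for every element,
    # so no overflow even for very large logits. The shift cancels
    # in the normalization, leaving the softmax output unchanged.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# A naive [math.exp(x) for x in xs] would overflow here,
# and a fixed shift of 10.0 would not be enough either:
probs = stable_softmax([1000.0, 1001.0])
```

This mirrors what the reverted code did per tensor, except that it computed the maximum via `attn.amax` over the tail dimensions (as a workaround for `max` not being parallelized across TPC cores, per the deleted comment) and applied the shift in place with `attn.sub_`.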
