From 1f94b524cc74ac9d7803cb71356474c5814e7a83 Mon Sep 17 00:00:00 2001
From: "Chendi.Xue"
Date: Thu, 12 Sep 2024 20:30:26 +0000
Subject: [PATCH] Revert "Remove hardcoded value from softmax in flat_pa
 (#280)"

This reverts commit 35a4a984a79dc421320a2e520005e48ed884571d.
---
 vllm/hpu/ops.py | 13 +------------
 1 file changed, 1 insertion(+), 12 deletions(-)

diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py
index 939d195a12b08..3d76c36f2648b 100644
--- a/vllm/hpu/ops.py
+++ b/vllm/hpu/ops.py
@@ -40,18 +40,7 @@ def block2batch(tensor, block_mapping):
 
 
 def block_softmax(batch_size, attn, block_mapping):
-    # We're using global maximum to decrease the exponent as
-    # it's fast to compute and performs reasonably well.
-    # This is by no means a final solution and needs to
-    # be properly addressed in the future.
-    #
-    # Additionally there's a bug where 'max' is not parallelized
-    # across TPC cores, so we need to split the tensor manually
-    # instead of simply doing attn_max = attn.max()
-
-    tail_dims = tuple(range(1, attn.dim()))
-    attn_max = attn.amax(tail_dims).amax()
-    attn.sub_(attn_max)
+    attn.sub_(10.0)
     attn = attn.exp_()
     sums = attn.sum(dim=-1).unsqueeze(-1)
     sums = block2batch(sums, block_mapping)
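
For context on the two approaches the hunk switches between: softmax is shift-invariant, so subtracting either the hardcoded 10.0 or the true maximum leaves the result mathematically unchanged; the difference is purely numerical (max-subtraction bounds every exponent at zero, while a fixed shift can overflow for large logits). Below is a minimal sketch, not part of the patch, assuming ordinary dense PyTorch tensors rather than the HPU block layout used by flat_pa; the function name and shapes are illustrative only.

```python
import torch


def softmax_shift_demo():
    attn = torch.randn(2, 4, 8)  # illustrative shape, not the flat_pa layout

    # Approach restored by this revert: subtract a hardcoded constant.
    # exp(a - c) / sum(exp(a - c)) == exp(a) / sum(exp(a)) for any c,
    # so the normalized result is mathematically unchanged.
    fixed = (attn - 10.0).exp()
    fixed = fixed / fixed.sum(dim=-1, keepdim=True)

    # Approach being reverted: subtract the global maximum, computed
    # via per-dimension amax as a workaround for 'max' not being
    # parallelized across TPC cores (per the removed comment).
    tail_dims = tuple(range(1, attn.dim()))
    attn_max = attn.amax(tail_dims).amax()
    stable = (attn - attn_max).exp()
    stable = stable / stable.sum(dim=-1, keepdim=True)

    # Both paths agree for moderate logits; with very large logits the
    # fixed shift can overflow in exp(), while max-subtraction keeps
    # every exponent <= 0.
    assert torch.allclose(fixed, stable, atol=1e-6)


softmax_shift_demo()
```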