diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py
index 3d76c36f2648b..939d195a12b08 100644
--- a/vllm/hpu/ops.py
+++ b/vllm/hpu/ops.py
@@ -40,7 +40,18 @@ def block2batch(tensor, block_mapping):
 
 
 def block_softmax(batch_size, attn, block_mapping):
-    attn.sub_(10.0)
+    # We're using a global maximum to decrease the exponent as
+    # it's fast to compute and performs reasonably well.
+    # This is by no means a final solution and needs to
+    # be properly addressed in the future.
+    #
+    # Additionally there's a bug where 'max' is not parallelized
+    # across TPC cores, so we need to split the tensor manually
+    # instead of simply doing attn_max = attn.max()
+
+    tail_dims = tuple(range(1, attn.dim()))
+    attn_max = attn.amax(tail_dims).amax()
+    attn.sub_(attn_max)
     attn = attn.exp_()
     sums = attn.sum(dim=-1).unsqueeze(-1)
     sums = block2batch(sums, block_mapping)
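
For reference, a minimal standalone sketch (not part of the patch; the tensor shape and values are arbitrary placeholders) showing that the two-stage reduction used above yields the same value as a single global max, and how subtracting it keeps the exponent non-positive before exp_():

import torch

# Dummy stand-in for the attention scores; the real shape and contents differ.
attn = torch.randn(4, 128, 128)

# Reduce every dim except the first, then reduce the per-block maxima.
tail_dims = tuple(range(1, attn.dim()))
attn_max = attn.amax(tail_dims).amax()

# Same value as a single global reduction (the form the comment says is
# currently not parallelized across TPC cores).
assert torch.equal(attn_max, attn.max())

# Subtracting the global max makes the largest exponent 0, so exp_() stays
# within (0, 1] and cannot overflow.
attn.sub_(attn_max)
attn = attn.exp_()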