From 8a7e91b49f38e9d63a92900b6c75a9301fe50850 Mon Sep 17 00:00:00 2001
From: Michal Adamczyk
Date: Thu, 12 Sep 2024 15:53:33 +0200
Subject: [PATCH] Remove hardcoded value from softmax in flat_pa (#280)

This PR removes the hardcoded value used to normalize softmax in
flat_pa. The current approach is to use the global maximum, as it is
very easy to compute, but it has the drawback that other samples in a
batch might slightly affect numerical stability.

This is a first step toward eliminating some of the INF/NaN issues we
see in certain configurations, and it is by no means a complete
solution; it will need to be revised in the future.
---
 vllm/hpu/ops.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py
index 3d76c36f2648b..939d195a12b08 100644
--- a/vllm/hpu/ops.py
+++ b/vllm/hpu/ops.py
@@ -40,7 +40,18 @@ def block2batch(tensor, block_mapping):
 
 
 def block_softmax(batch_size, attn, block_mapping):
-    attn.sub_(10.0)
+    # We're using the global maximum to decrease the exponent as
+    # it's fast to compute and performs reasonably well.
+    # This is by no means a final solution and needs to
+    # be properly addressed in the future.
+    #
+    # Additionally, there's a bug where 'max' is not parallelized
+    # across TPC cores, so we need to split the tensor manually
+    # instead of simply doing attn_max = attn.max().
+
+    tail_dims = tuple(range(1, attn.dim()))
+    attn_max = attn.amax(tail_dims).amax()
+    attn.sub_(attn_max)
     attn = attn.exp_()
     sums = attn.sum(dim=-1).unsqueeze(-1)
     sums = block2batch(sums, block_mapping)
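
Note for reviewers: below is a minimal, standalone sketch (illustrative
only; the logit values are made up and none of this is part of the
patch) of why subtracting the maximum before exponentiation avoids the
INF/NaN softmax issues described above:

    import torch

    # Large logits overflow exp() to inf, and inf / inf yields NaN.
    logits = torch.tensor([1000.0, 1001.0, 1002.0])
    naive = logits.exp() / logits.exp().sum()    # tensor([nan, nan, nan])

    # Shifting by the maximum keeps every exp() input <= 0, so exp()
    # stays in (0, 1] and the normalized result is finite.
    shifted = (logits - logits.amax()).exp()
    stable = shifted / shifted.sum()             # tensor([0.0900, 0.2447, 0.6652])

The patch applies the same shift in-place on the attention scores,
computing the global maximum via chained amax reductions to work around
the TPC parallelization issue noted in the code comment.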