From d49dd8efddf44b87666f02b360ac79d521ccd39c Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 23 Jan 2025 00:27:33 +0800 Subject: [PATCH 1/5] clean up Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/whisper.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index c1f3bb0ca33c2..e683dec860caa 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -35,6 +35,13 @@ logger = init_logger(__name__) +def _cast_overflow_values(x: torch.Tensor) -> torch.Tensor: + if x.isinf().any() or x.isnan().any(): + clamp_value = torch.finfo(x.dtype).max - 1000 + x = torch.clamp(x, min=-clamp_value, max=clamp_value) + return x + + class WhisperAudioInputs(TypedDict): input_features: NestedTensors """Shape: `(batch_size, 128, M)`""" @@ -295,12 +302,7 @@ def forward( hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - - if hidden_states.isinf().any() or hidden_states.isnan().any(): - clamp_value = torch.finfo(hidden_states.dtype).max - 1000 - hidden_states = torch.clamp(hidden_states, - min=-clamp_value, - max=clamp_value) + hidden_states = _cast_overflow_values(hidden_states) return hidden_states From f77cf40ca75fe166213c04855b81513ac79a7f2c Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 23 Jan 2025 14:59:31 +0800 Subject: [PATCH 2/5] revert whisper Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/whisper.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index e683dec860caa..b939f4d64b4b5 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -35,13 +35,6 @@ logger = init_logger(__name__) -def _cast_overflow_values(x: torch.Tensor) -> torch.Tensor: - if x.isinf().any() or x.isnan().any(): - clamp_value = torch.finfo(x.dtype).max - 1000 - x = torch.clamp(x, min=-clamp_value, max=clamp_value) - return x - - class WhisperAudioInputs(TypedDict): input_features: NestedTensors """Shape: `(batch_size, 128, M)`""" @@ -302,7 +295,12 @@ def forward( hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - hidden_states = _cast_overflow_values(hidden_states) + + if hidden_states.isinf().any() or hidden_states.isnan().any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, + min=-clamp_value, + max=clamp_value) return hidden_states @@ -734,4 +732,4 @@ def load_weights(self, weights: Iterable[Tuple[str, loaded_weights = [(name, loaded_weight) for name, loaded_weight in weights] mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}) - return loader.load_weights(loaded_weights, mapper=mapper) + return loader.load_weights(loaded_weights, mapper=mapper) \ No newline at end of file From b4f1c89d966d3a6763b11d9b46acda0d696f8ff5 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 23 Jan 2025 15:05:07 +0800 Subject: [PATCH 3/5] fix whisper self attn linear Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/layers/linear.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 52263e96fb9f9..deb14cfb7db52 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -319,8 +319,11 @@ def __init__(self, self.weight_loader_v2 if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) if bias: + # NOTE(Isotr0py): We intentionally use zeros to initialize the bias, + # so that it can be still compatible with the qkv_proj in model + # like whisper which has bias on q and v proj but not on k proj. self.bias = Parameter( - torch.empty(self.output_size_per_partition, + torch.zeros(self.output_size_per_partition, dtype=params_dtype)) set_weight_attrs(self.bias, { "output_dim": 0, From 7ca544361dd31c910734fe0c6ce80e921bbd99b3 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 23 Jan 2025 15:42:44 +0800 Subject: [PATCH 4/5] another solution Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/layers/linear.py | 5 +---- vllm/model_executor/models/whisper.py | 19 ++++++++++++++++++- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index deb14cfb7db52..52263e96fb9f9 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -319,11 +319,8 @@ def __init__(self, self.weight_loader_v2 if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) if bias: - # NOTE(Isotr0py): We intentionally use zeros to initialize the bias, - # so that it can be still compatible with the qkv_proj in model - # like whisper which has bias on q and v proj but not on k proj. self.bias = Parameter( - torch.zeros(self.output_size_per_partition, + torch.empty(self.output_size_per_partition, dtype=params_dtype)) set_weight_attrs(self.bias, { "output_dim": 0, diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index b939f4d64b4b5..4253a721db70a 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -728,8 +728,25 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: + # add fake zeros bias for k_proj to state_dict + weights = _create_fake_bias_for_k_proj(weights) loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."]) loaded_weights = [(name, loaded_weight) for name, loaded_weight in weights] mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}) - return loader.load_weights(loaded_weights, mapper=mapper) \ No newline at end of file + return loader.load_weights(loaded_weights, mapper=mapper) + + +def _create_fake_bias_for_k_proj( + weights: Iterable[Tuple[str, torch.Tensor]] +) -> Iterable[Tuple[str, torch.Tensor]]: + """ + Create full zeros bias for k_proj weight in self-attention layers. + So that the bias for k_proj in qkv_proj can be initialized with zeros. + """ + for name, weight in weights: + if ".self_attn.k_proj.weight" in name: + bias = torch.zeros(weight.size(0)) + bias_name = name.replace("weight", "bias") + yield from [(name, weight), (bias_name, bias)] + yield name, weight From bee02df41ca739142b1b4758cb3d52ef9e65372f Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 23 Jan 2025 15:45:11 +0800 Subject: [PATCH 5/5] clean up Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/whisper.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 4253a721db70a..b8512b735da94 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -728,13 +728,11 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - # add fake zeros bias for k_proj to state_dict - weights = _create_fake_bias_for_k_proj(weights) loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."]) - loaded_weights = [(name, loaded_weight) - for name, loaded_weight in weights] mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}) - return loader.load_weights(loaded_weights, mapper=mapper) + # add fake zeros bias for k_proj to state_dict + weights = _create_fake_bias_for_k_proj(weights) + return loader.load_weights(weights, mapper=mapper) def _create_fake_bias_for_k_proj(