
Whisper support #5964

Closed
wants to merge 16 commits
added whisper encoder decoder
huseinzol05 committed Jun 27, 2024
commit 860b70aae37c3203cc4035b2ab78aaf0c4838a38
250 changes: 209 additions & 41 deletions vllm/model_executor/models/whisper.py
@@ -40,20 +40,17 @@ class WhisperPositionalEmbedding(nn.Embedding):
    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
        super().__init__(num_positions, embedding_dim)

    def forward(self, input_ids, position_ids=None):
        # Positions are supplied explicitly by the runner, so this is a direct
        # lookup of absolute positions in the learned table.
        return self.weight[position_ids]
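For intuition, a minimal, self-contained sketch (toy sizes, not taken from the PR) of the lookup this forward performs: position ids supplied by the runner index directly into the learned table.

import torch
import torch.nn as nn

# Toy stand-in for WhisperPositionalEmbedding: 448 positions, d_model=4 (assumed sizes).
table = nn.Embedding(448, 4)
position_ids = torch.tensor([0, 1, 2, 3])   # absolute decoding positions
pos_embeds = table.weight[position_ids]     # same indexing as forward(input_ids, position_ids)
assert pos_embeds.shape == (4, 4)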

class WhisperAttention(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        is_decoder: bool = False,
        is_causal: bool = False,
        bias: bool = True,
        config: Optional[WhisperConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
@@ -78,22 +75,26 @@ def __init__(
        # Whisper uses a bias-free key projection; the value, query and output
        # projections carry biases.
        self.k_proj = RowParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            bias=False,
            quant_config=quant_config,
        )
        self.v_proj = RowParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            bias=bias,
            quant_config=quant_config,
        )
        self.q_proj = RowParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            bias=bias,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            bias=bias,
            quant_config=quant_config,
        )
        self.attn = Attention(
            self.num_heads,
@@ -110,51 +111,44 @@ def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        kv_cache: Optional[torch.Tensor] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
    ):
        # Cross-attention layers receive the encoder's pre-projected key/value
        # states; self-attention layers project from hidden_states.
        is_cross_attention = past_key_value is not None
        bsz, tgt_len, _ = hidden_states.size()
        q, _ = self.q_proj(hidden_states)
        q = q * self.scaling

        if kv_cache is None:
            # Encoder / cross-attention path: no paged KV cache, use plain SDPA.
            q = self._reshape(q, tgt_len, bsz)
            if is_cross_attention:
                # Reuse the key/value states computed from the encoder output.
                k = past_key_value[0]
                v = past_key_value[1]
            else:
                k, _ = self.k_proj(hidden_states)
                v, _ = self.v_proj(hidden_states)
                k = self._reshape(k, -1, bsz)
                v = self._reshape(v, -1, bsz)

            attn_output = torch.nn.functional.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=None,
                dropout_p=0.0,
                is_causal=self.is_causal and tgt_len > 1,
            )
            # (bsz, num_heads, tgt_len, head_dim) -> (bsz, tgt_len, embed_dim)
            attn_output = attn_output.transpose(1, 2).reshape(bsz, tgt_len, -1)
        else:
            # Decoder self-attention path: delegate KV caching to vLLM's Attention.
            k, _ = self.k_proj(hidden_states)
            v, _ = self.v_proj(hidden_states)
            attn_output = self.attn(q, k, v, kv_cache, attn_metadata)

        output, _ = self.out_proj(attn_output)
        return output
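For clarity, a toy shape walk-through of the SDPA branch above (the kv_cache-is-None path); all sizes are assumed for illustration, and the reshape helper only mirrors what _reshape is expected to do.

import torch

bsz, tgt_len, embed_dim, num_heads = 2, 5, 8, 2
head_dim = embed_dim // num_heads

def to_heads(t):
    # (bsz, seq, embed_dim) -> (bsz, num_heads, seq, head_dim)
    return t.view(bsz, -1, num_heads, head_dim).transpose(1, 2)

q = to_heads(torch.randn(bsz, tgt_len, embed_dim))
k = to_heads(torch.randn(bsz, tgt_len, embed_dim))
v = to_heads(torch.randn(bsz, tgt_len, embed_dim))
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=False)
out = out.transpose(1, 2).reshape(bsz, tgt_len, embed_dim)   # flatten heads before out_proj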

# class WhisperAttention(nn.Module):
#     def __init__(
#         self,
#         embed_dim: int,
#         num_heads: int,
#         dropout: float = 0.0,
#         is_decoder: bool = False,
#         bias: bool = True,
#         is_causal: bool = False,
#         config: Optional[WhisperConfig] = None,
#         cache_config: Optional[CacheConfig] = None,
#     ):

class WhisperEncoderLayer(nn.Module):
    def __init__(
        self,
@@ -173,8 +167,18 @@ def __init__(
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.activation_fn = FastGELU()
        self.fc1 = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=config.encoder_ffn_dim,
            bias=True,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            input_size=config.encoder_ffn_dim,
            output_size=self.embed_dim,
            bias=True,
            quant_config=quant_config,
        )
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
@@ -188,13 +192,177 @@ def forward(
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
        )
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        # RowParallelLinear returns an (output, bias) tuple, so unpack before the activation.
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
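Functionally, the feed-forward half of this layer is a pre-norm residual MLP. The sketch below shows the same computation with plain nn.Linear stand-ins for the tensor-parallel projections (toy sizes assumed).

import torch
import torch.nn as nn

d_model, ffn_dim = 8, 32                      # assumed toy sizes
ln = nn.LayerNorm(d_model)
fc1 = nn.Linear(d_model, ffn_dim)
fc2 = nn.Linear(ffn_dim, d_model)
x = torch.randn(2, 5, d_model)                # (batch, seq, d_model)
y = x + fc2(nn.functional.gelu(fc1(ln(x))))   # residual + FC2(GELU(FC1(LayerNorm(x))))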

class WhisperDecoderLayer(nn.Module):
    def __init__(
        self,
        config: WhisperConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
    ):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = WhisperAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            is_decoder=True,
            is_causal=True,
            config=config,
            quant_config=quant_config,
            cache_config=cache_config,
        )
        self.activation_fn = FastGELU()

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.encoder_attn = WhisperAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            is_decoder=True,
            config=config,
            quant_config=quant_config,
            cache_config=cache_config,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=config.decoder_ffn_dim,
            bias=True,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            input_size=config.decoder_ffn_dim,
            output_size=self.embed_dim,
            bias=True,
            quant_config=quant_config,
        )
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:

        # Self-attention over the decoded tokens (uses the paged KV cache).
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + hidden_states

        # Cross-attention over the encoder states passed in as past_key_value.
        residual = hidden_states
        hidden_states = self.encoder_attn_layer_norm(hidden_states)
        hidden_states = self.encoder_attn(
            hidden_states=hidden_states,
            past_key_value=past_key_value,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
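The sublayer ordering here is self-attention, then cross-attention over the encoder states, then the MLP, each in a pre-norm residual block. A self-contained toy version using torch.nn.MultiheadAttention as a stand-in (all sizes assumed, purely illustrative):

import torch
import torch.nn as nn

d_model, n_heads = 8, 2                       # assumed toy sizes
ln1, ln2, ln3 = nn.LayerNorm(d_model), nn.LayerNorm(d_model), nn.LayerNorm(d_model)
self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
fc1 = nn.Linear(d_model, 4 * d_model)
fc2 = nn.Linear(4 * d_model, d_model)

x = torch.randn(1, 3, d_model)                # decoder hidden states
enc = torch.randn(1, 6, d_model)              # encoder output (cross-attention memory)
h = ln1(x)
x = x + self_attn(h, h, h, need_weights=False)[0]
h = ln2(x)
x = x + cross_attn(h, enc, enc, need_weights=False)[0]
x = x + fc2(nn.functional.gelu(fc1(ln3(x))))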

class WhisperEncoder(nn.Module):
    def __init__(
        self,
        config: WhisperConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        embed_dim = config.d_model
        self.num_mel_bins = config.num_mel_bins
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        # Two-layer convolutional front end; conv2 halves the frame count, so the
        # encoder sees at most max_source_positions states.
        self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)

        self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)

        # The encoder stacks config.encoder_layers layers.
        self.layers = nn.ModuleList([
            WhisperEncoderLayer(config, quant_config=quant_config, cache_config=cache_config)
            for layer_idx in range(config.encoder_layers)
        ])

        self.layer_norm = nn.LayerNorm(config.d_model)

    def forward(
        self,
        input_features,
    ):
        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        embed_pos = self.embed_positions.weight

        hidden_states = inputs_embeds + embed_pos
        for idx, encoder_layer in enumerate(self.layers):
            hidden_states = encoder_layer(hidden_states)

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states
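A quick check of the convolutional front end, with standard Whisper sizes assumed (80 mel bins, 3000 frames for 30 s of 16 kHz audio, d_model=384 as in whisper-tiny): conv2 has stride 2, so the encoder runs over max_source_positions = 1500 states.

import torch
import torch.nn as nn

num_mel_bins, d_model = 80, 384                        # assumed whisper-tiny sizes
conv1 = nn.Conv1d(num_mel_bins, d_model, kernel_size=3, padding=1)
conv2 = nn.Conv1d(d_model, d_model, kernel_size=3, stride=2, padding=1)

input_features = torch.randn(1, num_mel_bins, 3000)    # (batch, mel_bins, frames)
x = nn.functional.gelu(conv1(input_features))
x = nn.functional.gelu(conv2(x))
print(x.permute(0, 2, 1).shape)                        # torch.Size([1, 1500, 384])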

class WhisperDecoder(nn.Module):
    def __init__(
        self,
        config: WhisperConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_target_positions
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.d_model)
        self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model)
        self.layers = nn.ModuleList([
            WhisperDecoderLayer(config, quant_config=quant_config, cache_config=cache_config)
            for layer_idx in range(config.decoder_layers)
        ])
        self.layer_norm = nn.LayerNorm(config.d_model)

    def forward(
        self,
        input_ids,
        positions: torch.Tensor,
        past_key_values,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ):
        inputs_embeds = self.embed_tokens(input_ids)
        positions = self.embed_positions(input_ids, positions)
        hidden_states = inputs_embeds + positions
        for idx, decoder_layer in enumerate(self.layers):
            hidden_states = decoder_layer(
                hidden_states,
                past_key_value=past_key_values[idx],
                kv_cache=kv_caches[idx],
                attn_metadata=attn_metadata,
            )

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states
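To make the decoder inputs concrete, a hedged prefill sketch (the token ids are illustrative only; in practice the forced prompt comes from the Whisper tokenizer): input_ids carry the prompt tokens and positions are the absolute indices consumed by WhisperPositionalEmbedding.

import torch

# Hypothetical prompt ids, e.g. <|startoftranscript|><|en|><|transcribe|><|notimestamps|>.
input_ids = torch.tensor([[50258, 50259, 50359, 50363]])
positions = torch.arange(input_ids.shape[1]).unsqueeze(0)   # absolute positions 0..3
print(positions)                                            # tensor([[0, 1, 2, 3]])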