Commit

caokai committed Jun 12, 2024
1 parent ed7dd0f commit 5054a19
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions vllm/model_executor/models/zhinao.py
@@ -29,7 +29,7 @@
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import LoRAConfig
+from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -94,6 +94,7 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
         sliding_window: Optional[int] = None,
+        cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -139,7 +140,7 @@ def __init__(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=bias,
-            quant_config=quant_config,
+            quant_config=quant_config
         )
 
         self.rotary_emb = get_rope(
@@ -153,7 +154,8 @@ def __init__(
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            sliding_window=sliding_window)
+            cache_config=cache_config,
+            quant_config=quant_config)
 
     def forward(
         self,
@@ -165,8 +167,7 @@ def forward(
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
-                                self.kv_scale)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -176,6 +177,7 @@ class ZhinaoDecoderLayer(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -200,6 +202,7 @@ def __init__(
             quant_config=quant_config,
             bias=attention_bias,
             sliding_window=sliding_window,
+            cache_config=cache_config,
         )
         self.mlp = ZhinaoMLP(
             hidden_size=self.hidden_size,
@@ -246,6 +249,7 @@ class ZhinaoModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
@@ -262,7 +266,7 @@ def __init__(
             org_num_embeddings=config.vocab_size,
         )
         self.layers = nn.ModuleList([
-            ZhinaoDecoderLayer(config, quant_config)
+            ZhinaoDecoderLayer(config, cache_config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -322,12 +326,16 @@ class ZhinaoForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.model = ZhinaoModel(config, quant_config, lora_config=lora_config)
+        self.model = ZhinaoModel(config,
+                                 cache_config,
+                                 quant_config,
+                                 lora_config=lora_config)
         self.unpadded_vocab_size = config.vocab_size
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
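
For context, a brief usage sketch (not part of the commit): the cache-related engine arguments below are turned into the CacheConfig that, after this change, is threaded through ZhinaoForCausalLM, ZhinaoModel, ZhinaoDecoderLayer, and ZhinaoAttention into the Attention layer. The model id and argument values are illustrative assumptions.

# Hedged sketch: exercising the Zhinao model through vLLM's public API.
# The model id and the cache-related values are assumptions for illustration.
from vllm import LLM, SamplingParams

llm = LLM(model="qihoo360/360Zhinao-7B-Base",  # example model id (assumption)
          trust_remote_code=True,
          block_size=16,               # these engine arguments populate the
          gpu_memory_utilization=0.90, # CacheConfig that is now passed down
          swap_space=4)                # to the Zhinao attention layers

outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)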