From 4095db0cb720099fe77f6d05a10782fa85ab8d7d Mon Sep 17 00:00:00 2001
From: Kyle Herndon
Date: Mon, 9 Sep 2024 11:46:50 -0500
Subject: [PATCH] Add in some missing grok specific model structure and constants

---
 sharktank/sharktank/layers/paged_llama_attention_block.py | 4 ++++
 sharktank/sharktank/models/grok/grok.py                   | 4 ++++
 2 files changed, 8 insertions(+)

diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py
index 1e5224e89..403ad1abc 100644
--- a/sharktank/sharktank/layers/paged_llama_attention_block.py
+++ b/sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -37,6 +37,7 @@ def __init__(
         head_count_kv: int,
         rms_epsilon: float,
         use_hf: bool = False,
+        use_grok: bool = False,
     ):
         super().__init__(theta)
         self.add_module(
@@ -53,6 +54,7 @@ def __init__(
         self.head_dim = head_dim
         self.head_count_kv = head_count_kv
         self.use_hf = use_hf
+        self.use_grok = use_grok
 
     def forward(
         self,
@@ -141,6 +143,8 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
 
         # Flash attention.
         attn_weights = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
+        if self.use_grok:
+            attn_weights = 30 * torch.tanh(attn_weights * (0.08838834764831845 / 30.0))
         self.assert_not_nan(attn_weights)
 
         # Apply attention mask.
diff --git a/sharktank/sharktank/models/grok/grok.py b/sharktank/sharktank/models/grok/grok.py
index 7b971c6e2..def2f0fbc 100644
--- a/sharktank/sharktank/models/grok/grok.py
+++ b/sharktank/sharktank/models/grok/grok.py
@@ -147,6 +147,8 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
                     head_dim=hp.attn_head_dim,
                     head_count_kv=hp.attention_head_count_kv,
                     rms_epsilon=hp.attention_layer_norm_rms_epsilon,
+                    use_hf=True,
+                    use_grok=True,
                 )
             )
             self.attn_blocks.append(
@@ -250,6 +252,7 @@ def decode(
         )
 
         h = self.token_embedding(tokens)
+        h *= 78.38367176906169
         self.trace_tensor("mixtral.token_embedding", h)
 
         # Iterate over attention blocks.
@@ -278,4 +281,5 @@ def decode(
 
         h = self.output_norm(h)
         logits = self.output_lm_head(h)
+        logits = logits * 0.5773502691896257
         return logits
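
Note (not part of the patch itself): the magic constants above appear to correspond to Grok-1's published scaling factors. 0.08838834764831845 equals 1/sqrt(128) (the attention head dimension), 30.0 acts as an attention-logit soft cap, 78.38367176906169 equals sqrt(6144) (the embedding width, applied as an embedding multiplier), and 0.5773502691896257 is approximately 1/sqrt(3) (the output logit multiplier). Below is a minimal standalone sketch of the soft-cap behavior added in paged_llama_attention_block.py, assuming PyTorch; the function name is hypothetical and not part of this change.

    import torch

    def softcap_attention_logits(
        attn_weights: torch.Tensor,
        scale: float = 0.08838834764831845,  # assumed to be 1/sqrt(128)
        cap: float = 30.0,                   # soft cap on logit magnitude
    ) -> torch.Tensor:
        # Scale the raw logits, then squash them through tanh so their
        # magnitude never exceeds `cap`, mirroring the hunk above.
        return cap * torch.tanh(attn_weights * (scale / cap))

Because tanh saturates at +/-1, the rescaled logits are bounded to (-30, 30), which keeps the softmax numerically tame for long sequences.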