
Commit

[WIP] Add support for Mistral-Nemo by supporting head_dim through config (#2254)

* Support passing head_dim through config

* Using `head_dim` as a fallback is necessary since it's a non-standard
key in `MistralConfig` (as defined in transformers); a minimal sketch of this
fallback follows the diff below.

* Shorter diff.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
2 people authored and ErikKaum committed Jul 26, 2024
1 parent 9ae43a1 commit 7c874e5
Showing 1 changed file with 3 additions and 4 deletions.
@@ -149,15 +149,14 @@ def __init__(self, prefix: str, config, weights, layer_id):
             bias=False,
         )
 
-        head_size = config.hidden_size // config.num_attention_heads
         self.query_key_value = TensorParallelMultiAdapterLinear.load(
             query_key_value,
             layer_id,
             ["q_proj", "k_proj", "v_proj"],
             sizes=[
-                head_size * config.num_attention_heads,
-                head_size * config.num_key_value_heads,
-                head_size * config.num_key_value_heads,
+                self.head_size * config.num_attention_heads,
+                self.head_size * config.num_key_value_heads,
+                self.head_size * config.num_key_value_heads,
             ],
             process_group=weights.process_group,
         )
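To make the change above concrete, here is a minimal, hypothetical sketch of the
head_dim fallback described in the commit message. It is not the PR's actual code:
the `_Config` dataclass and `resolve_head_size` helper are illustrative names only,
and the example assumes a config shaped like transformers' `MistralConfig`, with
the resolved value ending up in `self.head_size` as used in the diff.

from dataclasses import dataclass
from typing import Optional


@dataclass
class _Config:
    # Hypothetical stand-in for a Mistral-style config; not transformers' class.
    hidden_size: int
    num_attention_heads: int
    num_key_value_heads: int
    head_dim: Optional[int] = None  # non-standard key, carried by Mistral-Nemo-style configs


def resolve_head_size(config: _Config) -> int:
    # Prefer an explicit head_dim when the config defines one;
    # otherwise fall back to the usual hidden_size // num_attention_heads.
    if getattr(config, "head_dim", None) is not None:
        return config.head_dim
    return config.hidden_size // config.num_attention_heads


# Mistral-Nemo-style config (values assumed here): head_dim differs from
# hidden_size // num_attention_heads (5120 // 32 = 160).
nemo = _Config(hidden_size=5120, num_attention_heads=32, num_key_value_heads=8, head_dim=128)
print(resolve_head_size(nemo))  # 128

# Older Mistral-7B-style config without head_dim: falls back to the computed value.
mistral = _Config(hidden_size=4096, num_attention_heads=32, num_key_value_heads=8)
print(resolve_head_size(mistral))  # 4096 // 32 = 128

With `self.head_size` resolved this way, the `sizes` passed to
`TensorParallelMultiAdapterLinear.load` in the diff become head_size * num_attention_heads
for q_proj and head_size * num_key_value_heads for k_proj and v_proj, which, for a
Mistral-Nemo-style config, matches its narrower heads rather than the value that
hidden_size // num_attention_heads would give.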
