Skip to content

Commit

Permalink
[GemmaModel] fix small typo (#31202)
Browse files Browse the repository at this point in the history
* fixes

* fix-copies
  • Loading branch information
ArthurZucker authored Jun 3, 2024
1 parent 39b2ff6 commit 1749841
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 14 deletions.
8 changes: 4 additions & 4 deletions src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)

if not output_attentions:
Expand Down Expand Up @@ -594,7 +594,7 @@ def forward(
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = attn_output.view(bsz, q_len, -1)

attn_output = self.o_proj(attn_output)

Expand Down Expand Up @@ -866,9 +866,9 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

return_legacy_cache = False
return_legacy_cache = False # noqa: F841
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
return_legacy_cache = True # noqa: F841
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

if cache_position is None:
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def forward(

attn_output = attn_output.transpose(1, 2).contiguous()

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = attn_output.reshape(bsz, q_len, -1)

if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
Expand Down Expand Up @@ -467,7 +467,7 @@ def forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)

if not output_attentions:
Expand Down Expand Up @@ -653,7 +653,7 @@ def forward(
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = attn_output.view(bsz, q_len, -1)

attn_output = self.o_proj(attn_output)

Expand Down
3 changes: 1 addition & 2 deletions src/transformers/models/mistral/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,6 @@ def forward(
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

cos, sin = self.rotary_emb(value_states, position_ids)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
Expand Down Expand Up @@ -656,7 +655,7 @@ def forward(
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = attn_output.view(bsz, q_len, -1)

attn_output = self.o_proj(attn_output)

Expand Down
15 changes: 10 additions & 5 deletions utils/diff_model_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,16 @@ def update_body(self, existing_body, new_statements):
Helper method to update the body by removing duplicates before adding new statements.
"""
deduplicated_new_body = []
existing_nodes = {
self.python_module.code_for_node(node).strip() for node in new_statements if isinstance(node, cst.CSTNode)
}
existing_nodes = set()
for node in new_statements:
code = self.python_module.code_for_node(node)
comment_less_code = re.sub(r"#.*", "", code).strip()
comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
existing_nodes.add(comment_less_code)
for stmt in existing_body:
if self.python_module.code_for_node(stmt).strip() not in existing_nodes:
comment_less_code = re.sub(r"#.*", "", self.python_module.code_for_node(stmt)).strip()
comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
if comment_less_code not in existing_nodes:
if m.matches(stmt, DOCSTRING_NODE) and self.has_docstring:
continue
deduplicated_new_body.append(stmt)
Expand Down Expand Up @@ -542,7 +547,7 @@ def convert_file(diff_file, cst_transformers=None):
parser = argparse.ArgumentParser()
parser.add_argument(
"--files_to_parse",
default=["/Users/arthurzucker/Work/transformers/examples/diff-conversion/diff_my_new_model.py"],
default=["all"],
nargs="+",
help="A list of `diff_xxxx` files that should be converted to single model file",
)
Expand Down

0 comments on commit 1749841

Please sign in to comment.