Skip to content

Commit

Permalink
Fix error: pass `past_key_values` through the forward method
Browse files Browse the repository at this point in the history
  • Loading branch information
nakranivaibhav committed Jan 31, 2024
1 parent 42f2a03 commit c90f9b3
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1436,12 +1436,19 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
    """Return the input token-embedding module owned by ``self.transformer``."""
    embedding_module = self.transformer.embed_tokens
    return embedding_module

def set_input_embeddings(self, value):
    """Install *value* as the input token-embedding module on ``self.transformer``."""
    # Mirror of get_input_embeddings: write through to the wrapped transformer.
    setattr(self.transformer, "embed_tokens", value)

@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
Expand All @@ -1465,6 +1472,7 @@ def forward(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
Expand Down

0 comments on commit c90f9b3

Please sign in to comment.