Skip to content

Commit

Permalink
fix:missing output_router_logits in SwitchTransformers (#30573)
Browse files Browse the repository at this point in the history
* fix:missing `output_router_logits` in SwitchTransformers

* fix whitespace in blank line
  • Loading branch information
lausannel authored May 2, 2024
1 parent 4ad5ada commit a65da83
Showing 1 changed file with 3 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1721,6 +1721,8 @@ def prepare_inputs_for_generation(

input_ids = input_ids[:, remove_prefix_length:]

output_router_logits = kwargs.get("output_router_logits", True)

return {
"decoder_input_ids": input_ids,
"past_key_values": past_key_values,
Expand All @@ -1730,6 +1732,7 @@ def prepare_inputs_for_generation(
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
"output_router_logits": output_router_logits,
}

def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
Expand Down

0 comments on commit a65da83

Please sign in to comment.