diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 375d94043e6c..774f9bf1a2fc 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1721,6 +1721,8 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, remove_prefix_length:] + output_router_logits = kwargs.get("output_router_logits", True) + return { "decoder_input_ids": input_ids, "past_key_values": past_key_values, @@ -1730,6 +1732,7 @@ def prepare_inputs_for_generation( "decoder_head_mask": decoder_head_mask, "cross_attn_head_mask": cross_attn_head_mask, "use_cache": use_cache, + "output_router_logits": output_router_logits, } def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):