diff --git a/src/transformers/pipelines/visual_question_answering.py b/src/transformers/pipelines/visual_question_answering.py index f456835d7090..9106b19d3367 100644 --- a/src/transformers/pipelines/visual_question_answering.py +++ b/src/transformers/pipelines/visual_question_answering.py @@ -123,9 +123,9 @@ def preprocess(self, inputs, padding=False, truncation=False, timeout=None): model_inputs.update(image_features) return model_inputs - def _forward(self, model_inputs): + def _forward(self, model_inputs, **generate_kwargs): if self.model.can_generate(): - model_outputs = self.model.generate(**model_inputs) + model_outputs = self.model.generate(**model_inputs, **generate_kwargs) else: model_outputs = self.model(**model_inputs) return model_outputs