Skip to content

Commit

Permalink
Fix pipeline task dropping arguments bug (#828)
Browse files Browse the repository at this point in the history
* test

* fix

* shorter workflow

* fix task bug

* fix yml
  • Loading branch information
fxmarty authored Feb 28, 2023
1 parent 536ccac commit 9d76da2
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 10 deletions.
12 changes: 6 additions & 6 deletions optimum/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,15 +327,15 @@ def pipeline(

no_feature_extractor_tasks = set()
no_tokenizer_tasks = set()
for task, values in supported_tasks.items():
for _task, values in supported_tasks.items():
if values["type"] == "text":
no_feature_extractor_tasks.add(task)
no_feature_extractor_tasks.add(_task)
elif values["type"] in {"image", "video"}:
no_tokenizer_tasks.add(task)
no_tokenizer_tasks.add(_task)
elif values["type"] in {"audio"}:
no_tokenizer_tasks.add(task)
no_tokenizer_tasks.add(_task)
elif values["type"] not in ["multimodal", "audio", "video"]:
raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {values['type']}")
raise ValueError(f"SUPPORTED_TASK {_task} contains invalid type {values['type']}")

# copied from transformers.pipelines.__init__.py l.609
if targeted_task in no_tokenizer_tasks:
Expand Down Expand Up @@ -372,7 +372,7 @@ def pipeline(
feature_extractor = get_preprocessor(model_id)

return transformers_pipeline(
targeted_task,
task,
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
Expand Down
5 changes: 1 addition & 4 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2917,10 +2917,7 @@ def test_pipeline_text_generation(self, test_name: str, model_arch: str, use_cac
# Translation
pipe = pipeline("translation_en_to_de", model=onnx_model, tokenizer=tokenizer)
text = "This is a test"
if model_arch in ["m2m_100", "mbart"]:
outputs = pipe(text, src_lang="en", tgt_lang="fr")
else:
outputs = pipe(text)
outputs = pipe(text)
self.assertEqual(pipe.device, onnx_model.device)
self.assertIsInstance(outputs[0]["translation_text"], str)

Expand Down

0 comments on commit 9d76da2

Please sign in to comment.