Skip to content

Commit

Permalink
docs: fix return type annotation of get_default_model_revision (#35982
Browse files Browse the repository at this point in the history
)
  • Loading branch information
MarcoGorelli authored Feb 13, 2025
1 parent 6a1ab63 commit 3c912c9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
6 changes: 4 additions & 2 deletions src/transformers/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def get_framework(model, revision: Optional[str] = None):

def get_default_model_and_revision(
targeted_task: Dict, framework: Optional[str], task_options: Optional[Any]
) -> Union[str, Tuple[str, str]]:
) -> Tuple[str, str]:
"""
Select a default model to use for a given task. Defaults to pytorch if ambiguous.
Expand All @@ -401,7 +401,9 @@ def get_default_model_and_revision(
Returns
`str` The model string representing the default model for this pipeline
Tuple:
- `str` The model string representing the default model for this pipeline.
- `str` The revision of the model.
"""
if is_torch_available() and not is_tf_available():
framework = "pt"
Expand Down
6 changes: 4 additions & 2 deletions tests/pipelines/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,7 @@ def test_register_pipeline(self):
pipeline_class=PairClassificationPipeline,
pt_model=AutoModelForSequenceClassification if is_torch_available() else None,
tf_model=TFAutoModelForSequenceClassification if is_tf_available() else None,
default={"pt": "hf-internal-testing/tiny-random-distilbert"},
default={"pt": ("hf-internal-testing/tiny-random-distilbert", "2ef615d")},
type="text",
)
assert "custom-text-classification" in PIPELINE_REGISTRY.get_supported_tasks()
Expand All @@ -806,7 +806,9 @@ def test_register_pipeline(self):
self.assertEqual(task_def["tf"], (TFAutoModelForSequenceClassification,) if is_tf_available() else ())
self.assertEqual(task_def["type"], "text")
self.assertEqual(task_def["impl"], PairClassificationPipeline)
self.assertEqual(task_def["default"], {"model": {"pt": "hf-internal-testing/tiny-random-distilbert"}})
self.assertEqual(
task_def["default"], {"model": {"pt": ("hf-internal-testing/tiny-random-distilbert", "2ef615d")}}
)

# Clean registry for next tests.
del PIPELINE_REGISTRY.supported_tasks["custom-text-classification"]
Expand Down

0 comments on commit 3c912c9

Please sign in to comment.