Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix test_pipelines_video_classification that was always failing #35842

Merged
6 changes: 3 additions & 3 deletions src/transformers/pipelines/video_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ def __call__(self, inputs: Union[str, List[str]] = None, **kwargs):
post-processing.

Return:
A dictionary or a list of dictionaries containing result. If the input is a single video, will return a
dictionary, if the input is a list of several videos, will return a list of dictionaries corresponding to
the videos.
A list of dictionaries or a list of list of dictionaries containing result. If the input is a single video,
will return a list of `top_k` dictionaries, if the input is a list of several videos, will return a list of list of
`top_k` dictionaries corresponding to the videos.

The dictionaries contain the following keys:

Expand Down
9 changes: 4 additions & 5 deletions tests/pipelines/test_pipelines_video_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,13 @@ def test_small_model_pt(self):
)

video_file_path = hf_hub_download(repo_id="nateraw/video-demo", filename="archery.mp4", repo_type="dataset")
outputs = video_classifier(video_file_path, top_k=2)
output = video_classifier(video_file_path, top_k=2)
self.assertEqual(
nested_simplify(outputs, decimals=4),
nested_simplify(output, decimals=4),
[{"score": 0.5199, "label": "LABEL_0"}, {"score": 0.4801, "label": "LABEL_1"}],
)
for output in outputs:
for element in output:
compare_pipeline_output_to_hub_spec(element, VideoClassificationOutputElement)
for element in output:
compare_pipeline_output_to_hub_spec(element, VideoClassificationOutputElement)

outputs = video_classifier(
[
Expand Down