From b5aaf875090388e2bbdbf2d8641ed7967365f435 Mon Sep 17 00:00:00 2001 From: CalOmnie Date: Thu, 23 Jan 2025 19:22:32 +0100 Subject: [PATCH] Fix `test_pipelines_video_classification` that was always failing (#35842) * Fix test_pipelines_video_classification that was always failing * Update video pipeline docstring to reflect actual return type --------- Co-authored-by: Louis Groux --- src/transformers/pipelines/video_classification.py | 6 +++--- tests/pipelines/test_pipelines_video_classification.py | 9 ++++----- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/transformers/pipelines/video_classification.py b/src/transformers/pipelines/video_classification.py index 057910098da2..616eb8def780 100644 --- a/src/transformers/pipelines/video_classification.py +++ b/src/transformers/pipelines/video_classification.py @@ -106,9 +106,9 @@ def __call__(self, inputs: Union[str, List[str]] = None, **kwargs): post-processing. Return: - A dictionary or a list of dictionaries containing result. If the input is a single video, will return a - dictionary, if the input is a list of several videos, will return a list of dictionaries corresponding to - the videos. + A list of dictionaries or a list of list of dictionaries containing result. If the input is a single video, + will return a list of `top_k` dictionaries, if the input is a list of several videos, will return a list of list of + `top_k` dictionaries corresponding to the videos. The dictionaries contain the following keys: diff --git a/tests/pipelines/test_pipelines_video_classification.py b/tests/pipelines/test_pipelines_video_classification.py index f1ed97ac13df..078e825ef6bc 100644 --- a/tests/pipelines/test_pipelines_video_classification.py +++ b/tests/pipelines/test_pipelines_video_classification.py @@ -91,14 +91,13 @@ def test_small_model_pt(self): ) video_file_path = hf_hub_download(repo_id="nateraw/video-demo", filename="archery.mp4", repo_type="dataset") - outputs = video_classifier(video_file_path, top_k=2) + output = video_classifier(video_file_path, top_k=2) self.assertEqual( - nested_simplify(outputs, decimals=4), + nested_simplify(output, decimals=4), [{"score": 0.5199, "label": "LABEL_0"}, {"score": 0.4801, "label": "LABEL_1"}], ) - for output in outputs: - for element in output: - compare_pipeline_output_to_hub_spec(element, VideoClassificationOutputElement) + for element in output: + compare_pipeline_output_to_hub_spec(element, VideoClassificationOutputElement) outputs = video_classifier( [