diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 23e1d14114cb..2b42571876d9 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -1083,6 +1083,7 @@
         "add_end_docstrings",
         "add_start_docstrings",
         "is_apex_available",
+        "is_av_available",
         "is_bitsandbytes_available",
         "is_datasets_available",
         "is_decord_available",
@@ -5951,6 +5952,7 @@
         add_end_docstrings,
         add_start_docstrings,
         is_apex_available,
+        is_av_available,
         is_bitsandbytes_available,
         is_datasets_available,
         is_decord_available,
diff --git a/src/transformers/pipelines/video_classification.py b/src/transformers/pipelines/video_classification.py
index f8596ce14c71..5702f23c5f60 100644
--- a/src/transformers/pipelines/video_classification.py
+++ b/src/transformers/pipelines/video_classification.py
@@ -3,13 +3,19 @@
 
 import requests
 
-from ..utils import add_end_docstrings, is_decord_available, is_torch_available, logging, requires_backends
+from ..utils import (
+    add_end_docstrings,
+    is_av_available,
+    is_torch_available,
+    logging,
+    requires_backends,
+)
 from .base import Pipeline, build_pipeline_init_args
 
 
-if is_decord_available():
+if is_av_available():
+    import av
     import numpy as np
-    from decord import VideoReader
 
 
 if is_torch_available():
@@ -33,7 +39,7 @@ class VideoClassificationPipeline(Pipeline):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        requires_backends(self, "decord")
+        requires_backends(self, "av")
         self.check_model_type(MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES)
 
     def _sanitize_parameters(self, top_k=None, num_frames=None, frame_sampling_rate=None):
@@ -90,14 +96,13 @@ def preprocess(self, video, num_frames=None, frame_sampling_rate=1):
         if video.startswith("http://") or video.startswith("https://"):
             video = BytesIO(requests.get(video).content)
 
-        videoreader = VideoReader(video)
-        videoreader.seek(0)
+        container = av.open(video)
 
         start_idx = 0
         end_idx = num_frames * frame_sampling_rate - 1
         indices = np.linspace(start_idx, end_idx, num=num_frames, dtype=np.int64)
 
-        video = videoreader.get_batch(indices).asnumpy()
+        video = read_video_pyav(container, indices)
         video = list(video)
 
         model_inputs = self.image_processor(video, return_tensors=self.framework)
@@ -120,3 +125,16 @@ def postprocess(self, model_outputs, top_k=5):
         scores = scores.tolist()
         ids = ids.tolist()
         return [{"score": score, "label": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
+
+
+def read_video_pyav(container, indices):
+    frames = []
+    container.seek(0)
+    start_index = indices[0]
+    end_index = indices[-1]
+    for i, frame in enumerate(container.decode(video=0)):
+        if i > end_index:
+            break
+        if i >= start_index and i in indices:
+            frames.append(frame)
+    return np.stack([x.to_ndarray(format="rgb24") for x in frames])
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 8b7814163739..44e0e1ebb2f6 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -57,6 +57,7 @@
     is_aqlm_available,
     is_auto_awq_available,
     is_auto_gptq_available,
+    is_av_available,
     is_bitsandbytes_available,
     is_bs4_available,
     is_cv2_available,
@@ -1010,6 +1011,13 @@ def require_aqlm(test_case):
     return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case)
 
 
+def require_av(test_case):
+    """
+    Decorator marking a test that requires av
+    """
+    return unittest.skipUnless(is_av_available(), "test requires av")(test_case)
+
+
 def require_bitsandbytes(test_case):
     """
     Decorator marking a test that requires the bitsandbytes library. Will be skipped when the library or its hard
     dependency torch is not installed.
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index b8da221a8c91..a3f596d0e955 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -109,6 +109,7 @@
     is_aqlm_available,
     is_auto_awq_available,
     is_auto_gptq_available,
+    is_av_available,
     is_bitsandbytes_available,
     is_bs4_available,
     is_coloredlogs_available,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 3835831e88a4..9382ca9528da 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -94,6 +94,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 _accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
 _apex_available = _is_package_available("apex")
 _aqlm_available = _is_package_available("aqlm")
+_av_available = importlib.util.find_spec("av") is not None
 _bitsandbytes_available = _is_package_available("bitsandbytes")
 _galore_torch_available = _is_package_available("galore_torch")
 # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
@@ -656,6 +657,10 @@ def is_aqlm_available():
     return _aqlm_available
 
 
+def is_av_available():
+    return _av_available
+
+
 def is_ninja_available():
     r"""
     Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
@@ -1012,6 +1017,16 @@ def is_mlx_available():
     return _mlx_available
 
 
+# docstyle-ignore
+AV_IMPORT_ERROR = """
+{0} requires the PyAv library but it was not found in your environment. You can install it with:
+```
+pip install av
+```
+Please note that you may need to restart your runtime after installation.
+"""
+
+
 # docstyle-ignore
 CV2_IMPORT_ERROR = """
 {0} requires the OpenCV library but it was not found in your environment. You can install it with:
 ```
@@ -1336,6 +1351,7 @@ def is_mlx_available():
 BACKENDS_MAPPING = OrderedDict(
     [
+        ("av", (is_av_available, AV_IMPORT_ERROR)),
         ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
         ("cv2", (is_cv2_available, CV2_IMPORT_ERROR)),
         ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),
diff --git a/tests/pipelines/test_pipelines_video_classification.py b/tests/pipelines/test_pipelines_video_classification.py
index 33e06e30f5ae..d23916bad84f 100644
--- a/tests/pipelines/test_pipelines_video_classification.py
+++ b/tests/pipelines/test_pipelines_video_classification.py
@@ -21,7 +21,7 @@
 from transformers.testing_utils import (
     is_pipeline_test,
     nested_simplify,
-    require_decord,
+    require_av,
     require_tf,
     require_torch,
     require_torch_or_tf,
@@ -34,7 +34,7 @@
 
 @is_pipeline_test
 @require_torch_or_tf
 @require_vision
-@require_decord
+@require_av
 class VideoClassificationPipelineTests(unittest.TestCase):
     model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING
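
For reviewers, a minimal sketch of what the new decode path does outside the pipeline: `read_video_pyav` replaces decord's `videoreader.get_batch(indices).asnumpy()` by decoding the stream sequentially and keeping only the sampled frame indices. The snippet assumes `av` and `numpy` are installed; `sample.mp4` is a placeholder path, not part of the PR.

```python
import av
import numpy as np

# Same sampling scheme as VideoClassificationPipeline.preprocess:
# num_frames indices spread evenly over num_frames * frame_sampling_rate frames.
num_frames, frame_sampling_rate = 8, 4
indices = np.linspace(0, num_frames * frame_sampling_rate - 1, num=num_frames, dtype=np.int64)

container = av.open("sample.mp4")  # placeholder path
container.seek(0)

frames = []
for i, frame in enumerate(container.decode(video=0)):
    if i > indices[-1]:  # all sampled frames collected, stop decoding early
        break
    if i in indices:
        frames.append(frame.to_ndarray(format="rgb24"))

video = np.stack(frames)  # shape: (num_frames, height, width, 3), dtype uint8
print(video.shape)
```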
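And a quick end-to-end smoke test of the migrated pipeline. The checkpoint name below is illustrative only; any video-classification checkpoint on the Hub exercises the same code path. With `av` missing, construction now fails with the new `AV_IMPORT_ERROR` message instead of the decord one.

```python
from transformers import pipeline

# Illustrative checkpoint; substitute any video-classification model.
classifier = pipeline("video-classification", model="MCG-NJU/videomae-base-finetuned-kinetics")

# Local paths and http(s) URLs both work; URLs are fetched into a BytesIO
# before being handed to av.open.
predictions = classifier("sample.mp4", top_k=3)
print(predictions)  # [{"score": ..., "label": ...}, ...]
```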