diff --git a/.changeset/eleven-windows-relate.md b/.changeset/eleven-windows-relate.md new file mode 100644 index 0000000000000..fd40f3187774b --- /dev/null +++ b/.changeset/eleven-windows-relate.md @@ -0,0 +1,5 @@ +--- +"gradio": minor +--- + +fix:Allows `sources` to be a string for `gr.Image` diff --git a/gradio/components/audio.py b/gradio/components/audio.py index 61e98a7313028..7e46c2eec8980 100644 --- a/gradio/components/audio.py +++ b/gradio/components/audio.py @@ -108,19 +108,21 @@ def __init__( waveform_options: A dictionary of options for the waveform display. Options include: waveform_color (str), waveform_progress_color (str), show_controls (bool), skip_length (int). Default is None, which uses the default values for these options. """ valid_sources: list[Literal["upload", "microphone"]] = ["upload", "microphone"] - if sources is None: - sources = ["microphone"] if streaming else valid_sources + self.sources = ["microphone"] if streaming else valid_sources elif isinstance(sources, str) and sources in valid_sources: - sources = [sources] + self.sources = [sources] elif isinstance(sources, list): - pass + self.sources = sources else: raise ValueError( f"`sources` must be a list consisting of elements in {valid_sources}" ) - - self.sources = sources + for source in self.sources: + if source not in valid_sources: + raise ValueError( + f"`sources` must a list consisting of elements in {valid_sources}" + ) valid_types = ["numpy", "filepath"] if type not in valid_types: raise ValueError( diff --git a/gradio/components/image.py b/gradio/components/image.py index 2fc77a9bae13e..e0c22a8788ac4 100644 --- a/gradio/components/image.py +++ b/gradio/components/image.py @@ -116,8 +116,6 @@ def __init__( raise ValueError( f"`sources` must a list consisting of elements in {valid_sources}" ) - self.sources = sources - self.streaming = streaming self.show_download_button = show_download_button if streaming and self.sources != ["webcam"]: diff --git a/gradio/components/video.py b/gradio/components/video.py index f4e0b9009b757..e354cf0e62582 100644 --- a/gradio/components/video.py +++ b/gradio/components/video.py @@ -111,27 +111,29 @@ def __init__( min_length: The minimum length of video (in seconds) that the user can pass into the prediction function. If None, there is no minimum length. max_length: The maximum length of video (in seconds) that the user can pass into the prediction function. If None, there is no maximum length. """ - self.format = format - self.autoplay = autoplay - valid_sources: list[Literal["upload", "webcam"]] = ["webcam", "upload"] - if sources is None: - sources = valid_sources + self.sources = valid_sources elif isinstance(sources, str) and sources in valid_sources: - sources = [sources] + self.sources = [sources] elif isinstance(sources, list): - pass + self.sources = sources else: raise ValueError( f"`sources` must be a list consisting of elements in {valid_sources}" ) - self.sources = sources + for source in self.sources: + if source not in valid_sources: + raise ValueError( + f"`sources` must a list consisting of elements in {valid_sources}" + ) + self.format = format + self.autoplay = autoplay self.height = height self.width = width self.mirror_webcam = mirror_webcam self.include_audio = ( - include_audio if include_audio is not None else "upload" in sources + include_audio if include_audio is not None else "upload" in self.sources ) self.show_share_button = ( (utils.get_space() is not None) diff --git a/test/test_components.py b/test/test_components.py index a3fa9754c659f..ec71e968b0595 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -572,7 +572,7 @@ def test_component_functions(self, gradio_temp_dir): image_input = gr.Image(type="pil", label="Upload Your Image") assert image_input.get_config() == { "image_mode": "RGB", - "sources": None, + "sources": ["upload", "webcam", "clipboard"], "name": "image", "show_share_button": False, "show_download_button": True, @@ -604,6 +604,8 @@ def test_component_functions(self, gradio_temp_dir): with pytest.raises(ValueError): gr.Image(type="unknown") + string_source = gr.Image(sources="upload") + assert string_source.sources == ["upload"] # Output functionalities image_output = gr.Image(type="pil") processed_image = image_output.postprocess(