Skip to content

Commit

Permalink
Allows sources to be a string for gr.Image (#6378)
Browse files Browse the repository at this point in the history
* fixes

* components

* add changeset

* sources

* video

* add changeset

* fixes

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
abidlabs and gradio-pr-bot authored Nov 11, 2023
1 parent c55f927 commit d31d8c6
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 18 deletions.
5 changes: 5 additions & 0 deletions .changeset/eleven-windows-relate.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": minor
---

fix:Allows `sources` to be a string for `gr.Image`
14 changes: 8 additions & 6 deletions gradio/components/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,19 +108,21 @@ def __init__(
waveform_options: A dictionary of options for the waveform display. Options include: waveform_color (str), waveform_progress_color (str), show_controls (bool), skip_length (int). Default is None, which uses the default values for these options.
"""
valid_sources: list[Literal["upload", "microphone"]] = ["upload", "microphone"]

if sources is None:
sources = ["microphone"] if streaming else valid_sources
self.sources = ["microphone"] if streaming else valid_sources
elif isinstance(sources, str) and sources in valid_sources:
sources = [sources]
self.sources = [sources]
elif isinstance(sources, list):
pass
self.sources = sources
else:
raise ValueError(
f"`sources` must be a list consisting of elements in {valid_sources}"
)

self.sources = sources
for source in self.sources:
if source not in valid_sources:
raise ValueError(
f"`sources` must a list consisting of elements in {valid_sources}"
)
valid_types = ["numpy", "filepath"]
if type not in valid_types:
raise ValueError(
Expand Down
2 changes: 0 additions & 2 deletions gradio/components/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,6 @@ def __init__(
raise ValueError(
f"`sources` must a list consisting of elements in {valid_sources}"
)
self.sources = sources

self.streaming = streaming
self.show_download_button = show_download_button
if streaming and self.sources != ["webcam"]:
Expand Down
20 changes: 11 additions & 9 deletions gradio/components/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,27 +111,29 @@ def __init__(
min_length: The minimum length of video (in seconds) that the user can pass into the prediction function. If None, there is no minimum length.
max_length: The maximum length of video (in seconds) that the user can pass into the prediction function. If None, there is no maximum length.
"""
self.format = format
self.autoplay = autoplay

valid_sources: list[Literal["upload", "webcam"]] = ["webcam", "upload"]

if sources is None:
sources = valid_sources
self.sources = valid_sources
elif isinstance(sources, str) and sources in valid_sources:
sources = [sources]
self.sources = [sources]
elif isinstance(sources, list):
pass
self.sources = sources
else:
raise ValueError(
f"`sources` must be a list consisting of elements in {valid_sources}"
)
self.sources = sources
for source in self.sources:
if source not in valid_sources:
raise ValueError(
f"`sources` must a list consisting of elements in {valid_sources}"
)
self.format = format
self.autoplay = autoplay
self.height = height
self.width = width
self.mirror_webcam = mirror_webcam
self.include_audio = (
include_audio if include_audio is not None else "upload" in sources
include_audio if include_audio is not None else "upload" in self.sources
)
self.show_share_button = (
(utils.get_space() is not None)
Expand Down
4 changes: 3 additions & 1 deletion test/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ def test_component_functions(self, gradio_temp_dir):
image_input = gr.Image(type="pil", label="Upload Your Image")
assert image_input.get_config() == {
"image_mode": "RGB",
"sources": None,
"sources": ["upload", "webcam", "clipboard"],
"name": "image",
"show_share_button": False,
"show_download_button": True,
Expand Down Expand Up @@ -604,6 +604,8 @@ def test_component_functions(self, gradio_temp_dir):
with pytest.raises(ValueError):
gr.Image(type="unknown")

string_source = gr.Image(sources="upload")
assert string_source.sources == ["upload"]
# Output functionalities
image_output = gr.Image(type="pil")
processed_image = image_output.postprocess(
Expand Down

0 comments on commit d31d8c6

Please sign in to comment.