diff --git a/CHANGELOG.md b/CHANGELOG.md index b3a6c3b83d7c9..fd3df74bca81b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,11 +28,11 @@ No changes to highlight. ## Full Changelog: +- Allow users to submit with enter in Interfaces with textbox / number inputs [@aliabid94](https://github.com/aliabid94) in [PR 4090](https://github.com/gradio-app/gradio/pull/4090). - Updates gradio's requirements.txt to requires uvicorn>=0.14.0 by [@abidlabs](https://github.com/abidlabs) in [PR 4086](https://github.com/gradio-app/gradio/pull/4086) - Updates some error messaging by [@abidlabs](https://github.com/abidlabs) in [PR 4086](https://github.com/gradio-app/gradio/pull/4086) - Renames simplified Chinese translation file from `zh-cn.json` to `zh-CN.json` by [@abidlabs](https://github.com/abidlabs) in [PR 4086](https://github.com/gradio-app/gradio/pull/4086) - ## Contributors Shoutout: No changes to highlight. diff --git a/gradio/interface.py b/gradio/interface.py index f32afa0cc4e4f..5d82e9bd4c7df 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -25,7 +25,7 @@ get_component_instance, ) from gradio.data_classes import InterfaceTypes -from gradio.events import Changeable, Streamable +from gradio.events import Changeable, Streamable, Submittable from gradio.flagging import CSVLogger, FlaggingCallback, FlagMethod from gradio.layouts import Column, Row, Tab, Tabs from gradio.pipelines import load_from_pipeline @@ -625,17 +625,37 @@ def fn(*args): ] + [Button.update(visible=True), Button.update(visible=False)] extra_output = [submit_btn, stop_btn] - pred = submit_btn.click( - fn, - self.input_components, - self.output_components + extra_output, - api_name="predict", - scroll_to_output=True, - preprocess=not (self.api_mode), - postprocess=not (self.api_mode), - batch=self.batch, - max_batch_size=self.max_batch_size, - ) + triggers = [submit_btn.click] + [ + component.submit + for component in self.input_components + if isinstance(component, Submittable) + ] + predict_events = [] + for i, trigger in enumerate(triggers): + predict_events.append( + trigger( + fn, + self.input_components, + self.output_components + extra_output, + api_name="predict" if i == 0 else None, + scroll_to_output=True, + preprocess=not (self.api_mode), + postprocess=not (self.api_mode), + batch=self.batch, + max_batch_size=self.max_batch_size, + ) + ) + if stop_btn: + trigger( + lambda: ( + submit_btn.update(visible=False), + stop_btn.update(visible=True), + ), + inputs=None, + outputs=[submit_btn, stop_btn], + queue=False, + ) + if stop_btn: submit_btn.click( lambda: ( @@ -653,7 +673,7 @@ def fn(*args): ), inputs=None, outputs=[submit_btn, stop_btn], - cancels=[pred], + cancels=predict_events, queue=False, ) diff --git a/test/test_examples.py b/test/test_examples.py index 505417fd4a5e8..3771448e70a2a 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -327,10 +327,10 @@ def concatenate(str1, str2): app, _, _ = io.launch(prevent_thread_lock=True) client = TestClient(app) - response = client.post("/api/predict/", json={"fn_index": 5, "data": [0]}) + response = client.post("/api/predict/", json={"fn_index": 7, "data": [0]}) assert response.json()["data"] == ["Hello,"] - response = client.post("/api/predict/", json={"fn_index": 5, "data": [1]}) + response = client.post("/api/predict/", json={"fn_index": 7, "data": [1]}) assert response.json()["data"] == ["Michael"] def test_end_to_end_cache_examples(self): @@ -348,8 +348,8 @@ def concatenate(str1, str2): app, _, _ = io.launch(prevent_thread_lock=True) client = TestClient(app) - response = client.post("/api/predict/", json={"fn_index": 5, "data": [0]}) + response = client.post("/api/predict/", json={"fn_index": 7, "data": [0]}) assert response.json()["data"] == ["Hello,", "World", "Hello, World"] - response = client.post("/api/predict/", json={"fn_index": 5, "data": [1]}) + response = client.post("/api/predict/", json={"fn_index": 7, "data": [1]}) assert response.json()["data"] == ["Michael", "Jordan", "Michael Jordan"]