Skip to content

Commit

Permalink
changes
Browse files Browse the repository at this point in the history
  • Loading branch information
aliabid94 committed Nov 2, 2023
1 parent afb72bd commit 5a745ed
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 19 deletions.
2 changes: 1 addition & 1 deletion demo/chatbot_streaming/run.ipynb
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_streaming"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import random\n", "import time\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot()\n", " msg = gr.Textbox()\n", " clear = gr.Button(\"Clear\")\n", "\n", " def user(user_message, history):\n", " return \"\", history + [[user_message, None]]\n", "\n", " def bot(history):\n", " bot_message = random.choice([\"How are you?\", \"I love you\", \"I'm very hungry\"])\n", " history[-1][1] = \"\"\n", " for character in bot_message:\n", " history[-1][1] += character\n", " time.sleep(0.05)\n", " yield history\n", "\n", " msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(\n", " bot, chatbot, chatbot\n", " )\n", " clear.click(lambda: None, None, chatbot, queue=False)\n", " \n", "demo.queue()\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_streaming"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import random\n", "import time\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot()\n", " msg = gr.Textbox()\n", " clear = gr.Button(\"Clear\")\n", "\n", " def user(user_message, history):\n", " return \"\", history + [[user_message, None]]\n", "\n", " def bot(history):\n", " bot_message = random.choice([\"How are you?\", \"I love you\", \"I'm very hungry\"])\n", " for character in bot_message:\n", " time.sleep(0.05)\n", " yield character\n", "\n", " msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(\n", " bot, chatbot, chatbot\n", " )\n", " clear.click(lambda: None, None, chatbot, queue=False)\n", " \n", "demo.queue()\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
4 changes: 1 addition & 3 deletions demo/chatbot_streaming/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@ def user(user_message, history):

def bot(history):
bot_message = random.choice(["How are you?", "I love you", "I'm very hungry"])
history[-1][1] = ""
for character in bot_message:
history[-1][1] += character
time.sleep(0.05)
yield history
yield character

msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
bot, chatbot, chatbot
Expand Down
14 changes: 11 additions & 3 deletions gradio/components/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ class ChatbotData(GradioRootModel):
root: List[Tuple[Union[str, FileMessage, None], Union[str, FileMessage, None]]]


class ChatbotDataOut(GradioRootModel):
root: Union[
List[Tuple[Union[str, FileMessage, None], Union[str, FileMessage, None]]], str
]


@document()
class Chatbot(Component):
"""
Expand Down Expand Up @@ -189,9 +195,11 @@ def _postprocess_chat_messages(
def postprocess(
self,
value: list[list[str | tuple[str] | tuple[str, str] | None] | tuple],
) -> ChatbotData:
) -> ChatbotDataOut:
if value is None:
return ChatbotData(root=[])
return ChatbotDataOut(root=[])
elif isinstance(value, str):
return ChatbotDataOut(root=value)
processed_messages = []
for message_pair in value:
if not isinstance(message_pair, (tuple, list)):
Expand All @@ -208,7 +216,7 @@ def postprocess(
self._postprocess_chat_messages(message_pair[1]),
]
)
return ChatbotData(root=processed_messages)
return ChatbotDataOut(root=processed_messages)

def example_inputs(self) -> Any:
return [["Hello!", None]]
40 changes: 28 additions & 12 deletions js/chatbot/Index.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
export let elem_id = "";
export let elem_classes: string[] = [];
export let visible = true;
export let value: [
string | { file: FileData; alt_text: string | null } | null,
string | { file: FileData; alt_text: string | null } | null
][] = [];
export let value:
| [
string | { file: FileData; alt_text: string | null } | null,
string | { file: FileData; alt_text: string | null } | null
][]
| string = [];
export let scale: number | null = null;
export let min_width: number | undefined = undefined;
export let label: string;
Expand All @@ -41,7 +43,7 @@
display: boolean;
}[];
export let gradio: Gradio<{
change: typeof value;
change: typeof _value;
select: SelectData;
share: ShareData;
error: string;
Expand All @@ -65,20 +67,34 @@
}
return {
file: normalise_file(message?.file, root, proxy_url) as FileData,
alt_text: message?.alt_text
alt_text: message?.alt_text,
};
}
$: _value = value
? value.map(([user_msg, bot_msg]) => [
$: {
if (value === null) {
_value = [];
} else if (typeof value === "string") {
if (_value.length === 0) {
_value = [null, null];
}
if (_value[_value.length - 1][1] === null) {
_value[_value.length - 1][1] = value;
} else {
_value[_value.length - 1][1] += value;
}
value = _value;
} else {
_value = value.map(([user_msg, bot_msg]) => [
typeof user_msg === "string"
? redirect_src_url(user_msg)
: normalize_messages(user_msg),
typeof bot_msg === "string"
? redirect_src_url(bot_msg)
: normalize_messages(bot_msg)
])
: [];
: normalize_messages(bot_msg),
]);
}
}
export let loading_status: LoadingStatus | undefined = undefined;
export let height = 400;
Expand Down Expand Up @@ -124,7 +140,7 @@
pending_message={loading_status?.status === "pending"}
{rtl}
{show_copy_button}
on:change={() => gradio.dispatch("change", value)}
on:change={() => gradio.dispatch("change", _value)}
on:select={(e) => gradio.dispatch("select", e.detail)}
on:like={(e) => gradio.dispatch("like", e.detail)}
on:share={(e) => gradio.dispatch("share", e.detail)}
Expand Down

0 comments on commit 5a745ed

Please sign in to comment.