-
Notifications
You must be signed in to change notification settings - Fork 0
/
gradio_interface.py
58 lines (45 loc) · 2.36 KB
/
gradio_interface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from haystack import Pipeline
from haystack.core.serialization import DeserializationCallbacks
from typing import Type, Dict, Any
import gradio as gr
import os
def component_pre_init_callback(component_name: str, component_cls: Type, init_params: Dict[str, Any]):
    """Adjust a component's init parameters before it is constructed.

    Passed to ``Pipeline.loads`` through ``DeserializationCallbacks`` so it is
    invoked for each component during deserialization, before that component
    is instantiated. ``init_params`` is modified in place; the modified values
    are what the component's ``__init__`` receives.

    Args:
        component_name: Name of the component as declared in the pipeline YAML.
        component_cls: Class that will be instantiated for this component.
        init_params: Keyword arguments for the component's ``__init__``;
            mutated in place.

    Raises:
        TypeError: If the component named ``"cleaner"`` is not a
            ``DocumentCleaner`` (checked by class name).
    """
    if component_name == "cleaner":
        # Explicit check instead of ``assert``: asserts are stripped when
        # Python runs with -O, which would silently disable this guard.
        if "DocumentCleaner" not in component_cls.__name__:
            raise TypeError(
                "Expected component 'cleaner' to be a DocumentCleaner, "
                f"got {component_cls.__name__}"
            )
        # Disable the cleaner's ``remove_empty_lines`` option.
        init_params["remove_empty_lines"] = False
        print("Modified 'remove_empty_lines' to False in 'cleaner' component")
    else:
        print(f"Not modifying component {component_name} of class {component_cls}")
def load_pipeline_from_yaml(yaml_file_path):
    """Load a Haystack pipeline from a YAML file.

    Components are deserialized with ``component_pre_init_callback`` so their
    init parameters can be adjusted before construction.

    Args:
        yaml_file_path: Path (str or path-like) to the pipeline YAML file.

    Returns:
        The ``Pipeline`` produced by ``Pipeline.loads``.
    """
    # Explicit encoding: the default depends on the platform locale, which
    # can mis-decode a UTF-8 YAML containing non-ASCII text.
    with open(yaml_file_path, "r", encoding="utf-8") as stream:
        pipeline_yaml = stream.read()
    return Pipeline.loads(pipeline_yaml, callbacks=DeserializationCallbacks(component_pre_init_callback))
def ask_question(question, pipeline):
    """Run *question* through the RAG pipeline and return the answer text.

    Args:
        question: Question text from the UI; may be ``None`` or empty.
        pipeline: Object whose ``run`` accepts inputs for the
            ``text_embedder``, ``prompt_builder`` and ``answer_builder``
            components and returns a dict keyed by component name.

    Returns:
        The ``data`` of the first answer from ``answer_builder``, or ``""``
        when the question is blank or the pipeline produced no answers.
    """
    # Skip the pipeline entirely for blank input (None, "", whitespace only).
    if question is None or not question.strip():
        return ""
    result = pipeline.run({
        "text_embedder": {"text": question},
        "prompt_builder": {"question": question},
        "answer_builder": {"query": question},
    })
    answers = result["answer_builder"]["answers"]
    # Guard against an empty answers list, which would otherwise raise
    # IndexError inside the UI event handler.
    if not answers:
        return ""
    return answers[0].data
# Login password for the Gradio app. Fail fast when it is missing or empty:
# an unset variable would otherwise pass None as the password to launch().
GRADIO_KEY = os.getenv("GRADIO_KEY")
if not GRADIO_KEY:
    raise RuntimeError("GRADIO_KEY environment variable must be set to enable login")

# Pipeline definition file; override with the PIPELINE_PATH environment variable.
PIPELINE_PATH = os.getenv("PIPELINE_PATH", "./pipeline.yml")
my_pipeline = load_pipeline_from_yaml(PIPELINE_PATH)

# Build the UI: title, question box, submit button, then the answer area.
with gr.Blocks() as interface:
    gr.Markdown("# Wire RAG Documentation")
    input_box = gr.Textbox(label="Ask your question:", placeholder="Type your question here...", lines=1)
    # Submit button placed between the input box and the answer output.
    submit_btn = gr.Button("Submit")
    output_box = gr.Markdown(label="Answer:")
    # Clicking the button and pressing Enter in the textbox run the same handler.
    submit_btn.click(fn=lambda question: ask_question(question, my_pipeline), inputs=input_box, outputs=output_box)
    input_box.submit(fn=lambda question: ask_question(question, my_pipeline), inputs=input_box, outputs=output_box)

# Launch without a public share link, behind username/password authentication.
interface.launch(share=False, auth=("user", GRADIO_KEY))