search_module.py
"""
Search module for RL training loop.
This module provides functions to search through vectorized documents and retrieve question-answer pairs.
"""
import pickle
import json
import random
import asyncio
from typing import List, Tuple, Optional, Union, Dict, Any
from enum import Enum
from pydantic import BaseModel
from langchain.vectorstores import FAISS
from datasets import Dataset
from embeddings import CustomHuggingFaceEmbeddings

# Load pre-saved vectorstore
def load_vectorstore():
    """Load the pre-saved FAISS index."""
    try:
        import os
        embeddings = CustomHuggingFaceEmbeddings()
        # Load the FAISS index using an absolute path relative to this file
        index_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "faiss_index")
        print(f"Loading FAISS index from: {index_path}")
        vectorstore = FAISS.load_local(index_path, embeddings, allow_dangerous_deserialization=True)
        print("Successfully loaded FAISS index")
        return vectorstore
    except Exception as e:
        print(f"Error loading vectorstore: {e}")
        import traceback
        traceback.print_exc()
        return None

# Load the vectorstore when the module is imported
try:
    vectorstore = load_vectorstore()
    if vectorstore is None:
        print("Warning: FAISS vectorstore could not be loaded.")
except Exception as e:
    print(f"Error loading vectorstore: {e}")
    vectorstore = None

def search(query: str, return_type=str, results: int = 5) -> Union[str, List[str]]:
    """
    Search for relevant chunks using similarity search.

    Args:
        query: The search query.
        return_type: Return results as a formatted string (str) or a list (list). Default: str.
        results: Number of results to return. Default: 5.

    Returns:
        The matching chunks as a string or a list, depending on return_type.
    """
    if vectorstore is None:
        raise ValueError("Vectorstore not loaded. Please ensure the FAISS index exists.")
    search_results = vectorstore.similarity_search(query, k=results)
    if return_type == str:
        str_results = ""
        for idx, result in enumerate(search_results, start=1):
            str_results += f"Result {idx}:\n"
            str_results += result.page_content + "\n"
            str_results += "------\n"
        return str_results
    elif return_type == list:
        return [result.page_content for result in search_results]
    else:
        raise ValueError("Invalid return_type. Use str or list.")

# Load questions from saved data
def load_qa_data():
    """Load the pre-generated questions and document chunks."""
    try:
        import os
        # Get absolute paths to the data files, relative to this module
        base_dir = os.path.dirname(os.path.abspath(__file__))
        chunks_path = os.path.join(base_dir, "saved_data", "chunks.pkl")
        questions_path = os.path.join(base_dir, "saved_data", "questions.json")
        print(f"Loading chunks from: {chunks_path}")
        print(f"Loading questions from: {questions_path}")
        # Load the document chunks
        with open(chunks_path, "rb") as f:
            chunks = pickle.load(f)
        # Load the question-answer pairs
        with open(questions_path, "r") as f:
            questions = json.load(f)
        print(f"Successfully loaded {len(chunks)} chunks and {len(questions)} questions")
        return chunks, questions
    except Exception as e:
        print(f"Error loading QA data: {e}")
        import traceback
        traceback.print_exc()
        return None, None

# Load chunks and questions when the module is imported
try:
    chunks, questions = load_qa_data()
    if chunks is None or questions is None:
        print("Warning: Could not load QA data.")
except Exception as e:
    print(f"Error initializing QA data: {e}")
    chunks, questions = None, None

def get_question_answer(idx: Optional[int] = None, return_both: bool = True) -> Union[dict, str]:
    """
    Get a question-answer pair, either by index or at random.

    Args:
        idx: Index of the question to retrieve (if None, a random question is selected).
        return_both: Whether to return both question and answer (default: True).

    Returns:
        A dict with "question" and "answer" keys if return_both=True, otherwise just the question string.
    """
    if questions is None:
        raise ValueError("Questions not loaded. Please ensure questions.json exists.")
    if idx is None:
        # Select a random question
        qa_pair = random.choice(questions)
    elif 0 <= idx < len(questions):
        # Select question by index
        qa_pair = questions[idx]
    else:
        raise ValueError(f"Index out of range. Must be between 0 and {len(questions) - 1}")
    question = qa_pair['question']
    answer = qa_pair['answer']
    if return_both:
        return {"question": question, "answer": answer}
    else:
        return question
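
# Illustrative usage sketch (comments only, nothing executed at import); assumes
# questions.json was loaded successfully.
#
#     qa = get_question_answer()                         # random pair: {"question": ..., "answer": ...}
#     q = get_question_answer(idx=0, return_both=False)  # just the question text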

# Function to get the total number of questions
def get_question_count() -> int:
    """Get the total number of available questions."""
    if questions is None:
        raise ValueError("Questions not loaded. Please ensure questions.json exists.")
    return len(questions)

def get_qa_dataset():
    """
    Return HuggingFace Datasets containing question-answer pairs.

    The datasets are constructed from the loaded questions data (questions.json).
    Each element is a dictionary that includes at least:
      - "prompt": The question text (renamed from "question").
      - "answer": The corresponding answer text.
    Additional keys present in the original questions data are also included.

    Returns:
        A (train_dataset, test_dataset) tuple of HuggingFace Dataset objects,
        split 90/10 with a fixed seed.
    """
    if questions is None:
        raise ValueError("Questions not loaded. Please ensure questions.json exists.")
    qa_dataset = Dataset.from_list(questions)
    full_dataset = qa_dataset.shuffle(seed=42)
    split = full_dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset = split['train']
    test_dataset = split['test']
    # Rename the "question" column to "prompt" for the training loop
    train_dataset = train_dataset.rename_column("question", "prompt")
    test_dataset = test_dataset.rename_column("question", "prompt")
    return train_dataset, test_dataset
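
# Minimal smoke-test sketch: assumes the FAISS index and saved_data files exist
# alongside this module, so it only exercises whatever loaded successfully.
if __name__ == "__main__":
    if questions is not None:
        print(f"Available questions: {get_question_count()}")
        example = get_question_answer(idx=0)
        print(f"Example question: {example['question']}")
        train_ds, test_ds = get_qa_dataset()
        print(f"Train size: {len(train_ds)}, test size: {len(test_ds)}")
    if vectorstore is not None:
        print(search("example query", results=2))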