diff --git a/RAG.py b/RAG.py index 76521aa..e9884c7 100644 --- a/RAG.py +++ b/RAG.py @@ -74,24 +74,23 @@ def query_rag(query): # logic to add sources to the response max_relevant_sources = 4 # number of sources at most to be added to the response all_sources = "" - sources = set() + sources = [] count = 1 - for i in response["context"]: - # limiting the no.of sources to 4 for better readability - if count > max_relevant_sources: - break - else: - source = i.metadata["source"] - if source in sources: + for i in range(max_relevant_sources): + try: + source = response["context"][i].metadata["source"] + # check if the source is already added to the list + if source not in sources: + sources.append(source) + all_sources += f"[Source {count}]({source}), " count += 1 - continue # to remove duplicates in the most relevant sources - sources.add(source) - all_sources += f"[Source {count}]({source}), " - count += 1 - all_sources = all_sources[:-2] + except IndexError: # if there are no more sources to add + break + all_sources = all_sources[:-2] # remove the last comma and space response["answer"] += f"\n\nSources: {all_sources}" print("Response Generated") - return response["answer"], list(sources) + + return response["answer"], sources