Commit

up
shreyaspimpalgaonkar committed Oct 8, 2024
1 parent a97a8c7 commit 4f439a3
Showing 9 changed files with 18 additions and 126 deletions.
2 changes: 1 addition & 1 deletion docs/api-reference/openapi.json

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions docs/cookbooks/walkthrough.mdx
@@ -626,7 +626,7 @@ r2r search --query="Who founded Airbnb?" --use-kg-search --kg-search-type=local
```python
client.search("Who founded Airbnb?", kg_search_settings={
"use_kg_search": True,
"kg_search_type": "global",
"kg_search_type": "local",
"kg_search_level": 0, # level of community to search
"max_community_description_length": 65536,
"max_llm_queries_for_global_search": 250,
@@ -643,7 +643,7 @@ client.search("Who founded Airbnb?", kg_search_settings={
```javascript
await client.search("Who founded Airbnb?", true, {}, 10, false, {}, {
useKgSearch: true,
kgSearchType: "global",
kgSearchType: "local",
kgSearchLevel: "0",
maxCommunityDescriptionLength: 65536,
maxLlmQueriesForGlobalSearch: 250,
@@ -664,7 +664,7 @@ curl -X POST http://localhost:7272/v2/search \
"query": "Who founded Airbnb?",
"kg_search_settings": {
"use_kg_search": true,
"kg_search_type": "global",
"kg_search_type": "local",
"kg_search_level": "0",
"max_community_description_length": 65536,
"max_llm_queries_for_global_search": 250,
@@ -682,7 +682,7 @@ curl -X POST http://localhost:7272/v2/search \
Key configurable parameters for knowledge graph search include:

- `use_kg_search`: Enable knowledge graph search.
- `kg_search_type`: Choose between "global" or "local" search.
- `kg_search_type`: Type of knowledge graph search; only "local" is currently supported.
- `kg_search_level`: Specify the level of community to search.
- `entity_types`: List of entity types to include in the search.
- `relationships`: List of relationship types to include in the search.
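
Taken together, these parameters are passed as a `kg_search_settings` dictionary via the `kg_search_settings=` argument to `client.search`, as shown earlier in this walkthrough. A minimal sketch for local-only search (values are illustrative, not defaults):

```python
# Illustrative kg_search_settings after this change; only "local" search remains.
kg_search_settings = {
    "use_kg_search": True,               # enable knowledge graph search
    "kg_search_type": "local",           # "global" is no longer supported
    "kg_search_level": 0,                # level of community to search
    "entity_types": [],                  # entity types to include in the search
    "relationships": [],                 # relationship types to include in the search
    "max_community_description_length": 65536,
}
```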
2 changes: 1 addition & 1 deletion docs/documentation/configuration/rag.mdx
@@ -42,7 +42,7 @@ vector_search_settings = {
# Configure graphRAG search
kg_search_settings = {
"use_kg_search": True,
"kg_search_type": "global",
"kg_search_type": "local",
"kg_search_level": None,
"generation_config": {
"model": "gpt-4",
@@ -7,7 +7,7 @@ Knowledge graph search settings can be configured both server-side and at runtim
```python
kg_search_settings = {
"use_kg_search": True,
"kg_search_type": "global",
"kg_search_type": "local",
"kg_search_level": None,
"generation_config": {
"model": "gpt-4",
6 changes: 3 additions & 3 deletions docs/documentation/js-sdk/retrieval.mdx
@@ -101,8 +101,8 @@ const searchResponse = await client.search("What was Uber's profit in 2020?");
Whether to use knowledge graph search.
</ParamField>

<ParamField path="kg_search_type" type="str" default="global">
Type of knowledge graph search. Can be 'global' or 'local'.
<ParamField path="kg_search_type" type="str" default="local">
Type of knowledge graph search. Supported values: "local".
</ParamField>

<ParamField path="kg_search_level" type="Optional[str]" default="None">
@@ -323,7 +323,7 @@ const ragResponse = await client.rag("What was Uber's profit in 2020?");
Whether to use knowledge graph search.
</ParamField>

<ParamField path="kg_search_type" type="str" default="global">
<ParamField path="kg_search_type" type="str" default="local">
Type of knowledge graph search. Only 'local' is supported.
</ParamField>

6 changes: 3 additions & 3 deletions docs/documentation/python-sdk/retrieval.mdx
@@ -121,7 +121,7 @@ search_response = client.search("What was Uber's profit in 2020?")
Whether to use knowledge graph search.
</ParamField>

<ParamField path="kg_search_type" type="str" default="global">
<ParamField path="kg_search_type" type="str" default="local">
Type of knowledge graph search. Only 'local' is supported.
</ParamField>
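
For context, a minimal end-to-end sketch of a local KG search through the Python SDK; the import, constructor, and base URL below are assumptions chosen to match the rest of these docs (local deployment on port 7272):

```python
from r2r import R2RClient  # assumed import, matching the Python SDK these docs describe

client = R2RClient("http://localhost:7272")  # assumed local deployment, as in the curl examples

# Only "local" KG search is accepted after this change.
search_response = client.search(
    "What was Uber's profit in 2020?",
    kg_search_settings={
        "use_kg_search": True,
        "kg_search_type": "local",
        "kg_search_level": None,
    },
)
```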

@@ -385,7 +385,7 @@ rag_response = client.rag("What was Uber's profit in 2020?")
Whether to use knowledge graph search.
</ParamField>

<ParamField path="kg_search_type" type="str" default="global">
<ParamField path="kg_search_type" type="str" default="local">
Type of knowledge graph search. Only 'local' is supported.
</ParamField>

@@ -695,7 +695,7 @@ Note that any of the customization seen in AI powered search and RAG documentati
Whether to use knowledge graph search.
</ParamField>
<ParamField path="kg_search_type" type="str" default="global">
<ParamField path="kg_search_type" type="str" default="local">
Type of knowledge graph search. Only 'local' is supported.
</ParamField>
2 changes: 1 addition & 1 deletion js/sdk/src/models.tsx
@@ -57,7 +57,7 @@ export interface KGSearchSettings {
filters?: Record<string, any>;
selected_collection_ids?: string[];
graphrag_map_system_prompt?: string;
kg_search_type?: "global" | "local";
kg_search_type?: "local";
kg_search_level?: number | null;
generation_config?: GenerationConfig;
// entity_types?: any[];
113 changes: 3 additions & 110 deletions py/core/pipes/retrieval/kg_search_pipe.py
@@ -209,107 +209,6 @@ async def local_search(
},
)

async def global_search(
self,
input: GeneratorPipe.Input,
state: AsyncState,
run_id: UUID,
kg_search_settings: KGSearchSettings,
*args: Any,
**kwargs: Any,
) -> AsyncGenerator[KGSearchResult, None]:
# map reduce
async for message in input.message:
map_responses = []
communities = self.kg_provider.get_communities( # type: ignore
level=kg_search_settings.kg_search_level
)

if len(communities) == 0:
raise R2RException(
"No communities found. Please make sure you have run the KG enrichment step before running the search: r2r create-graph and r2r enrich-graph",
400,
)

async def preprocess_communities(communities):
merged_report = ""
for community in communities:
community_report = community.summary
if (
len(merged_report) + len(community_report)
> kg_search_settings.max_community_description_length
):
yield merged_report.strip()
merged_report = ""
merged_report += community_report + "\n\n"
if merged_report:
yield merged_report.strip()

async def process_community(merged_report):
output = await self.llm_provider.aget_completion(
messages=self.prompt_provider._get_message_payload(
task_prompt_name=self.kg_provider.config.kg_search_settings.graphrag_map_system_prompt,
task_inputs={
"context_data": merged_report,
"input": message,
},
),
generation_config=kg_search_settings.generation_config,
)

return output.choices[0].message.content

preprocessed_reports = [
merged_report
async for merged_report in preprocess_communities(communities)
]

# Use asyncio.gather to process all preprocessed community reports concurrently
logger.info(
f"Processing {len(communities)} communities, {len(preprocessed_reports)} reports, Max LLM queries = {kg_search_settings.max_llm_queries_for_global_search}"
)

map_responses = await asyncio.gather(
*[
process_community(report)
for report in preprocessed_reports[
: kg_search_settings.max_llm_queries_for_global_search
]
]
)
# Filter only the relevant responses
filtered_responses = self.filter_responses(map_responses)

# reducing the outputs
output = await self.llm_provider.aget_completion(
messages=self.prompt_provider._get_message_payload(
task_prompt_name=self.kg_provider.config.kg_search_settings.graphrag_reduce_system_prompt,
task_inputs={
"response_type": "multiple paragraphs",
"report_data": filtered_responses,
"input": message,
},
),
generation_config=kg_search_settings.generation_config,
)

output_text = output.choices[0].message.content

if not output_text:
logger.warning(f"No output generated for query: {message}.")
raise R2RException(
"No output generated for query.",
400,
)

yield KGSearchResult(
content=KGGlobalResult(
name="Global Result", description=output_text
),
method=KGSearchMethod.GLOBAL,
metadata={"associated_query": message},
)

async def _run_logic( # type: ignore
self,
input: GeneratorPipe.Input,
@@ -321,17 +220,11 @@ async def _run_logic( # type: ignore
) -> AsyncGenerator[KGSearchResult, None]:
kg_search_type = kg_search_settings.kg_search_type

# runs local and/or global search
if kg_search_type == "local" or kg_search_type == "local_and_global":
if kg_search_type == "local":
logger.info("Performing KG local search")
async for result in self.local_search(
input, state, run_id, kg_search_settings
):
yield result

if kg_search_type == "global" or kg_search_type == "local_and_global":
logger.info("Performing KG global search")
async for result in self.global_search(
input, state, run_id, kg_search_settings
):
yield result
else:
raise ValueError(f"Unsupported KG search type: {kg_search_type}")
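
A simplified, hypothetical mirror of the new dispatch logic above, to make the behavior change concrete (the real pipe is asynchronous and yields `KGSearchResult` objects):

```python
# Hypothetical, synchronous sketch of the new _run_logic dispatch:
# "local" is routed to the local search path; any other value raises,
# which is why "global" and "local_and_global" no longer work.
def dispatch_kg_search(kg_search_type: str) -> str:
    if kg_search_type == "local":
        return "performing KG local search"
    raise ValueError(f"Unsupported KG search type: {kg_search_type}")


print(dispatch_kg_search("local"))   # -> performing KG local search
# dispatch_kg_search("global")       # would raise ValueError: Unsupported KG search type: global
```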
3 changes: 1 addition & 2 deletions py/shared/abstractions/search.py
@@ -61,7 +61,7 @@ class KGSearchResultType(str, Enum):

class KGSearchMethod(str, Enum):
LOCAL = "local"
GLOBAL = "global"


class KGEntityResult(R2RSerializable):
@@ -357,7 +356,7 @@ class Config:
json_encoders = {UUID: str}
json_schema_extra = {
"use_kg_search": True,
"kg_search_type": "global",
"kg_search_type": "local",
"kg_search_level": "0",
"generation_config": GenerationConfig.Config.json_schema_extra,
"max_community_description_length": 65536,