Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add KG tests #1351

Merged
merged 25 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
c34a966
cli tests
shreyaspimpalgaonkar Oct 7, 2024
06f52b6
add sdk tests
shreyaspimpalgaonkar Oct 7, 2024
337436c
typo fix
shreyaspimpalgaonkar Oct 7, 2024
86e8145
change workflow ordering
shreyaspimpalgaonkar Oct 7, 2024
18ef93d
add collection integration tests (#1352)
emrgnt-cmplxty Oct 7, 2024
5b9b81a
bump pkg
emrgnt-cmplxty Oct 7, 2024
cb6c436
remove workflows
emrgnt-cmplxty Oct 7, 2024
95050d1
fix sdk test port
emrgnt-cmplxty Oct 7, 2024
6755b1e
fix delete collection return check
emrgnt-cmplxty Oct 7, 2024
9517851
Merge branch 'dev-minor' into add-tests
shreyaspimpalgaonkar Oct 7, 2024
6582f6d
Fix document info serialization (#1353)
NolanTrem Oct 7, 2024
5b4a39e
Update integration-test-workflow-debian.yml
shreyaspimpalgaonkar Oct 7, 2024
a7e793d
pre-commit
shreyaspimpalgaonkar Oct 7, 2024
ed2356a
Merge branch 'add-tests' of https://github.com/SciPhi-AI/R2R into add…
shreyaspimpalgaonkar Oct 7, 2024
5829b71
slightly modify
shreyaspimpalgaonkar Oct 7, 2024
8e84e60
up
shreyaspimpalgaonkar Oct 7, 2024
3735f48
Merge remote-tracking branch 'origin/dev-minor' into add-tests
shreyaspimpalgaonkar Oct 7, 2024
698eea4
up
shreyaspimpalgaonkar Oct 7, 2024
58664b6
smaller file
shreyaspimpalgaonkar Oct 8, 2024
cbd432a
up
shreyaspimpalgaonkar Oct 8, 2024
3dd8a91
typo, change order
shreyaspimpalgaonkar Oct 8, 2024
14783b7
up
shreyaspimpalgaonkar Oct 8, 2024
5a48604
up
shreyaspimpalgaonkar Oct 8, 2024
6617050
change order
shreyaspimpalgaonkar Oct 8, 2024
491b295
Merge branch 'dev-minor' into add-tests
shreyaspimpalgaonkar Oct 8, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion .github/workflows/integration-test-workflow-debian.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ jobs:
- name: Install Poetry and dependencies
run: |
curl -sSL https://install.python-poetry.org | python3 -
export PATH="/root/.local/bin:$PATH"
cd py && poetry install -E core -E ingestion-bundle

- name: Remove pre-installed PostgreSQL
Expand Down Expand Up @@ -107,6 +108,20 @@ jobs:
poetry run python tests/integration/harness_sdk.py test_hybrid_search_sample_file_filter_sdk
poetry run python tests/integration/harness_sdk.py test_rag_response_sample_file_sdk

# Each argument must exactly match a function defined in
# tests/integration/harness_cli.py, which looks the name up via globals().
- name: Run CLI GraphRAG
  working-directory: ./py
  run: |
    poetry run python tests/integration/harness_cli.py test_kg_create_graph_sample_file_cli
    poetry run python tests/integration/harness_cli.py test_kg_enrich_graph_sample_file_cli
    poetry run python tests/integration/harness_cli.py test_kg_search_sample_file_cli

# Runs the SDK knowledge-graph tests in order: create graph, enrich graph, search.
- name: Run SDK GraphRAG
working-directory: ./py
run: |
poetry run python tests/integration/harness_sdk.py test_kg_create_graph_sample_file_sdk
poetry run python tests/integration/harness_sdk.py test_kg_enrich_graph_sample_file_sdk
poetry run python tests/integration/harness_sdk.py test_kg_search_sample_file_sdk

- name: Run SDK Auth
working-directory: ./py
run: |
Expand All @@ -129,4 +144,4 @@ jobs:
sudo apt-get purge -y 'postgresql-*'
sudo rm -rf /var/lib/postgresql
sudo rm -rf /var/log/postgresql
sudo rm -rf /etc/postgresql
sudo rm -rf /etc/postgresql
33 changes: 33 additions & 0 deletions py/core/main/api/kg_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,36 @@ async def get_triples(
limit,
triple_ids,
)


@self.router.get("/communities")
@self.base_endpoint
async def get_communities(
    collection_id: UUID = Query(
        ...,
        description="Collection ID to retrieve communities from.",
    ),
    offset: int = Query(0, ge=0, description="Offset for pagination."),
    limit: int = Query(100, ge=1, le=1000, description="Limit for pagination."),
    levels: Optional[list[int]] = Query(None, description="Levels to filter by."),
    community_numbers: Optional[list[int]] = Query(
        None,
        description="Community numbers to filter by.",
    ),
    auth_user=Depends(self.service.providers.auth.auth_wrapper),
):
    """
    Retrieve communities from the knowledge graph.
    """
    # Non-superusers are not rejected; the request is only logged as a
    # reminder that permission checks are still to be implemented.
    is_superuser = auth_user.is_superuser
    if not is_superuser:
        logger.warning("Implement permission checks here.")

    # Forward the validated query parameters, in order, to the KG service.
    communities = await self.service.get_communities(
        collection_id, offset, limit, levels, community_numbers
    )
    return communities
19 changes: 19 additions & 0 deletions py/core/main/services/kg_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,3 +294,22 @@ async def get_triples(
limit,
triple_ids,
)


@telemetry_event("get_communities")
async def get_communities(
    self,
    collection_id: UUID,
    offset: int = 0,
    limit: int = 100,
    levels: Optional[list[int]] = None,
    community_numbers: Optional[list[int]] = None,
    **kwargs,
):
    """Fetch communities for a collection from the knowledge-graph provider.

    The pagination and filter arguments are forwarded positionally, and
    unchanged, to ``self.providers.kg.get_communities``. Extra keyword
    arguments are accepted but not forwarded.
    """
    kg_provider = self.providers.kg
    communities = await kg_provider.get_communities(
        collection_id, offset, limit, levels, community_numbers
    )
    return communities
19 changes: 19 additions & 0 deletions py/core/providers/kg/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,25 @@ async def add_communities(self, communities: List[Any]) -> None:
"""
await self.execute_many(QUERY, communities)

async def get_communities(
    self,
    collection_id: UUID,
    offset: int = 0,
    limit: int = 100,
    levels: Optional[list[int]] = None,
    community_numbers: Optional[list[int]] = None,
) -> List[CommunityReport]:
    """Fetch community reports for a collection, ordered by community number.

    Args:
        collection_id: Collection whose community reports are selected.
        offset: Number of rows to skip (bound to ``OFFSET``).
        limit: Maximum number of rows to return (bound to ``LIMIT``).
        levels: When not ``None``, only rows whose ``level`` is in this
            list are returned.
        community_numbers: When not ``None``, only rows whose
            ``community_number`` is in this list are returned.

    Returns:
        Whatever ``self.fetch_query`` returns for the built query.
    """
    # $1-$3 are fixed; each optional filter takes the next free placeholder.
    params = [collection_id, limit, offset]
    conditions = ["collection_id = $1"]

    if levels is not None:
        params.append(levels)
        conditions.append(f"level = ANY(${len(params)})")

    if community_numbers is not None:
        params.append(community_numbers)
        conditions.append(f"community_number = ANY(${len(params)})")

    # All predicates must sit inside WHERE, before ORDER BY / LIMIT / OFFSET;
    # appending "AND ..." after OFFSET produces invalid SQL.
    where_clause = " AND ".join(conditions)
    QUERY = (
        f"SELECT * FROM {self._get_table_name('community_report')} "
        f"WHERE {where_clause} "
        "ORDER BY community_number LIMIT $2 OFFSET $3"
    )

    return await self.fetch_query(QUERY, params)

async def add_community_report(
self, community_report: CommunityReport
) -> None:
Expand Down
33 changes: 33 additions & 0 deletions py/sdk/kg.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,36 @@ async def get_triples(
params["triple_ids"] = ",".join(triple_ids)

return await client._make_request("GET", "triples", params=params)


@staticmethod
async def get_communities(
    client,
    collection_id: str,
    offset: int = 0,
    limit: int = 100,
    levels: Optional[list[int]] = None,
    community_numbers: Optional[list[int]] = None,
) -> dict:
    """
    Retrieve communities from the knowledge graph.

    Args:
        collection_id (str): The ID of the collection to retrieve communities from.
        offset (int): Pagination offset. Defaults to 0.
        limit (int): Maximum number of communities to request. Defaults to 100.
        levels (Optional[list[int]]): Levels to filter by; omitted from the
            request when empty or None.
        community_numbers (Optional[list[int]]): Community numbers to filter
            by; omitted from the request when empty or None.

    Returns:
        dict: A dictionary containing the retrieved communities.
    """
    params = {
        "collection_id": collection_id,
        "offset": offset,
        "limit": limit,
    }

    # The filters are lists of ints, so ",".join(...) would raise TypeError.
    # They are passed as lists instead.
    # NOTE(review): assumes client._make_request hands `params` to an HTTP
    # library that encodes list values as repeated query keys — confirm.
    if levels:
        params["levels"] = list(levels)
    if community_numbers:
        params["community_numbers"] = list(community_numbers)

    return await client._make_request("GET", "communities", params=params)

87 changes: 86 additions & 1 deletion py/tests/integration/harness_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import json
import subprocess
import sys

import time
import requests
import re

def compare_result_fields(result, expected_fields):
for field, expected_value in expected_fields.items():
Expand Down Expand Up @@ -236,10 +238,93 @@ def test_rag_response_stream_sample_file_cli():
print("~" * 100)


def test_kg_create_graph_sample_file_cli():
    """Run the CLI create-graph command and check that ARISTOTLE is an entity."""
    print("Testing: KG create graph")
    run_command("poetry run r2r kg create-graph --run")

    entities_response = requests.get(
        "http://localhost:7272/v2/entities",
        params={"collection_id": "122fdf6a-e116-546b-a8f6-e4cb2e2c0a09"},
    )

    # Any non-200 answer from the entities endpoint fails the test run.
    if entities_response.status_code != 200:
        print("KG create graph test failed: Graph not created")
        sys.exit(1)

    payload = entities_response.json()
    entity_names = []
    for entity in payload["results"]["results"]:
        entity_names.append(entity["name"])

    assert "ARISTOTLE" in entity_names

    print("KG create graph test passed")
    print("~" * 100)

def test_kg_enrich_graph_sample_file_cli():
    """Run the CLI enrich-graph command and validate the communities endpoint."""
    print("Testing: KG enrich graph")
    run_command("poetry run r2r kg enrich-graph --run")

    communities_response = requests.get(
        "http://localhost:7272/v2/communities",
        params={"collection_id": "122fdf6a-e116-546b-a8f6-e4cb2e2c0a09"},
    )

    if communities_response.status_code != 200:
        print("KG enrichment test failed: Communities not created")
        sys.exit(1)

    communities = communities_response.json()["results"]
    assert len(communities) >= 10

    # Every community must carry each of these keys.
    required_keys = (
        "community_number",
        "level",
        "collection_id",
        "name",
        "summary",
        "findings",
    )
    for community in communities:
        for required_key in required_keys:
            assert required_key in community

    print("KG enrichment test passed")
    print("~" * 100)

def test_kg_search_sample_file_cli():
    """Run a KG-enabled CLI search and check for local, entity and community hits.

    Each non-empty output line is parsed as JSON when possible and otherwise
    kept as raw text. The test exits with status 1 when the command printed
    nothing, and fails an assertion when no line contains the local-search
    marker, the word 'entity', or the word 'community'.
    """
    print("Testing: KG search")

    output = run_command("poetry run r2r search --query='Who was aristotle?' --use-kg-search")

    results = []
    for raw_line in output.strip().split("\n"):
        line = raw_line.strip()
        if not line:
            # split() always yields at least one (possibly empty) item;
            # skipping blanks is what makes the "no results" guard reachable.
            continue

        try:
            parsed = json.loads(line)
        except json.JSONDecodeError:
            results.append(line)
            continue

        # JSON scalars (numbers, null, booleans) do not support the `in`
        # checks below, so keep the raw text for those lines.
        if isinstance(parsed, (dict, list, str)):
            results.append(parsed)
        else:
            results.append(line)

    if not results:
        print("KG search test failed: No results returned")
        sys.exit(1)

    # there should be vector search and KG search results
    kg_search_result_present = False
    entities_found = False
    communities_found = False
    for result in results:
        if "{'method': 'local'" in result:
            kg_search_result_present = True
        if 'entity' in result:
            entities_found = True
        if 'community' in result:
            communities_found = True

    assert kg_search_result_present, "No KG search result present"
    assert entities_found, "No entities found"
    assert communities_found, "No communities found"

    print("KG search test passed")
    print("~" * 100)


if __name__ == "__main__":
    # Exactly one thing is read from the command line: the name of the
    # module-level test function to execute.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Please specify a test function to run")
        sys.exit(1)

    # Look the function up by name in this module's globals and call it.
    selected_test = globals()[cli_args[0]]
    selected_test()

61 changes: 61 additions & 0 deletions py/tests/integration/harness_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,67 @@ def test_superuser_capabilities():
print("~" * 100)


def test_kg_create_graph_sample_file_sdk():
    """Create the graph through the SDK client and check ARISTOTLE is an entity."""
    print("Testing: KG create graph")

    # The return value of create_graph is not inspected; the check is done
    # against the entities fetched afterwards.
    client.create_graph(
        collection_id="122fdf6a-e116-546b-a8f6-e4cb2e2c0a09",
        run_type="run",
    )

    entities_response = client.get_entities(
        collection_id="122fdf6a-e116-546b-a8f6-e4cb2e2c0a09"
    )

    entity_names = []
    for entity in entities_response["results"]["results"]:
        entity_names.append(entity["name"])

    assert "ARISTOTLE" in entity_names

    print("KG create graph test passed")
    print("~" * 100)


def test_kg_enrich_graph_sample_file_sdk():
    """Enrich the graph through the SDK client and validate the communities."""
    print("Testing: KG enrich graph")

    # The return value of enrich_graph is not inspected; the check is done
    # against the communities fetched afterwards.
    client.enrich_graph(
        collection_id="122fdf6a-e116-546b-a8f6-e4cb2e2c0a09",
        run_type="run",
    )

    communities_response = client.get_communities(
        collection_id="122fdf6a-e116-546b-a8f6-e4cb2e2c0a09"
    )

    communities = communities_response["results"]
    assert len(communities) >= 10

    # Every community must carry each of these keys.
    required_keys = ("community_number", "level", "collection_id", "name")
    for community in communities:
        for required_key in required_keys:
            assert required_key in community

    print("KG enrich graph test passed")
    print("~" * 100)

def test_kg_search_sample_file_sdk():
    """Run a KG-enabled SDK search and check for local, entity and community hits."""
    print("Testing: KG search")

    search_response = client.search(
        query="Who was aristotle?",
        kg_search_settings={"use_kg_search": True},
    )

    kg_search_results = search_response["results"]["kg_search_results"]
    assert len(kg_search_results) >= 1

    kg_search_result_present = False
    entities_found = False
    communities_found = False
    for hit in kg_search_results:
        is_local = 'method' in hit and hit['method'] == 'local'
        has_result_type = 'result_type' in hit
        is_entity = has_result_type and hit['result_type'] == 'entity'
        is_community = has_result_type and hit['result_type'] == 'community'

        # Flags only ever move from False to True.
        kg_search_result_present = kg_search_result_present or is_local
        entities_found = entities_found or is_entity
        communities_found = communities_found or is_community

    assert kg_search_result_present, "No KG search result present"
    assert entities_found, "No entities found"
    assert communities_found, "No communities found"

    print("KG search test passed")
    print("~" * 100)

if __name__ == "__main__":
if len(sys.argv) < 2:
print("Please specify a test function to run")
Expand Down
Loading