Add KG tests #1351

Merged on Oct 8, 2024 (25 commits)
Commits
c34a966 cli tests (shreyaspimpalgaonkar, Oct 7, 2024)
06f52b6 add sdk tests (shreyaspimpalgaonkar, Oct 7, 2024)
337436c typo fix (shreyaspimpalgaonkar, Oct 7, 2024)
86e8145 change workflow ordering (shreyaspimpalgaonkar, Oct 7, 2024)
18ef93d add collection integration tests (#1352) (emrgnt-cmplxty, Oct 7, 2024)
5b9b81a bump pkg (emrgnt-cmplxty, Oct 7, 2024)
cb6c436 remove workflows (emrgnt-cmplxty, Oct 7, 2024)
95050d1 fix sdk test port (emrgnt-cmplxty, Oct 7, 2024)
6755b1e fix delete collection return check (emrgnt-cmplxty, Oct 7, 2024)
9517851 Merge branch 'dev-minor' into add-tests (shreyaspimpalgaonkar, Oct 7, 2024)
6582f6d Fix document info serialization (#1353) (NolanTrem, Oct 7, 2024)
5b4a39e Update integration-test-workflow-debian.yml (shreyaspimpalgaonkar, Oct 7, 2024)
a7e793d pre-commit (shreyaspimpalgaonkar, Oct 7, 2024)
ed2356a Merge branch 'add-tests' of https://github.com/SciPhi-AI/R2R into add… (shreyaspimpalgaonkar, Oct 7, 2024)
5829b71 slightly modify (shreyaspimpalgaonkar, Oct 7, 2024)
8e84e60 up (shreyaspimpalgaonkar, Oct 7, 2024)
3735f48 Merge remote-tracking branch 'origin/dev-minor' into add-tests (shreyaspimpalgaonkar, Oct 7, 2024)
698eea4 up (shreyaspimpalgaonkar, Oct 7, 2024)
58664b6 smaller file (shreyaspimpalgaonkar, Oct 8, 2024)
cbd432a up (shreyaspimpalgaonkar, Oct 8, 2024)
3dd8a91 typo, change order (shreyaspimpalgaonkar, Oct 8, 2024)
14783b7 up (shreyaspimpalgaonkar, Oct 8, 2024)
5a48604 up (shreyaspimpalgaonkar, Oct 8, 2024)
6617050 change order (shreyaspimpalgaonkar, Oct 8, 2024)
491b295 Merge branch 'dev-minor' into add-tests (shreyaspimpalgaonkar, Oct 8, 2024)
37 changes: 37 additions & 0 deletions .github/workflows/integration-test-workflow-debian.yml
@@ -75,6 +75,22 @@ jobs:
echo "Waiting for services to start..."
sleep 30

- name: Run SDK GraphRAG
working-directory: ./py
run: |
poetry run python tests/integration/harness_sdk.py test_remove_all_files_and_ingest_sample_file_sdk
poetry run python tests/integration/harness_sdk.py test_kg_create_graph_sample_file_sdk
poetry run python tests/integration/harness_sdk.py test_kg_enrich_graph_sample_file_sdk
poetry run python tests/integration/harness_sdk.py test_kg_search_sample_file_sdk

- name: Run CLI GraphRAG
working-directory: ./py
run: |
poetry run python tests/integration/harness_sdk.py test_remove_all_files_and_ingest_sample_file_sdk
poetry run python tests/integration/harness_cli.py test_kg_create_graph_sample_file_cli
poetry run python tests/integration/harness_cli.py test_kg_enrichment_sample_file_cli
poetry run python tests/integration/harness_cli.py test_kg_search_sample_file_cli

- name: Run CLI Ingestion
working-directory: ./py
run: |
@@ -117,8 +133,29 @@ jobs:
poetry run python tests/integration/harness_sdk.py test_user_search_and_rag
poetry run python tests/integration/harness_sdk.py test_user_password_management
poetry run python tests/integration/harness_sdk.py test_user_profile_management
poetry run python tests/integration/harness_sdk.py test_user_overview
poetry run python tests/integration/harness_sdk.py test_user_logout

- name: Run Collections
working-directory: ./py
run: |
poetry run python tests/integration/harness_sdk.py test_user_creates_collection
poetry run python tests/integration/harness_sdk.py test_user_updates_collection
poetry run python tests/integration/harness_sdk.py test_user_lists_collections
poetry run python tests/integration/harness_sdk.py test_user_collection_document_management
poetry run python tests/integration/harness_sdk.py test_user_removes_document_from_collection
poetry run python tests/integration/harness_sdk.py test_user_lists_documents_in_collection
poetry run python tests/integration/harness_sdk.py test_pagination_and_filtering
poetry run python tests/integration/harness_sdk.py test_advanced_collection_management
poetry run python tests/integration/harness_sdk.py test_user_gets_collection_details
poetry run python tests/integration/harness_sdk.py test_user_adds_user_to_collection
poetry run python tests/integration/harness_sdk.py test_user_removes_user_from_collection
poetry run python tests/integration/harness_sdk.py test_user_lists_users_in_collection
poetry run python tests/integration/harness_sdk.py test_user_gets_collections_for_user
poetry run python tests/integration/harness_sdk.py test_user_gets_collections_for_document
poetry run python tests/integration/harness_sdk.py test_user_permissions


- name: Stop R2R server
if: always()
run: ps aux | grep "r2r serve" | awk '{print $2}' | xargs kill || true
1 change: 1 addition & 0 deletions py/cli/commands/ingestion.py
@@ -42,6 +42,7 @@ def ingest_files_from_urls(client, urls):

files_to_ingest.append(temp_file.name)
metadatas.append({"title": filename})
# TODO: use the utils function generate_document_id
document_ids.append(uuid.uuid5(uuid.NAMESPACE_DNS, url))

response = client.ingest_files(
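The `uuid.uuid5` call in the hunk above derives each document ID from its URL, so ingesting the same URL twice yields the same ID. A minimal sketch of that property follows. The helper name `document_id_from_url` and the example URLs are hypothetical; this is not the `generate_document_id` utility the TODO mentions, whose signature is not shown in this diff.

```python
import uuid


def document_id_from_url(url: str) -> uuid.UUID:
    """Hypothetical helper mirroring the uuid5 call in ingest_files_from_urls."""
    # uuid5 hashes (namespace, name), so the result is stable across runs.
    return uuid.uuid5(uuid.NAMESPACE_DNS, url)


first = document_id_from_url("https://example.com/a.txt")
second = document_id_from_url("https://example.com/a.txt")
other = document_id_from_url("https://example.com/b.txt")
```

Here `first` and `second` compare equal while `other` differs, which is what makes re-ingestion of the same URL idempotent with respect to document IDs.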
2 changes: 1 addition & 1 deletion py/cli/utils/docker_utils.py
@@ -243,7 +243,7 @@ def check_external_ollama(ollama_url="http://localhost:11434/api/version"):
def check_set_docker_env_vars(exclude_postgres: bool = False):

env_vars = {
- "R2R_PROJECT_NAME": "r2r",
+ # "R2R_PROJECT_NAME": "r2r",
}

if not exclude_postgres:
12 changes: 12 additions & 0 deletions py/core/base/providers/kg.py
@@ -81,6 +81,18 @@ async def add_kg_extractions(
"""Abstract method to add KG extractions."""
pass

@abstractmethod
async def get_communities(
self,
collection_id: UUID,
offset: int,
limit: int,
levels: list[int] | None = None,
community_numbers: list[int] | None = None,
) -> list[CommunityReport]:
"""Abstract method to get communities."""
pass

@abstractmethod
async def get_entities(
self,
32 changes: 32 additions & 0 deletions py/core/main/api/kg_router.py
@@ -235,3 +235,35 @@ async def get_triples(
limit,
triple_ids,
)

@self.router.get("/communities")
@self.base_endpoint
async def get_communities(
collection_id: UUID = Query(
..., description="Collection ID to retrieve communities from."
),
offset: int = Query(0, ge=0, description="Offset for pagination."),
limit: int = Query(
100, ge=1, le=1000, description="Limit for pagination."
),
levels: Optional[list[int]] = Query(
None, description="Levels to filter by."
),
community_numbers: Optional[list[int]] = Query(
None, description="Community numbers to filter by."
),
auth_user=Depends(self.service.providers.auth.auth_wrapper),
):
"""
Retrieve communities from the knowledge graph.
"""
if not auth_user.is_superuser:
logger.warning("Implement permission checks here.")

return await self.service.get_communities(
collection_id,
offset,
limit,
levels,
community_numbers,
)
25 changes: 21 additions & 4 deletions py/core/main/api/management_router.py
@@ -370,8 +370,24 @@ async def document_chunks_app(
is_owner = str(document_chunks_result[0].get("user_id")) == str(
auth_user.id
)
+ document_collections = await self.service.document_collections(
+ document_uuid, 0, 1
+ )

- if not is_owner and not auth_user.is_superuser:
+ user_has_access = (
+ is_owner
+ or set(auth_user.collection_ids).intersection(
+ set(
+ [
+ ele.collection_id
+ for ele in document_collections["results"]
+ ]
+ )
+ )
+ != set()
+ )
+
+ if not user_has_access and not auth_user.is_superuser:
raise R2RException(
"Only a superuser can arbitrarily call document_chunks.",
403,
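The access rule introduced in this hunk (owner, superuser, or a user who shares at least one collection with the document) can be sketched in isolation as below. The function name `user_has_access` and its flat argument list are assumptions made for illustration; the real endpoint reads these values from `auth_user` and the service response.

```python
from uuid import UUID, uuid4


def user_has_access(
    is_owner: bool,
    is_superuser: bool,
    user_collection_ids: list[UUID],
    document_collection_ids: list[UUID],
) -> bool:
    """Hypothetical stand-alone version of the check in document_chunks_app."""
    # isdisjoint is True when the two ID sets share no element.
    shares_collection = not set(user_collection_ids).isdisjoint(
        document_collection_ids
    )
    return is_owner or is_superuser or shares_collection


cid_a, cid_b = uuid4(), uuid4()
shared = user_has_access(False, False, [cid_a], [cid_a, cid_b])
unrelated = user_has_access(False, False, [cid_a], [cid_b])
```

One thing worth checking in the hunk itself: `document_collections(document_uuid, 0, 1)` is called with `0, 1`. If those arguments are offset and limit, only the first collection of the document is compared against the user's collections.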
@@ -512,7 +528,8 @@ async def list_collections_app(
) -> WrappedCollectionListResponse:
if not auth_user.is_superuser:
raise R2RException(
- "Only a superuser can list all collections.", 403
+ "Only a superuser can call the list collections endpoint.",
+ 403,
)
list_collections_response = await self.service.list_collections(
offset=offset, limit=min(max(limit, 1), 1000)
@@ -543,7 +560,7 @@ async def add_user_to_collection_app(
result = await self.service.add_user_to_collection(
user_uuid, collection_uuid
)
- return {"result": result} # type: ignore
+ return result # type: ignore

@self.router.post("/remove_user_from_collection")
@self.base_endpoint
@@ -611,7 +628,7 @@ async def get_collections_for_user_app(
),
auth_user=Depends(self.service.providers.auth.auth_wrapper),
) -> WrappedUserCollectionResponse:
- if str(auth_user.id) != user_id or not auth_user.is_superuser:
+ if str(auth_user.id) != user_id and not auth_user.is_superuser:
raise R2RException(
"The currently authenticated user does not have access to the specified collection.",
403,
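The one-word change from `or` to `and` in this hunk alters who is rejected. The sketch below reduces both conditions to booleans to make the difference visible; `denied_before` and `denied_after` are hypothetical names used only for this illustration.

```python
def denied_before(is_self: bool, is_superuser: bool) -> bool:
    """Old condition: `id != user_id or not is_superuser`."""
    # Rejects everyone except a superuser querying their own user ID.
    return (not is_self) or (not is_superuser)


def denied_after(is_self: bool, is_superuser: bool) -> bool:
    """New condition: `id != user_id and not is_superuser`."""
    # Rejects only callers who are neither the user in question nor a superuser.
    return (not is_self) and (not is_superuser)


# Columns: (is_self, is_superuser, denied_before, denied_after)
truth_table = [
    (s, su, denied_before(s, su), denied_after(s, su))
    for s in (True, False)
    for su in (True, False)
]
```

Under the old condition a regular user could not list their own collections and a superuser could not list anyone else's; the new condition permits both.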
2 changes: 1 addition & 1 deletion py/core/main/orchestration/hatchet/ingestion_workflow.py
@@ -136,7 +136,7 @@ async def embed(self, context: Context) -> dict:

return {
"status": "Successfully finalized ingestion",
- "document_info": document_info.model_dump(),
+ "document_info": document_info.to_dict(),
}

@orchestration_provider.failure()
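The diff does not show `to_dict`, so the reason for replacing `model_dump()` is an assumption here: a plausible motive is that a raw dump can keep values such as `UUID` objects, which the standard JSON encoder rejects. The sketch below illustrates that failure mode with a plain dataclass; `DocumentInfoSketch` is hypothetical and is not the project's document info model.

```python
import json
from dataclasses import asdict, dataclass
from uuid import UUID, uuid4


@dataclass
class DocumentInfoSketch:
    """Hypothetical stand-in for a document info object."""

    id: UUID
    title: str

    def to_dict(self) -> dict:
        # Convert non-JSON-native values to strings before returning.
        return {"id": str(self.id), "title": self.title}


info = DocumentInfoSketch(id=uuid4(), title="sample")

try:
    json.dumps(asdict(info))  # the UUID object is still present here
    raw_dump_ok = True
except TypeError:
    raw_dump_ok = False

encoded = json.dumps(info.to_dict())
```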
18 changes: 18 additions & 0 deletions py/core/main/services/kg_service.py
@@ -294,3 +294,21 @@ async def get_triples(
limit,
triple_ids,
)

@telemetry_event("get_communities")
async def get_communities(
self,
collection_id: UUID,
offset: int = 0,
limit: int = 100,
levels: Optional[list[int]] = None,
community_numbers: Optional[list[int]] = None,
**kwargs,
):
return await self.providers.kg.get_communities(
collection_id,
offset,
limit,
levels,
community_numbers,
)
2 changes: 1 addition & 1 deletion py/core/pipes/retrieval/kg_search_pipe.py
@@ -221,7 +221,7 @@ async def global_search(
# map reduce
async for message in input.message:
map_responses = []
- communities = self.kg_provider.get_communities( # type: ignore
+ communities = await self.kg_provider.get_communities( # type: ignore
level=kg_search_settings.kg_search_level
)
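The added `await` matters because calling an `async def` function without it returns a coroutine object rather than the result. A minimal sketch, using a hypothetical stand-in for the provider method:

```python
import asyncio


async def get_communities(level: int) -> list[str]:
    """Hypothetical stand-in for the async provider method."""
    await asyncio.sleep(0)
    return [f"community-at-level-{level}"]


async def main() -> tuple[bool, list[str]]:
    pending = get_communities(level=0)  # no await: this is a coroutine object
    is_coro = asyncio.iscoroutine(pending)
    communities = await pending  # awaiting yields the actual list
    return is_coro, communities


is_coroutine, communities = asyncio.run(main())
```

Separately, the call in the hunk passes `level=`, while the provider method added elsewhere in this PR takes `collection_id` and `levels`. Whether another signature accepts `level` is not visible in this diff.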

3 changes: 2 additions & 1 deletion py/core/providers/database/relational.py
@@ -1,6 +1,7 @@
+ import asyncio
import logging
from contextlib import asynccontextmanager
- import asyncio

import asyncpg

from core.base import RelationalDBProvider
28 changes: 28 additions & 0 deletions py/core/providers/kg/postgres.py
@@ -513,6 +513,34 @@ async def add_communities(self, communities: List[Any]) -> None:
"""
await self.execute_many(QUERY, communities)

async def get_communities(
self,
collection_id: UUID,
offset: int = 0,
limit: int = 100,
levels: Optional[list[int]] = None,
community_numbers: Optional[list[int]] = None,
) -> List[CommunityReport]:

query_parts = [
f"SELECT * FROM {self._get_table_name('community_report')} WHERE collection_id = $1 ORDER BY community_number LIMIT $2 OFFSET $3"
]
params = [collection_id, limit, offset]

if levels is not None:
query_parts.append(f"AND level = ANY(${len(params) + 1})")
params.append(levels)

if community_numbers is not None:
query_parts.append(
f"AND community_number = ANY(${len(params) + 1})"
)
params.append(community_numbers)

QUERY = " ".join(query_parts)

return await self.fetch_query(QUERY, params)
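In the hunk above, the optional `AND ...` filters are appended to a base string that already ends in `ORDER BY ... LIMIT ... OFFSET ...`. PostgreSQL expects `WHERE` conditions before `ORDER BY`, so a call that passes `levels` or `community_numbers` would most likely be rejected rather than filtered. The sketch below shows one way to order the clauses. `build_communities_query` and the literal table name are hypothetical; the function only builds the SQL string and parameter list and does not touch the database.

```python
from typing import Any, Optional
from uuid import UUID, uuid4


def build_communities_query(
    table: str,
    collection_id: UUID,
    offset: int = 0,
    limit: int = 100,
    levels: Optional[list[int]] = None,
    community_numbers: Optional[list[int]] = None,
) -> tuple[str, list[Any]]:
    """Hypothetical builder: filters first, then ORDER BY / LIMIT / OFFSET."""
    conditions = ["collection_id = $1"]
    params: list[Any] = [collection_id]

    if levels is not None:
        params.append(levels)
        conditions.append(f"level = ANY(${len(params)})")

    if community_numbers is not None:
        params.append(community_numbers)
        conditions.append(f"community_number = ANY(${len(params)})")

    # Pagination placeholders are numbered after all filter parameters.
    params.extend([limit, offset])
    query = (
        f"SELECT * FROM {table} WHERE {' AND '.join(conditions)} "
        f"ORDER BY community_number "
        f"LIMIT ${len(params) - 1} OFFSET ${len(params)}"
    )
    return query, params


query, params = build_communities_query(
    "community_report", uuid4(), levels=[0, 1]
)
```

Numbering the placeholders from the current length of `params` keeps each `$n` aligned with its value regardless of which filters are present.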

async def add_community_report(
self, community_report: CommunityReport
) -> None:
2 changes: 1 addition & 1 deletion py/pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "r2r"
readme = "README.md"
- version = "3.2.8"
+ version = "3.2.9"

description = "SciPhi R2R"
authors = ["Owen Colegrove <owen@sciphi.ai>"]
1 change: 0 additions & 1 deletion py/sdk/auth.py
@@ -108,7 +108,6 @@ async def update_user(
"bio": bio,
"profile_picture": profile_picture,
}
- print("data = ", data)
data = {k: v for k, v in data.items() if v is not None}
return await client._make_request("PUT", "user", json=data)

31 changes: 31 additions & 0 deletions py/sdk/kg.py
@@ -130,3 +130,34 @@ async def get_triples(
params["triple_ids"] = ",".join(triple_ids)

return await client._make_request("GET", "triples", params=params)

@staticmethod
async def get_communities(
client,
collection_id: str,
offset: int = 0,
limit: int = 100,
levels: Optional[list[int]] = None,
community_numbers: Optional[list[int]] = None,
) -> dict:
"""
Retrieve communities from the knowledge graph.

Args:
collection_id (str): The ID of the collection to retrieve communities from.

Returns:
dict: A dictionary containing the retrieved communities.
"""
params = {
"collection_id": collection_id,
"offset": offset,
"limit": limit,
}

if levels:
params["levels"] = levels
if community_numbers:
params["community_numbers"] = community_numbers

return await client._make_request("GET", "communities", params=params)
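The SDK method above passes `levels` and `community_numbers` as Python lists, while the router declares them as `Optional[list[int]] = Query(...)`, which FastAPI reads from repeated query keys. The sketch below reproduces the parameter-building logic and encodes it with the standard library to show that wire format. `build_communities_params` is a hypothetical helper, and how `client._make_request` actually encodes lists is not shown in this diff.

```python
from typing import Optional
from urllib.parse import urlencode


def build_communities_params(
    collection_id: str,
    offset: int = 0,
    limit: int = 100,
    levels: Optional[list[int]] = None,
    community_numbers: Optional[list[int]] = None,
) -> dict[str, object]:
    """Hypothetical helper mirroring the params dict built in the SDK method."""
    params: dict[str, object] = {
        "collection_id": collection_id,
        "offset": offset,
        "limit": limit,
    }
    # Empty or missing filters are omitted, as in the SDK code.
    if levels:
        params["levels"] = levels
    if community_numbers:
        params["community_numbers"] = community_numbers
    return params


params = build_communities_params("abc", levels=[0, 1])
# doseq=True expands each list into repeated key=value pairs.
query_string = urlencode(params, doseq=True)
```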
4 changes: 2 additions & 2 deletions py/shared/api/models/management/responses.py
@@ -151,10 +151,10 @@ class AddUserResponse(BaseModel):
WrappedCollectionOverviewResponse = ResultsWrapper[
list[CollectionOverviewResponse]
]
- WrappedAddUserResponse = ResultsWrapper[AddUserResponse]
+ WrappedAddUserResponse = ResultsWrapper[None]
WrappedUsersInCollectionResponse = PaginatedResultsWrapper[list[UserResponse]]
WrappedUserCollectionResponse = PaginatedResultsWrapper[
- list[CollectionOverviewResponse]
+ list[CollectionResponse]
]
WrappedDocumentChunkResponse = PaginatedResultsWrapper[
list[DocumentChunkResponse]