diff --git a/py/Dockerfile b/py/Dockerfile
index df143bcc9..0db577a1e 100644
--- a/py/Dockerfile
+++ b/py/Dockerfile
@@ -49,8 +49,4 @@ COPY r2r.toml /app/r2r.toml
 COPY pyproject.toml /app/pyproject.toml
 
 # Run the application
-<<<<<<< HEAD
-CMD ["sh", "-c", "uvicorn core.main.app_entry:app --host $HOST --port $R2R_PORT"]
-=======
 CMD ["sh", "-c", "uvicorn core.main.app_entry:app --host $R2R_HOST --port $R2R_PORT"]
->>>>>>> 8ae04c5bfdbeab77073b6ae1169c5bff1b32489b
diff --git a/py/core/base/providers/embedding.py b/py/core/base/providers/embedding.py
index 1d7b5557a..4798f78a5 100644
--- a/py/core/base/providers/embedding.py
+++ b/py/core/base/providers/embedding.py
@@ -10,6 +10,7 @@
     VectorSearchResult,
     default_embedding_prefixes,
 )
+
 from .base import Provider, ProviderConfig
 
 logger = logging.getLogger(__name__)
diff --git a/py/core/base/providers/kg.py b/py/core/base/providers/kg.py
index 36c82186d..1c3e12a3a 100644
--- a/py/core/base/providers/kg.py
+++ b/py/core/base/providers/kg.py
@@ -196,6 +196,7 @@ async def get_entity_count(
         self,
         collection_id: Optional[UUID] = None,
         document_id: Optional[UUID] = None,
+        distinct: bool = False,
         entity_table_name: str = "entity_embedding",
     ) -> int:
         """Abstract method to get the entity count."""
diff --git a/py/core/main/api/ingestion_router.py b/py/core/main/api/ingestion_router.py
index c546fec10..91f6106dd 100644
--- a/py/core/main/api/ingestion_router.py
+++ b/py/core/main/api/ingestion_router.py
@@ -10,6 +10,7 @@
 from pydantic import Json
 
 from core.base import R2RException, RawChunk, generate_document_id
+
 from core.base.api.models import (
     CreateVectorIndexResponse,
     WrappedCreateVectorIndexResponse,
diff --git a/py/core/main/api/kg_router.py b/py/core/main/api/kg_router.py
index ed9da0bb2..77f95b149 100644
--- a/py/core/main/api/kg_router.py
+++ b/py/core/main/api/kg_router.py
@@ -18,7 +18,6 @@
 from core.utils import generate_default_user_collection_id
 from shared.abstractions.kg import KGRunType
 from shared.utils.base_utils import update_settings_from_dict
-
 from ..services.kg_service import KgService
 from .base_router import BaseRouter
 
diff --git a/py/core/main/orchestration/hatchet/kg_workflow.py b/py/core/main/orchestration/hatchet/kg_workflow.py
index 878879b57..8e1bab8a0 100644
--- a/py/core/main/orchestration/hatchet/kg_workflow.py
+++ b/py/core/main/orchestration/hatchet/kg_workflow.py
@@ -3,14 +3,17 @@
 import logging
 import math
 import uuid
+import time
 
 from hatchet_sdk import ConcurrencyLimitStrategy, Context
 
 from core import GenerationConfig
 from core.base import OrchestrationProvider
-
+from shared.abstractions.document import KGExtractionStatus
 from ...services import KgService
 
+from shared.utils import create_hatchet_logger
+
 logger = logging.getLogger(__name__)
 
 from typing import TYPE_CHECKING
@@ -57,9 +60,7 @@ def concurrency(self, context: Context) -> str:
 
         @orchestration_provider.step(retries=1, timeout="360m")
         async def kg_extract(self, context: Context) -> dict:
-            context.log(
-                f"Running KG Extraction for input: {context.workflow_input()['request']}"
-            )
+            start_time = time.time()
 
             input_data = get_input_data_dict(
                 context.workflow_input()["request"]
@@ -70,12 +71,16 @@ async def kg_extract(self, context: Context) -> dict:
 
             await self.kg_service.kg_triples_extraction(
                 document_id=uuid.UUID(document_id),
-                logger=context.log,
+                logger=create_hatchet_logger(context.log),
                 **input_data["kg_creation_settings"],
             )
 
+            context.log(
+                f"Successfully ran kg triples extraction for document {document_id}"
+            )
+
             return {
-                "result": f"successfully ran kg triples extraction for document {document_id}"
f"successfully ran kg triples extraction for document {document_id}" + "result": f"successfully ran kg triples extraction for document {document_id} in {time.time() - start_time:.2f} seconds", } @orchestration_provider.step( @@ -90,13 +95,44 @@ async def kg_entity_description(self, context: Context) -> dict: await self.kg_service.kg_entity_description( document_id=uuid.UUID(document_id), + logger=create_hatchet_logger(context.log), **input_data["kg_creation_settings"], ) + context.log( + f"Successfully ran kg node description for document {document_id}" + ) + return { "result": f"successfully ran kg node description for document {document_id}" } + @orchestration_provider.failure() + async def on_failure(self, context: Context) -> None: + request = context.workflow_input().get("request", {}) + document_id = request.get("document_id") + + if not document_id: + context.log( + "No document id was found in workflow input to mark a failure." + ) + return + + try: + await self.kg_service.providers.database.relational.set_workflow_status( + id=uuid.UUID(document_id), + status_type="kg_extraction_status", + status=KGExtractionStatus.FAILED, + ) + context.log( + f"Updated KG extraction status for {document_id} to FAILED" + ) + + except Exception as e: + context.log( + f"Failed to update document status for {document_id}: {e}" + ) + @orchestration_provider.workflow(name="create-graph", timeout="360m") class CreateGraphWorkflow: def __init__(self, kg_service: KgService): @@ -187,6 +223,8 @@ def __init__(self, kg_service: KgService): @orchestration_provider.step(retries=1, parents=[], timeout="360m") async def kg_clustering(self, context: Context) -> dict: + start_time = time.time() + logger.info("Running KG Clustering") input_data = get_input_data_dict( context.workflow_input()["request"] @@ -195,11 +233,12 @@ async def kg_clustering(self, context: Context) -> dict: kg_clustering_results = await self.kg_service.kg_clustering( collection_id=collection_id, + logger=create_hatchet_logger(context.log), **input_data["kg_enrichment_settings"], ) context.log( - f"Successfully ran kg clustering for collection {collection_id}: {json.dumps(kg_clustering_results)}" + f"Successfully ran kg clustering for collection {collection_id}: {json.dumps(kg_clustering_results)} in {time.time() - start_time:.2f} seconds" ) logger.info( f"Successfully ran kg clustering for collection {collection_id}: {json.dumps(kg_clustering_results)}" @@ -220,10 +259,14 @@ async def kg_community_summary(self, context: Context) -> dict: num_communities = context.step_output("kg_clustering")[ "kg_clustering" ][0]["num_communities"] - parallel_communities = min(100, num_communities) total_workflows = math.ceil(num_communities / parallel_communities) workflows = [] + + context.log( + f"Running KG Community Summary for {num_communities} communities, spawning {total_workflows} workflows" + ) + for i in range(total_workflows): offset = i * parallel_communities workflows.append( @@ -257,15 +300,19 @@ def __init__(self, kg_service: KgService): @orchestration_provider.step(retries=1, timeout="360m") async def kg_community_summary(self, context: Context) -> dict: + + start_time = time.time() + input_data = get_input_data_dict( context.workflow_input()["request"] ) community_summary = await self.kg_service.kg_community_summary( - **input_data + logger=create_hatchet_logger(context.log), + **input_data, ) context.log( - f"Successfully ran kg community summary for communities {input_data['offset']} to {input_data['offset'] + len(community_summary)}" + 
f"Successfully ran kg community summary for communities {input_data['offset']} to {input_data['offset'] + len(community_summary)} in {time.time() - start_time:.2f} seconds " ) return { "result": f"successfully ran kg community summary for communities {input_data['offset']} to {input_data['offset'] + len(community_summary)}" diff --git a/py/core/main/orchestration/simple/kg_workflow.py b/py/core/main/orchestration/simple/kg_workflow.py index b5bca7d39..6eaee35f9 100644 --- a/py/core/main/orchestration/simple/kg_workflow.py +++ b/py/core/main/orchestration/simple/kg_workflow.py @@ -40,15 +40,22 @@ async def create_graph(input_data): for _, document_id in enumerate(document_ids): # Extract triples from the document - await service.kg_triples_extraction( - document_id=document_id, - **input_data["kg_creation_settings"], - ) - # Describe the entities in the graph - await service.kg_entity_description( - document_id=document_id, - **input_data["kg_creation_settings"], - ) + + try: + await service.kg_triples_extraction( + document_id=document_id, + **input_data["kg_creation_settings"], + ) + # Describe the entities in the graph + await service.kg_entity_description( + document_id=document_id, + **input_data["kg_creation_settings"], + ) + + except Exception as e: + logger.error( + f"Error in creating graph for document {document_id}: {e}" + ) async def enrich_graph(input_data): diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index 24bf4bd3e..833ffa155 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -1,6 +1,7 @@ import logging import math -from typing import Any, AsyncGenerator, Optional +import time +from typing import Any, AsyncGenerator, Optional, Union from uuid import UUID from core.base import KGExtractionStatus, RunLoggingSingleton, RunManager @@ -11,10 +12,14 @@ ) from core.telemetry.telemetry_decorator import telemetry_event +from shared.utils import HatchetLogger + from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders from ..config import R2RConfig from .base import Service +from time import strftime + logger = logging.getLogger(__name__) @@ -57,11 +62,16 @@ async def kg_triples_extraction( max_knowledge_triples: int, entity_types: list[str], relation_types: list[str], + logger: Union[logging.Logger, HatchetLogger] = logging.getLogger( + __name__ + ), **kwargs, ): try: - logger.info(f"Processing document {document_id} for KG extraction") + logger.info( + f"KGService: Processing document {document_id} for KG extraction" + ) await self.providers.database.relational.set_workflow_status( id=document_id, @@ -78,12 +88,17 @@ async def kg_triples_extraction( "max_knowledge_triples": max_knowledge_triples, "entity_types": entity_types, "relation_types": relation_types, + "logger": logger, } ), state=None, run_manager=self.run_manager, ) + logger.info( + f"KGService: Finished processing document {document_id} for KG extraction" + ) + result_gen = await self.pipes.kg_storage_pipe.run( input=self.pipes.kg_storage_pipe.Input(message=triples), state=None, @@ -91,12 +106,13 @@ async def kg_triples_extraction( ) except Exception as e: - logger.error(f"Error in kg_extraction: {e}") + logger.error(f"KGService: Error in kg_extraction: {e}") await self.providers.database.relational.set_workflow_status( id=document_id, status_type="kg_extraction_status", status=KGExtractionStatus.FAILED, ) + raise e return await _collect_results(result_gen) @@ -114,7 +130,6 @@ async def get_document_ids_for_create_graph( 
         ]
 
         if force_kg_creation:
             document_status_filter += [
-                KGExtractionStatus.SUCCESS,
                 KGExtractionStatus.PROCESSING,
             ]
 
@@ -131,14 +146,28 @@ async def kg_entity_description(
         self,
         document_id: UUID,
         max_description_input_length: int,
+        logger: Union[logging.Logger, HatchetLogger] = logging.getLogger(
+            __name__
+        ),
         **kwargs,
     ):
+        start_time = time.time()
+
+        logger.info(
+            f"KGService: Running kg_entity_description for document {document_id}"
+        )
+
         entity_count = await self.providers.kg.get_entity_count(
             document_id=document_id,
+            distinct=True,
             entity_table_name="entity_raw",
         )
 
+        logger.info(
+            f"KGService: Found {entity_count} entities in document {document_id}"
+        )
+
         # TODO - Do not hardcode the batch size,
         # make it a configurable parameter at runtime & server-side defaults
@@ -147,7 +176,7 @@ async def kg_entity_description(
         all_results = []
         for i in range(num_batches):
             logger.info(
-                f"Running kg_entity_description for batch {i+1}/{num_batches} for document {document_id}"
+                f"KGService: Running kg_entity_description for batch {i+1}/{num_batches} for document {document_id}"
             )
 
             node_descriptions = await self.pipes.kg_entity_description_pipe.run(
@@ -157,6 +186,7 @@ async def kg_entity_description(
                         "limit": 256,
                         "max_description_input_length": max_description_input_length,
                         "document_id": document_id,
+                        "logger": logger,
                     }
                 ),
                 state=None,
@@ -165,12 +195,20 @@ async def kg_entity_description(
 
             all_results.append(await _collect_results(node_descriptions))
 
+            logger.info(
+                f"KGService: Completed kg_entity_description for batch {i+1}/{num_batches} for document {document_id}"
+            )
+
         await self.providers.database.relational.set_workflow_status(
             id=document_id,
             status_type="kg_extraction_status",
             status=KGExtractionStatus.SUCCESS,
         )
 
+        logger.info(
+            f"KGService: Completed kg_entity_description for document {document_id} in {time.time() - start_time:.2f} seconds",
+        )
+
         return all_results
 
     @telemetry_event("kg_clustering")
@@ -179,6 +217,9 @@ async def kg_clustering(
         collection_id: UUID,
         generation_config: GenerationConfig,
         leiden_params: dict,
+        logger: Union[logging.Logger, HatchetLogger] = logging.getLogger(
+            __name__
+        ),
         **kwargs,
     ):
         clustering_result = await self.pipes.kg_clustering_pipe.run(
@@ -187,6 +228,7 @@ async def kg_clustering(
                     "collection_id": collection_id,
                     "generation_config": generation_config,
                     "leiden_params": leiden_params,
+                    "logger": logger,
                 }
             ),
             state=None,
@@ -202,6 +244,9 @@ async def kg_community_summary(
         max_summary_input_length: int,
         generation_config: GenerationConfig,
         collection_id: UUID,
+        logger: Union[logging.Logger, HatchetLogger] = logging.getLogger(
+            __name__
+        ),
         **kwargs,
     ):
         summary_results = await self.pipes.kg_community_summary_pipe.run(
@@ -212,6 +257,7 @@ async def kg_community_summary(
                     "generation_config": generation_config,
                     "max_summary_input_length": max_summary_input_length,
                     "collection_id": collection_id,
+                    "logger": logger,
                 }
             ),
             state=None,
diff --git a/py/core/pipes/kg/community_summary.py b/py/core/pipes/kg/community_summary.py
index 0acece79c..8af6684fc 100644
--- a/py/core/pipes/kg/community_summary.py
+++ b/py/core/pipes/kg/community_summary.py
@@ -3,7 +3,7 @@
 import logging
 from typing import Any, AsyncGenerator, Optional
 from uuid import UUID
-
+import time
 from core.base import (
     AsyncPipe,
     AsyncState,
@@ -193,20 +193,31 @@ async def _run_logic(  # type: ignore
         Executes the KG community summary pipe: summarizing communities.
""" + start_time = time.time() + offset = input.message["offset"] limit = input.message["limit"] generation_config = input.message["generation_config"] max_summary_input_length = input.message["max_summary_input_length"] collection_id = input.message["collection_id"] community_summary_jobs = [] + logger = input.message.get("logger", logging.getLogger(__name__)) # check which community summaries exist and don't run them again + + logger.info( + f"KGCommunitySummaryPipe: Checking if community summaries exist for communities {offset} to {offset + limit}" + ) community_numbers_exist = ( await self.kg_provider.check_community_reports_exist( collection_id=collection_id, offset=offset, limit=limit ) ) + logger.info( + f"KGCommunitySummaryPipe: Community summaries exist for communities {len(community_numbers_exist)}" + ) + for community_number in range(offset, offset + limit): if community_number not in community_numbers_exist: community_summary_jobs.append( @@ -218,5 +229,11 @@ async def _run_logic( # type: ignore ) ) + completed_community_summary_jobs = 0 for community_summary in asyncio.as_completed(community_summary_jobs): + completed_community_summary_jobs += 1 + if completed_community_summary_jobs % 50 == 0: + logger.info( + f"KGCommunitySummaryPipe: {completed_community_summary_jobs}/{len(community_summary_jobs)} community summaries completed, elapsed time: {time.time() - start_time:.2f} seconds" + ) yield await community_summary diff --git a/py/core/pipes/kg/entity_description.py b/py/core/pipes/kg/entity_description.py index 12de049b3..cba1aaa8a 100644 --- a/py/core/pipes/kg/entity_description.py +++ b/py/core/pipes/kg/entity_description.py @@ -16,6 +16,7 @@ ) from core.base.abstractions import Entity from core.base.pipes.base_pipe import AsyncPipe +import time logger = logging.getLogger(__name__) @@ -60,6 +61,8 @@ async def _run_logic( # type: ignore Extracts description from the input. 
""" + start_time = time.time() + # TODO - Move this to a .yaml file and load it as we do in triples extraction summarization_content = """ Provide a comprehensive yet concise summary of the given entity, incorporating its description and associated triples: @@ -167,16 +170,23 @@ async def process_entity( offset = input.message["offset"] limit = input.message["limit"] document_id = input.message["document_id"] + logger = input.message["logger"] + + logger.info( + f"KGEntityDescriptionPipe: Getting entity map for document {document_id}", + ) + entity_map = await self.kg_provider.get_entity_map( offset, limit, document_id ) - total_entities = len(entity_map) + logger.info( - f"Processing {total_entities} entities for document {document_id}" + f"KGEntityDescriptionPipe: Got entity map for document {document_id}, total entities: {total_entities}, time from start: {time.time() - start_time:.2f} seconds", ) workflows = [] + for i, (entity_name, entity_info) in enumerate(entity_map.items()): try: workflows.append( @@ -190,9 +200,15 @@ async def process_entity( except Exception as e: logger.error(f"Error processing entity {entity_name}: {e}") + completed_entities = 0 for result in asyncio.as_completed(workflows): + if completed_entities % 100 == 0: + logger.info( + f"KGEntityDescriptionPipe: Completed {completed_entities+1} of {total_entities} entities for document {document_id}", + ) yield await result + completed_entities += 1 logger.info( - f"Processed {total_entities} entities for document {document_id}" + f"KGEntityDescriptionPipe: Processed {total_entities} entities for document {document_id}, time from start: {time.time() - start_time:.2f} seconds", ) diff --git a/py/core/pipes/kg/triples_extraction.py b/py/core/pipes/kg/triples_extraction.py index e818b6cc9..872bcf295 100644 --- a/py/core/pipes/kg/triples_extraction.py +++ b/py/core/pipes/kg/triples_extraction.py @@ -3,7 +3,8 @@ import logging import re from typing import Any, AsyncGenerator, Optional, Union - +import time +from shared.utils import HatchetLogger from core.base import ( AsyncState, CompletionProvider, @@ -82,6 +83,11 @@ async def extract_kg( relation_types: list[str], retries: int = 5, delay: int = 2, + logger: Union[logging.Logger, HatchetLogger] = logging.getLogger( + __name__ + ), + task_id: Optional[int] = None, + total_tasks: Optional[int] = None, ) -> KGExtraction: """ Extracts NER triples from a extraction with retries. @@ -207,6 +213,10 @@ def parse_fn(response_str: str) -> Any: # raise e # you should raise an error. 
 
         # add metadata to entities and triples
+        logger.info(
+            f"KGExtractionPipe: Completed task number {task_id} of {total_tasks} for document {extractions[0].document_id}",
+        )
+
         return KGExtraction(
             extraction_ids=[extraction.id for extraction in extractions],
             document_id=extractions[0].document_id,
@@ -222,7 +232,8 @@ async def _run_logic(  # type: ignore
         *args: Any,
         **kwargs: Any,
     ) -> AsyncGenerator[Union[KGExtraction, R2RDocumentProcessingError], None]:
-        logger.info("Running KG Extraction Pipe")
+
+        start_time = time.time()
 
         document_id = input.message["document_id"]
         generation_config = input.message["generation_config"]
@@ -230,6 +241,12 @@ async def _run_logic(  # type: ignore
         max_knowledge_triples = input.message["max_knowledge_triples"]
         entity_types = input.message["entity_types"]
         relation_types = input.message["relation_types"]
+        logger = input.message.get("logger", logging.getLogger(__name__))
+
+        logger.info(
+            f"KGTriplesExtractionPipe: Processing document {document_id} for KG extraction",
+        )
+
         extractions = [
             DocumentExtraction(
                 id=extraction["extraction_id"],
@@ -246,6 +263,10 @@ async def _run_logic(  # type: ignore
             ]
         ]
 
+        logger.info(
+            f"KGTriplesExtractionPipe: Obtained {len(extractions)} extractions to process, time from start: {time.time() - start_time:.2f} seconds",
+        )
+
         # sort the extractions accroding to chunk_order field in metadata in ascending order
         extractions = sorted(
             extractions, key=lambda x: x.metadata["chunk_order"]
@@ -257,6 +278,10 @@ async def _run_logic(  # type: ignore
             for i in range(0, len(extractions), extraction_merge_count)
         ]
 
+        logger.info(
+            f"KGTriplesExtractionPipe: Extracting KG triples for document {document_id}, created {len(extractions_groups)} extraction groups, time from start: {time.time() - start_time:.2f} seconds",
+        )
+
         tasks = [
             asyncio.create_task(
                 self.extract_kg(
@@ -265,24 +290,36 @@ async def _run_logic(  # type: ignore
                     max_knowledge_triples=max_knowledge_triples,
                     entity_types=entity_types,
                     relation_types=relation_types,
+                    logger=logger,
+                    task_id=task_id,
+                    total_tasks=len(extractions_groups),
                 )
             )
-            for extractions_group in extractions_groups
+            for task_id, extractions_group in enumerate(extractions_groups)
         ]
 
         completed_tasks = 0
         total_tasks = len(tasks)
 
+        logger.info(
+            f"KGTriplesExtractionPipe: Waiting for {total_tasks} KG extraction tasks to complete",
+        )
+
         for completed_task in asyncio.as_completed(tasks):
             try:
                 yield await completed_task
                 completed_tasks += 1
-                logger.info(
-                    f"Completed {completed_tasks}/{total_tasks} KG extraction tasks for document {document_id}"
-                )
+                if completed_tasks % 100 == 0:
+                    logger.info(
+                        f"KGTriplesExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks",
+                    )
             except Exception as e:
                 logger.error(f"Error in Extracting KG Triples: {e}")
                 yield R2RDocumentProcessingError(
                     document_id=document_id,
                     error_message=str(e),
                 )
+
+        logger.info(
+            f"KGTriplesExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks, time from start: {time.time() - start_time:.2f} seconds",
+        )
diff --git a/py/core/providers/database/vecs/collection.py b/py/core/providers/database/vecs/collection.py
index 6770a7ab2..1c2686c42 100644
--- a/py/core/providers/database/vecs/collection.py
+++ b/py/core/providers/database/vecs/collection.py
@@ -15,6 +15,7 @@
 from typing import TYPE_CHECKING, Any, Iterable, Optional, Union
 from uuid import UUID, uuid4
 
+import time
 from flupy import flu
 from sqlalchemy import (
     Column,
@@ -1029,6 +1030,7 @@ def create_index(
 def _build_table(
     project_name: str, name: str, meta: MetaData, dimension: int
 ) -> Table:
+
     table = Table(
         name,
         meta,
diff --git a/py/core/providers/database/vector.py b/py/core/providers/database/vector.py
index 535dc770f..748260958 100644
--- a/py/core/providers/database/vector.py
+++ b/py/core/providers/database/vector.py
@@ -15,11 +15,13 @@
     VectorSearchResult,
 )
 from core.base.abstractions import VectorSearchSettings
+
 from shared.abstractions.vector import (
-    IndexArgsHNSW,
+    IndexMethod,
     IndexArgsIVFFlat,
+    IndexArgsHNSW,
+    VectorTableName,
     IndexMeasure,
-    IndexMethod,
     VectorTableName,
 )
 
diff --git a/py/core/providers/kg/postgres.py b/py/core/providers/kg/postgres.py
index 3a1467925..104b4ff25 100644
--- a/py/core/providers/kg/postgres.py
+++ b/py/core/providers/kg/postgres.py
@@ -1049,6 +1049,7 @@ async def get_entity_count(
         self,
         collection_id: Optional[UUID] = None,
         document_id: Optional[UUID] = None,
+        distinct: bool = False,
         entity_table_name: str = "entity_embedding",
     ) -> int:
         if collection_id is None and document_id is None:
@@ -1073,8 +1074,13 @@ async def get_entity_count(
             conditions.append("document_id = $1")
             params.append(str(document_id))
 
+        if distinct:
+            count_value = "DISTINCT name"
+        else:
+            count_value = "*"
+
         QUERY = f"""
-            SELECT COUNT(*) FROM {self._get_table_name(entity_table_name)}
+            SELECT COUNT({count_value}) FROM {self._get_table_name(entity_table_name)}
             WHERE {" AND ".join(conditions)}
         """
         return (await self.fetch_query(QUERY, params))[0]["count"]
diff --git a/py/shared/utils/__init__.py b/py/shared/utils/__init__.py
index ea86962d6..53cf96cd2 100644
--- a/py/shared/utils/__init__.py
+++ b/py/shared/utils/__init__.py
@@ -1,4 +1,5 @@
 from .base_utils import (
+    HatchetLogger,
     decrement_version,
     format_entity_types,
     format_relations,
@@ -17,6 +18,7 @@
     run_pipeline,
     to_async_generator,
     validate_uuid,
+    create_hatchet_logger,
 )
 from .splitter.text import RecursiveCharacterTextSplitter, TextSplitter
 
@@ -41,7 +43,9 @@
     "to_async_generator",
     "llm_cost_per_million_tokens",
     "validate_uuid",
+    "create_hatchet_logger",
     # Text splitter
     "RecursiveCharacterTextSplitter",
     "TextSplitter",
+    "HatchetLogger",
 ]
diff --git a/py/shared/utils/base_utils.py b/py/shared/utils/base_utils.py
index e90332a99..285979f61 100644
--- a/py/shared/utils/base_utils.py
+++ b/py/shared/utils/base_utils.py
@@ -1,3 +1,5 @@
+from time import strftime
+from typing import Optional
 import asyncio
 import json
 import logging
@@ -5,9 +7,11 @@
 from datetime import datetime
 from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable
 from uuid import NAMESPACE_DNS, UUID, uuid4, uuid5
+from copy import deepcopy
 
 from ..abstractions import R2RSerializable
 from ..abstractions.graph import EntityType, RelationshipType
+from ..abstractions import R2RSerializable
 from ..abstractions.search import (
     AggregateSearchResult,
     KGCommunityResult,
@@ -264,3 +268,37 @@ def update_settings_from_dict(server_settings, settings_dict: dict):
         setattr(settings, key, value)
 
     return settings
+
+
+class HatchetLogger:
+    def __init__(self, hatchet_logger: Any):
+        self.hatchet_logger = hatchet_logger
+
+    def _log(self, level: str, message: str, function: Optional[str] = None):
+        if function:
+            log_message = f"[{level}]: {function}: {message}"
+        else:
+            log_message = f"[{level}]: {message}"
+        self.hatchet_logger(log_message)
+
+    def debug(self, message: str, function: Optional[str] = None):
+        self._log("DEBUG", message, function)
+
+    def info(self, message: str, function: Optional[str] = None):
+        self._log("INFO", message, function)
+
+    def warning(self, message: str, function: Optional[str] = None):
+        self._log("WARNING", message, function)
+
+    def error(self, message: str, function: Optional[str] = None):
+        self._log("ERROR", message, function)
+
+    def critical(self, message: str, function: Optional[str] = None):
+        self._log("CRITICAL", message, function)
+
+
+def create_hatchet_logger(hatchet_logger: Any) -> HatchetLogger:
+    """
+    Creates a HatchetLogger instance with different logging levels.
+    """
+    return HatchetLogger(hatchet_logger)
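
Reviewer note: the snippet below is a minimal usage sketch of the new logging shim, not part of the patch. It relies only on what the diff itself defines (HatchetLogger, create_hatchet_logger, and the "[LEVEL]: function: message" format); the fake_context_log stand-in is hypothetical, and inside a real Hatchet step the wrapped callable would be context.log, as wired up in hatchet/kg_workflow.py above.

    from shared.utils import create_hatchet_logger

    # Stand-in for Hatchet's context.log; in a real step the Hatchet SDK
    # provides this callable and forwards messages to the workflow log stream.
    def fake_context_log(message: str) -> None:
        print(message)

    # Wrap the raw callable so service and pipe code can keep calling the
    # familiar logging API (.debug/.info/.warning/.error/.critical).
    logger = create_hatchet_logger(fake_context_log)

    logger.info("Processing document 1234", function="kg_triples_extraction")
    logger.error("Extraction failed: timeout")
    # Prints:
    # [INFO]: kg_triples_extraction: Processing document 1234
    # [ERROR]: Extraction failed: timeout

This is also why the service methods in kg_service.py accept Union[logging.Logger, HatchetLogger]: the same call sites work whether the caller passes a standard module logger (the default) or a wrapped Hatchet context.log.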