Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve kg throughput #1342

Merged
merged 22 commits into from
Oct 5, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions py/core/base/providers/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ class EmbeddingConfig(ProviderConfig):
batch_size: int = 1
prefixes: Optional[dict[str, str]] = None
add_title_as_prefix: bool = True
concurrent_request_limit: int = 16
max_retries: int = 2
initial_backoff: float = 1.0
max_backoff: float = 60.0
concurrent_request_limit: int = 256
max_retries: int = 8
initial_backoff: float = 1
max_backoff: float = 64.0

def validate_config(self) -> None:
if self.provider not in self.supported_providers:
Expand Down Expand Up @@ -63,6 +63,7 @@ async def _execute_with_backoff_async(self, task: dict[str, Any]):
try:
async with self.semaphore:
return await self._execute_task(task)
# TODO: Capture different error types and handle them accordingly
except Exception as e:
logger.warning(
f"Request failed (attempt {retries + 1}): {str(e)}"
Expand Down
4 changes: 2 additions & 2 deletions py/core/base/providers/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ class CompletionConfig(ProviderConfig):
provider: Optional[str] = None
generation_config: GenerationConfig = GenerationConfig()
concurrent_request_limit: int = 256
max_retries: int = 2
max_retries: int = 8
initial_backoff: float = 1.0
max_backoff: float = 60.0
max_backoff: float = 64.0

def validate_config(self) -> None:
if not self.provider:
Expand Down
2 changes: 1 addition & 1 deletion py/core/configs/r2r_aws_bedrock.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ overlap = 20

[completion]
provider = "litellm"
concurrent_request_limit = 16
concurrent_request_limit = 256

[completion.generation_config]
model = "bedrock/anthropic.claude-v2"
Expand Down
2 changes: 1 addition & 1 deletion py/core/main/orchestration/hatchet/kg_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def get_input_data_dict(input_data):
class KGExtractDescribeEmbedWorkflow:
def __init__(self, kg_service: KgService):
self.kg_service = kg_service

@orchestration_provider.concurrency(
max_runs=orchestration_provider.config.kg_creation_concurrency_limit,
limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
Expand Down
2 changes: 1 addition & 1 deletion py/core/pipes/kg/entity_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ async def process_entity(
),
}
],
generation_config=self.kg_provider.config.kg_enrichment_settings.generation_config,
generation_config=self.kg_provider.config.kg_creation_settings.generation_config,
)
)
.choices[0]
Expand Down
11 changes: 8 additions & 3 deletions py/core/providers/database/relational.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from contextlib import asynccontextmanager

import asyncio
import asyncpg

from core.base import RelationalDBProvider
Expand Down Expand Up @@ -35,13 +35,17 @@ def __init__(
self.project_name = project_name
self.pool = None
self.postgres_configuration_settings = postgres_configuration_settings
self.semaphore = asyncio.Semaphore(
int(self.postgres_configuration_settings.max_connections * 0.9)
)

async def initialize(self):
try:
self.pool = await asyncpg.create_pool(
self.connection_string,
max_size=self.postgres_configuration_settings.max_connections,
)

logger.info(
"Successfully connected to Postgres database and created connection pool."
)
Expand All @@ -57,8 +61,9 @@ def _get_table_name(self, base_name: str) -> str:

@asynccontextmanager
async def get_connection(self):
async with self.pool.acquire() as conn:
yield conn
async with self.semaphore:
async with self.pool.acquire() as conn:
yield conn

async def execute_query(self, query, params=None):
async with self.get_connection() as conn:
Expand Down
4 changes: 1 addition & 3 deletions py/core/providers/kg/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(

self.db_provider = db_provider.relational
self.embedding_provider = embedding_provider

try:
import networkx as nx

Expand Down Expand Up @@ -164,9 +165,6 @@ async def create_tables(self, project_name: str):

await self.execute_query(query)

# TODO: Create another table for entity_embedding_collection
# entity embeddings at a collection level

# communities table, result of the Leiden algorithm
query = f"""
CREATE TABLE IF NOT EXISTS {self._get_table_name("community")} (
Expand Down
1 change: 1 addition & 0 deletions py/core/providers/orchestration/hatchet.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def get_worker(self, name: str, max_threads: Optional[int] = None) -> Any:
self.worker = self.orchestrator.worker(name, max_threads)
return self.worker


def concurrency(self, *args, **kwargs) -> Callable:
shreyaspimpalgaonkar marked this conversation as resolved.
Show resolved Hide resolved
return self.orchestrator.concurrency(*args, **kwargs)

Expand Down
2 changes: 1 addition & 1 deletion py/r2r.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ default_admin_password = "change_me_immediately"

[completion]
provider = "litellm"
concurrent_request_limit = 16
concurrent_request_limit = 256

[completion.generation_config]
model = "openai/gpt-4o"
Expand Down
Loading