diff --git a/nucliadb/src/nucliadb/common/maindb/pg.py b/nucliadb/src/nucliadb/common/maindb/pg.py index 8af9223b15..4f0d63a1d8 100644 --- a/nucliadb/src/nucliadb/common/maindb/pg.py +++ b/nucliadb/src/nucliadb/common/maindb/pg.py @@ -64,6 +64,10 @@ class DataLayer: def __init__(self, connection: Union[asyncpg.Connection, asyncpg.Pool]): self.connection = connection + # A lock to avoid sending concurrent queries to the connection. asyncpg has its own system to control this + # but instead of waiting, it raises an Exception. We use our own lock so that concurrent tasks wait for each + # other, rather than exploding. This could be avoided if we can guarantee that a single asyncpg connection + # is not shared between concurrent tasks. self.lock = asyncio.Lock() async def get(self, key: str, select_for_update: bool = False) -> Optional[bytes]: @@ -163,16 +167,14 @@ class PGTransaction(Transaction): def __init__( self, - pool: asyncpg.Pool, + driver: PGDriver, connection: asyncpg.Connection, txn: Any, - driver: PGDriver, ): - self.pool = pool + self.driver = driver self.connection = connection self.data_layer = DataLayer(connection) self.txn = txn - self.driver = driver self.open = True self._lock = asyncio.Lock() @@ -226,8 +228,9 @@ async def keys( count: int = DEFAULT_SCAN_LIMIT, include_start: bool = True, ): - async with self.pool.acquire() as conn, conn.transaction(): - # all txn implementations implement this API outside of the current txn + # Check out a new connection to guarantee that the cursor iteration does not + # run concurrently with other queries + async with self.driver._get_connection() as conn, conn.transaction(): dl = DataLayer(conn) async for key in dl.scan_keys(match, count, include_start=include_start): yield key @@ -239,8 +242,7 @@ async def count(self, match: str) -> int: class ReadOnlyPGTransaction(Transaction): driver: PGDriver - def __init__(self, pool: asyncpg.Pool, driver: PGDriver): - self.pool = pool + def __init__(self, driver: 
PGDriver): self.driver = driver self.open = True @@ -252,10 +254,12 @@ async def commit(self): raise Exception("Cannot commit transaction in read only mode") async def batch_get(self, keys: list[str]): - return await DataLayer(self.pool).batch_get(keys) + async with self.driver._get_connection() as conn: + return await DataLayer(conn).batch_get(keys) async def get(self, key: str) -> Optional[bytes]: - return await DataLayer(self.pool).get(key) + async with self.driver._get_connection() as conn: + return await DataLayer(conn).get(key) async def set(self, key: str, value: bytes): raise Exception("Cannot set in read only transaction") @@ -269,22 +273,49 @@ async def keys( count: int = DEFAULT_SCAN_LIMIT, include_start: bool = True, ): - async with self.pool.acquire() as conn, conn.transaction(): - # all txn implementations implement this API outside of the current txn + async with self.driver._get_connection() as conn, conn.transaction(): dl = DataLayer(conn) async for key in dl.scan_keys(match, count, include_start=include_start): yield key async def count(self, match: str) -> int: - return await DataLayer(self.pool).count(match) + async with self.driver._get_connection() as conn: + return await DataLayer(conn).count(match) + + +class InstrumentedAcquireContext: + def __init__(self, context): + self.context = context + + async def __aenter__(self): + with pg_observer({"type": "acquire"}): + return await self.context.__aenter__() + + async def __aexit__(self, *exc): + return await self.context.__aexit__(*exc) + + def __await__(self): + async def wrap(): + with pg_observer({"type": "acquire"}): + return await self.context + + return wrap().__await__() class PGDriver(Driver): pool: asyncpg.Pool - def __init__(self, url: str, connection_pool_max_size: int = 10): + def __init__( + self, + url: str, + connection_pool_min_size: int = 10, + connection_pool_max_size: int = 10, + acquire_timeout_ms: int = 200, + ): self.url = url + self.connection_pool_min_size = 
connection_pool_min_size self.connection_pool_max_size = connection_pool_max_size + self.acquire_timeout_ms = acquire_timeout_ms self._lock = asyncio.Lock() async def initialize(self, for_replication: bool = False): @@ -292,6 +323,7 @@ async def initialize(self, for_replication: bool = False): if self.initialized is False: self.pool = await asyncpg.create_pool( self.url, + min_size=self.connection_pool_min_size, max_size=self.connection_pool_max_size, ) @@ -313,10 +345,14 @@ async def finalize(self): @backoff.on_exception(backoff.expo, RETRIABLE_EXCEPTIONS, jitter=backoff.random_jitter, max_tries=3) async def begin(self, read_only: bool = False) -> Union[PGTransaction, ReadOnlyPGTransaction]: if read_only: - return ReadOnlyPGTransaction(self.pool, driver=self) + return ReadOnlyPGTransaction(self) else: - conn: asyncpg.Connection = await self.pool.acquire() + conn = await self._get_connection() with pg_observer({"type": "begin"}): txn = conn.transaction() await txn.start() - return PGTransaction(self.pool, conn, txn, driver=self) + return PGTransaction(self, conn, txn) + + def _get_connection(self) -> InstrumentedAcquireContext: + timeout = self.acquire_timeout_ms / 1000 + return InstrumentedAcquireContext(self.pool.acquire(timeout=timeout)) diff --git a/nucliadb/src/nucliadb/common/maindb/utils.py b/nucliadb/src/nucliadb/common/maindb/utils.py index 265dab8bc1..4ba04a90ce 100644 --- a/nucliadb/src/nucliadb/common/maindb/utils.py +++ b/nucliadb/src/nucliadb/common/maindb/utils.py @@ -91,7 +91,9 @@ async def setup_driver() -> Driver: raise ConfigurationError("No DRIVER_PG_URL env var defined.") pg_driver = PGDriver( url=settings.driver_pg_url, + connection_pool_min_size=settings.driver_pg_connection_pool_min_size, connection_pool_max_size=settings.driver_pg_connection_pool_max_size, + acquire_timeout_ms=settings.driver_pg_connection_pool_acquire_timeout_ms, ) set_utility(Utility.MAINDB_DRIVER, pg_driver) elif settings.driver == DriverConfig.LOCAL: diff --git 
a/nucliadb/src/nucliadb/ingest/settings.py b/nucliadb/src/nucliadb/ingest/settings.py index 1ae430d965..2863e9112a 100644 --- a/nucliadb/src/nucliadb/ingest/settings.py +++ b/nucliadb/src/nucliadb/ingest/settings.py @@ -61,10 +61,18 @@ class DriverSettings(BaseSettings): default=None, description="PostgreSQL DSN. The connection string to the PG server. Example: postgres://username:password@postgres:5432/nucliadb.", # noqa ) + driver_pg_connection_pool_min_size: int = Field( + default=10, + description="PostgreSQL min pool size. The minimum number of connections to the PostgreSQL server.", + ) driver_pg_connection_pool_max_size: int = Field( default=20, description="PostgreSQL max pool size. The maximum number of connections to the PostgreSQL server.", ) + driver_pg_connection_pool_acquire_timeout_ms: int = Field( + default=200, + description="PostgreSQL pool acquire timeout in ms. The maximum time to wait until a connection becomes available.", + ) driver_tikv_connection_pool_size: int = Field( default=3, description="TiKV max pool size. 
The maximum number of connections to the TiKV server.", diff --git a/nucliadb/tests/nucliadb/integration/common/maindb/test_drivers.py b/nucliadb/tests/nucliadb/integration/common/maindb/test_drivers.py index 91231e3439..5b9ba13d53 100644 --- a/nucliadb/tests/nucliadb/integration/common/maindb/test_drivers.py +++ b/nucliadb/tests/nucliadb/integration/common/maindb/test_drivers.py @@ -24,6 +24,7 @@ import pytest from nucliadb.common.maindb.driver import Driver +from nucliadb.common.maindb.pg import PGDriver from nucliadb.common.maindb.redis import RedisDriver from nucliadb.common.maindb.tikv import TiKVDriver @@ -77,6 +78,29 @@ async def test_local_driver(local_driver): await driver_basic(local_driver) +@pytest.mark.skipif("pg" not in TESTING_MAINDB_DRIVERS, reason="pg not in TESTING_MAINDB_DRIVERS") +async def test_pg_driver_pool_timeout(pg): + url = f"postgresql://postgres:postgres@{pg[0]}:{pg[1]}/postgres" + driver = PGDriver(url, connection_pool_min_size=1, connection_pool_max_size=1) + await driver.initialize() + + # Get one connection and hold it + conn = await driver.begin() + + # Try to get another connection, should fail because pool is full + with pytest.raises(TimeoutError): + await driver.begin() + + # Abort the connection and try again + await conn.abort() + + # Should now work + conn2 = await driver.begin() + + # Closing for hygiene + await conn2.abort() + + async def _clear_db(driver: Driver): all_keys = [] async with driver.transaction() as txn: