PG pool timeout (#2248)
javitonino authored Jun 12, 2024
1 parent 0734da8 commit db21116
Showing 4 changed files with 87 additions and 17 deletions.
70 changes: 53 additions & 17 deletions nucliadb/src/nucliadb/common/maindb/pg.py
@@ -64,6 +64,10 @@
 class DataLayer:
     def __init__(self, connection: Union[asyncpg.Connection, asyncpg.Pool]):
         self.connection = connection
+        # A lock to avoid sending concurrent queries to the connection. asyncpg has its own system to control
+        # this but, instead of waiting, it raises an exception. We use our own lock so that concurrent tasks
+        # wait for each other rather than exploding. This could be avoided if we could guarantee that a single
+        # asyncpg connection is never shared between concurrent tasks.
         self.lock = asyncio.Lock()

     async def get(self, key: str, select_for_update: bool = False) -> Optional[bytes]:
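The comment above captures the motivation for the lock: asyncpg refuses, rather than queues, a second query on a connection that is already busy. A minimal standalone sketch of that behavior and of the lock workaround (not part of the commit; the DSN and queries are hypothetical):

```python
import asyncio

import asyncpg


async def main():
    conn = await asyncpg.connect("postgresql://postgres:postgres@localhost/postgres")
    lock = asyncio.Lock()

    async def unsafe():
        # A second concurrent call on the same connection raises InterfaceError
        # ("another operation is in progress") instead of waiting.
        return await conn.fetch("SELECT pg_sleep(0.1)")

    async def safe():
        # The lock serializes access, so concurrent tasks wait for each other.
        async with lock:
            return await conn.fetch("SELECT pg_sleep(0.1)")

    try:
        await asyncio.gather(unsafe(), unsafe())
    except asyncpg.InterfaceError as exc:
        print(f"concurrent use failed: {exc}")

    await asyncio.sleep(0.2)  # let the in-flight query from above finish
    await asyncio.gather(safe(), safe())  # fine: the queries run one at a time
    await conn.close()


asyncio.run(main())
```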
@@ -163,16 +167,14 @@ class PGTransaction(Transaction):

     def __init__(
         self,
-        pool: asyncpg.Pool,
+        driver: PGDriver,
         connection: asyncpg.Connection,
         txn: Any,
-        driver: PGDriver,
     ):
-        self.pool = pool
+        self.driver = driver
         self.connection = connection
         self.data_layer = DataLayer(connection)
         self.txn = txn
-        self.driver = driver
         self.open = True
         self._lock = asyncio.Lock()

@@ -226,8 +228,9 @@ async def keys(
         count: int = DEFAULT_SCAN_LIMIT,
         include_start: bool = True,
     ):
-        async with self.pool.acquire() as conn, conn.transaction():
-            # all txn implementations implement this API outside of the current txn
+        # Check out a new connection to guarantee that the cursor iteration does not
+        # run concurrently with other queries.
+        async with self.driver._get_connection() as conn, conn.transaction():
             dl = DataLayer(conn)
             async for key in dl.scan_keys(match, count, include_start=include_start):
                 yield key
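The new comment explains the checkout: a server-side cursor keeps its connection busy for the whole iteration, so any other query sent on the same connection mid-scan would hit the same "operation in progress" error. A condensed sketch of the pattern, with a hypothetical pool, table, and column:

```python
import asyncpg


async def scan_keys(pool: asyncpg.Pool, prefix: str):
    # Dedicated connection: nothing else can interleave queries with the cursor.
    async with pool.acquire() as conn:
        # asyncpg cursors must run inside a transaction.
        async with conn.transaction():
            async for record in conn.cursor(
                "SELECT key FROM resources WHERE key LIKE $1", prefix + "%"
            ):
                yield record["key"]
```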
@@ -239,8 +242,7 @@ async def count(self, match: str) -> int:

 class ReadOnlyPGTransaction(Transaction):
     driver: PGDriver

-    def __init__(self, pool: asyncpg.Pool, driver: PGDriver):
-        self.pool = pool
+    def __init__(self, driver: PGDriver):
         self.driver = driver
         self.open = True

@@ -252,10 +254,12 @@ async def commit(self):
         raise Exception("Cannot commit transaction in read only mode")

     async def batch_get(self, keys: list[str]):
-        return await DataLayer(self.pool).batch_get(keys)
+        async with self.driver._get_connection() as conn:
+            return await DataLayer(conn).batch_get(keys)

     async def get(self, key: str) -> Optional[bytes]:
-        return await DataLayer(self.pool).get(key)
+        async with self.driver._get_connection() as conn:
+            return await DataLayer(conn).get(key)

     async def set(self, key: str, value: bytes):
         raise Exception("Cannot set in read only transaction")
@@ -269,29 +273,57 @@ async def keys(
         count: int = DEFAULT_SCAN_LIMIT,
         include_start: bool = True,
     ):
-        async with self.pool.acquire() as conn, conn.transaction():
-            # all txn implementations implement this API outside of the current txn
+        async with self.driver._get_connection() as conn, conn.transaction():
             dl = DataLayer(conn)
             async for key in dl.scan_keys(match, count, include_start=include_start):
                 yield key

     async def count(self, match: str) -> int:
-        return await DataLayer(self.pool).count(match)
+        async with self.driver._get_connection() as conn:
+            return await DataLayer(conn).count(match)


+class InstrumentedAcquireContext:
+    def __init__(self, context):
+        self.context = context
+
+    async def __aenter__(self):
+        with pg_observer({"type": "acquire"}):
+            return await self.context.__aenter__()
+
+    async def __aexit__(self, *exc):
+        return await self.context.__aexit__(*exc)
+
+    def __await__(self):
+        async def wrap():
+            with pg_observer({"type": "acquire"}):
+                return await self.context
+
+        return wrap().__await__()
+
+
 class PGDriver(Driver):
     pool: asyncpg.Pool

-    def __init__(self, url: str, connection_pool_max_size: int = 10):
+    def __init__(
+        self,
+        url: str,
+        connection_pool_min_size: int = 10,
+        connection_pool_max_size: int = 10,
+        acquire_timeout_ms: int = 200,
+    ):
         self.url = url
+        self.connection_pool_min_size = connection_pool_min_size
         self.connection_pool_max_size = connection_pool_max_size
+        self.acquire_timeout_ms = acquire_timeout_ms
         self._lock = asyncio.Lock()

     async def initialize(self, for_replication: bool = False):
         async with self._lock:
             if self.initialized is False:
                 self.pool = await asyncpg.create_pool(
                     self.url,
+                    min_size=self.connection_pool_min_size,
                     max_size=self.connection_pool_max_size,
                 )
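InstrumentedAcquireContext works because pool.acquire() returns an object that can be either awaited directly or used as an async context manager; the wrapper preserves both styles while timing the acquisition with pg_observer. A usage sketch, assuming an initialized driver:

```python
async def demo(driver: PGDriver) -> None:
    # Style 1: async context manager, as used by keys()/count() above.
    async with driver._get_connection() as conn:
        await conn.fetchval("SELECT 1")

    # Style 2: bare await, as used by begin(); the caller must release it.
    conn = await driver._get_connection()
    try:
        await conn.fetchval("SELECT 1")
    finally:
        await driver.pool.release(conn)
```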

@@ -313,10 +345,14 @@ async def finalize(self):
     @backoff.on_exception(backoff.expo, RETRIABLE_EXCEPTIONS, jitter=backoff.random_jitter, max_tries=3)
     async def begin(self, read_only: bool = False) -> Union[PGTransaction, ReadOnlyPGTransaction]:
         if read_only:
-            return ReadOnlyPGTransaction(self.pool, driver=self)
+            return ReadOnlyPGTransaction(self)
         else:
-            conn: asyncpg.Connection = await self.pool.acquire()
+            conn = await self._get_connection()
             with pg_observer({"type": "begin"}):
                 txn = conn.transaction()
                 await txn.start()
-            return PGTransaction(self.pool, conn, txn, driver=self)
+            return PGTransaction(self, conn, txn)
+
+    def _get_connection(self) -> asyncpg.Connection:
+        timeout = self.acquire_timeout_ms / 1000
+        return InstrumentedAcquireContext(self.pool.acquire(timeout=timeout))
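asyncpg's Pool.acquire() takes its timeout in seconds, hence the millisecond conversion in _get_connection, and it raises asyncio.TimeoutError when no connection frees up in time. A standalone sketch of those semantics (DSN and pool sizes hypothetical):

```python
import asyncio

import asyncpg


async def main():
    pool = await asyncpg.create_pool(
        "postgresql://postgres:postgres@localhost/postgres", min_size=1, max_size=1
    )
    held = await pool.acquire()  # hold the pool's only connection
    try:
        await pool.acquire(timeout=0.2)  # 200 ms, the driver's default
    except asyncio.TimeoutError:
        print("pool exhausted: acquire timed out")
    finally:
        await pool.release(held)
        await pool.close()


asyncio.run(main())
```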
2 changes: 2 additions & 0 deletions nucliadb/src/nucliadb/common/maindb/utils.py
@@ -91,7 +91,9 @@ async def setup_driver() -> Driver:
             raise ConfigurationError("No DRIVER_PG_URL env var defined.")
         pg_driver = PGDriver(
             url=settings.driver_pg_url,
+            connection_pool_min_size=settings.driver_pg_connection_pool_min_size,
             connection_pool_max_size=settings.driver_pg_connection_pool_max_size,
+            acquire_timeout_ms=settings.driver_pg_connection_pool_acquire_timeout_ms,
         )
         set_utility(Utility.MAINDB_DRIVER, pg_driver)
     elif settings.driver == DriverConfig.LOCAL:
8 changes: 8 additions & 0 deletions nucliadb/src/nucliadb/ingest/settings.py
@@ -61,10 +61,18 @@ class DriverSettings(BaseSettings):
         default=None,
         description="PostgreSQL DSN. The connection string to the PG server. Example: postgres://username:password@postgres:5432/nucliadb.",  # noqa
     )
+    driver_pg_connection_pool_min_size: int = Field(
+        default=10,
+        description="PostgreSQL min pool size. The minimum number of connections to the PostgreSQL server.",
+    )
     driver_pg_connection_pool_max_size: int = Field(
         default=20,
         description="PostgreSQL max pool size. The maximum number of connections to the PostgreSQL server.",
     )
+    driver_pg_connection_pool_acquire_timeout_ms: int = Field(
+        default=200,
+        description="PostgreSQL pool acquire timeout in ms. The maximum time to wait until a connection becomes available.",
+    )
     driver_tikv_connection_pool_size: int = Field(
         default=3,
         description="TiKV max pool size. The maximum number of connections to the TiKV server.",
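Since DriverSettings is a pydantic BaseSettings, each of the new fields can be driven by a same-named (case-insensitive) environment variable. A hedged sketch, assuming the class is importable from nucliadb.ingest.settings as the diff path suggests:

```python
import os

from nucliadb.ingest.settings import DriverSettings

os.environ["DRIVER_PG_CONNECTION_POOL_MIN_SIZE"] = "5"
os.environ["DRIVER_PG_CONNECTION_POOL_ACQUIRE_TIMEOUT_MS"] = "500"

settings = DriverSettings()
assert settings.driver_pg_connection_pool_min_size == 5
assert settings.driver_pg_connection_pool_acquire_timeout_ms == 500
```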
24 changes: 24 additions & 0 deletions nucliadb/tests/nucliadb/integration/common/maindb/test_drivers.py
@@ -24,6 +24,7 @@
 import pytest

 from nucliadb.common.maindb.driver import Driver
+from nucliadb.common.maindb.pg import PGDriver
 from nucliadb.common.maindb.redis import RedisDriver
 from nucliadb.common.maindb.tikv import TiKVDriver

@@ -77,6 +78,29 @@ async def test_local_driver(local_driver):
     await driver_basic(local_driver)


+@pytest.mark.skipif("pg" not in TESTING_MAINDB_DRIVERS, reason="pg not in TESTING_MAINDB_DRIVERS")
+async def test_pg_driver_pool_timeout(pg):
+    url = f"postgresql://postgres:postgres@{pg[0]}:{pg[1]}/postgres"
+    driver = PGDriver(url, connection_pool_min_size=1, connection_pool_max_size=1)
+    await driver.initialize()
+
+    # Get one connection and hold it
+    conn = await driver.begin()
+
+    # Trying to get another connection should fail because the pool is exhausted
+    with pytest.raises(TimeoutError):
+        await driver.begin()
+
+    # Abort the transaction, releasing its connection back to the pool
+    await conn.abort()
+
+    # Acquiring a connection should now succeed
+    conn2 = await driver.begin()
+
+    # Close it for hygiene
+    await conn2.abort()
+
+
 async def _clear_db(driver: Driver):
     all_keys = []
     async with driver.transaction() as txn:
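One note on the assertion in the new test: the pool raises asyncio.TimeoutError, and pytest.raises(TimeoutError) matches it because the two are aliases on Python 3.11 and later (an assumption about the version the suite runs on):

```python
import asyncio

assert asyncio.TimeoutError is TimeoutError  # true on Python >= 3.11
```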

3 comments on commit db21116

@github-actions (Contributor)

Benchmark

| Benchmark suite | Current: db21116 | Previous: 0d03d9f | Ratio |
| --- | --- | --- | --- |
| tests/search/unit/search/test_fetch.py::test_highligh_error | 2962.275325865513 iter/sec (stddev: 0.000004504356633707327) | 2841.0684406726436 iter/sec (stddev: 0.000004954958228416619) | 0.96 |

This comment was automatically generated by a workflow using github-action-benchmark.

@github-actions (Contributor)

Benchmark

| Benchmark suite | Current: db21116 | Previous: 0d03d9f | Ratio |
| --- | --- | --- | --- |
| tests/search/unit/search/test_fetch.py::test_highligh_error | 2914.094496526268 iter/sec (stddev: 0.000012732212986876445) | 2841.0684406726436 iter/sec (stddev: 0.000004954958228416619) | 0.97 |

This comment was automatically generated by a workflow using github-action-benchmark.

@github-actions (Contributor)

Benchmark

| Benchmark suite | Current: db21116 | Previous: 0d03d9f | Ratio |
| --- | --- | --- | --- |
| tests/search/unit/search/test_fetch.py::test_highligh_error | 3013.849625696085 iter/sec (stddev: 0.0000065974920125946385) | 2841.0684406726436 iter/sec (stddev: 0.000004954958228416619) | 0.94 |

This comment was automatically generated by a workflow using github-action-benchmark.
