From 75434722bb488d597df79fec92e49150ba51d64b Mon Sep 17 00:00:00 2001 From: Sourour Benzarti Date: Sun, 5 Jan 2025 22:07:56 +0100 Subject: [PATCH] manage connections using a Connection pool --- postgrestq/task_queue.py | 615 ++++++++++++++++++++------------------- pyproject.toml | 1 + 2 files changed, 309 insertions(+), 307 deletions(-) diff --git a/postgrestq/task_queue.py b/postgrestq/task_queue.py index 2666342..1c6fb53 100644 --- a/postgrestq/task_queue.py +++ b/postgrestq/task_queue.py @@ -14,7 +14,8 @@ Sequence, ) -from psycopg import sql, connect, Connection +from psycopg import sql +from psycopg_pool import ConnectionPool # supported only from 3.11 onwards: # from datetime import UTC @@ -68,29 +69,24 @@ def __init__( # called when ttl <= 0 for a task self.ttl_zero_callback = ttl_zero_callback - - self.conn = self.connect() + self.connect() if create_table: self._create_queue_table() if reset: self._reset() - def get_connection(self): - connection = connect(self._dsn) - with connection.cursor() as cur: - cur.execute("SELECT 1+1") - cur.fetchone() - - return connection - def connect(self) -> None: """ Establish a connection to Postgres. If a connection already exists, it's overwritten. """ - if self.conn is None or self.conn.closed: - self.conn = self.get_connection() + # if self.conn is None or self.conn.closed: + # self.conn = self.get_connection() + self.pool = ConnectionPool(self._dsn, open=True, min_size=1, max_size=1) + # This will block the use of the pool until min_size connections + # have been acquired + self.pool.wait() def _create_queue_table(self) -> None: """ @@ -98,52 +94,54 @@ def _create_queue_table(self) -> None: """ # TODO: check if the table already exist # whether it has the same schema - with self.conn.cursor() as cur: - cur.execute( - sql.SQL( - """CREATE TABLE IF NOT EXISTS {} ( - id UUID PRIMARY KEY, - queue_name TEXT NOT NULL, - task JSONB NOT NULL, - ttl SMALLINT NOT NULL, - can_start_at TIMESTAMPTZ NOT NULL - DEFAULT CURRENT_TIMESTAMP, - lease_timeout FLOAT, - started_at TIMESTAMPTZ, - completed_at TIMESTAMPTZ - )""" - ).format(sql.Identifier(self._table_name)) - ) - cur.execute( - sql.SQL( - """CREATE INDEX IF NOT EXISTS - task_queue_queue_name_can_start_at_idx - ON {} (queue_name, can_start_at) - """ - ).format(sql.Identifier(self._table_name)) - ) - self.conn.commit() + with self.pool.connection() as conn: + with conn.cursor() as cur: + cur.execute( + sql.SQL( + """CREATE TABLE IF NOT EXISTS {} ( + id UUID PRIMARY KEY, + queue_name TEXT NOT NULL, + task JSONB NOT NULL, + ttl SMALLINT NOT NULL, + can_start_at TIMESTAMPTZ NOT NULL + DEFAULT CURRENT_TIMESTAMP, + lease_timeout FLOAT, + started_at TIMESTAMPTZ, + completed_at TIMESTAMPTZ + )""" + ).format(sql.Identifier(self._table_name)) + ) + cur.execute( + sql.SQL( + """CREATE INDEX IF NOT EXISTS + task_queue_queue_name_can_start_at_idx + ON {} (queue_name, can_start_at) + """ + ).format(sql.Identifier(self._table_name)) + ) + conn.commit() def __len__(self) -> int: """ Returns the length of processing or to be processed tasks """ - with self.conn.cursor() as cursor: - cursor.execute( - sql.SQL( - """ - SELECT count(*) as count - FROM {} - WHERE queue_name = %s - AND completed_at IS NULL - """ - ).format(sql.Identifier(self._table_name)), - (self._queue_name,), - ) - row = cursor.fetchone() - count: int = row[0] if row else 0 - self.conn.commit() - return count + with self.pool.connection() as conn: + with conn.cursor() as cursor: + cursor.execute( + sql.SQL( + """ + SELECT count(*) as count + FROM {} + WHERE queue_name = %s + AND completed_at IS NULL + """ + ).format(sql.Identifier(self._table_name)), + (self._queue_name,), + ) + row = cursor.fetchone() + count: int = row[0] if row else 0 + conn.commit() + return count def add( self, @@ -180,29 +178,29 @@ def add( id_ = str(uuid4()) serialized_task = self._serialize(task) - - with self.conn.cursor() as cursor: - # store the task + metadata and put task-id into the task queue - cursor.execute( - sql.SQL( - """ - INSERT INTO {} ( - id, - queue_name, - task, - ttl, - lease_timeout, - can_start_at + with self.pool.connection() as conn: + with conn.cursor() as cursor: + # store the task + metadata and put task-id into the task queue + cursor.execute( + sql.SQL( + """ + INSERT INTO {} ( + id, + queue_name, + task, + ttl, + lease_timeout, + can_start_at + ) + VALUES (%s, %s, %s, %s, %s, COALESCE(%s, current_timestamp)) + """ + ).format(sql.Identifier(self._table_name)), + ( + id_, self._queue_name, serialized_task, + ttl, lease_timeout, can_start_at + ), ) - VALUES (%s, %s, %s, %s, %s, COALESCE(%s, current_timestamp)) - """ - ).format(sql.Identifier(self._table_name)), - ( - id_, self._queue_name, serialized_task, - ttl, lease_timeout, can_start_at - ), - ) - self.conn.commit() + conn.commit() return id_ def add_many( @@ -243,35 +241,36 @@ def add_many( # into problems later when we calculate the actual deadline lease_timeout = float(lease_timeout) ret_ids = [] - with self.conn.cursor() as cursor: - for task in tasks: - id_ = str(uuid4()) - - serialized_task = self._serialize(task) - - cursor.execute( - sql.SQL( - """ - INSERT INTO {} ( - id, - queue_name, - task, - ttl, - lease_timeout, - can_start_at - ) - VALUES ( - %s, %s, %s, %s, %s, COALESCE(%s, current_timestamp) + with self.pool.connection() as conn: + with conn.cursor() as cursor: + for task in tasks: + id_ = str(uuid4()) + + serialized_task = self._serialize(task) + + cursor.execute( + sql.SQL( + """ + INSERT INTO {} ( + id, + queue_name, + task, + ttl, + lease_timeout, + can_start_at + ) + VALUES ( + %s, %s, %s, %s, %s, COALESCE(%s, current_timestamp) + ) + """ + ).format(sql.Identifier(self._table_name)), + ( + id_, self._queue_name, serialized_task, + ttl, lease_timeout, can_start_at + ), ) - """ - ).format(sql.Identifier(self._table_name)), - ( - id_, self._queue_name, serialized_task, - ttl, lease_timeout, can_start_at - ), - ) - ret_ids.append(id_) - self.conn.commit() + ret_ids.append(id_) + conn.commit() return ret_ids def get(self) -> Tuple[ @@ -314,41 +313,41 @@ def get(self) -> Tuple[ empty """ - conn = self.conn - - with conn.cursor() as cur: - cur.execute( - sql.SQL( - """ - UPDATE {} - SET started_at = current_timestamp - WHERE id = ( - SELECT id - FROM {} - WHERE completed_at IS NULL - AND started_at IS NULL - AND queue_name = %s - AND ttl > 0 - AND can_start_at <= current_timestamp - ORDER BY can_start_at - FOR UPDATE SKIP LOCKED - LIMIT 1 + with self.pool.connection() as conn: + with conn.cursor() as cur: + cur.execute( + sql.SQL( + """ + UPDATE {} + SET started_at = current_timestamp + WHERE id = ( + SELECT id + FROM {} + WHERE completed_at IS NULL + AND started_at IS NULL + AND queue_name = %s + AND ttl > 0 + AND can_start_at <= current_timestamp + ORDER BY can_start_at + FOR UPDATE SKIP LOCKED + LIMIT 1 + ) + RETURNING id, task;""" + ).format( + sql.Identifier(self._table_name), + sql.Identifier(self._table_name), + ), + (self._queue_name,), ) - RETURNING id, task;""" - ).format( - sql.Identifier(self._table_name), - sql.Identifier(self._table_name), - ), - (self._queue_name,), - ) - row = cur.fetchone() - conn.commit() - if row is None: - return None, None, None - task_id, task = row - logger.info(f"Got task with id {task_id}") - return task, task_id, self._queue_name + row = cur.fetchone() + conn.commit() + if row is None: + return None, None, None + + task_id, task = row + logger.info(f"Got task with id {task_id}") + return task, task_id, self._queue_name def get_many(self, amount: int) -> Sequence[ Tuple[Optional[Dict[str, Any]], Optional[UUID], Optional[str]], @@ -368,39 +367,38 @@ def get_many(self, amount: int) -> Sequence[ The tasks and their IDs, and the queue_name """ - conn = self.conn - - with conn.cursor() as cur: - cur.execute( - sql.SQL( - """ - UPDATE {} - SET started_at = current_timestamp - WHERE id IN ( - SELECT id - FROM {} - WHERE completed_at IS NULL - AND started_at IS NULL - AND queue_name = %s - AND ttl > 0 - AND can_start_at <= current_timestamp - ORDER BY can_start_at - FOR UPDATE SKIP LOCKED - LIMIT %s + with self.pool.connection() as conn: + with conn.cursor() as cur: + cur.execute( + sql.SQL( + """ + UPDATE {} + SET started_at = current_timestamp + WHERE id IN ( + SELECT id + FROM {} + WHERE completed_at IS NULL + AND started_at IS NULL + AND queue_name = %s + AND ttl > 0 + AND can_start_at <= current_timestamp + ORDER BY can_start_at + FOR UPDATE SKIP LOCKED + LIMIT %s + ) + RETURNING task, id;""" + ).format( + sql.Identifier(self._table_name), + sql.Identifier(self._table_name), + ), + (self._queue_name, amount), ) - RETURNING task, id;""" - ).format( - sql.Identifier(self._table_name), - sql.Identifier(self._table_name), - ), - (self._queue_name, amount), - ) - ret = [] - for task, task_id in cur.fetchall(): - logger.info(f"Got task with id {task_id}") - ret.append((task, task_id, self._queue_name,)) - conn.commit() + ret = [] + for task, task_id in cur.fetchall(): + logger.info(f"Got task with id {task_id}") + ret.append((task, task_id, self._queue_name,)) + conn.commit() return ret def complete(self, task_id: Optional[UUID]) -> int: @@ -424,24 +422,24 @@ def complete(self, task_id: Optional[UUID]) -> int: """ logger.info(f"Marking task {task_id} as completed") - conn = self.conn count = 0 - with conn.cursor() as cur: - cur.execute( - sql.SQL( - """ - UPDATE {} - SET completed_at = current_timestamp - WHERE id = %s - AND completed_at is NULL""" - ).format(sql.Identifier(self._table_name)), - (task_id,), - ) - count = cur.rowcount - if count == 0: - logger.info(f"Task {task_id} was already completed") + with self.pool.connection() as conn: + with conn.cursor() as cur: + cur.execute( + sql.SQL( + """ + UPDATE {} + SET completed_at = current_timestamp + WHERE id = %s + AND completed_at is NULL""" + ).format(sql.Identifier(self._table_name)), + (task_id,), + ) + count = cur.rowcount + if count == 0: + logger.info(f"Task {task_id} was already completed") - conn.commit() + conn.commit() return count def is_empty(self) -> bool: @@ -479,55 +477,57 @@ def check_expired_leases(self) -> None: """ # goes through all the tasks that are marked as started # and check the ones with expired timeout - with self.conn.cursor() as cur: - cur.execute( - sql.SQL( - """ - SELECT id - FROM {} - WHERE completed_at IS NULL - AND started_at IS NOT NULL - AND queue_name = %s - AND ( - started_at + (lease_timeout || ' seconds')::INTERVAL - ) < current_timestamp - ORDER BY can_start_at; - """ - ).format(sql.Identifier(self._table_name)), - (self._queue_name,), - ) - expired_tasks = cur.fetchall() - self.conn.commit() - logger.debug(f"Expired tasks {expired_tasks}") - for row in expired_tasks: - task_id: UUID = row[0] - logger.debug(f"Got expired task with id {task_id}") - task, ttl = self.get_updated_expired_task(task_id) - - if ttl is None: - # race condition! between the time we got `key` from the - # set of tasks (this outer loop) and the time we tried - # to get that task from the queue, it has been completed - # and therefore deleted from the queue. In this case - # tasks is None and we can continue - logger.info( - f"Task {task_id} was marked completed while we " - "checked for expired leases, nothing to do." + with self.pool.connection() as conn: + with conn.cursor() as cur: + cur.execute( + sql.SQL( + """ + SELECT id + FROM {} + WHERE completed_at IS NULL + AND started_at IS NOT NULL + AND queue_name = %s + AND ( + started_at + (lease_timeout || ' seconds')::INTERVAL + ) < current_timestamp + ORDER BY can_start_at; + """ + ).format(sql.Identifier(self._table_name)), + (self._queue_name,), ) - continue + expired_tasks = cur.fetchall() + conn.commit() + logger.debug(f"Expired tasks {expired_tasks}") + + for row in expired_tasks: + task_id: UUID = row[0] + logger.debug(f"Got expired task with id {task_id}") + task, ttl = self.get_updated_expired_task(task_id) + + if ttl is None: + # race condition! between the time we got `key` from the + # set of tasks (this outer loop) and the time we tried + # to get that task from the queue, it has been completed + # and therefore deleted from the queue. In this case + # tasks is None and we can continue + logger.info( + f"Task {task_id} was marked completed while we " + "checked for expired leases, nothing to do." + ) + continue - if ttl <= 0: - logger.error( - f"Job {task} with id {task_id} " - "failed too many times, marking it as completed." - ) - # # here committing to release the previous update lock - self.conn.commit() - self.complete(task_id) + if ttl <= 0: + logger.error( + f"Job {task} with id {task_id} " + "failed too many times, marking it as completed." + ) + # # here committing to release the previous update lock + conn.commit() + self.complete(task_id) - if self.ttl_zero_callback: - self.ttl_zero_callback(task_id, task) - self.conn.commit() + if self.ttl_zero_callback: + self.ttl_zero_callback(task_id, task) + conn.commit() def get_updated_expired_task( self, task_id: UUID @@ -545,42 +545,43 @@ def get_updated_expired_task( task_id after it's rescheduled """ - with self.conn.cursor() as cur: - cur.execute( - sql.SQL( - """ - UPDATE {} - SET ttl = ttl - 1, - started_at = NULL - WHERE id = ( - SELECT id - FROM {} - WHERE completed_at IS NULL - AND started_at IS NOT NULL - AND queue_name = %s - AND id = %s - FOR UPDATE SKIP LOCKED - LIMIT 1 + with self.pool.connection() as conn: + with conn.cursor() as cur: + cur.execute( + sql.SQL( + """ + UPDATE {} + SET ttl = ttl - 1, + started_at = NULL + WHERE id = ( + SELECT id + FROM {} + WHERE completed_at IS NULL + AND started_at IS NOT NULL + AND queue_name = %s + AND id = %s + FOR UPDATE SKIP LOCKED + LIMIT 1 + ) + RETURNING task, ttl; + """ + ).format( + sql.Identifier(self._table_name), + sql.Identifier(self._table_name), + ), + ( + self._queue_name, + task_id, + ), ) - RETURNING task, ttl; - """ - ).format( - sql.Identifier(self._table_name), - sql.Identifier(self._table_name), - ), - ( - self._queue_name, - task_id, - ), - ) - updated_row = cur.fetchone() + updated_row = cur.fetchone() + conn.commit() + if updated_row is None: + return None, None - if updated_row is None: - return None, None - - task, ttl = updated_row - task = self._serialize(task) - return task, ttl + task, ttl = updated_row + task = self._serialize(task) + return task, ttl def _serialize(self, task: Any) -> str: return json.dumps(task, sort_keys=True) @@ -626,45 +627,44 @@ def reschedule( if decrease_ttl: decrease_ttl_sql = "ttl = ttl - 1," - conn = self.conn - with conn.cursor() as cur: - cur.execute( - sql.SQL( - """ - UPDATE {} - SET {} started_at = NULL - WHERE id = ( - SELECT id - FROM {} - WHERE started_at IS NOT NULL - AND id = %s - FOR UPDATE SKIP LOCKED + with self.pool.connection() as conn: + with conn.cursor() as cur: + cur.execute( + sql.SQL( + """ + UPDATE {} + SET {} started_at = NULL + WHERE id = ( + SELECT id + FROM {} + WHERE started_at IS NOT NULL + AND id = %s + FOR UPDATE SKIP LOCKED + ) + RETURNING id;""" + ).format( + sql.Identifier(self._table_name), + sql.SQL(decrease_ttl_sql), + sql.Identifier(self._table_name), + ), + (task_id,), ) - RETURNING id;""" - ).format( - sql.Identifier(self._table_name), - sql.SQL(decrease_ttl_sql), - sql.Identifier(self._table_name), - ), - (task_id,), - ) - found = cur.fetchone() - conn.commit() - if found is None: - raise ValueError(f"Task {task_id} does not exist.") + found = cur.fetchone() + conn.commit() + if found is None: + raise ValueError(f"Task {task_id} does not exist.") def _reset(self) -> None: """Delete all tasks in the DB with our queue name.""" - with self.conn.cursor() as cursor: - cursor.execute( + with self.pool.connection() as conn: + conn.execute( sql.SQL("DELETE FROM {} WHERE queue_name = %s ").format( sql.Identifier(self._table_name) ), (self._queue_name,), ) - - self.conn.commit() + conn.commit() def prune_completed_tasks(self, before: int) -> None: """Delete all completed tasks older than the given number of seconds. @@ -680,21 +680,22 @@ def prune_completed_tasks(self, before: int) -> None: logger.info(f"Pruning all tasks completed more than " f"{before} second(s) ago.") - with self.conn.cursor() as cursor: - cursor.execute( - sql.SQL( - """ - DELETE FROM {} - WHERE queue_name = %s - AND completed_at IS NOT NULL - AND completed_at < current_timestamp - CAST( - %s || ' seconds' AS INTERVAL); - """ - ).format(sql.Identifier(self._table_name)), - (self._queue_name, before), - ) + with self.pool.connection() as conn: + with conn.cursor() as cursor: + cursor.execute( + sql.SQL( + """ + DELETE FROM {} + WHERE queue_name = %s + AND completed_at IS NOT NULL + AND completed_at < current_timestamp - CAST( + %s || ' seconds' AS INTERVAL); + """ + ).format(sql.Identifier(self._table_name)), + (self._queue_name, before), + ) - self.conn.commit() + conn.commit() def __iter__( self, diff --git a/pyproject.toml b/pyproject.toml index f471261..8e6aa1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ authors = [ ] dependencies = [ "psycopg[binary]>=3.1.12", + "psycopg-pool>=3.1.12", ] requires-python = ">=3.9" readme = "README.md"