Skip to content

Commit

Permalink
fix: Ensure Postgres queries are committed or autocommit is used (#5039)
Browse files Browse the repository at this point in the history
  • Loading branch information
TomSteenbergen authored Feb 11, 2025
1 parent d937dcb commit 46f8d7a
Showing 1 changed file with 12 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,24 +58,28 @@ class PostgreSQLOnlineStore(OnlineStore):
_conn_pool_async: Optional[AsyncConnectionPool] = None

@contextlib.contextmanager
def _get_conn(self, config: RepoConfig) -> Generator[Connection, Any, Any]:
def _get_conn(
self, config: RepoConfig, autocommit: bool = False
) -> Generator[Connection, Any, Any]:
assert config.online_store.type == "postgres"

if config.online_store.conn_type == ConnectionType.pool:
if not self._conn_pool:
self._conn_pool = _get_connection_pool(config.online_store)
self._conn_pool.open()
connection = self._conn_pool.getconn()
connection.set_autocommit(autocommit)
yield connection
self._conn_pool.putconn(connection)
else:
if not self._conn:
self._conn = _get_conn(config.online_store)
self._conn.set_autocommit(autocommit)
yield self._conn

@contextlib.asynccontextmanager
async def _get_conn_async(
self, config: RepoConfig
self, config: RepoConfig, autocommit: bool = False
) -> AsyncGenerator[AsyncConnection, Any]:
if config.online_store.conn_type == ConnectionType.pool:
if not self._conn_pool_async:
Expand All @@ -84,11 +88,13 @@ async def _get_conn_async(
)
await self._conn_pool_async.open()
connection = await self._conn_pool_async.getconn()
await connection.set_autocommit(autocommit)
yield connection
await self._conn_pool_async.putconn(connection)
else:
if not self._conn_async:
self._conn_async = await _get_conn_async(config.online_store)
await self._conn_async.set_autocommit(autocommit)
yield self._conn_async

def online_write_batch(
Expand Down Expand Up @@ -161,7 +167,7 @@ def online_read(
config, table, keys, requested_features
)

with self._get_conn(config) as conn, conn.cursor() as cur:
with self._get_conn(config, autocommit=True) as conn, conn.cursor() as cur:
cur.execute(query, params)
rows = cur.fetchall()

Expand All @@ -179,7 +185,7 @@ async def online_read_async(
config, table, keys, requested_features
)

async with self._get_conn_async(config) as conn:
async with self._get_conn_async(config, autocommit=True) as conn:
async with conn.cursor() as cur:
await cur.execute(query, params)
rows = await cur.fetchall()
Expand Down Expand Up @@ -339,6 +345,7 @@ def teardown(
for table in tables:
table_name = _table_id(project, table)
cur.execute(_drop_table_and_index(table_name))
conn.commit()
except Exception:
logging.exception("Teardown failed")
raise
Expand Down Expand Up @@ -398,7 +405,7 @@ def retrieve_online_documents(
Optional[ValueProto],
]
] = []
with self._get_conn(config) as conn, conn.cursor() as cur:
with self._get_conn(config, autocommit=True) as conn, conn.cursor() as cur:
table_name = _table_id(project, table)

# Search query template to find the top k items that are closest to the given embedding
Expand Down

0 comments on commit 46f8d7a

Please sign in to comment.