Skip to content

Commit

Permalink
fix data races in _applied_schema_hashes (#76)
Browse files Browse the repository at this point in the history
  • Loading branch information
sazikov-a authored Mar 4, 2024
1 parent 2f4d7cd commit ea1f773
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 16 deletions.
28 changes: 12 additions & 16 deletions testsuite/databases/pgsql/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ class PgControl:
_applied_schemas: typing.Dict[str, typing.Set[pathlib.Path]]
_connections: typing.Dict[str, ConnectionWrapper]
_connection_pool: typing.Optional[pool.AutocommitConnectionPool]
_applied_schema_hashes: typing.Optional[testsuite_db.AppliedSchemaHashes]

def __init__(
self,
Expand All @@ -237,16 +238,19 @@ def __init__(
self._applied_schemas = {}
self._skip_applied_schemas = skip_applied_schemas

@testsuite_utils.cached_property
def _applied_schema_hashes(
self,
) -> typing.Optional[testsuite_db.AppliedSchemaHashes]:
def initialize(self) -> None:
if not self._connection_pool:
self._connection_pool = pool.AutocommitConnectionPool(
minconn=1, maxconn=10, uri=self._get_connection_uri('postgres')
)

if self._skip_applied_schemas:
return testsuite_db.AppliedSchemaHashes(
self._get_connection_pool(),
self._applied_schema_hashes = None
else:
self._applied_schema_hashes = testsuite_db.AppliedSchemaHashes(
self._connection_pool,
self._conninfo,
)
return None

def get_connection_cached(self, dbname) -> ConnectionWrapper:
if dbname not in self._connections:
Expand Down Expand Up @@ -288,7 +292,7 @@ def _initialize_shard(self, shard: discover.PgShard) -> None:
def _create_database(self, dbname: str) -> None:
if dbname in self._applied_schemas:
return
with self._get_connection_pool().get_connection() as connection:
with self._connection_pool.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute(DROP_DATABASE_TEMPLATE.format(dbname))
cursor.execute(CREATE_DATABASE_TEMPLATE.format(dbname))
Expand Down Expand Up @@ -374,14 +378,6 @@ def close(self):
for conn in self._connections.values():
conn.close()

def _get_connection_pool(self) -> pool.AutocommitConnectionPool:
if not self._connection_pool:
self._connection_pool = pool.AutocommitConnectionPool(
minconn=1, maxconn=10, uri=self._get_connection_uri('postgres')
)

return self._connection_pool

def _get_connection_uri(self, dbname: str) -> str:
return self._conninfo.replace(dbname=dbname).get_uri()

Expand Down
2 changes: 2 additions & 0 deletions testsuite/databases/pgsql/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def initialize(
if self._initialized:
return self._shard_connections

self._pgsql_control.initialize()

def init_database(db):
self._pgsql_control.initialize_sharded_db(db)

Expand Down

0 comments on commit ea1f773

Please sign in to comment.