diff --git a/superset/key_value/commands/delete.py b/superset/key_value/commands/delete.py index b3cf84be07515..8b9095c09c9b2 100644 --- a/superset/key_value/commands/delete.py +++ b/superset/key_value/commands/delete.py @@ -57,13 +57,7 @@ def validate(self) -> None: def delete(self) -> bool: filter_ = get_filter(self.resource, self.key) - entry = ( - db.session.query(KeyValueEntry) - .filter_by(**filter_) - .autoflush(False) - .first() - ) - if entry: + if entry := db.session.query(KeyValueEntry).filter_by(**filter_).first(): db.session.delete(entry) db.session.commit() return True diff --git a/superset/key_value/commands/get.py b/superset/key_value/commands/get.py index 9d659f3bc7c06..8a7a250f1c088 100644 --- a/superset/key_value/commands/get.py +++ b/superset/key_value/commands/get.py @@ -66,12 +66,7 @@ def validate(self) -> None: def get(self) -> Optional[Any]: filter_ = get_filter(self.resource, self.key) - entry = ( - db.session.query(KeyValueEntry) - .filter_by(**filter_) - .autoflush(False) - .first() - ) + entry = db.session.query(KeyValueEntry).filter_by(**filter_).first() if entry and (entry.expires_on is None or entry.expires_on > datetime.now()): return self.codec.decode(entry.value) return None diff --git a/superset/key_value/commands/update.py b/superset/key_value/commands/update.py index becd6d9ca8d01..4bcd496243dda 100644 --- a/superset/key_value/commands/update.py +++ b/superset/key_value/commands/update.py @@ -77,10 +77,7 @@ def validate(self) -> None: def update(self) -> Optional[Key]: filter_ = get_filter(self.resource, self.key) entry: KeyValueEntry = ( - db.session.query(KeyValueEntry) - .filter_by(**filter_) - .autoflush(False) - .first() + db.session.query(KeyValueEntry).filter_by(**filter_).first() ) if entry: entry.value = self.codec.encode(self.value) diff --git a/superset/key_value/commands/upsert.py b/superset/key_value/commands/upsert.py index c5668f11610ab..9a4092c002716 100644 --- a/superset/key_value/commands/upsert.py +++ b/superset/key_value/commands/upsert.py @@ -81,10 +81,7 @@ def validate(self) -> None: def upsert(self) -> Key: filter_ = get_filter(self.resource, self.key) entry: KeyValueEntry = ( - db.session.query(KeyValueEntry) - .filter_by(**filter_) - .autoflush(False) - .first() + db.session.query(KeyValueEntry).filter_by(**filter_).first() ) if entry: entry.value = self.codec.encode(self.value) diff --git a/superset/tags/models.py b/superset/tags/models.py index a469c7a33d22b..7a77677a367ad 100644 --- a/superset/tags/models.py +++ b/superset/tags/models.py @@ -20,9 +20,9 @@ from typing import TYPE_CHECKING from flask_appbuilder import Model -from sqlalchemy import Column, Enum, ForeignKey, Integer, String, Table, Text +from sqlalchemy import Column, Enum, ForeignKey, Integer, orm, String, Table, Text from sqlalchemy.engine.base import Connection -from sqlalchemy.orm import relationship, Session, sessionmaker +from sqlalchemy.orm import relationship, sessionmaker from sqlalchemy.orm.mapper import Mapper from superset import security_manager @@ -35,7 +35,7 @@ from superset.models.slice import Slice from superset.models.sql_lab import Query -Session = sessionmaker(autoflush=False) +Session = sessionmaker() user_favorite_tag_table = Table( "user_favorite_tag", @@ -111,7 +111,7 @@ class TaggedObject(Model, AuditMixinNullable): tag = relationship("Tag", back_populates="objects", overlaps="tags") -def get_tag(name: str, session: Session, type_: TagType) -> Tag: +def get_tag(name: str, session: orm.Session, type_: TagType) -> Tag: tag_name = name.strip() tag = session.query(Tag).filter_by(name=tag_name, type=type_).one_or_none() if tag is None: @@ -148,7 +148,7 @@ def get_owners_ids( @classmethod def _add_owners( cls, - session: Session, + session: orm.Session, target: Dashboard | FavStar | Slice | Query | SqlaTable, ) -> None: for owner_id in cls.get_owners_ids(target): @@ -166,9 +166,7 @@ def after_insert( connection: Connection, target: Dashboard | FavStar | Slice | Query | SqlaTable, ) -> None: - session = Session(bind=connection) - - try: + with Session(bind=connection) as session: # add `owner:` tags cls._add_owners(session, target) @@ -179,8 +177,6 @@ def after_insert( ) session.add(tagged_object) session.commit() - finally: - session.close() @classmethod def after_update( @@ -189,9 +185,7 @@ def after_update( connection: Connection, target: Dashboard | FavStar | Slice | Query | SqlaTable, ) -> None: - session = Session(bind=connection) - - try: + with Session(bind=connection) as session: # delete current `owner:` tags query = ( session.query(TaggedObject.id) @@ -210,8 +204,6 @@ def after_update( # add `owner:` tags cls._add_owners(session, target) session.commit() - finally: - session.close() @classmethod def after_delete( @@ -220,9 +212,7 @@ def after_delete( connection: Connection, target: Dashboard | FavStar | Slice | Query | SqlaTable, ) -> None: - session = Session(bind=connection) - - try: + with Session(bind=connection) as session: # delete row from `tagged_objects` session.query(TaggedObject).filter( TaggedObject.object_type == cls.object_type, @@ -230,8 +220,6 @@ def after_delete( ).delete() session.commit() - finally: - session.close() class ChartUpdater(ObjectUpdater): @@ -271,8 +259,7 @@ class FavStarUpdater: def after_insert( cls, _mapper: Mapper, connection: Connection, target: FavStar ) -> None: - session = Session(bind=connection) - try: + with Session(bind=connection) as session: name = f"favorited_by:{target.user_id}" tag = get_tag(name, session, TagType.favorited_by) tagged_object = TaggedObject( @@ -282,15 +269,12 @@ def after_insert( ) session.add(tagged_object) session.commit() - finally: - session.close() @classmethod def after_delete( cls, _mapper: Mapper, connection: Connection, target: FavStar ) -> None: - session = Session(bind=connection) - try: + with Session(bind=connection) as session: name = f"favorited_by:{target.user_id}" query = ( session.query(TaggedObject.id) @@ -307,5 +291,3 @@ def after_delete( ) session.commit() - finally: - session.close() diff --git a/tests/integration_tests/key_value/commands/create_test.py b/tests/integration_tests/key_value/commands/create_test.py index a2ee3d13aed22..c7ba076b5fee4 100644 --- a/tests/integration_tests/key_value/commands/create_test.py +++ b/tests/integration_tests/key_value/commands/create_test.py @@ -46,9 +46,7 @@ def test_create_id_entry(app_context: AppContext, admin: User) -> None: value=JSON_VALUE, codec=JSON_CODEC, ).run() - entry = ( - db.session.query(KeyValueEntry).filter_by(id=key.id).autoflush(False).one() - ) + entry = db.session.query(KeyValueEntry).filter_by(id=key.id).one() assert json.loads(entry.value) == JSON_VALUE assert entry.created_by_fk == admin.id db.session.delete(entry) @@ -63,9 +61,7 @@ def test_create_uuid_entry(app_context: AppContext, admin: User) -> None: key = CreateKeyValueCommand( resource=RESOURCE, value=JSON_VALUE, codec=JSON_CODEC ).run() - entry = ( - db.session.query(KeyValueEntry).filter_by(uuid=key.uuid).autoflush(False).one() - ) + entry = db.session.query(KeyValueEntry).filter_by(uuid=key.uuid).one() assert json.loads(entry.value) == JSON_VALUE assert entry.created_by_fk == admin.id db.session.delete(entry) @@ -93,9 +89,7 @@ def test_create_pickle_entry(app_context: AppContext, admin: User) -> None: value=PICKLE_VALUE, codec=PICKLE_CODEC, ).run() - entry = ( - db.session.query(KeyValueEntry).filter_by(id=key.id).autoflush(False).one() - ) + entry = db.session.query(KeyValueEntry).filter_by(id=key.id).one() assert type(pickle.loads(entry.value)) == type(PICKLE_VALUE) assert entry.created_by_fk == admin.id db.session.delete(entry) diff --git a/tests/integration_tests/key_value/commands/update_test.py b/tests/integration_tests/key_value/commands/update_test.py index 2c0fc3e31de51..816a6f857ab0f 100644 --- a/tests/integration_tests/key_value/commands/update_test.py +++ b/tests/integration_tests/key_value/commands/update_test.py @@ -57,7 +57,7 @@ def test_update_id_entry( ).run() assert key is not None assert key.id == ID_KEY - entry = db.session.query(KeyValueEntry).filter_by(id=ID_KEY).autoflush(False).one() + entry = db.session.query(KeyValueEntry).filter_by(id=ID_KEY).one() assert json.loads(entry.value) == NEW_VALUE assert entry.changed_by_fk == admin.id @@ -79,9 +79,7 @@ def test_update_uuid_entry( ).run() assert key is not None assert key.uuid == UUID_KEY - entry = ( - db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).autoflush(False).one() - ) + entry = db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).one() assert json.loads(entry.value) == NEW_VALUE assert entry.changed_by_fk == admin.id diff --git a/tests/integration_tests/key_value/commands/upsert_test.py b/tests/integration_tests/key_value/commands/upsert_test.py index c26b66d02e7bf..9b094ef65ec45 100644 --- a/tests/integration_tests/key_value/commands/upsert_test.py +++ b/tests/integration_tests/key_value/commands/upsert_test.py @@ -57,9 +57,7 @@ def test_upsert_id_entry( ).run() assert key is not None assert key.id == ID_KEY - entry = ( - db.session.query(KeyValueEntry).filter_by(id=int(ID_KEY)).autoflush(False).one() - ) + entry = db.session.query(KeyValueEntry).filter_by(id=int(ID_KEY)).one() assert json.loads(entry.value) == NEW_VALUE assert entry.changed_by_fk == admin.id @@ -81,9 +79,7 @@ def test_upsert_uuid_entry( ).run() assert key is not None assert key.uuid == UUID_KEY - entry = ( - db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).autoflush(False).one() - ) + entry = db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).one() assert json.loads(entry.value) == NEW_VALUE assert entry.changed_by_fk == admin.id