Skip to content

Commit

Permalink
Merge pull request #216 from dinesh-aot/COMP-101
Browse files Browse the repository at this point in the history
SqlAlchemy session management fix
  • Loading branch information
nitheesh-aot authored Feb 4, 2025
2 parents 60c72e4 + bfd9c77 commit 16f0ba8
Show file tree
Hide file tree
Showing 24 changed files with 280 additions and 371 deletions.
51 changes: 21 additions & 30 deletions compliance-api/src/compliance_api/models/case_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from compliance_api.utils.constant import DELETE_DIC_PARAMS

from .base_model import BaseModelVersioned
from .db import db
from .utils import with_session


class CaseFileInitiationEnum(enum.Enum):
Expand Down Expand Up @@ -117,28 +117,24 @@ class CaseFile(BaseModelVersioned):
)

@classmethod
@with_session
def create_case_file(cls, case_file_data, session=None):
"""Persist case file data in database."""
case_file = CaseFile(**case_file_data)
if session:
session.add(case_file)
session.flush()
else:
case_file.save()
session.add(case_file)
session.flush()
return case_file

@classmethod
@with_session
def update_case_file(cls, case_file_id, case_file_data, session=None):
"""Update the case file data in database."""
query = cls.query.filter_by(id=case_file_id)
case_file: CaseFile = query.first()
if not case_file or case_file.is_deleted:
return None
case_file.update(case_file_data, commit=False)
if session:
session.flush()
else:
db.session.commit()
session.flush()
return case_file

@classmethod
Expand All @@ -149,6 +145,7 @@ def get_by_file_number(cls, case_file_number):
).first()

@classmethod
@with_session
def change_status(
cls, case_file_id, case_file_status: CaseFileStatusEnum, session=None
):
Expand All @@ -158,10 +155,7 @@ def change_status(
if not case_file or case_file.is_deleted:
return None
case_file.update({"case_file_status": case_file_status}, commit=False)
if session:
session.flush()
else:
db.session.commit()
session.flush()
return case_file

@classmethod
Expand All @@ -184,7 +178,9 @@ def get_max_case_file_number_by_year(cls, year: int):
.filter(
func.regexp_replace(cls.case_file_number, "[^0-9]", "", "g").op("~")(
f"^{year}[0-9]{{4}}$"
)
),
cls.is_active.is_(True),
cls.is_deleted.is_(False)
)
.scalar()
)
Expand Down Expand Up @@ -234,6 +230,7 @@ def get_all_by_case_file_id(cls, case_file_id: int):
return cls.query.filter_by(case_file_id=case_file_id, is_deleted=False).all()

@classmethod
@with_session
def bulk_delete(cls, case_file_id: int, officer_ids: list[int], session=None):
"""Delete officer ids by id per case file."""
query = session.query(CaseFileOfficer) if session else cls.query
Expand All @@ -242,20 +239,18 @@ def bulk_delete(cls, case_file_id: int, officer_ids: list[int], session=None):
)
for officer in officers:
officer.update(DELETE_DIC_PARAMS, commit=not session)
session.flush()

@classmethod
@with_session
def bulk_insert(cls, case_file_id: int, officer_ids: list[int], session=None):
"""Insert officers per case file."""
case_file_officer_data = [
CaseFileOfficer(**{"case_file_id": case_file_id, "officer_id": officer_id})
for officer_id in officer_ids
]
if session:
session.add_all(case_file_officer_data)
session.flush()
else:
db.session.add_all(case_file_officer_data)
db.session.commit()
session.add_all(case_file_officer_data)
session.flush()


class CaseFileInitiationOption(BaseModelVersioned):
Expand Down Expand Up @@ -318,6 +313,7 @@ def get_links_by_source_and_target(cls, source_id, target_id):
).first()

@classmethod
@with_session
def delete_link(cls, source_id, taget_id, session=None):
"""Delete the case file link."""
links = cls.query.filter(
Expand All @@ -327,18 +323,13 @@ def delete_link(cls, source_id, taget_id, session=None):
).all()
for link in links:
link.update(DELETE_DIC_PARAMS, commit=False)
if session:
session.flush()
else:
db.session.commit()
session.flush()

@classmethod
@with_session
def create_link(cls, link_data, session=None):
"""Persist case file link data in database."""
case_file = CaseFileLink(**link_data)
if session:
session.add(case_file)
session.flush()
else:
case_file.save()
session.add(case_file)
session.flush()
return case_file
39 changes: 16 additions & 23 deletions compliance-api/src/compliance_api/models/complaint/complaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@

from compliance_api.utils.constant import DELETE_DIC_PARAMS

from ..base_model import BaseModelVersioned, db
from ..base_model import BaseModelVersioned
from ..case_file import CaseFile as CaseFileModel
from ..utils import with_session


class ComplaintStatusEnum(enum.Enum):
Expand Down Expand Up @@ -134,52 +135,48 @@ def get_count_by_project_nd_case_file_id(cls, project_id: int, case_file_id: int
func.count(Complaint.id).label( # pylint: disable=not-callable
"complaint_count"
),
)
).join(CaseFileModel, CaseFileModel.id == Complaint.case_file_id)
.filter(
CaseFileModel.project_id == project_id,
Complaint.case_file_id == case_file_id,
Complaint.is_active.is_(True),
Complaint.is_deleted.is_(False)
)
.group_by(Complaint.case_file_id, CaseFileModel.project_id)
.first()
)
return result.complaint_count if result else 0

@classmethod
@with_session
def create_complaint(cls, complaint_obj, session=None):
"""Persist inspection in database."""
complaint = Complaint(**complaint_obj)
if session:
session.add(complaint)
session.flush()
else:
complaint.save()
session.add(complaint)
session.flush()
return complaint

@classmethod
@with_session
def update_complaint(cls, complaint_id, complaint_data, session=None):
"""Update inspection."""
query = cls.query.filter_by(id=complaint_id)
complaint: Complaint = query.first()
if not complaint or complaint.is_deleted:
return None
complaint.update(complaint_data, commit=False)
if session:
session.flush()
else:
db.session.commit()
session.flush()
return complaint

@classmethod
@with_session
def change_status(
cls, complaint_id, complaint_status: ComplaintStatusEnum, session=None
):
"""Update the complaint status."""
complaint = cls.query.filter(cls.id == complaint_id).first()
complaint.update({"status": complaint_status}, commit=False)
if session:
session.flush()
else:
db.session.commit()
session.flush()

@classmethod
def get_by_complaint_number(cls, complaint_number):
Expand All @@ -189,24 +186,20 @@ def get_by_complaint_number(cls, complaint_number):
).first()

@classmethod
@with_session
def delete_by_case_file(cls, case_file_id, session=None):
"""Delete complaint by case file."""
complaints = cls.query.filter(
Complaint.case_file_id == case_file_id, Complaint.is_deleted is False
).all()
for complaint in complaints:
complaint.update(DELETE_DIC_PARAMS, commit=False)
if session:
session.flush()
else:
db.session.commit()
session.flush()

@classmethod
@with_session
def delete_complaint(cls, complaint_id, session=None):
"""Delete complaint."""
complaint = cls.query.filter(Complaint.id == complaint_id).first()
complaint.update(DELETE_DIC_PARAMS, commit=False)
if session:
session.flush()
else:
db.session.commit()
session.flush()
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from compliance_api.utils.constant import DELETE_DIC_PARAMS

from ..base_model import BaseModelVersioned, db
from ..base_model import BaseModelVersioned
from ..utils import with_session


class ComplaintReqEACDetail(BaseModelVersioned):
Expand Down Expand Up @@ -46,27 +47,23 @@ def to_dict(self):
}

@classmethod
@with_session
def create(cls, requirement_obj, session=None):
"""Create eac details."""
requirement_more = ComplaintReqEACDetail(**requirement_obj)
if session:
session.add(requirement_more)
session.flush()
else:
requirement_more.save()
session.add(requirement_more)
session.flush()
return requirement_more

@classmethod
@with_session
def delete_eac_details(cls, requirement_id, session=None):
"""Mark the details as deleted."""
requirement = cls.query.filter_by(
req_id=requirement_id, is_deleted=False
).first()
requirement.update(DELETE_DIC_PARAMS, commit=False)
if session:
session.flush()
else:
db.session.commit()
session.flush()

@classmethod
def get_by_requirement(cls, req_id):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from compliance_api.utils.constant import DELETE_DIC_PARAMS

from ..base_model import BaseModelVersioned, db
from ..base_model import BaseModelVersioned
from ..utils import with_session


class ComplaintReqOrderDetail(BaseModelVersioned):
Expand Down Expand Up @@ -38,27 +39,23 @@ def to_dict(self):
return {"id": self.id, "req_id": self.req_id, "order_number": self.order_number}

@classmethod
@with_session
def create(cls, requirement_obj, session=None):
"""Create order details."""
requirement_more = ComplaintReqOrderDetail(**requirement_obj)
if session:
session.add(requirement_more)
session.flush()
else:
requirement_more.save()
session.add(requirement_more)
session.flush()
return requirement_more

@classmethod
@with_session
def delete_order_details(cls, requirement_id, session=None):
"""Mark the details as deleted."""
requirement = cls.query.filter_by(
req_id=requirement_id, is_deleted=False
).first()
requirement.update(DELETE_DIC_PARAMS, commit=False)
if session:
session.flush()
else:
db.session.commit()
session.flush()

@classmethod
def get_by_requirement(cls, req_id):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from compliance_api.utils.constant import DELETE_DIC_PARAMS

from ..base_model import BaseModelVersioned, db
from ..base_model import BaseModelVersioned
from ..utils import with_session


class ComplaintReqScheduleBDetail(BaseModelVersioned):
Expand Down Expand Up @@ -42,27 +43,23 @@ def to_dict(self):
}

@classmethod
@with_session
def create(cls, requirement_obj, session=None):
"""Create schedule b details."""
requirement_more = ComplaintReqScheduleBDetail(**requirement_obj)
if session:
session.add(requirement_more)
session.flush()
else:
requirement_more.save()
session.add(requirement_more)
session.flush()
return requirement_more

@classmethod
@with_session
def delete_schedule_b_details(cls, requirement_id, session=None):
"""Mark the details as deleted."""
requirement = cls.query.filter_by(
req_id=requirement_id, is_deleted=False
).first()
requirement.update(DELETE_DIC_PARAMS, commit=False)
if session:
session.flush()
else:
db.session.commit()
session.flush()

@classmethod
def get_by_requirement(cls, req_id):
Expand Down
Loading

0 comments on commit 16f0ba8

Please sign in to comment.