From 41bdfb317650d99a5ebc55e1ca6236104b3644c7 Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Mon, 19 Aug 2024 21:04:36 +0530 Subject: [PATCH 1/5] Fix a race condition which allowed duplicate consultations to be created --- .../api/serializers/patient_consultation.py | 229 +++++++++--------- .../api/viewsets/patient_consultation.py | 5 + care/utils/lock.py | 35 +++ 3 files changed, 154 insertions(+), 115 deletions(-) create mode 100644 care/utils/lock.py diff --git a/care/facility/api/serializers/patient_consultation.py b/care/facility/api/serializers/patient_consultation.py index a316153caa..f34acaced6 100644 --- a/care/facility/api/serializers/patient_consultation.py +++ b/care/facility/api/serializers/patient_consultation.py @@ -64,6 +64,7 @@ UserBaseMinimumSerializer, ) from care.users.models import User +from care.utils.lock import Lock from care.utils.notification_handler import NotificationGenerator from care.utils.queryset.facility import get_home_facility_queryset from care.utils.serializer.external_id_field import ExternalIdSerializerField @@ -191,6 +192,9 @@ def get_discharge_prn_prescription(self, consultation): dosage_type=PrescriptionDosageType.PRN.value, ).values() + def _lock_key(self, patient_id): + return f"patient_consultation__patient_registration__{patient_id}" + class Meta: model = PatientConsultation read_only_fields = TIMESTAMP_FIELDS + ( @@ -353,18 +357,16 @@ def create(self, validated_data): create_diagnosis = validated_data.pop("create_diagnoses") create_symptoms = validated_data.pop("create_symptoms") - action = -1 - review_interval = -1 - if "action" in validated_data: - action = validated_data.pop("action") - if "review_interval" in validated_data: - review_interval = validated_data.pop("review_interval") + + action = validated_data.pop("action", -1) + review_interval = validated_data.get("review_interval", -1) # Authorisation Check - allowed_facilities = get_home_facility_queryset(self.context["request"].user) + user = self.context["request"].user + allowed_facilities = get_home_facility_queryset(user) if not allowed_facilities.filter( - id=self.validated_data["patient"].facility.id + id=self.validated_data["patient"].facility_id ).exists(): raise ValidationError( {"facility": "Consultation creates are only allowed in home facility"} @@ -372,130 +374,127 @@ def create(self, validated_data): # End Authorisation Checks - if validated_data["patient"].last_consultation: + with Lock( + self._lock_key(validated_data["patient"].id), 30 + ), transaction.atomic(): + patient = validated_data["patient"] + if patient.last_consultation: + if patient.last_consultation.assigned_to == user: + raise ValidationError( + { + "Permission Denied": "Only Facility Staff can create consultation for a Patient" + }, + ) + + if not patient.last_consultation.discharge_date: + raise ValidationError( + {"consultation": "Exists please Edit Existing Consultation"} + ) + + if "is_kasp" in validated_data: + if validated_data["is_kasp"]: + validated_data["kasp_enabled_date"] = now() + + # Coercing facility as the patient's facility + validated_data["facility_id"] = patient.facility_id + + consultation: PatientConsultation = super().create(validated_data) + consultation.created_by = user + consultation.last_edited_by = user + consultation.previous_consultation = patient.last_consultation + last_consultation = patient.last_consultation if ( - self.context["request"].user - == validated_data["patient"].last_consultation.assigned_to + last_consultation + and consultation.suggestion == SuggestionChoices.A + and last_consultation.suggestion == SuggestionChoices.A + and last_consultation.discharge_date + and last_consultation.discharge_date + timedelta(days=30) + > consultation.encounter_date ): - raise ValidationError( - { - "Permission Denied": "Only Facility Staff can create consultation for a Patient" - }, - ) + consultation.is_readmission = True + + diagnosis = ConsultationDiagnosis.objects.bulk_create( + [ + ConsultationDiagnosis( + consultation=consultation, + diagnosis_id=obj["diagnosis"].id, + is_principal=obj["is_principal"], + verification_status=obj["verification_status"], + created_by=user, + ) + for obj in create_diagnosis + ] + ) - if validated_data["patient"].last_consultation: - if not validated_data["patient"].last_consultation.discharge_date: - raise ValidationError( - {"consultation": "Exists please Edit Existing Consultation"} + symptoms = EncounterSymptom.objects.bulk_create( + EncounterSymptom( + consultation=consultation, + symptom=obj.get("symptom"), + onset_date=obj.get("onset_date"), + cure_date=obj.get("cure_date"), + clinical_impression_status=obj.get("clinical_impression_status"), + other_symptom=obj.get("other_symptom") or "", + created_by=user, ) + for obj in create_symptoms + ) - if "is_kasp" in validated_data: - if validated_data["is_kasp"]: - validated_data["kasp_enabled_date"] = localtime(now()) - - bed = validated_data.pop("bed", None) - - validated_data["facility_id"] = validated_data[ - "patient" - ].facility_id # Coercing facility as the patient's facility - consultation = super().create(validated_data) - consultation.created_by = self.context["request"].user - consultation.last_edited_by = self.context["request"].user - patient = consultation.patient - consultation.previous_consultation = patient.last_consultation - last_consultation = patient.last_consultation - if ( - last_consultation - and consultation.suggestion == SuggestionChoices.A - and last_consultation.suggestion == SuggestionChoices.A - and last_consultation.discharge_date - and last_consultation.discharge_date + timedelta(days=30) - > consultation.encounter_date - ): - consultation.is_readmission = True - consultation.save() - - diagnosis = ConsultationDiagnosis.objects.bulk_create( - [ - ConsultationDiagnosis( + bed = validated_data.pop("bed", None) + if bed and consultation.suggestion == SuggestionChoices.A: + consultation_bed = ConsultationBed( + bed=bed, consultation=consultation, - diagnosis_id=obj["diagnosis"].id, - is_principal=obj["is_principal"], - verification_status=obj["verification_status"], - created_by=self.context["request"].user, + start_date=consultation.created_date, ) - for obj in create_diagnosis - ] - ) + consultation_bed.save() + consultation.current_bed = consultation_bed - symptoms = EncounterSymptom.objects.bulk_create( - EncounterSymptom( - consultation=consultation, - symptom=obj.get("symptom"), - onset_date=obj.get("onset_date"), - cure_date=obj.get("cure_date"), - clinical_impression_status=obj.get("clinical_impression_status"), - other_symptom=obj.get("other_symptom") or "", - created_by=self.context["request"].user, - ) - for obj in create_symptoms - ) + if consultation.suggestion == SuggestionChoices.OP: + consultation.discharge_date = now() + patient.is_active = False + patient.allow_transfer = True + else: + patient.is_active = True + patient.last_consultation = consultation - if bed and consultation.suggestion == SuggestionChoices.A: - consultation_bed = ConsultationBed( - bed=bed, - consultation=consultation, - start_date=consultation.created_date, - ) - consultation_bed.save() - consultation.current_bed = consultation_bed - consultation.save(update_fields=["current_bed"]) + if action != -1: + patient.action = action - if consultation.suggestion == SuggestionChoices.OP: - consultation.discharge_date = localtime(now()) - consultation.save() - patient.is_active = False - patient.allow_transfer = True - else: - patient.is_active = True - patient.last_consultation = consultation - - if action != -1: - patient.action = action - consultation.review_interval = review_interval - if review_interval > 0: - patient.review_time = localtime(now()) + timedelta(minutes=review_interval) - else: - patient.review_time = None + if review_interval > 0: + patient.review_time = now() + timedelta(minutes=review_interval) + else: + patient.review_time = None - patient.save() - NotificationGenerator( - event=Notification.Event.PATIENT_CONSULTATION_CREATED, - caused_by=self.context["request"].user, - caused_object=consultation, - facility=patient.facility, - ).generate() + consultation.save() + patient.save() - create_consultation_events( - consultation.id, - (consultation, *diagnosis, *symptoms), - consultation.created_by.id, - consultation.created_date, - ) + create_consultation_events( + consultation.id, + (consultation, *diagnosis, *symptoms), + consultation.created_by.id, + consultation.created_date, + ) - if consultation.assigned_to: NotificationGenerator( - event=Notification.Event.PATIENT_CONSULTATION_ASSIGNMENT, - caused_by=self.context["request"].user, + event=Notification.Event.PATIENT_CONSULTATION_CREATED, + caused_by=user, caused_object=consultation, - facility=consultation.patient.facility, - notification_mediums=[ - Notification.Medium.SYSTEM, - Notification.Medium.WHATSAPP, - ], + facility=patient.facility, ).generate() - return consultation + if consultation.assigned_to: + NotificationGenerator( + event=Notification.Event.PATIENT_CONSULTATION_ASSIGNMENT, + caused_by=user, + caused_object=consultation, + facility=consultation.patient.facility, + notification_mediums=[ + Notification.Medium.SYSTEM, + Notification.Medium.WHATSAPP, + ], + ).generate() + + return consultation def validate_create_diagnoses(self, value): # Reject if create_diagnoses is present for edits diff --git a/care/facility/api/viewsets/patient_consultation.py b/care/facility/api/viewsets/patient_consultation.py index 4fc1b857b2..95aa138331 100644 --- a/care/facility/api/viewsets/patient_consultation.py +++ b/care/facility/api/viewsets/patient_consultation.py @@ -1,3 +1,4 @@ +from django.db import transaction from django.db.models import Prefetch from django.db.models.query_utils import Q from django.shortcuts import get_object_or_404, render @@ -109,6 +110,10 @@ def get_queryset(self): applied_filters |= Q(facility=self.request.user.home_facility) return self.queryset.filter(applied_filters) + @transaction.non_atomic_requests + def create(self, request, *args, **kwargs) -> Response: + return super().create(request, *args, **kwargs) + @extend_schema(tags=["consultation"]) @action(detail=True, methods=["POST"]) def discharge_patient(self, request, *args, **kwargs): diff --git a/care/utils/lock.py b/care/utils/lock.py new file mode 100644 index 0000000000..7600105b38 --- /dev/null +++ b/care/utils/lock.py @@ -0,0 +1,35 @@ +from django.core.cache import cache +from rest_framework.exceptions import APIException + + +class ObjectLocked(APIException): + status_code = 423 + default_detail = "The resource you are trying to access is locked" + default_code = "object_locked" + + +class Lock: + def __init__(self, key, timeout=None): + self.key = f"lock:{key}" + self.timeout = timeout + + def acquire(self): + try: + if not cache.set(self.key, True, self.timeout, nx=True): + raise ObjectLocked() + # handle nx not supported + except TypeError: + if cache.get(self.key): + raise ObjectLocked() + cache.set(self.key, True, self.timeout) + + def release(self): + return cache.delete(self.key) + + def __enter__(self): + self.acquire() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.release() + return False From 69b46f0cc401d11a00196c575929fd36270cb14d Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Mon, 19 Aug 2024 22:13:21 +0530 Subject: [PATCH 2/5] make timeout constant --- care/facility/api/serializers/patient_consultation.py | 7 ++++--- care/utils/lock.py | 3 ++- config/settings/base.py | 3 +++ 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/care/facility/api/serializers/patient_consultation.py b/care/facility/api/serializers/patient_consultation.py index f34acaced6..0bc163c92f 100644 --- a/care/facility/api/serializers/patient_consultation.py +++ b/care/facility/api/serializers/patient_consultation.py @@ -374,9 +374,10 @@ def create(self, validated_data): # End Authorisation Checks - with Lock( - self._lock_key(validated_data["patient"].id), 30 - ), transaction.atomic(): + with ( + Lock(self._lock_key(validated_data["patient"].id)), + transaction.atomic(), + ): patient = validated_data["patient"] if patient.last_consultation: if patient.last_consultation.assigned_to == user: diff --git a/care/utils/lock.py b/care/utils/lock.py index 7600105b38..12d80cb2ea 100644 --- a/care/utils/lock.py +++ b/care/utils/lock.py @@ -1,3 +1,4 @@ +from django.conf import settings from django.core.cache import cache from rest_framework.exceptions import APIException @@ -9,7 +10,7 @@ class ObjectLocked(APIException): class Lock: - def __init__(self, key, timeout=None): + def __init__(self, key, timeout=settings.LOCK_TIMEOUT): self.key = f"lock:{key}" self.timeout = timeout diff --git a/config/settings/base.py b/config/settings/base.py index 8e67190fe9..8cac4e189a 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -62,6 +62,9 @@ DATABASES["default"]["CONN_MAX_AGE"] = env.int("CONN_MAX_AGE", default=0) DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" +# timeout for setnx lock +LOCK_TIMEOUT = env.int("LOCK_TIMEOUT", default=32) + REDIS_URL = env("REDIS_URL", default="redis://localhost:6379") # CACHES From a4c3d345275924d18c873c2938744b6e4aa0a571 Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Mon, 19 Aug 2024 22:22:11 +0530 Subject: [PATCH 3/5] cleanup lock --- care/utils/lock.py | 10 ++-------- care/utils/tests/test_utils.py | 2 +- config/caches.py | 12 ++++++++++++ config/settings/test.py | 2 +- 4 files changed, 16 insertions(+), 10 deletions(-) create mode 100644 config/caches.py diff --git a/care/utils/lock.py b/care/utils/lock.py index 12d80cb2ea..f622600daf 100644 --- a/care/utils/lock.py +++ b/care/utils/lock.py @@ -15,14 +15,8 @@ def __init__(self, key, timeout=settings.LOCK_TIMEOUT): self.timeout = timeout def acquire(self): - try: - if not cache.set(self.key, True, self.timeout, nx=True): - raise ObjectLocked() - # handle nx not supported - except TypeError: - if cache.get(self.key): - raise ObjectLocked() - cache.set(self.key, True, self.timeout) + if not cache.set(self.key, True, self.timeout, nx=True): + raise ObjectLocked() def release(self): return cache.delete(self.key) diff --git a/care/utils/tests/test_utils.py b/care/utils/tests/test_utils.py index 39ed7a6c42..e396b9a51e 100644 --- a/care/utils/tests/test_utils.py +++ b/care/utils/tests/test_utils.py @@ -50,7 +50,7 @@ def __init__(self, decorated): super().__init__( CACHES={ "default": { - "BACKEND": "django.core.cache.backends.locmem.LocMemCache", + "BACKEND": "config.caches.LocMemCache", "LOCATION": f"care-test-{uuid.uuid4()}", } }, diff --git a/config/caches.py b/config/caches.py new file mode 100644 index 0000000000..5f35a51e34 --- /dev/null +++ b/config/caches.py @@ -0,0 +1,12 @@ +from django.core.cache.backends import dummy, locmem +from django.core.cache.backends.base import DEFAULT_TIMEOUT + + +class DummyCache(dummy.DummyCache): + def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None, nx=False): + return super().set(key, value, timeout, version) + + +class LocMemCache(locmem.LocMemCache): + def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None, nx=False): + return super().set(key, value, timeout, version) diff --git a/config/settings/test.py b/config/settings/test.py index e78cde5780..69b5f54fb8 100644 --- a/config/settings/test.py +++ b/config/settings/test.py @@ -39,7 +39,7 @@ # test in peace CACHES = { "default": { - "BACKEND": "django.core.cache.backends.dummy.DummyCache", + "BACKEND": "config.caches.DummyCache", } } # for testing retelimit use override_settings decorator From d15295541951c7c01fa5fbad563301415e62af29 Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Mon, 19 Aug 2024 22:30:19 +0530 Subject: [PATCH 4/5] fix tests --- care/users/tests/test_api.py | 1 - config/caches.py | 10 ++++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/care/users/tests/test_api.py b/care/users/tests/test_api.py index bcd7216e0d..a91a891a74 100644 --- a/care/users/tests/test_api.py +++ b/care/users/tests/test_api.py @@ -240,7 +240,6 @@ def test_last_active_filter(self): response = self.client.get("/api/v1/users/?last_active_days=10") self.assertEqual(response.status_code, status.HTTP_200_OK) res_data_json = response.json() - print(res_data_json) self.assertEqual(res_data_json["count"], 2) self.assertIn( self.user_2.username, {r["username"] for r in res_data_json["results"]} diff --git a/config/caches.py b/config/caches.py index 5f35a51e34..99a90bc0fa 100644 --- a/config/caches.py +++ b/config/caches.py @@ -3,10 +3,12 @@ class DummyCache(dummy.DummyCache): - def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None, nx=False): - return super().set(key, value, timeout, version) + def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None, nx=True): + super().set(key, value, timeout, version) + return nx class LocMemCache(locmem.LocMemCache): - def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None, nx=False): - return super().set(key, value, timeout, version) + def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None, nx=True): + super().set(key, value, timeout, version) + return nx From af9da5b86f2ecff495f5837f8f4b73787dffc34e Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Mon, 19 Aug 2024 22:39:52 +0530 Subject: [PATCH 5/5] cleanups --- config/caches.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/config/caches.py b/config/caches.py index 99a90bc0fa..d6d7cbe123 100644 --- a/config/caches.py +++ b/config/caches.py @@ -3,12 +3,14 @@ class DummyCache(dummy.DummyCache): - def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None, nx=True): + def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None, nx=None): super().set(key, value, timeout, version) - return nx + # mimic the behavior of django_redis with setnx, for tests + return True class LocMemCache(locmem.LocMemCache): - def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None, nx=True): + def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None, nx=None): super().set(key, value, timeout, version) - return nx + # mimic the behavior of django_redis with setnx, for tests + return True