diff --git a/care/emr/api/viewsets/location.py b/care/emr/api/viewsets/location.py index 752caf0250..b965f6519b 100644 --- a/care/emr/api/viewsets/location.py +++ b/care/emr/api/viewsets/location.py @@ -1,4 +1,5 @@ from django_filters import rest_framework as filters +from drf_spectacular.utils import extend_schema from pydantic import UUID4, BaseModel from rest_framework.decorators import action from rest_framework.exceptions import PermissionDenied, ValidationError @@ -149,6 +150,10 @@ def authorize_organization(self, facility, organization): ): raise PermissionDenied("You do not have permission to given organizations") + @extend_schema( + request=FacilityLocationOrganizationManageSpec, + responses={200: FacilityOrganizationReadSpec}, + ) @action(detail=True, methods=["POST"]) def organizations_add(self, request, *args, **kwargs): instance = self.get_object() @@ -168,6 +173,9 @@ def organizations_add(self, request, *args, **kwargs): ) return Response(FacilityOrganizationReadSpec.serialize(organization).to_json()) + @extend_schema( + request=FacilityLocationOrganizationManageSpec, responses={204: None} + ) @action(detail=True, methods=["POST"]) def organizations_remove(self, request, *args, **kwargs): instance = self.get_object() @@ -192,6 +200,7 @@ def organizations_remove(self, request, *args, **kwargs): class FacilityLocationEncounterAssignSpec(BaseModel): encounter: UUID4 + @extend_schema(request=FacilityLocationEncounterAssignSpec) @action(detail=True, methods=["POST"]) def associate_encounter(self, request, *args, **kwargs): instance = self.get_object() diff --git a/care/utils/swagger/schema.py b/care/utils/swagger/schema.py index 1cf7b9261e..d0b418f34d 100644 --- a/care/utils/swagger/schema.py +++ b/care/utils/swagger/schema.py @@ -28,54 +28,69 @@ def get_tags(self): def get_request_serializer(self): view = self.view + + action = getattr(view, "action", None) + + if action not in [ + "create", + "update", + "partial_update", + "destroy", + "list", + "retrieve", + ]: + return None + if self.method == "POST": - if hasattr(view, "pydantic_model"): - return view.pydantic_model - elif self.method in ["PUT", "PATCH"]: - if hasattr(view, "pydantic_update_model"): - return view.pydantic_update_model - if hasattr(view, "pydantic_model"): - return view.pydantic_model - elif self.method == "GET": - return None # Can be improved later, if required - return self._get_serializer() + return getattr(view, "pydantic_model", None) + + if self.method in {"PUT", "PATCH"}: + return getattr(view, "pydantic_update_model", None) or getattr( + view, "pydantic_model", None + ) + + return None def get_response_serializers(self): view = self.view + action = getattr(view, "action", None) - if self.method in ["POST", "PUT", "PATCH"] and ( - hasattr(view, "pydantic_model") or hasattr(view, "pydantic_read_model") - ): - return {200: view.pydantic_read_model or view.pydantic_model} + if action not in [ + "create", + "update", + "partial_update", + "destroy", + "list", + "retrieve", + ]: + return None if self.method == "DELETE": return {"204": {"description": "No response body"}} - if ( - self.method == "GET" - and ( - isinstance(self.view, ListModelMixin) + if self.method == "GET": + if ( + isinstance(self.view, (ListModelMixin, EMRListMixin)) or self.view.action == "list" - or isinstance(self.view, EMRListMixin) - ) - and ( - hasattr(view, "pydantic_model") or hasattr(view, "pydantic_read_model") + ): + model = getattr(view, "pydantic_read_model", None) or getattr( + view, "pydantic_model", None + ) + else: + model = ( + getattr(view, "pydantic_retrieve_model", None) + or getattr(view, "pydantic_read_model", None) + or getattr(view, "pydantic_model", None) + ) + + elif self.method in ["POST", "PUT", "PATCH"]: + model = getattr(view, "pydantic_read_model", None) or getattr( + view, "pydantic_model", None ) - ): - return {200: view.pydantic_read_model or view.pydantic_model} - - if self.method == "GET" and ( - hasattr(view, "pydantic_retrieve_model") - or hasattr(view, "pydantic_read_model") - or hasattr(view, "pydantic_model") - ): - return { - 200: view.pydantic_retrieve_model - or view.pydantic_read_model - or view.pydantic_model - } - - return self._get_serializer() + else: + return None + + return {200: model} if model else None def _resolve_path_parameters(self, variables): if hasattr(self.view, "database_model"):