Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modified User Route for User Management Redesign #2595

Merged
merged 14 commits into from
Dec 10, 2024
9 changes: 9 additions & 0 deletions care/facility/api/viewsets/facility_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,21 @@ class FacilityUserViewSet(GenericViewSet, mixins.ListModelMixin):

def get_queryset(self):
try:
search_fields = {
key: self.request.query_params.get(key)
for key in self.search_fields
if self.request.query_params.get(key)
}
facility = Facility.objects.get(
external_id=self.kwargs.get("facility_external_id"),
)
queryset = facility.users.filter(
deleted=False,
).order_by("-last_login")
if search_fields:
queryset = queryset.filter(
**{key: value for key, value in search_fields.items() if value}
)
Jacobjeevan marked this conversation as resolved.
Show resolved Hide resolved
return queryset.prefetch_related(
Prefetch(
"skills",
Expand Down
12 changes: 12 additions & 0 deletions care/facility/tests/test_facilityuser_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ def setUpTestData(cls) -> None:
cls.facility2 = cls.create_facility(
cls.super_user, cls.district, cls.local_body
)
cls.user2 = cls.create_user(
"dummystaff", cls.district, home_facility=cls.facility2
)

def setUp(self) -> None:
self.client.force_authenticate(self.super_user)
Expand All @@ -32,6 +35,15 @@ def test_get_queryset_with_prefetching(self):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNumQueries(2)

def test_get_queryset_with_search(self):
response = self.client.get(
f"/api/v1/facility/{self.facility2.external_id}/get_users/?username={self.user2.username}"
)

self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data["results"]), 1)
self.assertEqual(response.data["results"][0]["username"], self.user2.username)
Jacobjeevan marked this conversation as resolved.
Show resolved Hide resolved

def test_link_new_facility(self):
response = self.client.get("/api/v1/facility/")

Expand Down
2 changes: 2 additions & 0 deletions care/users/api/serializers/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ class Meta:
"pf_auth",
"read_profile_picture_url",
"user_flags",
"last_login",
Jacobjeevan marked this conversation as resolved.
Show resolved Hide resolved
)
read_only_fields = (
"is_superuser",
Expand All @@ -347,6 +348,7 @@ class Meta:
"pf_endpoint",
"pf_p256dh",
"pf_auth",
"last_login",
)

extra_kwargs = {"url": {"lookup_field": "username"}}
Expand Down
22 changes: 21 additions & 1 deletion care/users/api/viewsets/change_password.py
vigneshhari marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from django.contrib.auth import get_user_model
from django.shortcuts import get_object_or_404
from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework import serializers, status
from rest_framework.generics import UpdateAPIView
Expand Down Expand Up @@ -29,7 +30,22 @@ class ChangePasswordView(UpdateAPIView):
model = User

def update(self, request, *args, **kwargs):
self.object = self.request.user
username = request.data.get("username")
if not username:
return Response(
{"message": ["Username is required"]},
status=status.HTTP_400_BAD_REQUEST,
)
self.object = get_object_or_404(User, username=username)
vigneshhari marked this conversation as resolved.
Show resolved Hide resolved
if not self.has_permission(request, self.object):
return Response(
{
"message": [
"User does not have elevated permissions to change password"
]
},
status=status.HTTP_403_FORBIDDEN,
)
serializer = self.get_serializer(data=request.data)

if serializer.is_valid():
Expand All @@ -48,3 +64,7 @@ def update(self, request, *args, **kwargs):
return Response({"message": "Password updated successfully"})

return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

def has_permission(self, request, user):
authuser = request.user
return authuser == user or authuser.is_superuser
17 changes: 16 additions & 1 deletion care/users/api/viewsets/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,18 @@ def last_active_after(self, queryset, name, value):
return queryset.filter(last_login__gte=date)


class UserViewSetPermission(DRYPermissions):
def has_permission(self, request, view):
if request.method == "GET" and view.action == "retrieve":
return True
return super().has_permission(request, view)

def has_object_permission(self, request, view, obj):
if request.method == "GET" and view.action == "retrieve":
return True
return super().has_object_permission(request, view, obj)


class UserViewSet(
mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
Expand All @@ -113,7 +125,7 @@ class UserViewSet(
queryset = queryset.filter(Q(asset__isnull=True))
lookup_field = "username"
lookup_value_regex = "[^/]+"
permission_classes = (IsAuthenticated, DRYPermissions)
permission_classes = (IsAuthenticated, UserViewSetPermission)
filter_backends = (
filters.DjangoFilterBackend,
rest_framework_filters.OrderingFilter,
Expand Down Expand Up @@ -155,6 +167,9 @@ def get_queryset(self):

def get_object(self) -> User:
try:
if self.request.method == "GET" and self.action == "retrieve":
username = self.kwargs.get("username")
return get_object_or_404(User, username=username)
return super().get_object()
except Http404 as e:
error = "User not found"
Expand Down
138 changes: 130 additions & 8 deletions care/users/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def setUpTestData(cls) -> None:
cls.user = cls.create_user("staff1", cls.district)
cls.user_data = cls.get_user_data(cls.district, 40)

cls.data_2 = cls.get_user_data(cls.district)
cls.data_2.update({"username": "user_2", "password": "password"})
cls.user_2 = cls.create_user(**cls.data_2)

def setUp(self):
self.client.force_authenticate(self.super_user)

Expand Down Expand Up @@ -51,6 +55,7 @@ def get_detail_representation(self, obj=None) -> dict:
"video_connect_link": obj.video_connect_link,
"read_profile_picture_url": obj.profile_picture_url,
"user_flags": [],
"last_login": obj.last_login,
**self.get_local_body_district_state_representation(obj),
}

Expand All @@ -67,9 +72,11 @@ def test_superuser_can_view(self):
data = self.user_data.copy()
data["date_of_birth"] = str(data["date_of_birth"])
data.pop("password")
user_data = self.get_detail_representation(self.user)
user_data.pop("created_by")
self.assertDictEqual(
res_data_json,
self.get_detail_representation(self.user),
user_data,
)

def test_superuser_can_modify(self):
Expand Down Expand Up @@ -106,6 +113,61 @@ def test_superuser_can_delete(self):
deleted=False,
)

def test_superuser_can_change_password_of_others(self):
"""Test a user with superuser access can change the password of other users underneath the hierarchy"""
username = self.data_2["username"]
password = self.data_2["password"]
response = self.client.put(
"/api/v1/password_change/",
{
"username": username,
"old_password": password,
"new_password": "password2",
},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)

def test_superuser_cannot_change_password_of_others_without_username(
self,
):
"""Test a user with superuser access cannot change the password of other users without username"""
response = self.client.put(
"/api/v1/password_change/",
{"old_password": "password", "new_password": "password2"},
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data["message"][0], "Username is required")

def test_superuser_cannot_change_password_of_non_existing_user(self):
"""Test a user with superuser access cannot change the password of a non existing user"""
response = self.client.put(
"/api/v1/password_change/",
{
"username": "foobar",
"old_password": "password",
"new_password": "password2",
},
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

def test_superuser_cannot_change_password_of_others_with_invalid_old_password(
self,
):
"""Test a user with superuser access cannot change the password of other users with invalid old password"""
response = self.client.put(
"/api/v1/password_change/",
{
"username": self.data_2["username"],
"old_password": "wrong_password",
"new_password": "password2",
},
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(
response.data["old_password"][0],
"Wrong password entered. Please check your password.",
)


class TestUser(TestUtils, APITestCase):
def get_detail_representation(self, obj=None) -> dict:
Expand Down Expand Up @@ -137,10 +199,38 @@ def setUpTestData(cls) -> None:
cls.user_2 = cls.create_user(**cls.data_2)

cls.data_3 = cls.get_user_data(cls.district)
cls.data_3.update({"username": "user_3", "password": "password"})
cls.data_3.update(
{
"username": "user_3",
"password": "password",
"user_type": User.TYPE_VALUE_MAP["Doctor"],
}
)
cls.user_3 = cls.create_user(**cls.data_3)
cls.link_user_with_facility(cls.user_3, cls.facility, cls.super_user)

cls.data_4 = cls.get_user_data(cls.district)
cls.data_4.update(
{
"username": "user_4",
"password": "password",
"user_type": User.TYPE_VALUE_MAP["DistrictAdmin"],
}
)
cls.user_4 = cls.create_user(**cls.data_4)
cls.link_user_with_facility(cls.user_4, cls.facility, cls.super_user)

cls.data_5 = cls.get_user_data(cls.district)
cls.data_5.update(
{
"username": "user_5",
"password": "password",
"user_type": User.TYPE_VALUE_MAP["WardAdmin"],
}
)
cls.user_5 = cls.create_user(**cls.data_5)
cls.link_user_with_facility(cls.user_5, cls.facility, cls.super_user)

def test_user_can_access_url(self):
"""Test user can access the url by location"""
username = self.user.username
Expand All @@ -152,7 +242,7 @@ def test_user_can_read_all_users_within_accessible_facility(self):
response = self.client.get("/api/v1/users/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
res_data_json = response.json()
self.assertEqual(res_data_json["count"], 2)
self.assertEqual(res_data_json["count"], 3)
results = res_data_json["results"]
self.assertIn(self.user.id, {r["id"] for r in results})
self.assertIn(self.user_3.id, {r["id"] for r in results})
Expand All @@ -176,12 +266,12 @@ def test_user_can_modify_themselves(self):
User.objects.get(username=username).date_of_birth, date(2005, 4, 1)
)

def test_user_cannot_read_others(self):
"""Test 1 user can read the attributes of the other user"""
username = self.data_2["username"]
def test_user_can_read_others(self):
"""Test 1 user can read the attributes of any other user"""
username = self.user_2.username
response = self.client.get(f"/api/v1/users/{username}/")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.json()["detail"], "User not found")
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json()["first_name"], self.user_2.first_name)
Jacobjeevan marked this conversation as resolved.
Show resolved Hide resolved

def test_user_cannot_modify_others(self):
"""Test a user can't modify others"""
Expand All @@ -207,6 +297,38 @@ def test_user_cannot_delete_others(self):
User.objects.get(username=self.data_2[field]).username,
)

def test_user_cannot_change_password_of_others(self):
"""Test a user cannot change password of others"""
username = self.data_2["username"]
password = self.data_2["password"]
response = self.client.put(
"/api/v1/password_change/",
{
"username": username,
"old_password": password,
"new_password": "password2",
},
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

def test_user_with_districtadmin_access_can_modify_others(self):
"""Test a user with district admin access can modify others underneath the hierarchy"""
self.client.force_authenticate(self.user_4)
username = self.data_2["username"]
response = self.client.patch(
f"/api/v1/users/{username}/",
{
"date_of_birth": date(2005, 4, 1),
},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json()["date_of_birth"], "2005-04-01")

def test_user_gets_error_when_accessing_user_details_with_invalid_username(self):
"""Test a user gets error when accessing user details with invalid username"""
response = self.client.get("/api/v1/users/foobar/")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)


class TestUserFilter(TestUtils, APITestCase):
@classmethod
Expand Down
6 changes: 3 additions & 3 deletions care/utils/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,7 @@ def get_local_body_district_state_representation(self, obj):
response.update(self.get_state_representation(getattr(obj, "state", None)))
return response

def get_local_body_representation(self, local_body: LocalBody):
def get_local_body_representation(self, local_body: LocalBody | None):
if local_body is None:
return {"local_body": None, "local_body_object": None}
return {
Expand All @@ -778,7 +778,7 @@ def get_local_body_representation(self, local_body: LocalBody):
},
}

def get_district_representation(self, district: District):
def get_district_representation(self, district: District | None):
if district is None:
return {"district": None, "district_object": None}
return {
Expand All @@ -790,7 +790,7 @@ def get_district_representation(self, district: District):
},
}

def get_state_representation(self, state: State):
def get_state_representation(self, state: State | None):
if state is None:
return {"state": None, "state_object": None}
return {"state": state.id, "state_object": {"id": state.id, "name": state.name}}
Expand Down
Loading