diff --git a/care/users/api/viewsets/users.py b/care/users/api/viewsets/users.py index d581b170c0..f67ce6ac42 100644 --- a/care/users/api/viewsets/users.py +++ b/care/users/api/viewsets/users.py @@ -57,9 +57,13 @@ class UserFilterSet(filters.FilterSet): ) last_login = filters.DateFromToRangeFilter(field_name="last_login") district_id = filters.NumberFilter(field_name="district_id", lookup_expr="exact") - home_facility = filters.UUIDFilter( - field_name="home_facility__external_id", lookup_expr="exact" - ) + home_facility = filters.CharFilter(method="filter_home_facility") + + def filter_home_facility(self, queryset, name, value): + if value == "NONE": + return queryset.filter(home_facility__isnull=True) + return queryset.filter(home_facility__external_id=value) + last_active_days = filters.CharFilter(method="last_active_after") def get_user_type( diff --git a/care/users/tests/test_api.py b/care/users/tests/test_api.py index a91a891a74..ef58f25e7c 100644 --- a/care/users/tests/test_api.py +++ b/care/users/tests/test_api.py @@ -211,6 +211,9 @@ def setUpTestData(cls) -> None: cls.local_body = cls.create_local_body(cls.district) cls.super_user = cls.create_super_user("su", cls.district) cls.facility = cls.create_facility(cls.super_user, cls.district, cls.local_body) + cls.facility_2 = cls.create_facility( + cls.super_user, cls.district, cls.local_body + ) cls.user_1 = cls.create_user("staff1", cls.district, home_facility=cls.facility) @@ -218,6 +221,12 @@ def setUpTestData(cls) -> None: cls.user_3 = cls.create_user("staff3", cls.district, home_facility=cls.facility) + cls.user_4 = cls.create_user( + "staff4", cls.district, home_facility=cls.facility_2 + ) + + cls.user_5 = cls.create_user("doctor", cls.district) + def setUp(self): self.client.force_authenticate(self.super_user) self.user_1.last_login = timezone.now() - timedelta(hours=1) @@ -248,7 +257,37 @@ def test_last_active_filter(self): response = self.client.get("/api/v1/users/?last_active_days=never") self.assertEqual(response.status_code, status.HTTP_200_OK) res_data_json = response.json() - self.assertEqual(res_data_json["count"], 1) + self.assertEqual(res_data_json["count"], 3) self.assertIn( self.user_3.username, {r["username"] for r in res_data_json["results"]} ) + + def test_home_facility_filter(self): + """Test home facility filter""" + response = self.client.get("/api/v1/users/?home_facility=NOT_A_VALID_UUID") + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + response = self.client.get( + f"/api/v1/users/?home_facility={self.facility.external_id}" + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + res_data_json = response.json() + self.assertEqual(res_data_json["count"], 3) + self.assertIn( + self.user_1.username, {r["username"] for r in res_data_json["results"]} + ) + + response = self.client.get( + f"/api/v1/users/?home_facility={self.facility_2.external_id}" + ) + res_data_json = response.json() + self.assertEqual(res_data_json["count"], 1) + self.assertIn( + self.user_4.username, {r["username"] for r in res_data_json["results"]} + ) + + response = self.client.get("/api/v1/users/?home_facility=NONE") + res_data_json = response.json() + self.assertEqual(res_data_json["count"], 1) + self.assertIn( + self.user_5.username, {r["username"] for r in res_data_json["results"]} + )