From d03405386467b3a7f27146b8c9cf0a8fccbd0cda Mon Sep 17 00:00:00 2001 From: Aakash Singh Date: Thu, 28 Dec 2023 20:49:56 +0530 Subject: [PATCH] escape tokens --- care/facility/api/viewsets/icd.py | 4 +++- care/facility/api/viewsets/prescription.py | 12 ++++++------ care/hcx/api/viewsets/gateway.py | 3 ++- care/utils/static_data/helpers.py | 11 +++++++++++ 4 files changed, 22 insertions(+), 8 deletions(-) create mode 100644 care/utils/static_data/helpers.py diff --git a/care/facility/api/viewsets/icd.py b/care/facility/api/viewsets/icd.py index 6e6d8993f3..9d1f722886 100644 --- a/care/facility/api/viewsets/icd.py +++ b/care/facility/api/viewsets/icd.py @@ -4,6 +4,7 @@ from rest_framework.viewsets import ViewSet from care.facility.static_data.icd11 import ICD11 +from care.utils.static_data.helpers import query_builder class ICDViewSet(ViewSet): @@ -17,9 +18,10 @@ def list(self, request): limit = min(int(request.query_params.get("limit")), 20) except (ValueError, TypeError): limit = 20 + query = [] if q := request.query_params.get("query"): - query = [ICD11.label % f"{'* '.join(q.strip().rsplit(maxsplit=3))}*"] + query.append(ICD11.label % query_builder(q)) result = FindQuery(expressions=query, model=ICD11, limit=limit).execute( exhaust_results=False diff --git a/care/facility/api/viewsets/prescription.py b/care/facility/api/viewsets/prescription.py index 7fb7ea5ceb..e31359ac85 100644 --- a/care/facility/api/viewsets/prescription.py +++ b/care/facility/api/viewsets/prescription.py @@ -22,6 +22,7 @@ from care.facility.static_data.medibase import MedibaseMedicine from care.utils.filters.choicefilter import CareChoiceFilter from care.utils.queryset.consultation import get_consultation_queryset +from care.utils.static_data.helpers import query_builder, token_escaper def inverse_choices(choices): @@ -163,14 +164,13 @@ def list(self, request): query = [] if type := request.query_params.get("type"): - query = MedibaseMedicine.type == type + query.append(MedibaseMedicine.type == type) - if search_query := request.query_params.get("query"): - q = (MedibaseMedicine.name == search_query) | ( - MedibaseMedicine.vec - % f"{'* '.join(search_query.strip().rsplit(maxsplit=3))}*" + if q := request.query_params.get("query"): + query.append( + (MedibaseMedicine.name == token_escaper.escape(q)) + | (MedibaseMedicine.vec % query_builder(q)) ) - query = [query & q if query else q] result = FindQuery( expressions=query, model=MedibaseMedicine, limit=limit diff --git a/care/hcx/api/viewsets/gateway.py b/care/hcx/api/viewsets/gateway.py index ca135f3701..a236df222d 100644 --- a/care/hcx/api/viewsets/gateway.py +++ b/care/hcx/api/viewsets/gateway.py @@ -41,6 +41,7 @@ from care.hcx.utils.hcx import Hcx from care.hcx.utils.hcx.operations import HcxOperations from care.utils.queryset.communications import get_communications +from care.utils.static_data.helpers import query_builder class HcxGatewayViewSet(GenericViewSet): @@ -332,7 +333,7 @@ def pmjy_packages(self, request): query = [] if q := request.query_params.get("query"): - query = [PMJYPackage.vec % f"{'* '.join(q.strip().rsplit(maxsplit=3))}*"] + query.append(PMJYPackage.vec % query_builder(q)) results = FindQuery(expressions=query, model=PMJYPackage, limit=limit).execute( exhaust_results=False diff --git a/care/utils/static_data/helpers.py b/care/utils/static_data/helpers.py new file mode 100644 index 0000000000..b26d5bc507 --- /dev/null +++ b/care/utils/static_data/helpers.py @@ -0,0 +1,11 @@ +from redis_om.model.token_escaper import TokenEscaper + +token_escaper = TokenEscaper() + + +def query_builder(query: str) -> str: + """ + Builds a query for redis full text search from a given query string. + """ + words = query.strip().rsplit(maxsplit=3) + return f"{'* '.join([token_escaper.escape(word) for word in words])}*"