Skip to content

Commit

Permalink
use enum value in request validation (kubeflow#2249)
Browse files Browse the repository at this point in the history
Other method such as `.predict` assume that the enum value is used for the protocol, but the validation checks against the actual enum instead.

Signed-off-by: luranhe <luranjhe@gmail.com>
  • Loading branch information
luranhe authored Jun 21, 2022
1 parent cf4ed1d commit 7e74dca
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/kserve/kserve/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,13 @@ def _grpc_client(self):
return self._grpc_client_stub

def validate(self, request):
if self.protocol == PredictorProtocol.REST_V2:
if self.protocol == PredictorProtocol.REST_V2.value:
if "inputs" in request and not isinstance(request["inputs"], list):
raise tornado.web.HTTPError(
status_code=HTTPStatus.BAD_REQUEST,
reason="Expected \"inputs\" to be a list"
)
elif isinstance(request, Dict) or self.protocol == PredictorProtocol.REST_V1:
elif isinstance(request, Dict) or self.protocol == PredictorProtocol.REST_V1.value:
if "instances" in request and not isinstance(request["instances"], list):
raise tornado.web.HTTPError(
status_code=HTTPStatus.BAD_REQUEST,
Expand Down
22 changes: 22 additions & 0 deletions python/kserve/test/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
from kserve import Model
from kserve import ModelServer
from kserve import ModelRepository
from kserve.model import PredictorProtocol
from tornado.httpclient import HTTPClientError
from tornado.web import HTTPError
from ray import serve


Expand Down Expand Up @@ -159,6 +161,26 @@ async def load(self, name: str) -> bool:
return False


class TestModel:

async def test_validate(self):
model = DummyModel("TestModel")
good_request = {"instances": []}
validated_request = model.validate(good_request)
assert validated_request == good_request
bad_request = {"instances": "invalid"}
with pytest.raises(HTTPError):
model.validate(bad_request)

model.protocol = PredictorProtocol.REST_V2.value
good_request = {"inputs": []}
validated_request = model.validate(good_request)
assert validated_request == good_request
bad_request = {"inputs": "invalid"}
with pytest.raises(HTTPError):
model.validate(bad_request)


class TestTFHttpServer:

@pytest.fixture(scope="class")
Expand Down

0 comments on commit 7e74dca

Please sign in to comment.