From 26d91159e6d6b38a5e09c1df5670b56123d14100 Mon Sep 17 00:00:00 2001 From: "Afshin T. Darian" Date: Tue, 29 Oct 2024 12:29:57 +0000 Subject: [PATCH] clean up --- jupyter_server/services/events/handlers.py | 13 ++++++------- tests/services/events/test_api.py | 1 + 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/jupyter_server/services/events/handlers.py b/jupyter_server/services/events/handlers.py index 135cc9119..611c571eb 100644 --- a/jupyter_server/services/events/handlers.py +++ b/jupyter_server/services/events/handlers.py @@ -71,7 +71,9 @@ def on_close(self): self.event_logger.remove_listener(listener=self.event_listener) -def validate_model(data: dict[str, Any], schema: jupyter_events.schema.EventSchema) -> None: +def validate_model( + data: dict[str, Any], registry: jupyter_events.schema_registry.SchemaRegistry +) -> None: """Validates for required fields in the JSON request body and verifies that a registered schema/version exists""" required_keys = {"schema_id", "version", "data"} @@ -81,9 +83,7 @@ def validate_model(data: dict[str, Any], schema: jupyter_events.schema.EventSche raise Exception(message) schema_id = cast(str, data.get("schema_id")) version = cast(int, data.get("version")) - if schema is None: - message = f"Unregistered schema: `{schema_id}`" - raise Exception(message) + schema = registry.get(schema_id) if schema.version != version: message = f"Unregistered version: `{version}` for `{schema_id}`" raise Exception(message) @@ -121,10 +121,9 @@ async def post(self): raise web.HTTPError(400, "No JSON data provided") try: - schema = self.event_logger.schemas.get(cast(str, payload.get("schema_id"))) - validate_model(payload, schema) + validate_model(payload, self.event_logger.schemas) self.event_logger.emit( - schema_id=schema.id, + schema_id=cast(str, payload.get("schema_id")), data=cast("Dict[str, Any]", payload.get("data")), timestamp_override=get_timestamp(payload), ) diff --git a/tests/services/events/test_api.py b/tests/services/events/test_api.py index b8421929f..49599e838 100644 --- a/tests/services/events/test_api.py +++ b/tests/services/events/test_api.py @@ -155,4 +155,5 @@ async def test_post_event(jp_fetch, event_logger_sink, payload): async def test_post_event_400(jp_fetch, event_logger, payload): with pytest.raises(tornado.httpclient.HTTPClientError) as e: await jp_fetch("api", "events", method="POST", body=payload) + assert expected_http_error(e, 400)