From 78f5d4f5493f139ba95eeec89723c10f4729437d Mon Sep 17 00:00:00 2001 From: kicksent Date: Wed, 5 Feb 2025 01:17:05 -0700 Subject: [PATCH] change def in load and save, and add test --- burr/integrations/persisters/b_pymongo.py | 4 ++-- tests/integrations/persisters/test_b_mongodb.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/burr/integrations/persisters/b_pymongo.py b/burr/integrations/persisters/b_pymongo.py index 015e23dc..e32fcd2b 100644 --- a/burr/integrations/persisters/b_pymongo.py +++ b/burr/integrations/persisters/b_pymongo.py @@ -96,7 +96,7 @@ def list_app_ids(self, partition_key: str, **kwargs) -> list[str]: return app_ids def load( - self, partition_key: str, app_id: str, sequence_id: int = None, **kwargs + self, partition_key: Optional[str], app_id: str, sequence_id: int = None, **kwargs ) -> Optional[persistence.PersistedStateData]: """Load the state data for a given partition key, app id, and sequence id.""" query = {"partition_key": partition_key, "app_id": app_id} @@ -118,7 +118,7 @@ def load( def save( self, - partition_key: str, + partition_key: Optional[str], app_id: str, sequence_id: int, position: str, diff --git a/tests/integrations/persisters/test_b_mongodb.py b/tests/integrations/persisters/test_b_mongodb.py index e08073db..93d17286 100644 --- a/tests/integrations/persisters/test_b_mongodb.py +++ b/tests/integrations/persisters/test_b_mongodb.py @@ -66,3 +66,16 @@ def test_serialization_with_pickle(mongodb_persister): data = deserialized_persister.load("pk", "app_id_serde", 1) assert data["state"].get_all() == {"a": 1, "b": 2} + +def test_partition_key_is_optional(mongodb_persister): + # 1. Save and load with partition key = None + mongodb_persister.save(None, "app_id_none", 1, "pos1", state.State({"foo": "bar"}), "in_progress") + loaded_data = mongodb_persister.load(None, "app_id_none", 1) + assert loaded_data is not None + assert loaded_data["state"].get_all() == {"foo": "bar"} + + # 2. Save and load again (different key/index) with partition key = None + mongodb_persister.save(None, "app_id_none2", 2, "pos2", state.State({"hello": "world"}), "completed") + loaded_data2 = mongodb_persister.load(None, "app_id_none2", 2) + assert loaded_data2 is not None + assert loaded_data2["state"].get_all() == {"hello": "world"}