diff --git a/burr/integrations/persisters/b_mongodb.py b/burr/integrations/persisters/b_mongodb.py index 0481db17..4e135cb2 100644 --- a/burr/integrations/persisters/b_mongodb.py +++ b/burr/integrations/persisters/b_mongodb.py @@ -1,7 +1,7 @@ import json import logging from datetime import datetime, timezone -from typing import Literal, Optional +from typing import Any, Literal, Optional from pymongo import MongoClient @@ -35,6 +35,11 @@ class MongoDBBasePersister(persistence.BaseStatePersister): this change backwards compatible. """ + @classmethod + def default_client(cls) -> Any: + """Returns the default client for the persister.""" + return MongoClient + @classmethod def from_values( cls, @@ -47,7 +52,7 @@ def from_values( """Initializes the MongoDBBasePersister class.""" if mongo_client_kwargs is None: mongo_client_kwargs = {} - client = MongoClient(uri, **mongo_client_kwargs) + client = cls.default_client()(uri, **mongo_client_kwargs) return cls( client=client, db_name=db_name, @@ -130,14 +135,18 @@ def save( def __del__(self): self.client.close() - def __getstate__(self) -> dict: - state = self.__dict__.copy() - state["connection_params"] = { + def get_connection_params(self) -> dict: + """Get the connection parameters for the MongoDB persister.""" + return { "uri": self.client.address[0], "port": self.client.address[1], "db_name": self.db.name, "collection_name": self.collection.name, } + + def __getstate__(self) -> dict: + state = self.__dict__.copy() + state["connection_params"] = self.get_connection_params() del state["client"] del state["db"] del state["collection"] @@ -146,7 +155,7 @@ def __getstate__(self) -> dict: def __setstate__(self, state: dict): connection_params = state.pop("connection_params") # we assume MongoClient. - self.client = MongoClient(connection_params["uri"], connection_params["port"]) + self.client = self.default_client()(connection_params["uri"], connection_params["port"]) self.db = self.client[connection_params["db_name"]] self.collection = self.db[connection_params["collection_name"]] self.__dict__.update(state) @@ -169,7 +178,7 @@ def __init__( """Initializes the MongoDBPersister class.""" if mongo_client_kwargs is None: mongo_client_kwargs = {} - client = MongoClient(uri, **mongo_client_kwargs) + client = self.default_client()(uri, **mongo_client_kwargs) super(MongoDBPersister, self).__init__( client=client, db_name=db_name, diff --git a/burr/integrations/persisters/b_redis.py b/burr/integrations/persisters/b_redis.py index 25231bf9..9014914b 100644 --- a/burr/integrations/persisters/b_redis.py +++ b/burr/integrations/persisters/b_redis.py @@ -9,7 +9,7 @@ import json import logging from datetime import datetime, timezone -from typing import Literal, Optional +from typing import Any, Literal, Optional from burr.core import persistence, state @@ -28,6 +28,11 @@ class RedisBasePersister(persistence.BaseStatePersister): so this is an attempt to fix that in a backwards compatible way. """ + @classmethod + def default_client(cls) -> Any: + """Returns the default client for the persister.""" + return redis.Redis + @classmethod def from_values( cls, @@ -42,7 +47,7 @@ def from_values( """Creates a new instance of the RedisBasePersister from passed in values.""" if redis_client_kwargs is None: redis_client_kwargs = {} - connection = redis.Redis( + connection = cls.default_client()( host=host, port=port, db=db, password=password, **redis_client_kwargs ) return cls(connection, serde_kwargs, namespace) @@ -160,24 +165,26 @@ def save( def __del__(self): self.connection.close() - def __getstate__(self) -> dict: - state = self.__dict__.copy() - if not hasattr(self.connection, "connection_pool"): - logger.warning("Redis connection is not serializable.") - return state - state["connection_params"] = { + def get_connection_params(self) -> dict: + """Get the connection parameters for the Redis connection.""" + return { "host": self.connection.connection_pool.connection_kwargs["host"], "port": self.connection.connection_pool.connection_kwargs["port"], "db": self.connection.connection_pool.connection_kwargs["db"], "password": self.connection.connection_pool.connection_kwargs["password"], } + + def __getstate__(self) -> dict: + state = self.__dict__.copy() + # override self.get_connection_params if needed + state["connection_params"] = self.get_connection_params() del state["connection"] return state def __setstate__(self, state: dict): connection_params = state.pop("connection_params") - # we assume normal redis client. - self.connection = redis.Redis(**connection_params) + # override self.default_client if needed + self.connection = self.default_client()(**connection_params) self.__dict__.update(state) @@ -211,7 +218,7 @@ def __init__( """ if redis_client_kwargs is None: redis_client_kwargs = {} - connection = redis.Redis( + connection = self.default_client()( host=host, port=port, db=db, password=password, **redis_client_kwargs ) super(RedisPersister, self).__init__(connection, serde_kwargs, namespace) diff --git a/burr/integrations/persisters/postgresql.py b/burr/integrations/persisters/postgresql.py index 86f50f66..089cb1fb 100644 --- a/burr/integrations/persisters/postgresql.py +++ b/burr/integrations/persisters/postgresql.py @@ -7,7 +7,7 @@ import json import logging -from typing import Literal, Optional +from typing import Any, Literal, Optional from burr.core import persistence, state @@ -51,6 +51,11 @@ def from_config(cls, config: dict) -> "PostgreSQLPersister": table_name=config.get("table_name", "burr_state"), ) + @classmethod + def default_client(cls) -> Any: + """Returns the default client for the persister.""" + return psycopg2.connect + @classmethod def from_values( cls, @@ -70,7 +75,7 @@ def from_values( :param port: the port of the PostgreSQL database. :param table_name: the table name to store things under. """ - connection = psycopg2.connect( + connection = cls.default_client()( dbname=db_name, user=user, password=password, host=host, port=port ) return cls(connection, table_name) @@ -246,27 +251,25 @@ def __del__(self): # closes connection at end when things are being shutdown. self.connection.close() - def __getstate__(self) -> dict: - state = self.__dict__.copy() - if not hasattr(self.connection, "info"): - logger.warning( - "Postgresql information for connection object not available. Cannot serialize persister." - ) - return state - state["connection_params"] = { + def get_connection_params(self) -> dict: + """Returns the connection parameters for the persister.""" + return { "dbname": self.connection.info.dbname, "user": self.connection.info.user, "password": self.connection.info.password, "host": self.connection.info.host, "port": self.connection.info.port, } + + def __getstate__(self) -> dict: + state = self.__dict__.copy() + state["connection_params"] = self.get_connection_params() del state["connection"] return state def __setstate__(self, state: dict): connection_params = state.pop("connection_params") - # we assume normal psycopg2 client. - self.connection = psycopg2.connect(**connection_params) + self.connection = self.default_client()(**connection_params) self.__dict__.update(state)