From 8f9aab641fae8fa230278efebb51bddc1806713a Mon Sep 17 00:00:00 2001 From: Naman Jain Date: Fri, 24 Apr 2020 15:11:50 -0400 Subject: [PATCH] Support GCS caching for parallel processing GCS caching is broken for parallel processing because GCS client objects cannot be pickled even with cloudpickle. This change stops attempting to pickle those GCS objects and recreates them back in the subprocess. --- bionic/cache.py | 25 ++++++++++++++++++++++--- bionic/flow.py | 4 ---- bionic/util.py | 5 +++++ tests/conftest.py | 2 +- tests/test_flow/test_persistence_gcs.py | 1 - 5 files changed, 28 insertions(+), 9 deletions(-) diff --git a/bionic/cache.py b/bionic/cache.py index 814ef9e6..24600f8f 100644 --- a/bionic/cache.py +++ b/bionic/cache.py @@ -767,10 +767,29 @@ def __init__(self, url): self.url = url bucket_name, object_prefix = self._bucket_and_object_names_from_url(url) - logger.info("Initializing GCS client ...") - self._client = get_gcs_client_without_warnings() - self._bucket = self._client.get_bucket(bucket_name) + self._bucket_name = bucket_name self._object_prefix = object_prefix + self._init_client() + + def __getstate__(self): + # Copy the object's state from self.__dict__ which contains + # all our instance attributes. Always use the dict.copy() + # method to avoid modifying the original state. + state = self.__dict__.copy() + # Remove the unpicklable entries. + del state["_client"] + del state["_bucket"] + return state + + def __setstate__(self, state): + # Restore instance attributes. + self.__dict__.update(state) + # Restore the client and bucket. + self._init_client() + + def _init_client(self): + self._client = get_gcs_client_without_warnings() + self._bucket = self._client.get_bucket(self._bucket_name) def blob_from_url(self, url): object_name = self._validated_object_name_from_url(url) diff --git a/bionic/flow.py b/bionic/flow.py index ee717574..676cc6b4 100644 --- a/bionic/flow.py +++ b/bionic/flow.py @@ -45,10 +45,6 @@ oneline, ) -import logging - -logger = logging.getLogger(__name__) - DEFAULT_PROTOCOL = protos.CombinedProtocol( protos.ParquetDataFrameProtocol(), protos.ImageProtocol(), diff --git a/bionic/util.py b/bionic/util.py index a5cfef13..3f5a11ad 100644 --- a/bionic/util.py +++ b/bionic/util.py @@ -13,6 +13,10 @@ from .optdep import import_optional_dependency, oneline +import logging + +logger = logging.getLogger(__name__) + def n_present(*items): "Returns the number of non-None arguments." @@ -123,6 +127,7 @@ def get_gcs_client_without_warnings(cache_value=True): warnings.filterwarnings( "ignore", "Your application has authenticated using end user credentials" ) + logger.info("Initializing GCS client ...") return gcs.Client() diff --git a/tests/conftest.py b/tests/conftest.py index 658fd65f..d29dbeae 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,7 +19,7 @@ def process_executor(request): return None loky = import_optional_dependency("loky", purpose="parallel processing") - return loky.get_reusable_executor() + return loky.get_reusable_executor(max_workers=1) @pytest.fixture(scope="session") diff --git a/tests/test_flow/test_persistence_gcs.py b/tests/test_flow/test_persistence_gcs.py index 3313bf2b..24a2ad9e 100644 --- a/tests/test_flow/test_persistence_gcs.py +++ b/tests/test_flow/test_persistence_gcs.py @@ -53,7 +53,6 @@ def gcs_builder(builder, tmp_gcs_url_prefix): # place. # TODO Now that we have a workspace fixture and cached client initialization, # this may not be true anymore. -@pytest.mark.no_parallel def test_gcs_caching(gcs_builder, make_counter): # Setup.