Skip to content

Commit

Permalink
Merge pull request #113 from square/naman/parallel-processing-gcs
Browse files: browse the repository at this point in the history
Support gcs caching for parallel processing
  • Loading branch information
namanjain authored Apr 24, 2020
2 parents 15e9a15 + 8f9aab6 commit c49b47a
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 9 deletions.
25 changes: 22 additions & 3 deletions bionic/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,10 +767,29 @@ def __init__(self, url):
self.url = url
bucket_name, object_prefix = self._bucket_and_object_names_from_url(url)

logger.info("Initializing GCS client ...")
self._client = get_gcs_client_without_warnings()
self._bucket = self._client.get_bucket(bucket_name)
self._bucket_name = bucket_name
self._object_prefix = object_prefix
self._init_client()

def __getstate__(self):
    """Return the attribute dict to pickle, minus the client and bucket.

    The ``_client`` and ``_bucket`` entries are treated as unpicklable and
    are therefore dropped from the returned state; ``__setstate__`` is
    responsible for recreating them.
    """
    # Work on a shallow copy so the live instance keeps all its attributes.
    picklable = dict(self.__dict__)
    # Drop the unpicklable entries (a missing key raises KeyError).
    for attr in ("_client", "_bucket"):
        del picklable[attr]
    return picklable

def __setstate__(self, state):
    """Rebuild the instance from a previously saved attribute dict.

    ``state`` is merged into the instance's attributes, after which the
    client and bucket are recreated by calling ``_init_client``.
    """
    # Put the saved attributes back before re-initialising.
    vars(self).update(state)
    # The client and bucket are not part of the saved state; recreate them.
    self._init_client()

def _init_client(self):
    """Create the GCS client and fetch the bucket named ``_bucket_name``.

    Stores the results on ``self._client`` and ``self._bucket``.
    """
    gcs_client = get_gcs_client_without_warnings()
    # Store the client before the bucket lookup, as the lookup may fail.
    self._client = gcs_client
    self._bucket = gcs_client.get_bucket(self._bucket_name)

def blob_from_url(self, url):
object_name = self._validated_object_name_from_url(url)
Expand Down
4 changes: 0 additions & 4 deletions bionic/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,6 @@
oneline,
)

import logging

logger = logging.getLogger(__name__)

DEFAULT_PROTOCOL = protos.CombinedProtocol(
protos.ParquetDataFrameProtocol(),
protos.ImageProtocol(),
Expand Down
5 changes: 5 additions & 0 deletions bionic/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@

from .optdep import import_optional_dependency, oneline

import logging

logger = logging.getLogger(__name__)


def n_present(*items):
"Returns the number of non-None arguments."
Expand Down Expand Up @@ -123,6 +127,7 @@ def get_gcs_client_without_warnings(cache_value=True):
warnings.filterwarnings(
"ignore", "Your application has authenticated using end user credentials"
)
logger.info("Initializing GCS client ...")
return gcs.Client()


Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def process_executor(request):
return None

loky = import_optional_dependency("loky", purpose="parallel processing")
return loky.get_reusable_executor()
return loky.get_reusable_executor(max_workers=1)


@pytest.fixture(scope="session")
Expand Down
1 change: 0 additions & 1 deletion tests/test_flow/test_persistence_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def gcs_builder(builder, tmp_gcs_url_prefix):
# place.
# TODO Now that we have a workspace fixture and cached client initialization,
# this may not be true anymore.
@pytest.mark.no_parallel
def test_gcs_caching(gcs_builder, make_counter):
# Setup.

Expand Down

0 comments on commit c49b47a

Please sign in to comment.