-
Notifications
You must be signed in to change notification settings - Fork 433
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
`boto3` sessions are not thread safe. When used in the object store logger with `use_procs: False`, the default session was shared across threads, which caused us to run into boto/boto3#1592. To fix, this PR creates a new session within each `S3ObjectStore` instance. Closes https://mosaicml.atlassian.net/browse/CO-651
- Loading branch information
1 parent
b92c8eb
commit c51a542
Showing
2 changed files
with
39 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Copyright 2022 MosaicML Composer authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import os | ||
import pathlib | ||
import threading | ||
|
||
import pytest | ||
|
||
from composer.utils.object_store import S3ObjectStore | ||
|
||
|
||
def _worker(bucket: str, tmp_path: pathlib.Path, tid: int): | ||
object_store = S3ObjectStore(bucket=bucket) | ||
os.makedirs(tmp_path / str(tid)) | ||
with pytest.raises(FileNotFoundError): | ||
object_store.download_object('this_key_should_not_exist', filename=tmp_path / str(tid) / 'dummy_file') | ||
|
||
|
||
@pytest.mark.timeout(15) | ||
# This test requires properly configured aws credentials; otherwise the s3 client would hit a NoCredentialsError | ||
# when constructing the Session, which occurs before the bug this test checks | ||
@pytest.mark.remote | ||
def test_s3_object_store_multi_threads(tmp_path: pathlib.Path): | ||
"""Test to verify that we do not hit https://github.com/boto/boto3/issues/1592.""" | ||
pytest.importorskip('boto3') | ||
# TODO(Bandish) -- once we have integration tests configured, change the bucket below | ||
# to an integration test bucket | ||
bucket = 'allenai-c4' | ||
|
||
threads = [] | ||
# Manually tried fewer threads; it seems that 100 is needed to reliably re-produce the bug | ||
for i in range(100): | ||
t = threading.Thread(target=_worker, kwargs={'bucket': bucket, 'tid': i, 'tmp_path': tmp_path}) | ||
t.start() | ||
threads.append(t) | ||
for t in threads: | ||
t.join() |