Skip to content

Commit

Permalink
[AIFlow] Fix oss blob manager download concurrently
Browse files Browse the repository at this point in the history
  • Loading branch information
Sxnan committed Oct 15, 2021
1 parent 01dcc9d commit d3581ab
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 7 deletions.
62 changes: 55 additions & 7 deletions ai_flow_plugins/blob_manager_plugins/oss_blob_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,20 @@
# specific language governing permissions and limitations
# under the License.
#
import os
import logging
import tempfile
import fcntl
import time
from typing import Text, Dict, Any
from pathlib import Path
import oss2
from ai_flow.plugin_interface.blob_manager_interface import BlobManager
from ai_flow.util.file_util.zip_file_util import make_dir_zipfile
from ai_flow_plugins.blob_manager_plugins.blob_manager_utils import extract_project_zip_file

logger = logging.getLogger(__name__)


class OssBlobManager(BlobManager):
"""
Expand All @@ -34,14 +40,14 @@ class OssBlobManager(BlobManager):
3. bucket: The oss bucket name.
4. local_repository: It represents the root path of the downloaded project package.
"""

def __init__(self, config: Dict[str, Any]):
    """
    Remember the oss connection settings without building the bucket client.

    :param config: Blob manager configuration. The keys read here are
                   'access_key_id', 'access_key_secret', 'endpoint', 'bucket',
                   'repo_name' and 'local_repository'.
    """
    super().__init__(config)
    # Only the connection parameters are stored here. The oss2 client is
    # created on first access of the read-only `bucket` property, so nothing
    # in this constructor may assign to `self.bucket`.
    self.ack_id = config.get('access_key_id', None)
    self.ack_secret = config.get('access_key_secret', None)
    self.endpoint = config.get('endpoint', None)
    self.bucket_name = config.get('bucket', None)
    self._bucket = None
    self.repo_name = config.get('repo_name', '')
    self._local_repo = config.get('local_repository', None)

Expand Down Expand Up @@ -81,8 +87,50 @@ def download_project(self, workflow_snapshot_id, remote_path: Text, local_path:
repo_path = Path(tempfile.gettempdir())
local_zip_file_path = str(repo_path / local_zip_file_name) + '.zip'
extract_path = str(repo_path / local_zip_file_name)
self.bucket.get_object_to_file(oss_object_key, filename=local_zip_file_path)

if not os.path.exists(local_zip_file_path):
logger.debug("{} not exist".format(local_zip_file_path))
lock_file_path = os.path.join(repo_path, "{}.lock".format(local_zip_file_name))
lock_file = open(lock_file_path, 'w')
logger.debug("Locking file {}".format(lock_file_path))
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
logger.debug("Locked file {}".format(lock_file_path))
try:
if not os.path.exists(local_zip_file_path):
logger.info("Downloading oss object: {}".format(oss_object_key))
self._get_oss_object(local_zip_file_path, oss_object_key)
except Exception as e:
logger.error("Failed to download oss file: {}".format(oss_object_key), exc_info=e)
finally:
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
logger.debug('Unlocked file {}'.format(lock_file_path))
lock_file.close()
if os.path.exists(lock_file_path):
os.remove(lock_file_path)
else:
logger.info("Oss file: {} already exist at {}".format(oss_object_key, local_zip_file_path))

return extract_project_zip_file(workflow_snapshot_id=workflow_snapshot_id,
local_root_path=repo_path,
zip_file_path=local_zip_file_path,
extract_project_path=extract_path)

@property
def bucket(self):
    """
    The oss2 bucket client, created on first access and cached afterwards.

    :return: The object stored in `_bucket`; when that is unset, an
             oss2.Bucket built from the stored access key, endpoint and
             bucket name is created, stored and returned.
    """
    if not self._bucket:
        credentials = oss2.Auth(self.ack_id, self.ack_secret)
        self._bucket = oss2.Bucket(credentials, self.endpoint, self.bucket_name)
    return self._bucket

def _get_oss_object(self, dest, oss_object_key, retry_sleep_sec=5, max_attempts=3):
    """
    Download an oss object to a local file, retrying when an attempt fails.

    :param dest: Local file path the object is written to.
    :param oss_object_key: Key of the object in the oss bucket.
    :param retry_sleep_sec: Seconds to wait between two attempts.
    :param max_attempts: Maximum number of download attempts.
    :raises RuntimeError: If no attempt succeeded. The error of the last
                          attempt is attached as the cause.
    """
    last_error = None
    for attempt in range(1, max_attempts + 1):
        try:
            self.bucket.get_object_to_file(oss_object_key, filename=dest)
            return
        except Exception as e:
            # Keep a reference: `e` is unbound once the except block ends.
            last_error = e
            if attempt >= max_attempts:
                logger.error("Downloading object {} failed, giving up after {} attempts".format(
                    oss_object_key, max_attempts), exc_info=e)
                break
            logger.error("Downloading object {} failed, retrying {}/{} in {} second".format(
                oss_object_key, attempt, max_attempts, retry_sleep_sec), exc_info=e)
            # Wait only when another attempt follows; sleeping after the
            # final failure would just delay the error.
            time.sleep(retry_sleep_sec)
    raise RuntimeError("Failed to download oss file: {}".format(oss_object_key)) from last_error
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,15 @@
# specific language governing permissions and limitations
# under the License.
#
import threading
import unittest
from unittest import mock

import os

from ai_flow.util.path_util import get_file_dir
from ai_flow.plugin_interface.blob_manager_interface import BlobConfig, BlobManagerFactory
from ai_flow_plugins.blob_manager_plugins.oss_blob_manager import OssBlobManager


class TestOSSBlobManager(unittest.TestCase):
Expand Down Expand Up @@ -51,6 +56,73 @@ def test_project_upload_download_oss(self):
downloaded_path = blob_manager.download_project('1', uploaded_path)
self.assertEqual('/tmp/workflow_1_project/project', downloaded_path)

def test_download_oss_file_concurrently(self):
    """
    Call download_project from two threads at once and check that the
    underlying oss download is performed exactly one time.
    """
    blob_manager = OssBlobManager({})
    # Mutable record shared with the fake downloader below.
    record = {'calls': 0, 'dest': None}

    def mock_get_oss_object(dest, oss_object_key):
        # Stand-in for the real download: count the call, remember the
        # destination and leave an empty file there.
        record['calls'] += 1
        record['dest'] = dest
        open(dest, 'w').close()

    blob_manager._get_oss_object = mock_get_oss_object

    def download_loop():
        for _ in range(1000):
            blob_manager.download_project('1', 'dummy_path', '/tmp')

    extract_target = 'ai_flow_plugins.blob_manager_plugins.oss_blob_manager.extract_project_zip_file'
    # Extraction is patched out; only the download step is under test.
    with mock.patch(extract_target):
        try:
            worker = threading.Thread(target=download_loop)
            worker.start()

            # The main thread competes with the worker for the same file.
            download_loop()
            worker.join()

            self.assertEqual(1, record['calls'])
        finally:
            # Remove the file written by the fake downloader, if any.
            if record['dest'] is not None:
                os.remove(record['dest'])

def test_lazily_init_bucket(self):
    """
    The bucket client must be unset after construction; the first access of
    the bucket property builds it and a later access returns the same object.
    """
    manager = OssBlobManager({})
    self.assertIsNone(manager._bucket)

    oss2_target = 'ai_flow_plugins.blob_manager_plugins.oss_blob_manager.oss2'
    with mock.patch(oss2_target) as fake_oss2:
        expected_bucket = mock.Mock()
        fake_oss2.Bucket.return_value = expected_bucket

        first_access = manager.bucket

        fake_oss2.Auth.assert_called_once()
        fake_oss2.Bucket.assert_called_once()
        self.assertEqual(expected_bucket, first_access)
        # A second access hands back the same object.
        self.assertEqual(expected_bucket, manager.bucket)

def test__get_oss_object_retry(self):
    """
    Check how often _get_oss_object calls the bucket: three times when every
    attempt fails (and it raises), twice when the second attempt succeeds.
    """
    manager = OssBlobManager({})

    # (outcomes of successive download calls, whether the call must raise,
    #  expected number of download calls)
    scenarios = (
        ([RuntimeError("boom"), RuntimeError("boom"), RuntimeError("boom")], True, 3),
        ([RuntimeError("boom"), None], False, 2),
    )

    for outcomes, must_raise, expected_calls in scenarios:
        with mock.patch.object(manager, '_bucket') as fake_bucket:
            fake_bucket.get_object_to_file.side_effect = outcomes
            if must_raise:
                with self.assertRaises(RuntimeError):
                    manager._get_oss_object('dummy_dest', 'key', retry_sleep_sec=0.1)
            else:
                manager._get_oss_object('dummy_dest', 'key', retry_sleep_sec=0.1)
            self.assertEqual(expected_calls, fake_bucket.get_object_to_file.call_count)


if __name__ == '__main__':
unittest.main()

0 comments on commit d3581ab

Please sign in to comment.