diff --git a/ai_flow_plugins/blob_manager_plugins/oss_blob_manager.py b/ai_flow_plugins/blob_manager_plugins/oss_blob_manager.py index 044e61d49..c1d0ebb40 100644 --- a/ai_flow_plugins/blob_manager_plugins/oss_blob_manager.py +++ b/ai_flow_plugins/blob_manager_plugins/oss_blob_manager.py @@ -16,7 +16,11 @@ # specific language governing permissions and limitations # under the License. # +import os +import logging import tempfile +import fcntl +import time from typing import Text, Dict, Any from pathlib import Path import oss2 @@ -24,6 +28,8 @@ from ai_flow.util.file_util.zip_file_util import make_dir_zipfile from ai_flow_plugins.blob_manager_plugins.blob_manager_utils import extract_project_zip_file +logger = logging.getLogger(__name__) + class OssBlobManager(BlobManager): """ @@ -34,14 +40,14 @@ class OssBlobManager(BlobManager): 3. bucket: The oss bucket name. 4. local_repository: It represents the root path of the downloaded project package. """ + def __init__(self, config: Dict[str, Any]): super().__init__(config) - ack_id = config.get('access_key_id', None) - ack_secret = config.get('access_key_secret', None) - endpoint = config.get('endpoint', None) - bucket_name = config.get('bucket', None) - auth = oss2.Auth(ack_id, ack_secret) - self.bucket = oss2.Bucket(auth, endpoint, bucket_name) + self.ack_id = config.get('access_key_id', None) + self.ack_secret = config.get('access_key_secret', None) + self.endpoint = config.get('endpoint', None) + self.bucket_name = config.get('bucket', None) + self._bucket = None self.repo_name = config.get('repo_name', '') self._local_repo = config.get('local_repository', None) @@ -81,8 +87,50 @@ def download_project(self, workflow_snapshot_id, remote_path: Text, local_path: repo_path = Path(tempfile.gettempdir()) local_zip_file_path = str(repo_path / local_zip_file_name) + '.zip' extract_path = str(repo_path / local_zip_file_name) - self.bucket.get_object_to_file(oss_object_key, filename=local_zip_file_path) + + if not os.path.exists(local_zip_file_path): + logger.debug("{} not exist".format(local_zip_file_path)) + lock_file_path = os.path.join(repo_path, "{}.lock".format(local_zip_file_name)) + lock_file = open(lock_file_path, 'w') + logger.debug("Locking file {}".format(lock_file_path)) + fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) + logger.debug("Locked file {}".format(lock_file_path)) + try: + if not os.path.exists(local_zip_file_path): + logger.info("Downloading oss object: {}".format(oss_object_key)) + self._get_oss_object(local_zip_file_path, oss_object_key) + except Exception as e: + logger.error("Failed to download oss file: {}".format(oss_object_key), exc_info=e) + finally: + fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) + logger.debug('Unlocked file {}'.format(lock_file_path)) + lock_file.close() + if os.path.exists(lock_file_path): + os.remove(lock_file_path) + else: + logger.info("Oss file: {} already exist at {}".format(oss_object_key, local_zip_file_path)) + return extract_project_zip_file(workflow_snapshot_id=workflow_snapshot_id, local_root_path=repo_path, zip_file_path=local_zip_file_path, extract_project_path=extract_path) + + @property + def bucket(self): + if self._bucket: + return self._bucket + auth = oss2.Auth(self.ack_id, self.ack_secret) + self._bucket = oss2.Bucket(auth, self.endpoint, self.bucket_name) + return self._bucket + + def _get_oss_object(self, dest, oss_object_key, retry_sleep_sec=5): + for i in range(3): + try: + self.bucket.get_object_to_file(oss_object_key, filename=dest) + return + except Exception as e: + logger.error("Downloading object {} failed, retrying {}/3 in {} second".format(oss_object_key, i+1, + retry_sleep_sec), + exc_info=e) + time.sleep(retry_sleep_sec) + raise RuntimeError("Failed to download oss file: {}".format(oss_object_key)) diff --git a/ai_flow_plugins/tests/blob_manager_plugins/test_oss_blob_manager.py b/ai_flow_plugins/tests/blob_manager_plugins/test_oss_blob_manager.py index d6ef2fcb0..ac7cf9c32 100644 --- a/ai_flow_plugins/tests/blob_manager_plugins/test_oss_blob_manager.py +++ b/ai_flow_plugins/tests/blob_manager_plugins/test_oss_blob_manager.py @@ -16,10 +16,15 @@ # specific language governing permissions and limitations # under the License. # +import threading import unittest +from unittest import mock + import os + from ai_flow.util.path_util import get_file_dir from ai_flow.plugin_interface.blob_manager_interface import BlobConfig, BlobManagerFactory +from ai_flow_plugins.blob_manager_plugins.oss_blob_manager import OssBlobManager class TestOSSBlobManager(unittest.TestCase): @@ -51,6 +56,73 @@ def test_project_upload_download_oss(self): downloaded_path = blob_manager.download_project('1', uploaded_path) self.assertEqual('/tmp/workflow_1_project/project', downloaded_path) + def test_download_oss_file_concurrently(self): + config = {} + oss_blob_manager = OssBlobManager(config) + + zip_file_path = None + call_count = 0 + + def mock_get_oss_object(dest, oss_object_key): + nonlocal zip_file_path, call_count + call_count += 1 + zip_file_path = dest + with open(dest, 'w') as f: + pass + + oss_blob_manager._get_oss_object = mock_get_oss_object + + # get_oss_object_func = mock.patch.object(oss_blob_manager, '_get_oss_object', wraps=mock_get_oss_object) + with mock.patch( + 'ai_flow_plugins.blob_manager_plugins.oss_blob_manager.extract_project_zip_file'): + + def download_loop(): + for i in range(1000): + oss_blob_manager.download_project('1', 'dummy_path', '/tmp') + + try: + t1 = threading.Thread(target=download_loop) + t1.start() + + download_loop() + t1.join() + + self.assertEqual(1, call_count) + finally: + if zip_file_path is not None: + os.remove(zip_file_path) + + def test_lazily_init_bucket(self): + config = {} + oss_blob_manager = OssBlobManager(config) + self.assertIsNone(oss_blob_manager._bucket) + + with mock.patch('ai_flow_plugins.blob_manager_plugins.oss_blob_manager.oss2') as mock_oss: + mock_bucket = mock.Mock() + mock_oss.Bucket.return_value = mock_bucket + bucket = oss_blob_manager.bucket + mock_oss.Auth.assert_called_once() + mock_oss.Bucket.assert_called_once() + self.assertEqual(mock_bucket, bucket) + self.assertEqual(mock_bucket, oss_blob_manager.bucket) + + def test__get_oss_object_retry(self): + config = {} + oss_blob_manager = OssBlobManager(config) + + with mock.patch.object(oss_blob_manager, '_bucket') as mock_bucket: + mock_bucket.get_object_to_file.side_effect = [RuntimeError("boom"), RuntimeError("boom"), + RuntimeError("boom")] + with self.assertRaises(RuntimeError): + oss_blob_manager._get_oss_object('dummy_dest', 'key', retry_sleep_sec=0.1) + self.assertEqual(3, mock_bucket.get_object_to_file.call_count) + + with mock.patch.object(oss_blob_manager, '_bucket') as mock_bucket: + mock_bucket.get_object_to_file.side_effect = [RuntimeError("boom"), None] + oss_blob_manager._get_oss_object('dummy_dest', 'key', retry_sleep_sec=0.1) + self.assertEqual(2, mock_bucket.get_object_to_file.call_count) + if __name__ == '__main__': unittest.main() +