diff --git a/sdk/storage/azure-storage-file/azure/storage/file/_share_utils.py b/sdk/storage/azure-storage-file/azure/storage/file/_share_utils.py
index 4032ec44c2c2..f185529b5a70 100644
--- a/sdk/storage/azure-storage-file/azure/storage/file/_share_utils.py
+++ b/sdk/storage/azure-storage-file/azure/storage/file/_share_utils.py
@@ -4,7 +4,11 @@
 # license information.
 # --------------------------------------------------------------------------
 
-from .models import ShareProperties
+from .models import ShareProperties, DirectoryProperties, FileProperty
+from ._shared.utils import return_response_headers
+from ._upload_chunking import _upload_file_chunks
+from ._generated.models import StorageErrorException
+from ._shared.utils import process_storage_error
 
 
 def deserialize_metadata(response, obj, headers):
@@ -17,4 +21,76 @@ def deserialize_share_properties(response, obj, headers):
         metadata=metadata,
         **headers
     )
-    return share_properties
\ No newline at end of file
+    return share_properties
+
+def deserialize_directory_properties(response, obj, headers):
+    metadata = deserialize_metadata(response, obj, headers)
+    directory_properties = DirectoryProperties(
+        metadata=metadata,
+        **headers
+    )
+    return directory_properties
+
+def deserialize_file_properties(response, obj, headers):
+    metadata = deserialize_metadata(response, obj, headers)
+    file_properties = FileProperty(
+        metadata=metadata,
+        **headers
+    )
+    return file_properties
+
+def upload_file_helper(
+        client,
+        share_name,
+        directory_name,
+        file_name,
+        stream,
+        size,
+        headers,
+        file_http_headers,
+        validate_content,
+        timeout,
+        max_connections,
+        file_settings,
+        encryption_data,
+        **kwargs):
+    try:
+        if size is None or size < 0:
+            raise ValueError("A content size must be specified for a File.")
+        if size % 512 != 0:
+            raise ValueError("Invalid file size: {0}. "
+                             "The size must be aligned to a 512-byte boundary.".format(size))
+        if encryption_data is not None:
+            headers['x-ms-meta-encryptiondata'] = encryption_data
+        response = client.create(
+            file_content_length=size,
+            timeout=timeout,
+            file_http_headers=file_http_headers,
+            headers=headers,
+            cls=return_response_headers,
+            **kwargs)
+        if size == 0:
+            return response
+
+        return _upload_file_chunks(
+            file_service=client,
+            file_size=size,
+            share_name=share_name,
+            directory_name=directory_name,
+            file_name=file_name,
+            block_size=file_settings.max_page_size,
+            stream=stream,
+            max_connections=max_connections,
+            validate_content=validate_content,
+            timeout=timeout,
+            **kwargs)
+    except StorageErrorException as error:
+        process_storage_error(error)
+
+class StorageStreamDownloader(object):  # pylint: disable=too-many-instance-attributes
+
+    def __init__(
+            self, name, share_name, directory_name, service, config, length, validate_content,
+            timeout, require_encryption, key_encryption_key, key_resolver_function, **kwargs
+        ):
+        pass
\ No newline at end of file
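A quick sketch of the size validation in upload_file_helper above: content must be non-negative and 512-byte aligned, so callers typically pad up front. This is illustrative only (pad_to_boundary is a hypothetical helper, not part of this patch):

def pad_to_boundary(data, boundary=512):
    # Zero-pad so len(data) satisfies upload_file_helper's size % 512 == 0 check.
    remainder = len(data) % boundary
    if remainder:
        data += b'\x00' * (boundary - remainder)
    return data

payload = pad_to_boundary(b'hello world')
assert len(payload) % 512 == 0  # would pass the alignment check above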
diff --git a/sdk/storage/azure-storage-file/azure/storage/file/_shared/encryption.py b/sdk/storage/azure-storage-file/azure/storage/file/_shared/encryption.py
new file mode 100644
index 000000000000..c65e6b97dab9
--- /dev/null
+++ b/sdk/storage/azure-storage-file/azure/storage/file/_shared/encryption.py
@@ -0,0 +1,441 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+from json import (
+    dumps,
+    loads,
+)
+from os import urandom
+from collections import OrderedDict
+
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives.ciphers import Cipher
+from cryptography.hazmat.primitives.ciphers.algorithms import AES
+from cryptography.hazmat.primitives.ciphers.modes import CBC
+from cryptography.hazmat.primitives.padding import PKCS7
+
+from .version import __version__
+from .authentication import _encode_base64, _decode_base64_to_bytes
+
+
+_ENCRYPTION_PROTOCOL_V1 = '1.0'
+_ERROR_VALUE_NONE = '{0} should not be None.'
+_ERROR_OBJECT_INVALID = \
+    '{0} does not define a complete interface. Value of {1} is either missing or invalid.'
+_ERROR_DATA_NOT_ENCRYPTED = 'Encryption required, but received data does not contain appropriate metadata. ' + \
+                            'Data was either not encrypted or metadata has been lost.'
+_ERROR_UNSUPPORTED_ENCRYPTION_ALGORITHM = \
+    'Specified encryption algorithm is not supported.'
+
+
+def _validate_not_none(param_name, param):
+    if param is None:
+        raise ValueError(_ERROR_VALUE_NONE.format(param_name))
+
+
+def _validate_key_encryption_key_wrap(kek):
+    # Note that None is not callable and so will fail the second clause of each check.
+    if not hasattr(kek, 'wrap_key') or not callable(kek.wrap_key):
+        raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'wrap_key'))
+    if not hasattr(kek, 'get_kid') or not callable(kek.get_kid):
+        raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'get_kid'))
+    if not hasattr(kek, 'get_key_wrap_algorithm') or not callable(kek.get_key_wrap_algorithm):
+        raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'get_key_wrap_algorithm'))
+
+
+class _EncryptionAlgorithm(object):
+    '''
+    Specifies which client encryption algorithm is used.
+    '''
+    AES_CBC_256 = 'AES_CBC_256'
+
+
+class _WrappedContentKey:
+    '''
+    Represents the envelope key details stored on the service.
+    '''
+
+    def __init__(self, algorithm, encrypted_key, key_id):
+        '''
+        :param str algorithm:
+            The algorithm used for wrapping.
+        :param bytes encrypted_key:
+            The encrypted content-encryption-key.
+        :param str key_id:
+            The key-encryption-key identifier string.
+        '''
+
+        _validate_not_none('algorithm', algorithm)
+        _validate_not_none('encrypted_key', encrypted_key)
+        _validate_not_none('key_id', key_id)
+
+        self.algorithm = algorithm
+        self.encrypted_key = encrypted_key
+        self.key_id = key_id
+
+
+class _EncryptionAgent:
+    '''
+    Represents the encryption agent stored on the service.
+    It consists of the encryption protocol version and encryption algorithm used.
+    '''
+
+    def __init__(self, encryption_algorithm, protocol):
+        '''
+        :param _EncryptionAlgorithm encryption_algorithm:
+            The algorithm used for encrypting the message contents.
+        :param str protocol:
+            The protocol version used for encryption.
+        '''
+
+        _validate_not_none('encryption_algorithm', encryption_algorithm)
+        _validate_not_none('protocol', protocol)
+
+        self.encryption_algorithm = str(encryption_algorithm)
+        self.protocol = protocol
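For reference, a minimal object satisfying the wrap-key interface validated by _validate_key_encryption_key_wrap above. This is an illustrative sketch only (the class name, kid string, and the use of AES key wrap are assumptions, not SDK code):

import os
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.keywrap import aes_key_wrap, aes_key_unwrap

class ToyKeyWrapper(object):
    def __init__(self, kid='local:toy-kek'):
        self._kek = os.urandom(32)  # wrapping key, kept client-side
        self._kid = kid

    def wrap_key(self, key):               # required by the wrap validation
        return aes_key_wrap(self._kek, key, default_backend())

    def unwrap_key(self, key, algorithm):  # required on the decrypt path
        assert algorithm == 'A256KW'
        return aes_key_unwrap(self._kek, key, default_backend())

    def get_key_wrap_algorithm(self):      # recorded in the encryption metadata
        return 'A256KW'

    def get_kid(self):                     # matched against WrappedContentKey.KeyId
        return self._kid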
+class _EncryptionData:
+    '''
+    Represents the encryption data that is stored on the service.
+    '''
+
+    def __init__(self, content_encryption_IV, encryption_agent, wrapped_content_key,
+                 key_wrapping_metadata):
+        '''
+        :param bytes content_encryption_IV:
+            The content encryption initialization vector.
+        :param _EncryptionAgent encryption_agent:
+            The encryption agent.
+        :param _WrappedContentKey wrapped_content_key:
+            An object that stores the wrapping algorithm, the key identifier,
+            and the encrypted key bytes.
+        :param dict key_wrapping_metadata:
+            A dict containing metadata related to the key wrapping.
+        '''
+
+        _validate_not_none('content_encryption_IV', content_encryption_IV)
+        _validate_not_none('encryption_agent', encryption_agent)
+        _validate_not_none('wrapped_content_key', wrapped_content_key)
+
+        self.content_encryption_IV = content_encryption_IV
+        self.encryption_agent = encryption_agent
+        self.wrapped_content_key = wrapped_content_key
+        self.key_wrapping_metadata = key_wrapping_metadata
+
+
+def _generate_encryption_data_dict(kek, cek, iv):
+    '''
+    Generates and returns the encryption metadata as a dict.
+
+    :param object kek: The key encryption key. See calling functions for more information.
+    :param bytes cek: The content encryption key.
+    :param bytes iv: The initialization vector.
+    :return: A dict containing all the encryption metadata.
+    :rtype: dict
+    '''
+    # Encrypt the cek.
+    wrapped_cek = kek.wrap_key(cek)
+
+    # Build the encryption_data dict.
+    # Use OrderedDict to comply with Java's ordering requirement.
+    wrapped_content_key = OrderedDict()
+    wrapped_content_key['KeyId'] = kek.get_kid()
+    wrapped_content_key['EncryptedKey'] = _encode_base64(wrapped_cek)
+    wrapped_content_key['Algorithm'] = kek.get_key_wrap_algorithm()
+
+    encryption_agent = OrderedDict()
+    encryption_agent['Protocol'] = _ENCRYPTION_PROTOCOL_V1
+    encryption_agent['EncryptionAlgorithm'] = _EncryptionAlgorithm.AES_CBC_256
+
+    encryption_data_dict = OrderedDict()
+    encryption_data_dict['WrappedContentKey'] = wrapped_content_key
+    encryption_data_dict['EncryptionAgent'] = encryption_agent
+    encryption_data_dict['ContentEncryptionIV'] = _encode_base64(iv)
+    encryption_data_dict['KeyWrappingMetadata'] = {'EncryptionLibrary': 'Python ' + __version__}
+
+    return encryption_data_dict
+
+
+def _dict_to_encryption_data(encryption_data_dict):
+    '''
+    Converts the specified dictionary to an EncryptionData object for
+    eventual use in decryption.
+
+    :param dict encryption_data_dict:
+        The dictionary containing the encryption data.
+    :return: an _EncryptionData object built from the dictionary.
+    :rtype: _EncryptionData
+    '''
+    try:
+        if encryption_data_dict['EncryptionAgent']['Protocol'] != _ENCRYPTION_PROTOCOL_V1:
+            raise ValueError("Unsupported encryption version.")
+    except KeyError:
+        raise ValueError("Unsupported encryption version.")
+    wrapped_content_key = encryption_data_dict['WrappedContentKey']
+    wrapped_content_key = _WrappedContentKey(wrapped_content_key['Algorithm'],
+                                             _decode_base64_to_bytes(wrapped_content_key['EncryptedKey']),
+                                             wrapped_content_key['KeyId'])
+
+    encryption_agent = encryption_data_dict['EncryptionAgent']
+    encryption_agent = _EncryptionAgent(encryption_agent['EncryptionAlgorithm'],
+                                        encryption_agent['Protocol'])
+
+    if 'KeyWrappingMetadata' in encryption_data_dict:
+        key_wrapping_metadata = encryption_data_dict['KeyWrappingMetadata']
+    else:
+        key_wrapping_metadata = None
+
+    encryption_data = _EncryptionData(_decode_base64_to_bytes(encryption_data_dict['ContentEncryptionIV']),
+                                      encryption_agent,
+                                      wrapped_content_key,
+                                      key_wrapping_metadata)
+
+    return encryption_data
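The dict built by _generate_encryption_data_dict serializes to the JSON stored in the x-ms-meta-encryptiondata header and is consumed again by _dict_to_encryption_data. A sketch of its shape (values are placeholders):

encryption_metadata = {
    "WrappedContentKey": {
        "KeyId": "local:toy-kek",                 # from kek.get_kid()
        "EncryptedKey": "<base64 of wrapped cek>",
        "Algorithm": "A256KW",                    # from kek.get_key_wrap_algorithm()
    },
    "EncryptionAgent": {
        "Protocol": "1.0",
        "EncryptionAlgorithm": "AES_CBC_256",
    },
    "ContentEncryptionIV": "<base64 of the 16-byte IV>",
    "KeyWrappingMetadata": {"EncryptionLibrary": "Python 3.0.0a8"},
}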
+def _generate_AES_CBC_cipher(cek, iv):
+    '''
+    Generates and returns an encryption cipher for AES CBC using the given cek and iv.
+
+    :param bytes[] cek: The content encryption key for the cipher.
+    :param bytes[] iv: The initialization vector for the cipher.
+    :return: A cipher for encrypting in AES256 CBC.
+    :rtype: ~cryptography.hazmat.primitives.ciphers.Cipher
+    '''
+
+    backend = default_backend()
+    algorithm = AES(cek)
+    mode = CBC(iv)
+    return Cipher(algorithm, mode, backend)
+
+
+def _validate_and_unwrap_cek(encryption_data, key_encryption_key=None, key_resolver=None):
+    '''
+    Extracts and returns the content_encryption_key stored in the encryption_data object
+    and performs necessary validation on all parameters.
+    :param _EncryptionData encryption_data:
+        The encryption metadata of the retrieved value.
+    :param obj key_encryption_key:
+        The key_encryption_key used to unwrap the cek. Please refer to high-level service object
+        instance variables for more details.
+    :param func key_resolver:
+        A function that, given a key_id, will return a key_encryption_key. Please refer
+        to high-level service object instance variables for more details.
+    :return: the content_encryption_key stored in the encryption_data object.
+    :rtype: bytes[]
+    '''
+
+    _validate_not_none('content_encryption_IV', encryption_data.content_encryption_IV)
+    _validate_not_none('encrypted_key', encryption_data.wrapped_content_key.encrypted_key)
+
+    if _ENCRYPTION_PROTOCOL_V1 != encryption_data.encryption_agent.protocol:
+        raise ValueError('Encryption version is not supported.')
+
+    content_encryption_key = None
+
+    # If the resolver exists, give priority to the key it finds.
+    if key_resolver is not None:
+        key_encryption_key = key_resolver(encryption_data.wrapped_content_key.key_id)
+
+    _validate_not_none('key_encryption_key', key_encryption_key)
+    if not hasattr(key_encryption_key, 'get_kid') or not callable(key_encryption_key.get_kid):
+        raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'get_kid'))
+    if not hasattr(key_encryption_key, 'unwrap_key') or not callable(key_encryption_key.unwrap_key):
+        raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'unwrap_key'))
+    if encryption_data.wrapped_content_key.key_id != key_encryption_key.get_kid():
+        raise ValueError('Provided or resolved key-encryption-key does not match the id of key used to encrypt.')
+    # Will throw an exception if the specified algorithm is not supported.
+    content_encryption_key = key_encryption_key.unwrap_key(encryption_data.wrapped_content_key.encrypted_key,
+                                                           encryption_data.wrapped_content_key.algorithm)
+    _validate_not_none('content_encryption_key', content_encryption_key)
+
+    return content_encryption_key
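A sketch of a key_resolver usable with _validate_and_unwrap_cek above: a plain mapping from kid to key-encryption-key (ToyKeyWrapper refers to the illustrative wrapper sketched earlier):

known_keks = {wrapper.get_kid(): wrapper for wrapper in [ToyKeyWrapper()]}

def key_resolver(kid):
    # Returning None for an unknown kid makes _validate_not_none raise.
    return known_keks.get(kid)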
+def _encrypt_blob(blob, key_encryption_key):
+    '''
+    Encrypts the given blob using AES256 in CBC mode with 128 bit padding.
+    Wraps the generated content-encryption-key using the user-provided key-encryption-key (kek).
+    Returns a json-formatted string containing the encryption metadata. This method should
+    only be used when a blob is small enough for single shot upload. Encrypting larger blobs
+    is done as a part of the upload_blob_chunks method.
+
+    :param bytes blob:
+        The blob to be encrypted.
+    :param object key_encryption_key:
+        The user-provided key-encryption-key. Must implement the following methods:
+        wrap_key(key)--wraps the specified key using an algorithm of the user's choice.
+        get_key_wrap_algorithm()--returns the algorithm used to wrap the specified symmetric key.
+        get_kid()--returns a string key id for this key-encryption-key.
+    :return: A tuple of json-formatted string containing the encryption metadata and the encrypted blob data.
+    :rtype: (str, bytes)
+    '''
+
+    _validate_not_none('blob', blob)
+    _validate_not_none('key_encryption_key', key_encryption_key)
+    _validate_key_encryption_key_wrap(key_encryption_key)
+
+    # AES256 uses 256 bit (32 byte) keys and always with 16 byte blocks
+    content_encryption_key = urandom(32)
+    initialization_vector = urandom(16)
+
+    cipher = _generate_AES_CBC_cipher(content_encryption_key, initialization_vector)
+
+    # PKCS7 with 16 byte blocks ensures compatibility with AES.
+    padder = PKCS7(128).padder()
+    padded_data = padder.update(blob) + padder.finalize()
+
+    # Encrypt the data.
+    encryptor = cipher.encryptor()
+    encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
+    encryption_data = _generate_encryption_data_dict(key_encryption_key, content_encryption_key,
+                                                     initialization_vector)
+    encryption_data['EncryptionMode'] = 'FullBlob'
+
+    return dumps(encryption_data), encrypted_data
+
+
+def _generate_blob_encryption_data(key_encryption_key):
+    '''
+    Generates the encryption_metadata for the blob.
+
+    :param bytes key_encryption_key:
+        The key-encryption-key used to wrap the cek associated with this blob.
+    :return: A tuple containing the cek and iv for this blob as well as the
+        serialized encryption metadata for the blob.
+    :rtype: (bytes, bytes, str)
+    '''
+    encryption_data = None
+    content_encryption_key = None
+    initialization_vector = None
+    if key_encryption_key:
+        _validate_key_encryption_key_wrap(key_encryption_key)
+        content_encryption_key = urandom(32)
+        initialization_vector = urandom(16)
+        encryption_data = _generate_encryption_data_dict(key_encryption_key,
+                                                         content_encryption_key,
+                                                         initialization_vector)
+        encryption_data['EncryptionMode'] = 'FullBlob'
+        encryption_data = dumps(encryption_data)
+
+    return content_encryption_key, initialization_vector, encryption_data
+
+def _generate_file_encryption_data(key_encryption_key):
+    '''
+    Generates the encryption_metadata for the file.
+
+    :param bytes key_encryption_key:
+        The key-encryption-key used to wrap the cek associated with this file.
+    :return: A tuple containing the cek and iv for this file as well as the
+        serialized encryption metadata for the file.
+    :rtype: (bytes, bytes, str)
+    '''
+    encryption_data = None
+    content_encryption_key = None
+    initialization_vector = None
+    if key_encryption_key:
+        _validate_key_encryption_key_wrap(key_encryption_key)
+        content_encryption_key = urandom(32)
+        initialization_vector = urandom(16)
+        encryption_data = _generate_encryption_data_dict(key_encryption_key,
+                                                         content_encryption_key,
+                                                         initialization_vector)
+        encryption_data['EncryptionMode'] = 'FullFile'
+        encryption_data = dumps(encryption_data)
+
+    return content_encryption_key, initialization_vector, encryption_data
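The encrypt path above is plain AES-256-CBC with PKCS7 padding. A standalone round trip using the same cryptography primitives, as a sketch of exactly what _encrypt_blob does to the payload:

from os import urandom
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher
from cryptography.hazmat.primitives.ciphers.algorithms import AES
from cryptography.hazmat.primitives.ciphers.modes import CBC
from cryptography.hazmat.primitives.padding import PKCS7

cek, iv = urandom(32), urandom(16)          # same sizes as _encrypt_blob
cipher = Cipher(AES(cek), CBC(iv), default_backend())

padder = PKCS7(128).padder()                # 128-bit blocks match AES
padded = padder.update(b'file content') + padder.finalize()
ciphertext = cipher.encryptor().update(padded) + cipher.encryptor().finalize() if False else None
encryptor = cipher.encryptor()
ciphertext = encryptor.update(padded) + encryptor.finalize()

decryptor = cipher.decryptor()
unpadder = PKCS7(128).unpadder()
plain = unpadder.update(decryptor.update(ciphertext) + decryptor.finalize()) + unpadder.finalize()
assert plain == b'file content'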
+def _decrypt_blob(require_encryption, key_encryption_key, key_resolver,
+                  response, start_offset, end_offset):
+    '''
+    Decrypts the given blob contents and returns only the requested range.
+
+    :param bool require_encryption:
+        Whether or not the calling blob service requires objects to be decrypted.
+    :param object key_encryption_key:
+        The user-provided key-encryption-key. Must implement the following methods:
+        wrap_key(key)--wraps the specified key using an algorithm of the user's choice.
+        get_key_wrap_algorithm()--returns the algorithm used to wrap the specified symmetric key.
+        get_kid()--returns a string key id for this key-encryption-key.
+    :param key_resolver(kid):
+        The user-provided key resolver. Uses the kid string to return a key-encryption-key
+        implementing the interface defined above.
+    :return: The decrypted blob content.
+    :rtype: bytes
+    '''
+    if response is None:
+        raise ValueError("Response cannot be None.")
+    content = b"".join(list(response))
+    if not content:
+        return content
+
+    try:
+        encryption_data = _dict_to_encryption_data(loads(response.response.headers['x-ms-meta-encryptiondata']))
+    except:  # pylint: disable=bare-except
+        if require_encryption:
+            raise ValueError(_ERROR_DATA_NOT_ENCRYPTED)
+
+        return content
+
+    if encryption_data.encryption_agent.encryption_algorithm != _EncryptionAlgorithm.AES_CBC_256:
+        raise ValueError(_ERROR_UNSUPPORTED_ENCRYPTION_ALGORITHM)
+
+    blob_type = response.response.headers['x-ms-blob-type']
+
+    iv = None
+    unpad = False
+    if 'content-range' in response.response.headers:
+        content_range = response.response.headers['content-range']
+        # Format: 'bytes x-y/size'
+
+        # Ignore the word 'bytes'
+        content_range = content_range.split(' ')
+
+        content_range = content_range[1].split('-')
+        content_range = content_range[1].split('/')
+        end_range = int(content_range[0])
+        blob_size = int(content_range[1])
+
+        if start_offset >= 16:
+            iv = content[:16]
+            content = content[16:]
+            start_offset -= 16
+        else:
+            iv = encryption_data.content_encryption_IV
+
+        if end_range == blob_size - 1:
+            unpad = True
+    else:
+        unpad = True
+        iv = encryption_data.content_encryption_IV
+
+    if blob_type == 'PageBlob':
+        unpad = False
+
+    content_encryption_key = _validate_and_unwrap_cek(encryption_data, key_encryption_key, key_resolver)
+    cipher = _generate_AES_CBC_cipher(content_encryption_key, iv)
+    decryptor = cipher.decryptor()
+
+    content = decryptor.update(content) + decryptor.finalize()
+    if unpad:
+        unpadder = PKCS7(128).unpadder()
+        content = unpadder.update(content) + unpadder.finalize()
+
+    return content[start_offset: len(content) - end_offset]
+
+
+def _get_blob_encryptor_and_padder(cek, iv, should_pad):
+    encryptor = None
+    padder = None
+
+    if cek is not None and iv is not None:
+        cipher = _generate_AES_CBC_cipher(cek, iv)
+        encryptor = cipher.encryptor()
+        padder = PKCS7(128).padder() if should_pad else None
+
+    return encryptor, padder
diff --git a/sdk/storage/azure-storage-file/azure/storage/file/_shared/version.py b/sdk/storage/azure-storage-file/azure/storage/file/_shared/version.py
new file mode 100644
index 000000000000..b111042a2721
--- /dev/null
+++ b/sdk/storage/azure-storage-file/azure/storage/file/_shared/version.py
@@ -0,0 +1,7 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+__version__ = "3.0.0a8"
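The ranged branch of _decrypt_blob relies on a CBC property: block N is deciphered with ciphertext block N-1 as its IV, so the downloader fetches one extra 16-byte block in front of the requested range and the decryptor consumes it as the IV. A sketch of the arithmetic (expand_range is a hypothetical caller-side helper):

def expand_range(start, end):
    # Fetch from one 16-byte block before the aligned start; the extra bytes
    # become the IV (content[:16]) plus any intra-block lead-in to discard.
    adjusted_start = (start // 16) * 16 - 16 if start >= 16 else 0
    start_offset = start - adjusted_start  # bytes dropped after decryption
    return adjusted_start, end, start_offset

print(expand_range(100, 200))  # (80, 200, 20): fetch from 80, drop 20 bytes incl. the 16-byte IV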
diff --git a/sdk/storage/azure-storage-file/azure/storage/file/_upload_chunking.py b/sdk/storage/azure-storage-file/azure/storage/file/_upload_chunking.py
index c6fb34f77bf2..308586f6e804 100644
--- a/sdk/storage/azure-storage-file/azure/storage/file/_upload_chunking.py
+++ b/sdk/storage/azure-storage-file/azure/storage/file/_upload_chunking.py
@@ -4,7 +4,8 @@
 # license information.
 # --------------------------------------------------------------------------
 import threading
-
+import six
+from io import UnsupportedOperation
 
 def _upload_file_chunks(file_service, share_name, directory_name, file_name,
                         file_size, block_size, stream, max_connections,
@@ -131,3 +132,50 @@ def _upload_chunk_with_progress(self, chunk_start, chunk_data):
         range_id = 'bytes={0}-{1}'.format(chunk_start, chunk_end)
         self._update_progress(len(chunk_data))
         return range_id
+
+
+class IterStreamer(object):
+    """
+    File-like streaming iterator.
+    """
+    def __init__(self, generator, encoding='UTF-8'):
+        self.generator = generator
+        self.iterator = iter(generator)
+        self.leftover = b''
+        self.encoding = encoding
+
+    def __len__(self):
+        return self.generator.__len__()
+
+    def __iter__(self):
+        return self.iterator
+
+    def seekable(self):
+        return False
+
+    def next(self):
+        return next(self.iterator)
+
+    def tell(self, *args, **kwargs):
+        raise UnsupportedOperation("Data generator does not support tell.")
+
+    def seek(self, *args, **kwargs):
+        raise UnsupportedOperation("Data generator is unseekable.")
+
+    def read(self, size):
+        data = self.leftover
+        count = len(self.leftover)
+        try:
+            while count < size:
+                chunk = self.next()
+                if isinstance(chunk, six.text_type):
+                    chunk = chunk.encode(self.encoding)
+                data += chunk
+                count += len(chunk)
+        except StopIteration:
+            pass
+
+        # Keep any excess for the next read; reset otherwise so stale bytes
+        # are not returned twice.
+        self.leftover = data[size:] if count > size else b''
+
+        return data[:size]
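A usage sketch for IterStreamer: it adapts a chunk generator to the file-like read(size) interface the chunked uploader expects, buffering any excess between calls (note the leftover reset above, which keeps repeated reads from replaying stale bytes):

def chunks():
    yield b'abc'
    yield u'def'   # text chunks are encoded with the configured encoding
    yield b'ghi'

stream = IterStreamer(chunks())
assert stream.read(4) == b'abcd'     # crosses a chunk boundary, buffers b'ef'
assert stream.read(100) == b'efghi'  # drains the generator
assert stream.read(1) == b''         # exhausted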
diff --git a/sdk/storage/azure-storage-file/azure/storage/file/directory_client.py b/sdk/storage/azure-storage-file/azure/storage/file/directory_client.py
index e09756b00b33..11a6b32dad28 100644
--- a/sdk/storage/azure-storage-file/azure/storage/file/directory_client.py
+++ b/sdk/storage/azure-storage-file/azure/storage/file/directory_client.py
@@ -4,8 +4,33 @@
 # license information.
 # --------------------------------------------------------------------------
+import functools
 
-class DirectoryClient():
+try:
+    from urllib.parse import urlparse, quote, unquote
+except ImportError:
+    from urlparse import urlparse
+    from urllib2 import quote, unquote
+
+from .file_client import FileClient
+
+from .models import DirectoryPropertiesPaged
+from ._generated import AzureFileStorage
+from ._generated.version import VERSION
+from ._generated.models import StorageErrorException, SignedIdentifier
+from ._shared.utils import (
+    StorageAccountHostsMixin,
+    serialize_iso,
+    return_headers_and_deserialized,
+    parse_query,
+    return_response_headers,
+    add_metadata_headers,
+    process_storage_error,
+    parse_connection_str)
+
+from ._share_utils import deserialize_directory_properties
+
+class DirectoryClient(StorageAccountHostsMixin):
     """
     A client to interact with the directory.
     """
@@ -14,7 +39,8 @@ def __init__(
             self, share_name=None,  # type: Optional[Union[str, ShareProperties]]
             directory_path=None,  # type: Optional[str]
             credential=None,  # type: Optional[Any]
-            configuration=None  # type: Optional[Configuration]
+            configuration=None,  # type: Optional[Configuration]
+            **kwargs  # type: Optional[Any]
         ):
         # type: (...) -> DirectoryClient
         """Creates a new DirectoryClient. This client represents interaction with a specific
@@ -27,19 +53,47 @@ def __init__(
         :param configuration: An optional pipeline configuration.
             This can be retrieved with :func:`DirectoryClient.create_configuration()`
         """
-
+        try:
+            if not directory_path.lower().startswith('http'):
+                directory_path = "https://" + directory_path
+        except AttributeError:
+            raise ValueError("directory_path must be a string.")
+        parsed_url = urlparse(directory_path.rstrip('/'))
+        if not parsed_url.path and not share_name:
+            raise ValueError("Please specify a share name.")
+        if not parsed_url.netloc:
+            raise ValueError("Invalid URL: {}".format(directory_path))
+
+        share, path_dir = "", ""
+        if parsed_url.path:
+            share, _, path_dir = parsed_url.path.lstrip('/').partition('/')
+        _, sas_token = parse_query(parsed_url.query)
+        try:
+            self.share_name = share_name.name
+        except AttributeError:
+            self.share_name = share_name or unquote(share)
+
+        self.directory_path = path_dir or ""
+
+        self._query_str, credential = self._format_query_string(sas_token, credential)
+        super(DirectoryClient, self).__init__(parsed_url, credential, configuration, **kwargs)
+        self._client = AzureFileStorage(version=VERSION, url=self.url, pipeline=self._pipeline)
+
     @classmethod
     def from_connection_string(
             cls, conn_str,  # type: str
             share_name=None,  # type: Optional[Union[str, ShareProperties]]
             directory_path=None,  # type: Optional[str]
             configuration=None,  # type: Optional[Configuration]
+            credential=None,  # type: Optional[Any]
             **kwargs  # type: Any
         ):
         # type: (...) -> DirectoryClient
         """
         Create DirectoryClient from a Connection String.
         """
+        _, credential = parse_connection_str(conn_str, credential)
+        return cls(share_name, directory_path, credential=credential, configuration=configuration, **kwargs)
 
     def get_file_client(self, file_name):
         """Get a client to interact with the specified file.
@@ -51,8 +105,15 @@ def get_file_client(self, file_name):
         :returns: A File Client.
         :rtype: ~azure.core.file.file_client.FileClient
         """
+        file_url = self.directory_path.rstrip('/') + "/" + quote(file_name)
+        return FileClient(
+            file_url=file_url,
+            share_name=self.share_name,
+            directory_path=self.directory_path,
+            credential=self.credential,
+            configuration=self._config)
 
-    def get_subdirectory_client(self, directory_name):
+    def get_subdirectory_client(self, directory_name, **kwargs):
         """Get a client to interact with the specified subdirectory.
         The subdirectory need not already exist.
@@ -62,10 +123,13 @@ def get_subdirectory_client(self, directory_name):
         :returns: A Directory Client.
         :rtype: ~azure.core.file.directory_client.DirectoryClient
         """
+        directory_path = self.directory_path.rstrip('/') + "/" + quote(directory_name)
+        return DirectoryClient(self.share_name, directory_path, self.credential, self._config, **kwargs)
 
     def create_directory(
             self, metadata=None,  # type: Optional[Dict[str, str]]
-            timeout=None  # type: Optional[int]
+            timeout=None,  # type: Optional[int]
+            **kwargs  # type: Optional[Any]
         ):
         # type: (...) -> Dict[str, Any]
         """Creates a new Directory.
@@ -77,6 +141,16 @@ def create_directory(
         :returns: Directory-updated property dict (Etag and last modified).
         :rtype: dict(str, Any)
         """
+        headers = kwargs.pop('headers', {})
+        headers.update(add_metadata_headers(metadata))
+        try:
+            return self._client.directory.create(
+                timeout=timeout,
+                cls=return_response_headers,
+                headers=headers,
+                **kwargs)
+        except StorageErrorException as error:
+            process_storage_error(error)
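A usage sketch for the DirectoryClient surface above. Note that, as written, __init__ treats directory_path as a full URL; the connection string and names here are placeholders:

from azure.storage.file.directory_client import DirectoryClient

client = DirectoryClient.from_connection_string(
    "DefaultEndpointsProtocol=https;AccountName=...;AccountKey=...;EndpointSuffix=core.windows.net",
    share_name="myshare",
    directory_path="https://myaccount.file.core.windows.net/myshare/parent")

client.create_directory(metadata={'category': 'docs'}, timeout=30)
sub = client.get_subdirectory_client("child")  # another DirectoryClient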
 
     def delete_directory(self, timeout=None, **kwargs):
         # type: (Optional[int]) -> None
@@ -87,18 +161,37 @@ def delete_directory(self, timeout=None, **kwargs):
             The timeout parameter is expressed in seconds.
         :rtype: None
         """
+        try:
+            self._client.directory.delete(timeout=timeout, **kwargs)
+        except StorageErrorException as error:
+            process_storage_error(error)
 
     def list_directies_and_files(self, prefix=None, timeout=None, **kwargs):
         # type: (Optional[str], Optional[int]) -> DirectoryProperties
         """
-        :returns: An auto-paging iterable of dict-like DirectoryProperties and FileProperties
+            :returns: An auto-paging iterable of dict-like DirectoryProperties and FileProperties
         """
+        results_per_page = kwargs.pop('results_per_page', None)
+        command = functools.partial(
+            self._client.directory.list_files_and_directories_segment,
+            prefix=prefix,
+            timeout=timeout,
+            **kwargs)
+        return DirectoryPropertiesPaged(command, prefix=prefix, results_per_page=results_per_page)
 
     def get_directory_properties(self, timeout=None, **kwargs):
         # type: (Optional[int], Any) -> DirectoryProperties
         """
         :returns: DirectoryProperties
         """
+        try:
+            response = self._client.directory.get_properties(
+                timeout=timeout,
+                cls=deserialize_directory_properties,
+                **kwargs)
+        except StorageErrorException as error:
+            process_storage_error(error)
+        return response
 
     def set_directory_metadata(self, metadata, timeout=None, **kwargs):
         # type: (Dict[str, Any], Optional[int], Any) -> Dict[str, Any]
@@ -111,6 +204,16 @@ def set_directory_metadata(self, metadata, timeout=None, **kwargs):
         :returns: directory-updated property dict (Etag and last modified).
         :rtype: dict(str, Any)
         """
+        headers = kwargs.pop('headers', {})
+        headers.update(add_metadata_headers(metadata))
+        try:
+            return self._client.directory.set_metadata(
+                timeout=timeout,
+                cls=return_response_headers,
+                headers=headers,
+                **kwargs)
+        except StorageErrorException as error:
+            process_storage_error(error)
 
     def create_subdirectory(
             self, directory_name,  # type: str,
@@ -150,7 +253,8 @@ def create_file(
             self, file_name,  # type: str
             size,  # type: int
             content_settings=None,  # type: Any
             metadata=None,  #type: Optional[Dict[str, Any]]
-            timeout=None  # type: Optional[int]
+            timeout=None,  # type: Optional[int]
+            **kwargs  # type: Optional[Any]
         ):
         # type: (...) -> FileClient
         """Creates a new file.
@@ -170,10 +274,15 @@ def create_file(
         :returns: FileClient
         :rtype: FileClient
         """
+        file_client = self.get_file_client(file_name)
+        file_client.create_file(size, content_settings, metadata, timeout, **kwargs)
+        return file_client
+
     def delete_file(
             self, file_name,  # type: str,
-            timeout=None  # type: Optional[int]
+            timeout=None,  # type: Optional[int],
+            **kwargs  # type: Optional[Any]
         ):
         # type: (...) -> None
         """Deletes a file.
@@ -184,3 +293,5 @@ def delete_file(
             The timeout parameter is expressed in seconds.
         :returns: None
         """
+        file_client = self.get_file_client(file_name)
+        file_client.delete_file(timeout, **kwargs)
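Continuing the earlier sketch, paging through directory contents with the iterable returned above (the method name is as it appears in this PR's base file):

for item in client.list_directies_and_files(prefix="reports", results_per_page=50):
    # Items deserialize to DirectoryProperties / FileProperty dict-mixins.
    print(item.name)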
diff --git a/sdk/storage/azure-storage-file/azure/storage/file/file_client.py b/sdk/storage/azure-storage-file/azure/storage/file/file_client.py
index e86c51a122fa..d871bc656e14 100644
--- a/sdk/storage/azure-storage-file/azure/storage/file/file_client.py
+++ b/sdk/storage/azure-storage-file/azure/storage/file/file_client.py
@@ -4,8 +4,35 @@
 # license information.
 # --------------------------------------------------------------------------
 
+import functools
+from io import BytesIO
 
-class FileClient():
+try:
+    from urllib.parse import urlparse, quote, unquote
+except ImportError:
+    from urlparse import urlparse
+    from urllib2 import quote, unquote
+
+from .models import DirectoryPropertiesPaged
+from ._generated import AzureFileStorage
+from ._generated.version import VERSION
+from ._generated.models import StorageErrorException, FileHTTPHeaders
+from ._shared.utils import (
+    StorageAccountHostsMixin,
+    parse_query,
+    return_response_headers,
+    add_metadata_headers,
+    process_storage_error,
+    parse_connection_str)
+from ._shared.encryption import _generate_file_encryption_data
+from ._share_utils import upload_file_helper, deserialize_file_properties, StorageStreamDownloader
+from ._upload_chunking import IterStreamer
+from .polling import CopyStatusPoller
+
+from ._share_utils import deserialize_directory_properties
+
+
+class FileClient(StorageAccountHostsMixin):
     """
     A client to interact with the file.
     """
@@ -14,7 +41,7 @@ def __init__(
             share_name=None,  # type: Optional[Union[str, ShareProperties]]
             directory_path=None,  # type: Optional[str]
             snapshot=None,  # type: Optional[Union[str, Dict[str, Any]]]
-            credentials=None,  # type: Optional[Any]
+            credential=None,  # type: Optional[Any]
             configuration=None,  # type: Optional[Configuration]
             **kwargs  # type: Any
         ):
@@ -28,6 +55,31 @@ def __init__(
         :param credential:
         :param configuration: An optional pipeline configuration.
         """
+        try:
+            if not file_url.lower().startswith('http'):
+                file_url = "https://" + file_url
+        except AttributeError:
+            raise ValueError("File URL must be a string.")
+        parsed_url = urlparse(file_url.rstrip('/'))
+        if not parsed_url.path and not (share_name and directory_path):
+            raise ValueError("Please specify a directory_path and share_name.")
+        if not parsed_url.netloc:
+            raise ValueError("Invalid URL: {}".format(file_url))
+
+        path_share, path_file = "", ""
+        if parsed_url.path:
+            path_share, _, path_file = parsed_url.path.lstrip('/').partition('/')
+        _, sas_token = parse_query(parsed_url.query)
+
+        try:
+            self.share_name = share_name.name
+        except AttributeError:
+            self.share_name = share_name or unquote(path_share)
+        self.file_name = unquote(path_file.split('/')[-1]) or ""
+        self.directory_name = unquote(path_file.split('/')[-2]) if directory_path else ""
+        self._query_str, credential = self._format_query_string(sas_token, credential)
+        super(FileClient, self).__init__(parsed_url, credential, configuration, **kwargs)
+        self._client = AzureFileStorage(version=VERSION, url=self.url, pipeline=self._pipeline)
 
     @classmethod
     def from_connection_string(
             cls, conn_str,  # type: str
             share_name=None,  # type: Optional[Union[str, ShareProperties]]
             directory_path=None,  # type: Optional[str]
             file_name=None,  # type: Optional[str]
+            credential=None,  # type: Optional[Any]
             configuration=None,  # type: Optional[Configuration]
             **kwargs  # type: Any
         ):
         # type: (...) -> FileClient
         """
         Create FileClient from a Connection String.
""" + account_url, credential = parse_connection_str(conn_str, credential) + + return cls( + account_url, share_name=share_name, directory_path=directory_path, + credential=credential, configuration=configuration, **kwargs) def create_file( - self, source_url, # type: str + self, size=None, # type: Optional[int] + content_settings=None, # type: Optional[ContentSettings] metadata=None, # type: Optional[Dict[str, str]] - timeout=None # type: Optional[int] + timeout=None, # type: Optional[int] + **kwargs # type: Any ): # type: (...) -> Dict[str, Any] """Creates a new FileClient. @@ -60,6 +120,32 @@ def create_file( :returns: File-updated property dict (Etag and last modified). :rtype: dict(str, Any) """ + if self.require_encryption and not self.key_encryption_key: + raise ValueError("Encryption required but no key was provided.") + + headers = kwargs.pop('headers', {}) + headers.update(add_metadata_headers(metadata)) + file_http_headers = None + if content_settings: + file_http_headers = FileHTTPHeaders( + file_cache_control=content_settings.cache_control, + file_content_type=content_settings.content_type, + file_content_md5=bytearray(content_settings.content_md5) if content_settings.content_md5 else None, + file_content_encoding=content_settings.content_encoding, + file_content_language=content_settings.content_language, + file_content_disposition=content_settings.content_disposition + ) + try: + return self._client.file.create( + file_content_length=size, + timeout=timeout, + metadata=metadata, + file_http_headers=file_http_headers, + headers=headers, + cls=return_response_headers, + **kwargs) + except StorageErrorException as error: + process_storage_error(error) def upload_file( self, data, # type: Any @@ -68,7 +154,8 @@ def upload_file( metadata=None, # type: Optional[Dict[str, str]] validate_content=False, # type: bool max_connections=1, # type: Optional[int] - timeout=None # type: Optional[int] + timeout=None, # type: Optional[int] + **kwargs # type: Any ): # type: (...) -> Dict[str, Any] """Uploads a new file. @@ -93,11 +180,58 @@ def upload_file( :returns: File-updated property dict (Etag and last modified). 
 
     def upload_file(
             self, data,  # type: Any
@@ -68,7 +154,8 @@ def upload_file(
             size=None,  # type: Optional[int]
             metadata=None,  # type: Optional[Dict[str, str]]
             validate_content=False,  # type: bool
             max_connections=1,  # type: Optional[int]
-            timeout=None  # type: Optional[int]
+            timeout=None,  # type: Optional[int]
+            **kwargs  # type: Any
         ):
         # type: (...) -> Dict[str, Any]
         """Uploads a new file.
@@ -93,11 +180,58 @@ def upload_file(
         :returns: File-updated property dict (Etag and last modified).
         :rtype: dict(str, Any)
         """
+        if self.require_encryption and not self.key_encryption_key:
+            raise ValueError("Encryption required but no key was provided.")
+
+        encryption_data = None
+        if self.key_encryption_key is not None:
+            _, _, encryption_data = _generate_file_encryption_data(self.key_encryption_key)
+
+        if isinstance(data, bytes):
+            data = data[:size]
+
+        if isinstance(data, bytes):
+            stream = BytesIO(data)
+        elif hasattr(data, 'read'):
+            stream = data
+        elif hasattr(data, '__iter__'):
+            stream = IterStreamer(data)
+        else:
+            raise TypeError("Unsupported data type: {}".format(type(data)))
+
+        headers = kwargs.pop('headers', {})
+        headers.update(add_metadata_headers(metadata))
+        file_http_headers = None
+        if content_settings:
+            file_http_headers = FileHTTPHeaders(
+                file_cache_control=content_settings.cache_control,
+                file_content_type=content_settings.content_type,
+                file_content_md5=bytearray(content_settings.content_md5) if content_settings.content_md5 else None,
+                file_content_encoding=content_settings.content_encoding,
+                file_content_language=content_settings.content_language,
+                file_content_disposition=content_settings.content_disposition
+            )
+        return upload_file_helper(
+            self._client.file,
+            self.share_name,
+            self.directory_name,
+            self.file_name,
+            stream,
+            size,
+            headers,
+            file_http_headers,
+            validate_content,
+            timeout,
+            max_connections,
+            self._config.file_settings,
+            encryption_data,
+            **kwargs)
 
     def copy_from_url(
             self, source_url,  # type: str
             metadata=None,  # type: Optional[Dict[str, str]]
-            timeout=None  # type: Optional[int]
+            timeout=None,  # type: Optional[int]
+            **kwargs  # type: Any
         ):
         # type: (...) -> Any
         """Copies a file from the given source URL.
@@ -111,6 +245,25 @@ def copy_from_url(
         :returns: Polling object in order to wait on or abort the operation
         :rtype: Any
         """
+        headers = kwargs.pop('headers', {})
+        headers.update(add_metadata_headers(metadata))
+
+        try:
+            start_copy = self._client.file.start_copy(
+                source_url,
+                timeout=None,
+                metadata=metadata,
+                headers=headers,
+                cls=return_response_headers,
+                **kwargs)
+        except StorageErrorException as error:
+            process_storage_error(error)
+
+        poller = CopyStatusPoller(
+            self, start_copy,
+            configuration=self._config,
+            timeout=timeout)
+        return poller
 
     def download_file(
             self, start_range=None,  # type: Optional[int]
@@ -124,9 +277,26 @@ def download_file(
         """
         :returns: An iterable data generator (stream)
         """
+        if self.require_encryption and not self.key_encryption_key:
+            raise ValueError("Encryption required but no key was provided.")
+
+        return StorageStreamDownloader(
+            name=self.file_name,
+            share_name=self.share_name,
+            directory_name=self.directory_name,
+            service=self._client.file,
+            config=self._config.file_settings,
+            length=length,
+            validate_content=validate_content,
+            timeout=timeout,
+            require_encryption=self.require_encryption,
+            key_encryption_key=self.key_encryption_key,
+            key_resolver_function=self.key_resolver_function,
+            **kwargs)
 
     def download_file_to_stream(
-            self, start_range=None,  # type: Optional[int]
+            self, stream,  # type: Any
+            start_range=None,  # type: Optional[int]
             end_range=None,  # type: Optional[int]
             length=None,  # type: Optional[int]
             validate_content=False,  # type: bool
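A sketch of the upload and copy paths above, continuing the earlier placeholders:

data = b'x' * 1024  # 512-byte aligned, per upload_file_helper's check
file_client.upload_file(data, size=len(data), max_connections=2)

poller = file_client.copy_from_url(
    "https://myaccount.file.core.windows.net/myshare/reports/source.txt")
poller.wait()              # CopyStatusPoller blocks until a terminal state
print(poller.copy_id())    # the service-assigned copy id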
@@ -148,26 +318,70 @@ def delete_file(self, timeout=None, **kwargs):
             The timeout parameter is expressed in seconds.
         :rtype: None
         """
+        try:
+            self._client.file.delete(timeout=timeout, **kwargs)
+        except StorageErrorException as error:
+            process_storage_error(error)
 
     def get_file_properties(self, timeout=None, **kwargs):
         # type: (Optional[int], Any) -> FileProperties
         """
         :returns: FileProperties
         """
+        try:
+            file_props = self._client.file.get_properties(
+                timeout=timeout,
+                cls=deserialize_file_properties,
+                **kwargs)
+        except StorageErrorException as error:
+            process_storage_error(error)
+        file_props.name = self.file_name
+        file_props.share_name = self.share_name
+        return file_props
 
-    def set_http_headers(self, content_settings, timeout=None):
-        #type: (ContentSettings, Optional[int]) -> Dict[str, Any]
+    def set_http_headers(self, content_settings, timeout=None, **kwargs):
+        #type: (ContentSettings, Optional[int], Optional[Any]) -> Dict[str, Any]
         """
         :returns: File-updated property dict (Etag and last modified).
         :rtype: dict(str, Any)
         """
+        file_content_length = kwargs.pop('size', None)
+        file_http_headers = FileHTTPHeaders(
+            file_cache_control=content_settings.cache_control,
+            file_content_type=content_settings.content_type,
+            file_content_md5=bytearray(content_settings.content_md5) if content_settings.content_md5 else None,
+            file_content_encoding=content_settings.content_encoding,
+            file_content_language=content_settings.content_language,
+            file_content_disposition=content_settings.content_disposition
+        )
+        try:
+            return self._client.file.set_http_headers(
+                timeout=timeout,
+                file_content_length=file_content_length,
+                file_http_headers=file_http_headers,
+                cls=return_response_headers,
+                **kwargs)
+        except StorageErrorException as error:
+            process_storage_error(error)
 
-    def set_file_metadata(self, metadata=None, timeout=None):
-        #type: (Optional[Dict[str, Any]], Optional[int]) -> Dict[str, Any]
+
+    def set_file_metadata(self, metadata=None, timeout=None, **kwargs):
+        #type: (Optional[Dict[str, Any]], Optional[int], Optional[Any]) -> Dict[str, Any]
         """
         :returns: File-updated property dict (Etag and last modified).
         :rtype: dict(str, Any)
         """
+        headers = kwargs.pop('headers', {})
+        headers.update(add_metadata_headers(metadata))
+        try:
+            return self._client.file.set_metadata(
+                timeout=timeout,
+                cls=return_response_headers,
+                headers=headers,
+                metadata=metadata,
+                **kwargs)
+        except StorageErrorException as error:
+            process_storage_error(error)
 
     def update_ranges(
             self, start_range=None,  # type: Optional[int]
@@ -178,7 +392,7 @@ def update_ranges(
         ):
         # type: (...) -> List[dict[str, int]]
         """
-        Returns the list of valid ranges of a file.
+        Returns a dict with etag and last-modified.
         :param int start_range:
             Start of byte range to use for getting valid page ranges.
             If no end_range is given, all bytes after the start_range will be searched.
@@ -198,8 +412,25 @@ def update_ranges(
         :type validate_content: bool
         :param int timeout:
             The timeout parameter is expressed in seconds.
-        :returns: A list of page ranges.
+        :returns: A dict with etag and last-modified.
""" + if self.require_encryption or (self.key_encryption_key is not None): + raise ValueError("Unsupported method for encryption.") + + if start_range is None or start_range % 512 != 0: + raise ValueError("start_range must be an integer that aligns with 512 file size") + if end_range is None or end_range % 512 != 511: + raise ValueError("end_range must be an integer that aligns with 512 file size") + content_range = 'bytes={0}-{1}'.format(start_range, end_range) + try: + return self._client.file.upload_range( + range=content_range, + timeout=timeout, + cls=return_response_headers, + **kwargs) + except StorageErrorException as error: + process_storage_error(error) + def get_ranges( self, start_range=None, # type: Optional[int] @@ -228,6 +459,22 @@ def get_ranges( The timeout parameter is expressed in seconds. :returns: A list of page ranges. """ + if self.require_encryption or (self.key_encryption_key is not None): + raise ValueError("Unsupported method for encryption.") + + if start_range is None or start_range % 512 != 0: + raise ValueError("start_range must be an integer that aligns with 512 file size") + if end_range is None or end_range % 512 != 511: + raise ValueError("end_range must be an integer that aligns with 512 file size") + content_range = 'bytes={0}-{1}'.format(start_range, end_range) + try: + return self._client.file.get_range_list( + timeout=timeout, + cls=return_response_headers, + range=content_range, + **kwargs) + except StorageErrorException as error: + process_storage_error(error) def clear_range( self, start_range, # type: int @@ -250,9 +497,27 @@ def clear_range( :returns: File-updated property dict (Etag and last modified). :rtype: Dict[str, Any] """ + if self.require_encryption or (self.key_encryption_key is not None): + raise ValueError("Unsupported method for encryption.") + + if start_range is None or start_range % 512 != 0: + raise ValueError("start_range must be an integer that aligns with 512 file size") + if end_range is None or end_range % 512 != 511: + raise ValueError("end_range must be an integer that aligns with 512 file size") + content_range = 'bytes={0}-{1}'.format(start_range, end_range) + try: + return self._client.file.upload_range( + timeout=timeout, + cls=return_response_headers, + content_length=0, + file_range_write="clear", + range=content_range, + **kwargs) + except StorageErrorException as error: + process_storage_error(error) - def resize_file(self, size, timeout=None): - # type: (int, Optional[int]) -> Dict[str, Any] + def resize_file(self, size, timeout=None, **kwargs): + # type: (int, Optional[int], Optional[Any]) -> Dict[str, Any] """Resizes a file to the specified size. :param int size: Size to resize file to. @@ -261,3 +526,11 @@ def resize_file(self, size, timeout=None): :returns: File-updated property dict (Etag and last modified). 
 
-    def resize_file(self, size, timeout=None):
-        # type: (int, Optional[int]) -> Dict[str, Any]
+    def resize_file(self, size, timeout=None, **kwargs):
+        # type: (int, Optional[int], Optional[Any]) -> Dict[str, Any]
         """Resizes a file to the specified size.
 
         :param int size: Size to resize file to.
@@ -261,3 +526,11 @@ def resize_file(self, size, timeout=None):
         :returns: File-updated property dict (Etag and last modified).
         :rtype: Dict[str, Any]
         """
+        try:
+            return self._client.file.set_http_headers(
+                timeout=timeout,
+                file_content_length=size,
+                cls=return_response_headers,
+                **kwargs)
+        except StorageErrorException as error:
+            process_storage_error(error)
diff --git a/sdk/storage/azure-storage-file/azure/storage/file/models.py b/sdk/storage/azure-storage-file/azure/storage/file/models.py
index cac96f3b0564..17fcb0014eb5 100644
--- a/sdk/storage/azure-storage-file/azure/storage/file/models.py
+++ b/sdk/storage/azure-storage-file/azure/storage/file/models.py
@@ -106,3 +106,117 @@ def __next__(self):
         return ShareProperties._from_generated(item)  # pylint: disable=protected-access
 
     next = __next__
+
+class DirectoryProperties(DictMixin):
+    """Directory's properties class.
+    :param datetime last_modified:
+        A datetime object representing the last time the directory was modified.
+    :param str etag:
+        The ETag contains a value that you can use to perform operations
+        conditionally.
+    :param dict metadata: A dict with name_value pairs to associate with the
+        directory as metadata.
+    """
+
+    def __init__(self, **kwargs):
+        self.name = None
+        self.last_modified = kwargs.get('Last-Modified')
+        self.etag = kwargs.get('ETag')
+        self.is_server_encrypted = kwargs.get('is_server_encrypted')
+        self.metadata = kwargs.get('metadata')
+
+    @classmethod
+    def _from_generated(cls, generated):
+        props = cls()
+        props.name = generated.name
+        props.last_modified = generated.properties.last_modified
+        props.etag = generated.properties.etag
+        props.is_server_encrypted = generated.properties.is_server_encrypted
+        props.metadata = generated.metadata
+        return props
+
+class DirectoryPropertiesPaged(Paged):
+    """Directory properties paged.
+    :param callable command: Function to retrieve the next page of items.
+    :param str prefix: Filters the results to return only directories whose names
+        begin with the specified prefix.
+    :param int results_per_page: The maximum number of results to retrieve per
+        call.
+    :param str marker: An opaque continuation token.
+    """
+    def __init__(self, command, prefix=None, results_per_page=None, marker=None, **kwargs):
+        super(DirectoryPropertiesPaged, self).__init__(command, None)
+        self.service_endpoint = None
+        self.prefix = prefix
+        self.current_marker = None
+        self.results_per_page = results_per_page
+        self.next_marker = marker or ""
+        self.location_mode = None
+
+    def _advance_page(self):
+        # type: () -> List[Model]
+        """Force moving the cursor to the next azure call.
+        This method is for advanced usage, iterator protocol is preferred.
+        :raises: StopIteration if no further page
+        :return: The current page list
+        :rtype: list
+        """
+        if self.next_marker is None:
+            raise StopIteration("End of paging")
+        self._current_page_iter_index = 0
+        try:
+            self.location_mode, self._response = self._get_next(
+                marker=self.next_marker or None,
+                maxresults=self.results_per_page,
+                cls=return_context_and_deserialized,
+                use_location=self.location_mode)
+        except StorageErrorException as error:
+            process_storage_error(error)
+
+        self.service_endpoint = self._response.service_endpoint
+        self.prefix = self._response.prefix
+        self.current_marker = self._response.marker
+        self.results_per_page = self._response.max_results
+        self.current_page = self._response.directory_items
+        self.next_marker = self._response.next_marker or None
+        return self.current_page
+
+    def __next__(self):
+        item = super(DirectoryPropertiesPaged, self).__next__()
+        if isinstance(item, DirectoryProperties):
+            return item
+        return DirectoryProperties._from_generated(item)  # pylint: disable=protected-access
+
+    next = __next__
+
+class FileProperty(DictMixin):
+    """File's properties class.
+    :param int content_length:
+        The size of the file, in bytes.
+    :param dict metadata: A dict with name_value pairs to associate with the
+        file as metadata.
+    """
+
+    def __init__(self, **kwargs):
+        self.name = None
+        self.content_length = kwargs.get('content_length')
+        self.metadata = kwargs.get('metadata')
+
+    @classmethod
+    def _from_generated(cls, generated):
+        props = cls()
+        props.name = generated.name
+        props.content_length = generated.properties.content_length
+        props.metadata = generated.properties.metadata
+        return props
diff --git a/sdk/storage/azure-storage-file/azure/storage/file/polling.py b/sdk/storage/azure-storage-file/azure/storage/file/polling.py
new file mode 100644
index 000000000000..ea1afde8f773
--- /dev/null
+++ b/sdk/storage/azure-storage-file/azure/storage/file/polling.py
@@ -0,0 +1,132 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+import logging
+import time
+
+from azure.core.polling import PollingMethod, LROPoller
+
+from ._shared.utils import process_storage_error
+from ._generated.models import StorageErrorException
+from ._share_utils import deserialize_file_properties
+
+
+logger = logging.getLogger(__name__)
+
+
+class CopyStatusPoller(LROPoller):
+
+    def __init__(self, client, copy_id, polling=True, configuration=None, **kwargs):
+        if configuration:
+            polling_interval = configuration.file_settings.copy_polling_interval
+        else:
+            polling_interval = 2
+        polling_method = CopyFilePolling if polling else CopyFile
+        poller = polling_method(polling_interval, **kwargs)
+        super(CopyStatusPoller, self).__init__(client, copy_id, None, poller)
+
+    def copy_id(self):
+        return self._polling_method.id
+
+    def abort(self):  # Check whether this is in API guidelines, or should remain specific to Storage
+        return self._polling_method.abort()
+
+
+class CopyFile(PollingMethod):
+    """An empty poller that returns the deserialized initial response.
+    """
+    def __init__(self, interval, **kwargs):
+        self._client = None
+        self._status = None
+        self._exception = None
+        self.id = None
+        self.etag = None
+        self.last_modified = None
+        self.polling_interval = interval
+        self.kwargs = kwargs
+        self.file = None
+
+    def _update_status(self):
+        try:
+            self.file = self._client._client.file.get_properties(  # pylint: disable=protected-access
+                cls=deserialize_file_properties, **self.kwargs)
+        except StorageErrorException as error:
+            process_storage_error(error)
+        self._status = self.file.copy.status
+        self.etag = self.file.etag
+        self.last_modified = self.file.last_modified
+
+    def initialize(self, client, initial_status, _):  # pylint: disable=arguments-differ
+        # type: (Any, requests.Response, Callable) -> None
+        self._client = client
+        if isinstance(initial_status, str):
+            self.id = initial_status
+            self._update_status()
+        else:
+            self._status = initial_status['copy_status']
+            self.id = initial_status['copy_id']
+            self.etag = initial_status['etag']
+            self.last_modified = initial_status['last_modified']
+
+    def run(self):
+        # type: () -> None
+        """Empty run, no polling.
+        """
+
+    def abort(self):
+        try:
+            return self._client._client.file.abort_copy_from_url(  # pylint: disable=protected-access
+                self.id, **self.kwargs)
+        except StorageErrorException as error:
+            process_storage_error(error)
+
+    def status(self):
+        self._update_status()
+        return self._status
+
+    def finished(self):
+        # type: () -> bool
+        """Is this polling finished?
+        :rtype: bool
+        """
+        return str(self.status()).lower() in ['success', 'aborted', 'failed']
+
+    def resource(self):
+        # type: () -> Any
+        self._update_status()
+        return self.file
+
+
+class CopyFilePolling(CopyFile):
+
+    def run(self):
+        # type: () -> None
+        try:
+            while not self.finished():
+                self._update_status()
+                time.sleep(self.polling_interval)
+            if str(self.status()).lower() == 'aborted':
+                raise ValueError("Copy operation aborted.")
+            if str(self.status()).lower() == 'failed':
+                raise ValueError("Copy operation failed: {}".format(self.file.copy.status_description))
+        except Exception as e:
+            logger.warning(str(e))
+            raise
+
+    def status(self):
+        # type: () -> str
+        """Return the current status as a string.
+        :rtype: str
+        """
+        try:
+            return self._status.value
+        except AttributeError:
+            return self._status
+
+    def resource(self):
+        # type: () -> Any
+        if not self.file:
+            self._update_status()
+        return self.file
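Finally, a usage sketch for the poller surface above, continuing the earlier placeholders (`poller` is the CopyStatusPoller returned by FileClient.copy_from_url):

source_url = "https://myaccount.file.core.windows.net/myshare/reports/source.txt"
poller = file_client.copy_from_url(source_url)

print(poller.copy_id())             # service-assigned copy id
if poller.status().lower() == 'pending':
    poller.abort()                  # best-effort cancel, specific to Storage
else:
    props = poller.result()         # FileProperty once the copy is terminal
    print(props.name, props.metadata)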