Skip to content

Commit

Permalink
Hookup File Service Client (#5977)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rakshith Bhyravabhotla authored Jun 20, 2019
1 parent ee7a6a2 commit f4d1a6d
Show file tree
Hide file tree
Showing 6 changed files with 264 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,53 @@ class StorageErrorCode(str, Enum):
cannot_delete_file_or_directory = "CannotDeleteFileOrDirectory"
file_lock_conflict = "FileLockConflict"
invalid_file_or_directory_path_name = "InvalidFileOrDirectoryPathName"

class DictMixin(object):

def __setitem__(self, key, item):
self.__dict__[key] = item

def __getitem__(self, key):
return self.__dict__[key]

def __repr__(self):
return str(self)

def __len__(self):
return len(self.keys())

def __delitem__(self, key):
self.__dict__[key] = None

def __eq__(self, other):
"""Compare objects by comparing all attributes."""
if isinstance(other, self.__class__):
return self.__dict__ == other.__dict__
return False

def __ne__(self, other):
"""Compare objects by comparing all attributes."""
return not self.__eq__(other)

def __str__(self):
return str({k: v for k, v in self.__dict__.items() if not k.startswith('_')})

def has_key(self, k):
return k in self.__dict__

def update(self, *args, **kwargs):
return self.__dict__.update(*args, **kwargs)

def keys(self):
return [k for k in self.__dict__ if not k.startswith('_')]

def values(self):
return [v for k, v in self.__dict__.items() if not k.startswith('_')]

def items(self):
return [(k, v) for k, v in self.__dict__.items() if not k.startswith('_')]

def get(self, key, default=None):
if key in self.__dict__:
return self.__dict__[key]
return default
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from datetime import date

from .constants import X_MS_VERSION
from .utils import _sign_string, url_quote, _QueryStringConstants
from .utils import _sign_string, url_quote, _QueryStringConstants, _to_str


if sys.version_info < (3,):
Expand Down
28 changes: 24 additions & 4 deletions sdk/storage/azure-storage-file/azure/storage/file/_shared/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import base64
import hashlib
import hmac
import sys
import logging
from os import fstat
from io import (SEEK_END, SEEK_SET, UnsupportedOperation)
Expand Down Expand Up @@ -65,6 +66,19 @@

_LOGGER = logging.getLogger(__name__)

if sys.version_info < (3,):
def _str(value):
if isinstance(value, str):
return value.encode('utf-8')

return str(value)
else:
_str = str


def _to_str(value):
return _str(value) if value is not None else None


class _QueryStringConstants(object):
SIGNED_SIGNATURE = 'sig'
Expand Down Expand Up @@ -235,17 +249,23 @@ def format_shared_key_credential(account, credential):
return credential


def parse_connection_str(conn_str, credentials=None):
def parse_connection_str(conn_str, credential=None):
conn_settings = dict([s.split('=', 1) for s in conn_str.split(';')])
if not credential:
try:
credential = {
'account_name': conn_settings['AccountName'],
'account_key': conn_settings['AccountKey']
}
except KeyError:
credential = conn_settings.get('SharedAccessSignature')
try:
account_url = "{}://{}.file.{}".format(
conn_settings['DefaultEndpointsProtocol'],
conn_settings['AccountName'],
conn_settings['EndpointSuffix']
)
creds = credentials or SharedKeyCredentials(
conn_settings['AccountName'], conn_settings['AccountKey'])
return account_url, creds
return account_url, credential
except KeyError as error:
raise ValueError("Connection string missing setting: '{}'".format(error.args[0]))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,31 @@
# license information.
# --------------------------------------------------------------------------

import functools
from typing import ( # pylint: disable=unused-import
Union, Optional, Any, Iterable, Dict, List,
TYPE_CHECKING
)
try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse

class FileServiceClient():
from .share_client import ShareClient
from ._shared.shared_access_signature import SharedAccessSignature
from ._shared.utils import (
StorageAccountHostsMixin,
return_response_headers,
parse_connection_str,
process_storage_error,
parse_query)

from .models import SharePropertiesPaged
from ._generated import AzureFileStorage
from ._generated.models import StorageErrorException, StorageServiceProperties
from ._generated.version import VERSION

class FileServiceClient(StorageAccountHostsMixin):
""" A client interact with the File Service at the account level.
This client provides operations to retrieve and configure the account properties
Expand Down Expand Up @@ -35,7 +58,7 @@ class FileServiceClient():

def __init__(
self, account_url, # type: str
credentials=None, # type: Optional[Any]
credential=None, # type: Optional[Any]
configuration=None, # type: Optional[Configuration]
**kwargs # type: Any
):
Expand All @@ -47,21 +70,37 @@ def __init__(
in the URL path (e.g. share or file) will be discarded. This URL can be optionally
authenticated with a SAS token.
:param credential:
The credentials with which to authenticate. This is optional if the
The credential with which to authenticate. This is optional if the
account URL already has a SAS token. The value can be a SAS token string, and account
shared access key, or an instance of a TokenCredentials class from azure.identity.
:param ~azure.storage.file.Configuration configuration:
An optional pipeline configuration.
"""
try:
if not account_url.lower().startswith('http'):
account_url = "https://" + account_url
except AttributeError:
raise ValueError("Account URL must be a string.")
parsed_url = urlparse(account_url.rstrip('/'))
if not parsed_url.netloc:
raise ValueError("Invalid URL: {}".format(account_url))

_, sas_token = parse_query(parsed_url.query)
self._query_str, credential = self._format_query_string(sas_token, credential)
super(FileServiceClient, self).__init__(parsed_url, credential, configuration, **kwargs)
self.url = account_url if not parsed_url.path else self._format_url(parsed_url.hostname)
self._client = AzureFileStorage(version=VERSION, url=self.url, pipeline=self._pipeline)

def _format_url(self, hostname):
"""Format the endpoint URL according to the current location
mode hostname.
"""
return "{}://{}/{}".format(self.scheme, hostname, self._query_str)

@classmethod
def from_connection_string(
cls, conn_str, # type: str
credential=None, # type: Optional[Any]
configuration=None, # type: Optional[Configuration]
**kwargs # type: Any
):
Expand All @@ -70,14 +109,16 @@ def from_connection_string(
:param str conn_str:
A connection string to an Azure Storage account.
:param credential:
The credentials with which to authenticate. This is optional if the
The credential with which to authenticate. This is optional if the
account URL already has a SAS token, or the connection string already has shared
access key values. The value can be a SAS token string, and account shared access
key, or an instance of a TokenCredentials class from azure.identity.
:param configuration:
Optional pipeline configuration settings.
:type configuration: ~azure.core.configuration.Configuration
"""
account_url, credential = parse_connection_str(conn_str, credential)
return cls(account_url, credential=credential, configuration=configuration, **kwargs)

def generate_shared_access_signature(
self, resource_types, # type: Union[ResourceTypes, str]
Expand Down Expand Up @@ -126,40 +167,12 @@ def generate_shared_access_signature(
:return: A Shared Access Signature (sas) token.
:rtype: str
"""
if not hasattr(self.credential, 'account_key') and not self.credential.account_key:
raise ValueError("No account SAS key available.")

def get_account_information(self, **kwargs):
# type: (Optional[int]) -> Dict[str, str]
"""Gets information related to the storage account.
The information can also be retrieved if the user has a SAS to a share or file.
:returns: A dict of account information (SKU and account type).
:rtype: dict(str, str)
"""

def get_service_stats(self, timeout=None, **kwargs):
# type: (Optional[int], **Any) -> Dict[str, Any]
"""Retrieves statistics related to replication for the File service. It is
only available when read-access geo-redundant replication is enabled for
the storage account.
With geo-redundant replication, Azure Storage maintains your data durable
in two locations. In both locations, Azure Storage constantly maintains
multiple healthy replicas of your data. The location where you read,
create, update, or delete data is the primary storage account location.
The primary location exists in the region you choose at the time you
create an account via the Azure Management Azure classic portal, for
example, North Central US. The location to which your data is replicated
is the secondary location. The secondary location is automatically
determined based on the location of the primary; it is in a second data
center that resides in the same region as the primary location. Read-only
access is available from the secondary location, if read-access geo-redundant
replication is enabled for your storage account.
:param int timeout:
The timeout parameter is expressed in seconds.
:return: The file service stats.
:rtype: ~azure.storage.file._generated.models.StorageServiceStats
"""
sas = SharedAccessSignature(self.credential.account_name, self.credential.account_key)
return sas.generate_account(resource_types, permission,
expiry, start=start, ip=ip, protocol=protocol)

def get_service_properties(self, timeout=None, **kwargs):
# type(Optional[int]) -> Dict[str, Any]
Expand All @@ -170,6 +183,10 @@ def get_service_properties(self, timeout=None, **kwargs):
The timeout parameter is expressed in seconds.
:rtype: ~azure.storage.file._generated.models.StorageServiceProperties
"""
try:
return self._client.service.get_properties(timeout=timeout, **kwargs)
except StorageErrorException as error:
process_storage_error(error)

def set_service_properties(
self, logging=None, # type: Optional[Logging]
Expand Down Expand Up @@ -220,6 +237,16 @@ def set_service_properties(
The timeout parameter is expressed in seconds.
:rtype: None
"""
props = StorageServiceProperties(
logging=logging,
hour_metrics=hour_metrics,
minute_metrics=minute_metrics,
cors=cors
)
try:
self._client.service.set_properties(props, timeout=timeout, **kwargs)
except StorageErrorException as error:
process_storage_error(error)

def list_shares(
self, prefix=None, # type: Optional[str]
Expand All @@ -245,6 +272,14 @@ def list_shares(
:returns: An iterable (auto-paging) of ShareProperties.
:rtype: ~azure.core.file.models.SharePropertiesPaged
"""
include = 'metadata' if include_metadata else None
command = functools.partial(
self._client.service.list_shares_segment,
prefix=prefix,
include=include,
timeout=timeout,
**kwargs)
return SharePropertiesPaged(command, prefix=prefix, **kwargs)

def create_share(
self, share_name, # type: str
Expand All @@ -269,6 +304,9 @@ def create_share(
The timeout parameter is expressed in seconds.
:rtype: ~azure.storage.file.share_client.ShareClient
"""
share = self.get_share_client(share_name)
share.create_share(metadata, quota, timeout, **kwargs)
return share

def delete_share(
self, share_name, # type: Union[ShareProperties, str]
Expand All @@ -291,6 +329,8 @@ def delete_share(
The timeout parameter is expressed in seconds.
:rtype: None
"""
share = self.get_share_client(share_name)
share.delete_share(delete_snapshots, timeout, **kwargs)

def get_share_client(self, share_name, snapshot=None):
# type: (Union[ShareProperties, str],Optional[Union[SnapshotProperties, str]]) -> ShareClient
Expand All @@ -304,3 +344,6 @@ def get_share_client(self, share_name, snapshot=None):
:returns: A ShareClient.
:rtype: ~azure.core.file.share_client.ShareClient
"""
return ShareClient(
self.url, share_name=share_name, snapshot=snapshot,
credential=self.credential, configuration=self._config)
Loading

0 comments on commit f4d1a6d

Please sign in to comment.