From a1039c8568e60d1db7e48eb15bb83bb31827eb92 Mon Sep 17 00:00:00 2001 From: Xiaoxi Fu <49707495+xiafu-msft@users.noreply.github.com> Date: Wed, 12 Aug 2020 13:40:29 -0700 Subject: [PATCH] [Storage][Blob][Bug]Support parsing blob url with '/' in blob name (#12619) * [Storage][Blob][Bug]Support parsing blob url with '/' in blob name * [Storage]support parsing emulator url * fix pylint * make format_shared_key_credential to private * fix test --- .../azure/storage/blob/_blob_client.py | 20 +++++++---- .../azure/storage/blob/_shared/base_client.py | 15 +++++--- .../tests/test_blob_client.py | 26 ++++++++++++++ .../tests/test_blob_client_async.py | 11 ++++++ .../tests/test_largest_block_blob.py | 12 +++---- .../tests/test_largest_block_blob_async.py | 12 +++---- .../filedatalake/_shared/base_client.py | 15 +++++--- .../tests/test_large_file.py | 4 +-- .../tests/test_large_file_async.py | 4 +-- .../azure/storage/fileshare/_share_client.py | 35 ++++++++++++++----- .../storage/fileshare/_shared/base_client.py | 15 +++++--- .../tests/test_share.py | 9 +++++ .../storage/queue/_shared/base_client.py | 15 +++++--- 13 files changed, 143 insertions(+), 50 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py index e527e4ab88d7..b5366051d7ac 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py @@ -13,8 +13,8 @@ try: from urllib.parse import urlparse, quote, unquote except ImportError: - from urlparse import urlparse # type: ignore - from urllib2 import quote, unquote # type: ignore + from urlparse import urlparse # type: ignore + from urllib2 import quote, unquote # type: ignore import six from azure.core.tracing.decorator import distributed_trace @@ -180,7 +180,7 @@ def _format_url(self, hostname): @classmethod def from_blob_url(cls, blob_url, credential=None, snapshot=None, **kwargs): # type: (str, Optional[Any], Optional[Union[str, Dict[str, Any]]], Any) -> BlobClient - """Create BlobClient from a blob url. + """Create BlobClient from a blob url. This doesn't support customized blob url with '/' in blob name. :param str blob_url: The full endpoint URL to the Blob, including SAS token and snapshot if used. This could be @@ -209,10 +209,18 @@ def from_blob_url(cls, blob_url, credential=None, snapshot=None, **kwargs): if not parsed_url.netloc: raise ValueError("Invalid URL: {}".format(blob_url)) - path_blob = parsed_url.path.lstrip('/').split('/') account_path = "" - if len(path_blob) > 2: - account_path = "/" + "/".join(path_blob[:-2]) + if ".core." in parsed_url.netloc: + # .core. is indicating non-customized url. Blob name with directory info can also be parsed. + path_blob = parsed_url.path.lstrip('/').split('/', 1) + elif "localhost" in parsed_url.netloc or "127.0.0.1" in parsed_url.netloc: + path_blob = parsed_url.path.lstrip('/').split('/', 2) + account_path += path_blob[0] + else: + # for customized url. blob name that has directory info cannot be parsed. + path_blob = parsed_url.path.lstrip('/').split('/') + if len(path_blob) > 2: + account_path = "/" + "/".join(path_blob[:-2]) account_url = "{}://{}{}?{}".format( parsed_url.scheme, parsed_url.netloc.rstrip('/'), diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py index 79bab02d80f5..361931ae1656 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py @@ -84,12 +84,17 @@ def __init__( raise ValueError("Invalid service: {}".format(service)) service_name = service.split('-')[0] account = parsed_url.netloc.split(".{}.core.".format(service_name)) + self.account_name = account[0] if len(account) > 1 else None - secondary_hostname = None + if not self.account_name and parsed_url.netloc.startswith("localhost") \ + or parsed_url.netloc.startswith("127.0.0.1"): + self.account_name = parsed_url.path.strip("/") - self.credential = format_shared_key_credential(account, credential) + self.credential = _format_shared_key_credential(self.account_name, credential) if self.scheme.lower() != "https" and hasattr(self.credential, "get_token"): raise ValueError("Token credential is only supported with HTTPS.") + + secondary_hostname = None if hasattr(self.credential, "account_name"): self.account_name = self.credential.account_name secondary_hostname = "{}-secondary.{}.{}".format( @@ -326,11 +331,11 @@ def __exit__(self, *args): # pylint: disable=arguments-differ pass -def format_shared_key_credential(account, credential): +def _format_shared_key_credential(account_name, credential): if isinstance(credential, six.string_types): - if len(account) < 2: + if not account_name: raise ValueError("Unable to determine account name for shared key credential.") - credential = {"account_name": account[0], "account_key": credential} + credential = {"account_name": account_name, "account_key": credential} if isinstance(credential, dict): if "account_name" not in credential: raise ValueError("Shared key credential missing 'account_name") diff --git a/sdk/storage/azure-storage-blob/tests/test_blob_client.py b/sdk/storage/azure-storage-blob/tests/test_blob_client.py index 5a333e99955d..c82ba195c256 100644 --- a/sdk/storage/azure-storage-blob/tests/test_blob_client.py +++ b/sdk/storage/azure-storage-blob/tests/test_blob_client.py @@ -382,6 +382,7 @@ def test_create_service_with_cstr_succeeds_if_sec_with_prim(self, resource_group self.assertTrue(service.primary_endpoint.startswith('https://www.mydomain.com/')) self.assertTrue(service.secondary_endpoint.startswith('https://www-sec.mydomain.com/')) + def test_create_service_with_custom_account_endpoint_path(self): account_name = "blobstorage" account_key = "blobkey" @@ -438,6 +439,31 @@ def test_create_service_with_custom_account_endpoint_path(self): self.assertEqual(service.primary_hostname, 'local-machine:11002/custom/account/path') self.assertEqual(service.url, 'http://local-machine:11002/custom/account/path/foo/bar?snapshot=baz') + def test_create_blob_client_with_sub_directory_path_in_blob_name(self): + blob_url = "https://testaccount.blob.core.windows.net/containername/dir1/sub000/2010_Unit150_Ivan097_img0003.jpg" + blob_client = BlobClient.from_blob_url(blob_url) + self.assertEqual(blob_client.container_name, "containername") + self.assertEqual(blob_client.blob_name, "dir1/sub000/2010_Unit150_Ivan097_img0003.jpg") + + blob_emulator_url = 'http://127.0.0.1:1000/devstoreaccount1/containername/dir1/sub000/2010_Unit150_Ivan097_img0003.jpg' + blob_client = BlobClient.from_blob_url(blob_emulator_url) + self.assertEqual(blob_client.container_name, "containername") + self.assertEqual(blob_client.blob_name, "dir1/sub000/2010_Unit150_Ivan097_img0003.jpg") + + def test_create_client_for_emulator(self): + container_client = ContainerClient( + account_url='http://127.0.0.1:1000/devstoreaccount1', + container_name='newcontainer', + credential='Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==') + + self.assertEqual(container_client.container_name, "newcontainer") + self.assertEqual(container_client.account_name, "devstoreaccount1") + + ContainerClient.from_container_url('http://127.0.0.1:1000/devstoreaccount1/newcontainer') + self.assertEqual(container_client.container_name, "newcontainer") + self.assertEqual(container_client.account_name, "devstoreaccount1") + + @GlobalStorageAccountPreparer() def test_request_callback_signed_header(self, resource_group, location, storage_account, storage_account_key): # Arrange diff --git a/sdk/storage/azure-storage-blob/tests/test_blob_client_async.py b/sdk/storage/azure-storage-blob/tests/test_blob_client_async.py index 44f98d93fc4e..0de666766538 100644 --- a/sdk/storage/azure-storage-blob/tests/test_blob_client_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_blob_client_async.py @@ -449,6 +449,17 @@ def test_create_service_with_custom_account_endpoint_path(self): self.assertEqual(service.primary_hostname, 'local-machine:11002/custom/account/path') self.assertEqual(service.url, 'http://local-machine:11002/custom/account/path/foo/bar?snapshot=baz') + def test_create_blob_client_with_sub_directory_path_in_blob_name(self): + blob_url = "https://testaccount.blob.core.windows.net/containername/dir1/sub000/2010_Unit150_Ivan097_img0003.jpg" + blob_client = BlobClient.from_blob_url(blob_url) + self.assertEqual(blob_client.container_name, "containername") + self.assertEqual(blob_client.blob_name, "dir1/sub000/2010_Unit150_Ivan097_img0003.jpg") + + blob_emulator_url = 'http://127.0.0.1:1000/devstoreaccount1/containername/dir1/sub000/2010_Unit150_Ivan097_img0003.jpg' + blob_client = BlobClient.from_blob_url(blob_emulator_url) + self.assertEqual(blob_client.container_name, "containername") + self.assertEqual(blob_client.blob_name, "dir1/sub000/2010_Unit150_Ivan097_img0003.jpg") + @GlobalStorageAccountPreparer() @AsyncStorageTestCase.await_prepared_test async def test_request_callback_signed_header_async(self, resource_group, location, storage_account, storage_account_key): diff --git a/sdk/storage/azure-storage-blob/tests/test_largest_block_blob.py b/sdk/storage/azure-storage-blob/tests/test_largest_block_blob.py index 0fa722338cb5..b03814c4c486 100644 --- a/sdk/storage/azure-storage-blob/tests/test_largest_block_blob.py +++ b/sdk/storage/azure-storage-blob/tests/test_largest_block_blob.py @@ -17,7 +17,7 @@ BlobServiceClient, BlobBlock ) -from azure.storage.blob._shared.base_client import format_shared_key_credential +from azure.storage.blob._shared.base_client import _format_shared_key_credential from _shared.testcase import StorageTestCase, GlobalStorageAccountPreparer @@ -97,7 +97,7 @@ def test_put_block_bytes_largest(self, resource_group, location, storage_account @GlobalStorageAccountPreparer() def test_put_block_bytes_largest_without_network(self, resource_group, location, storage_account, storage_account_key): payload_dropping_policy = PayloadDroppingPolicy() - credential_policy = format_shared_key_credential([storage_account.name, "dummy"], storage_account_key) + credential_policy = _format_shared_key_credential(storage_account.name, storage_account_key) self._setup(storage_account, storage_account_key, [payload_dropping_policy, credential_policy]) blob = self._create_blob() @@ -157,7 +157,7 @@ def test_put_block_stream_largest(self, resource_group, location, storage_accoun @GlobalStorageAccountPreparer() def test_put_block_stream_largest_without_network(self, resource_group, location, storage_account, storage_account_key): payload_dropping_policy = PayloadDroppingPolicy() - credential_policy = format_shared_key_credential([storage_account.name, "dummy"], storage_account_key) + credential_policy = _format_shared_key_credential(storage_account.name, storage_account_key) self._setup(storage_account, storage_account_key, [payload_dropping_policy, credential_policy]) blob = self._create_blob() @@ -211,7 +211,7 @@ def test_create_largest_blob_from_path(self, resource_group, location, storage_a @GlobalStorageAccountPreparer() def test_create_largest_blob_from_path_without_network(self, resource_group, location, storage_account, storage_account_key): payload_dropping_policy = PayloadDroppingPolicy() - credential_policy = format_shared_key_credential([storage_account.name, "dummy"], storage_account_key) + credential_policy = _format_shared_key_credential(storage_account.name, storage_account_key) self._setup(storage_account, storage_account_key, [payload_dropping_policy, credential_policy]) blob_name = self._get_blob_reference() blob = self.bsc.get_blob_client(self.container_name, blob_name) @@ -237,7 +237,7 @@ def test_create_largest_blob_from_path_without_network(self, resource_group, loc @GlobalStorageAccountPreparer() def test_create_largest_blob_from_stream_without_network(self, resource_group, location, storage_account, storage_account_key): payload_dropping_policy = PayloadDroppingPolicy() - credential_policy = format_shared_key_credential([storage_account.name, "dummy"], storage_account_key) + credential_policy = _format_shared_key_credential(storage_account.name, storage_account_key) self._setup(storage_account, storage_account_key, [payload_dropping_policy, credential_policy]) blob_name = self._get_blob_reference() blob = self.bsc.get_blob_client(self.container_name, blob_name) @@ -257,7 +257,7 @@ def test_create_largest_blob_from_stream_without_network(self, resource_group, l @GlobalStorageAccountPreparer() def test_create_largest_blob_from_stream_single_upload_without_network(self, resource_group, location, storage_account, storage_account_key): payload_dropping_policy = PayloadDroppingPolicy() - credential_policy = format_shared_key_credential([storage_account.name, "dummy"], storage_account_key) + credential_policy = _format_shared_key_credential(storage_account.name, storage_account_key) self._setup(storage_account, storage_account_key, [payload_dropping_policy, credential_policy], max_single_put_size=LARGEST_SINGLE_UPLOAD_SIZE) blob_name = self._get_blob_reference() diff --git a/sdk/storage/azure-storage-blob/tests/test_largest_block_blob_async.py b/sdk/storage/azure-storage-blob/tests/test_largest_block_blob_async.py index 6d782f1d419a..e58c5017430e 100644 --- a/sdk/storage/azure-storage-blob/tests/test_largest_block_blob_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_largest_block_blob_async.py @@ -23,7 +23,7 @@ from azure.storage.blob import ( BlobBlock ) -from azure.storage.blob._shared.base_client import format_shared_key_credential +from azure.storage.blob._shared.base_client import _format_shared_key_credential from azure.storage.blob._shared.constants import CONNECTION_TIMEOUT, READ_TIMEOUT from _shared.asynctestcase import AsyncStorageTestCase @@ -123,7 +123,7 @@ async def test_put_block_bytes_largest(self, resource_group, location, storage_a @AsyncStorageTestCase.await_prepared_test async def test_put_block_bytes_largest_without_network(self, resource_group, location, storage_account, storage_account_key): payload_dropping_policy = PayloadDroppingPolicy() - credential_policy = format_shared_key_credential([storage_account.name, "dummy"], storage_account_key) + credential_policy = _format_shared_key_credential(storage_account.name, storage_account_key) await self._setup(storage_account, storage_account_key, [payload_dropping_policy, credential_policy]) blob = await self._create_blob() @@ -185,7 +185,7 @@ async def test_put_block_stream_largest(self, resource_group, location, storage_ @AsyncStorageTestCase.await_prepared_test async def test_put_block_stream_largest_without_network(self, resource_group, location, storage_account, storage_account_key): payload_dropping_policy = PayloadDroppingPolicy() - credential_policy = format_shared_key_credential([storage_account.name, "dummy"], storage_account_key) + credential_policy = _format_shared_key_credential(storage_account.name, storage_account_key) await self._setup(storage_account, storage_account_key, [payload_dropping_policy, credential_policy]) blob = await self._create_blob() @@ -241,7 +241,7 @@ async def test_create_largest_blob_from_path(self, resource_group, location, sto @AsyncStorageTestCase.await_prepared_test async def test_create_largest_blob_from_path_without_network(self, resource_group, location, storage_account, storage_account_key): payload_dropping_policy = PayloadDroppingPolicy() - credential_policy = format_shared_key_credential([storage_account.name, "dummy"], storage_account_key) + credential_policy = _format_shared_key_credential(storage_account.name, storage_account_key) await self._setup(storage_account, storage_account_key, [payload_dropping_policy, credential_policy]) blob_name = self._get_blob_reference() blob = self.bsc.get_blob_client(self.container_name, blob_name) @@ -268,7 +268,7 @@ async def test_create_largest_blob_from_path_without_network(self, resource_grou @AsyncStorageTestCase.await_prepared_test async def test_create_largest_blob_from_stream_without_network(self, resource_group, location, storage_account, storage_account_key): payload_dropping_policy = PayloadDroppingPolicy() - credential_policy = format_shared_key_credential([storage_account.name, "dummy"], storage_account_key) + credential_policy = _format_shared_key_credential(storage_account.name, storage_account_key) await self._setup(storage_account, storage_account_key, [payload_dropping_policy, credential_policy]) blob_name = self._get_blob_reference() blob = self.bsc.get_blob_client(self.container_name, blob_name) @@ -289,7 +289,7 @@ async def test_create_largest_blob_from_stream_without_network(self, resource_gr @AsyncStorageTestCase.await_prepared_test async def test_create_largest_blob_from_stream_single_upload_without_network(self, resource_group, location, storage_account, storage_account_key): payload_dropping_policy = PayloadDroppingPolicy() - credential_policy = format_shared_key_credential([storage_account.name, "dummy"], storage_account_key) + credential_policy = _format_shared_key_credential(storage_account.name, storage_account_key) await self._setup(storage_account, storage_account_key, [payload_dropping_policy, credential_policy], max_single_put_size=LARGEST_SINGLE_UPLOAD_SIZE) blob_name = self._get_blob_reference() diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/base_client.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/base_client.py index 53348f8ecc22..14deea6a977f 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/base_client.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/base_client.py @@ -84,12 +84,17 @@ def __init__( raise ValueError("Invalid service: {}".format(service)) service_name = service.split('-')[0] account = parsed_url.netloc.split(".{}.core.".format(service_name)) + self.account_name = account[0] if len(account) > 1 else None - secondary_hostname = None + if not self.account_name and parsed_url.netloc.startswith("localhost") \ + or parsed_url.netloc.startswith("127.0.0.1"): + self.account_name = parsed_url.path.strip("/") - self.credential = format_shared_key_credential(account, credential) + self.credential = _format_shared_key_credential(self.account_name, credential) if self.scheme.lower() != "https" and hasattr(self.credential, "get_token"): raise ValueError("Token credential is only supported with HTTPS.") + + secondary_hostname = None if hasattr(self.credential, "account_name"): self.account_name = self.credential.account_name secondary_hostname = "{}-secondary.{}.{}".format( @@ -320,11 +325,11 @@ def __exit__(self, *args): # pylint: disable=arguments-differ pass -def format_shared_key_credential(account, credential): +def _format_shared_key_credential(account_name, credential): if isinstance(credential, six.string_types): - if len(account) < 2: + if not account_name: raise ValueError("Unable to determine account name for shared key credential.") - credential = {"account_name": account[0], "account_key": credential} + credential = {"account_name": account_name, "account_key": credential} if isinstance(credential, dict): if "account_name" not in credential: raise ValueError("Shared key credential missing 'account_name") diff --git a/sdk/storage/azure-storage-file-datalake/tests/test_large_file.py b/sdk/storage/azure-storage-file-datalake/tests/test_large_file.py index 7bb1396e8cbe..e540508f2452 100644 --- a/sdk/storage/azure-storage-file-datalake/tests/test_large_file.py +++ b/sdk/storage/azure-storage-file-datalake/tests/test_large_file.py @@ -13,7 +13,7 @@ from azure.core.pipeline.policies import HTTPPolicy from azure.core.exceptions import ResourceExistsError -from azure.storage.blob._shared.base_client import format_shared_key_credential +from azure.storage.blob._shared.base_client import _format_shared_key_credential from azure.storage.filedatalake import DataLakeServiceClient from testcase import ( StorageTestCase, @@ -34,7 +34,7 @@ def setUp(self): super(LargeFileTest, self).setUp() url = self._get_account_url() self.payload_dropping_policy = PayloadDroppingPolicy() - credential_policy = format_shared_key_credential([self.settings.STORAGE_DATA_LAKE_ACCOUNT_NAME, "dummy"], + credential_policy = _format_shared_key_credential(self.settings.STORAGE_DATA_LAKE_ACCOUNT_NAME, self.settings.STORAGE_DATA_LAKE_ACCOUNT_KEY) self.dsc = DataLakeServiceClient(url, credential=self.settings.STORAGE_DATA_LAKE_ACCOUNT_KEY, diff --git a/sdk/storage/azure-storage-file-datalake/tests/test_large_file_async.py b/sdk/storage/azure-storage-file-datalake/tests/test_large_file_async.py index 872aa674a685..3b20b78b860f 100644 --- a/sdk/storage/azure-storage-file-datalake/tests/test_large_file_async.py +++ b/sdk/storage/azure-storage-file-datalake/tests/test_large_file_async.py @@ -15,7 +15,7 @@ from azure.core.exceptions import ResourceExistsError from azure.core.pipeline.policies import SansIOHTTPPolicy -from azure.storage.blob._shared.base_client import format_shared_key_credential +from azure.storage.blob._shared.base_client import _format_shared_key_credential from azure.storage.filedatalake.aio import DataLakeServiceClient from testcase import ( StorageTestCase, @@ -36,7 +36,7 @@ def setUp(self): super(LargeFileTest, self).setUp() url = self._get_account_url() self.payload_dropping_policy = PayloadDroppingPolicy() - credential_policy = format_shared_key_credential([self.settings.STORAGE_DATA_LAKE_ACCOUNT_NAME, "dummy"], + credential_policy = _format_shared_key_credential(self.settings.STORAGE_DATA_LAKE_ACCOUNT_NAME, self.settings.STORAGE_DATA_LAKE_ACCOUNT_KEY) self.dsc = DataLakeServiceClient(url, credential=self.settings.STORAGE_DATA_LAKE_ACCOUNT_KEY, diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_share_client.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_share_client.py index 8dd53cfdfb40..765ec5e6124a 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_share_client.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_share_client.py @@ -110,10 +110,10 @@ def __init__( # type: ignore @classmethod def from_share_url(cls, share_url, # type: str - snapshot=None, # type: Optional[Union[str, Dict[str, Any]]] - credential=None, # type: Optional[Any] - **kwargs # type: Any - ): + snapshot=None, # type: Optional[Union[str, Dict[str, Any]]] + credential=None, # type: Optional[Any] + **kwargs # type: Any + ): # type: (...) -> ShareClient """ :param str share_url: The full URI to the share. @@ -135,12 +135,31 @@ def from_share_url(cls, share_url, # type: str parsed_url = urlparse(share_url.rstrip('/')) if not (parsed_url.path and parsed_url.netloc): raise ValueError("Invalid URL: {}".format(share_url)) - account_url = parsed_url.netloc.rstrip('/') + "?" + parsed_url.query + + share_path = parsed_url.path.lstrip('/').split('/') + account_path = "" + if len(share_path) > 1: + account_path = "/" + "/".join(share_path[:-1]) + account_url = "{}://{}{}?{}".format( + parsed_url.scheme, + parsed_url.netloc.rstrip('/'), + account_path, + parsed_url.query) + + share_name = unquote(share_path[-1]) path_snapshot, _ = parse_query(parsed_url.query) - share_name = unquote(parsed_url.path.lstrip('/')) - snapshot = snapshot or unquote(path_snapshot) + if snapshot: + try: + path_snapshot = snapshot.snapshot # type: ignore + except AttributeError: + try: + path_snapshot = snapshot['snapshot'] # type: ignore + except TypeError: + path_snapshot = snapshot - return cls(account_url, share_name, snapshot, credential, **kwargs) + if not share_name: + raise ValueError("Invalid URL. Please provide a URL with a valid share name") + return cls(account_url, share_name, path_snapshot, credential, **kwargs) def _format_url(self, hostname): """Format the endpoint URL according to the current location diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/base_client.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/base_client.py index 53348f8ecc22..14deea6a977f 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/base_client.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/base_client.py @@ -84,12 +84,17 @@ def __init__( raise ValueError("Invalid service: {}".format(service)) service_name = service.split('-')[0] account = parsed_url.netloc.split(".{}.core.".format(service_name)) + self.account_name = account[0] if len(account) > 1 else None - secondary_hostname = None + if not self.account_name and parsed_url.netloc.startswith("localhost") \ + or parsed_url.netloc.startswith("127.0.0.1"): + self.account_name = parsed_url.path.strip("/") - self.credential = format_shared_key_credential(account, credential) + self.credential = _format_shared_key_credential(self.account_name, credential) if self.scheme.lower() != "https" and hasattr(self.credential, "get_token"): raise ValueError("Token credential is only supported with HTTPS.") + + secondary_hostname = None if hasattr(self.credential, "account_name"): self.account_name = self.credential.account_name secondary_hostname = "{}-secondary.{}.{}".format( @@ -320,11 +325,11 @@ def __exit__(self, *args): # pylint: disable=arguments-differ pass -def format_shared_key_credential(account, credential): +def _format_shared_key_credential(account_name, credential): if isinstance(credential, six.string_types): - if len(account) < 2: + if not account_name: raise ValueError("Unable to determine account name for shared key credential.") - credential = {"account_name": account[0], "account_key": credential} + credential = {"account_name": account_name, "account_key": credential} if isinstance(credential, dict): if "account_name" not in credential: raise ValueError("Shared key credential missing 'account_name") diff --git a/sdk/storage/azure-storage-file-share/tests/test_share.py b/sdk/storage/azure-storage-file-share/tests/test_share.py index 0a75d28465ad..c304e7e10a4d 100644 --- a/sdk/storage/azure-storage-file-share/tests/test_share.py +++ b/sdk/storage/azure-storage-file-share/tests/test_share.py @@ -77,6 +77,15 @@ def _delete_shares(self, prefix=TEST_SHARE_PREFIX): pass # --Test cases for shares ----------------------------------------- + def test_create_share_client(self): + share_client = ShareClient.from_share_url("http://127.0.0.1:11002/account/customized/path/share?snapshot=baz&", credential={"account_name": "myaccount", "account_key": "key"}) + self.assertEqual(share_client.share_name, "share") + self.assertEqual(share_client.snapshot, "baz") + + share_client = ShareClient.from_share_url("http://127.0.0.1:11002/account/share?snapshot=baz&", credential="credential") + self.assertEqual(share_client.share_name, "share") + self.assertEqual(share_client.snapshot, "baz") + @GlobalStorageAccountPreparer() def test_create_share(self, resource_group, location, storage_account, storage_account_key): self._setup(storage_account, storage_account_key) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py index f7656f89b73f..fd3c32ee7fa0 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client.py @@ -84,12 +84,17 @@ def __init__( raise ValueError("Invalid service: {}".format(service)) service_name = service.split('-')[0] account = parsed_url.netloc.split(".{}.core.".format(service_name)) + self.account_name = account[0] if len(account) > 1 else None - secondary_hostname = None + if not self.account_name and parsed_url.netloc.startswith("localhost") \ + or parsed_url.netloc.startswith("127.0.0.1"): + self.account_name = parsed_url.path.strip("/") - self.credential = format_shared_key_credential(account, credential) + self.credential = _format_shared_key_credential(self.account_name, credential) if self.scheme.lower() != "https" and hasattr(self.credential, "get_token"): raise ValueError("Token credential is only supported with HTTPS.") + + secondary_hostname = None if hasattr(self.credential, "account_name"): self.account_name = self.credential.account_name secondary_hostname = "{}-secondary.{}.{}".format( @@ -320,11 +325,11 @@ def __exit__(self, *args): # pylint: disable=arguments-differ pass -def format_shared_key_credential(account, credential): +def _format_shared_key_credential(account_name, credential): if isinstance(credential, six.string_types): - if len(account) < 2: + if not account_name: raise ValueError("Unable to determine account name for shared key credential.") - credential = {"account_name": account[0], "account_key": credential} + credential = {"account_name": account_name, "account_key": credential} if isinstance(credential, dict): if "account_name" not in credential: raise ValueError("Shared key credential missing 'account_name")