diff --git a/sdk/core/azure-core/CHANGELOG.md b/sdk/core/azure-core/CHANGELOG.md index d31de967c9da..dbeecdcc81f9 100644 --- a/sdk/core/azure-core/CHANGELOG.md +++ b/sdk/core/azure-core/CHANGELOG.md @@ -4,6 +4,7 @@ ### Features +- Added `azure.core.utils.parse_connection_string` function to parse connection strings across SDKs, with common validation and support for case insensitive keys. - Supported adding custom policies #16519 ### Bug fixes diff --git a/sdk/core/azure-core/azure/core/__init__.py b/sdk/core/azure-core/azure/core/__init__.py index ddd1d8da4b69..f4543b08c571 100644 --- a/sdk/core/azure-core/azure/core/__init__.py +++ b/sdk/core/azure-core/azure/core/__init__.py @@ -35,7 +35,7 @@ __all__ = [ "PipelineClient", "MatchConditions", - "CaseInsensitiveEnumMeta" + "CaseInsensitiveEnumMeta", ] try: diff --git a/sdk/core/azure-core/azure/core/utils/__init__.py b/sdk/core/azure-core/azure/core/utils/__init__.py new file mode 100644 index 000000000000..13d98f4b9b68 --- /dev/null +++ b/sdk/core/azure-core/azure/core/utils/__init__.py @@ -0,0 +1,36 @@ +# -------------------------------------------------------------------------- +# +# Copyright (c) Microsoft Corporation. All rights reserved. +# +# The MIT License (MIT) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the ""Software""), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +# +# -------------------------------------------------------------------------- +""" + +This `utils` module provides functionality that is intended to be used by developers +building on top of `azure-core`. + +""" +from ._connection_string_parser import ( + parse_connection_string +) + +__all__ = ["parse_connection_string"] diff --git a/sdk/core/azure-core/azure/core/utils/_connection_string_parser.py b/sdk/core/azure-core/azure/core/utils/_connection_string_parser.py new file mode 100644 index 000000000000..a074df43b586 --- /dev/null +++ b/sdk/core/azure-core/azure/core/utils/_connection_string_parser.py @@ -0,0 +1,46 @@ +# coding=utf-8 +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from typing import Mapping + + +def parse_connection_string(conn_str, case_sensitive_keys=False): + # type: (str, bool) -> Mapping[str, str] + """Parses the connection string into a dict of its component parts, with the option of preserving case + of keys, and validates that each key in the connection string has a provided value. If case of keys + is not preserved (ie. `case_sensitive_keys=False`), then a dict with LOWERCASE KEYS will be returned. + + :param str conn_str: String with connection details provided by Azure services. + :param bool case_sensitive_keys: Indicates whether the casing of the keys will be preserved. When `False`(the + default), all keys will be lower-cased. If set to `True`, the original casing of the keys will be preserved. + :rtype: Mapping + :raises: + ValueError: if each key in conn_str does not have a corresponding value and + for other bad formatting of connection strings - including duplicate + args, bad syntax, etc. + """ + + cs_args = [s.split("=", 1) for s in conn_str.strip().rstrip(";").split(";")] + if any(len(tup) != 2 or not all(tup) for tup in cs_args): + raise ValueError("Connection string is either blank or malformed.") + args_dict = dict(cs_args) # type: ignore + + if len(cs_args) != len(args_dict): + raise ValueError("Connection string is either blank or malformed.") + + if not case_sensitive_keys: + # if duplicate case insensitive keys are passed in, raise error + new_args_dict = {} + for key in args_dict.keys(): + new_key = key.lower() + if new_key in new_args_dict: + raise ValueError( + "Duplicate key in connection string: {}".format(new_key) + ) + new_args_dict[new_key] = args_dict[key] + return new_args_dict + + return args_dict diff --git a/sdk/core/azure-core/tests/test_connection_string_parsing.py b/sdk/core/azure-core/tests/test_connection_string_parsing.py new file mode 100644 index 000000000000..1decf6ce45a5 --- /dev/null +++ b/sdk/core/azure-core/tests/test_connection_string_parsing.py @@ -0,0 +1,131 @@ +import sys +import pytest +from azure.core.utils import parse_connection_string + +from devtools_testutils import AzureMgmtTestCase + +class CoreConnectionStringParserTests(AzureMgmtTestCase): + + def test_parsing_with_case_sensitive_keys_for_sensitive_conn_str(self, **kwargs): + conn_str = 'Endpoint=XXXXENDPOINTXXXX;SharedAccessKeyName=XXXXPOLICYXXXX;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + parse_result = parse_connection_string(conn_str, True) + assert parse_result["Endpoint"] == 'XXXXENDPOINTXXXX' + assert parse_result["SharedAccessKeyName"] == 'XXXXPOLICYXXXX' + assert parse_result["SharedAccessKey"] == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + with pytest.raises(KeyError): + parse_result["endPoint"] + with pytest.raises(KeyError): + parse_result["sharedAccESSkEynAME"] + with pytest.raises(KeyError): + parse_result["sharedaccesskey"] + + def test_parsing_with_case_insensitive_keys_for_sensitive_conn_str(self, **kwargs): + conn_str = 'Endpoint=XXXXENDPOINTXXXX;SharedAccessKeyName=XXXXPOLICYXXXX;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + parse_result = parse_connection_string(conn_str, False) + assert parse_result["endpoint"] == 'XXXXENDPOINTXXXX' + assert parse_result["sharedaccesskeyname"] == 'XXXXPOLICYXXXX' + assert parse_result["sharedaccesskey"] == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + + def test_parsing_with_case_insensitive_keys_for_insensitive_conn_str(self, **kwargs): + conn_str = 'enDpoiNT=XXXXENDPOINTXXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + parse_result = parse_connection_string(conn_str, False) + assert parse_result["endpoint"] == 'XXXXENDPOINTXXXX' + assert parse_result["sharedaccesskeyname"] == 'XXXXPOLICYXXXX' + assert parse_result["sharedaccesskey"] == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + + def test_error_with_duplicate_case_sensitive_keys_for_sensitive_conn_str(self, **kwargs): + conn_str = 'Endpoint=XXXXENDPOINTXXXX;Endpoint=XXXXENDPOINT2XXXX;SharedAccessKeyName=XXXXPOLICYXXXX;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + with pytest.raises(ValueError) as e: + parse_result = parse_connection_string(conn_str, True) + assert str(e.value) == "Connection string is either blank or malformed." + + def test_success_with_duplicate_case_sensitive_keys_for_sensitive_conn_str(self, **kwargs): + conn_str = 'enDpoInt=XXXXENDPOINTXXXX;Endpoint=XXXXENDPOINT2XXXX;' + parse_result = parse_connection_string(conn_str, True) + assert parse_result["enDpoInt"] == 'XXXXENDPOINTXXXX' + assert parse_result["Endpoint"] == 'XXXXENDPOINT2XXXX' + + def test_error_with_duplicate_case_insensitive_keys_for_insensitive_conn_str(self, **kwargs): + conn_str = 'endPoinT=XXXXENDPOINTXXXX;eNdpOint=XXXXENDPOINT2XXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + with pytest.raises(ValueError) as e: + parse_result = parse_connection_string(conn_str, False) + assert str(e.value) == "Duplicate key in connection string: endpoint" + + def test_error_with_malformed_conn_str(self): + for conn_str in ["", "foobar", "foo;bar;baz", ";", "foo=;bar=;", "=", "=;=="]: + with pytest.raises(ValueError) as e: + parse_result = parse_connection_string(conn_str) + self.assertEqual(str(e.value), "Connection string is either blank or malformed.") + + def test_case_insensitive_clear_method(self): + conn_str = 'enDpoiNT=XXXXENDPOINTXXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + parse_result = parse_connection_string(conn_str, False) + parse_result.clear() + assert len(parse_result) == 0 + + def test_case_insensitive_copy_method(self): + conn_str = 'enDpoiNT=XXXXENDPOINTXXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + parse_result = parse_connection_string(conn_str, False) + copied = parse_result.copy() + assert copied == parse_result + + def test_case_insensitive_get_method(self): + conn_str = 'Endpoint=XXXXENDPOINTXXXX;SharedAccessKeyName=XXXXPOLICYXXXX;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + parse_result = parse_connection_string(conn_str, False) + assert parse_result.get("sharedaccesskeyname") == 'XXXXPOLICYXXXX' + assert parse_result.get("sharedaccesskey") == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + assert parse_result.get("accesskey") is None + assert parse_result.get("accesskey", "XXothertestkeyXX=") == "XXothertestkeyXX=" + + def test_case_insensitive_keys_method(self): + conn_str = 'enDpoiNT=XXXXENDPOINTXXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + parse_result = parse_connection_string(conn_str, False) + keys = parse_result.keys() + assert len(keys) == 3 + assert "endpoint" in keys + + def test_case_insensitive_pop_method(self): + conn_str = 'enDpoiNT=XXXXENDPOINTXXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + parse_result = parse_connection_string(conn_str, False) + endpoint = parse_result.pop("endpoint") + sharedaccesskey = parse_result.pop("sharedaccesskey") + assert len(parse_result) == 1 + assert endpoint == "XXXXENDPOINTXXXX" + assert sharedaccesskey == "THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=" + + def test_case_insensitive_update_with_insensitive_method(self): + conn_str = 'enDpoiNT=XXXXENDPOINTXXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + conn_str2 = 'hostName=XXXXENDPOINTXXXX;ACCessKEy=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;' + parse_result_insensitive = parse_connection_string(conn_str, False) + parse_result_insensitive2 = parse_connection_string(conn_str2, False) + + parse_result_insensitive.update(parse_result_insensitive2) + assert len(parse_result_insensitive) == 5 + assert parse_result_insensitive["hostname"] == "XXXXENDPOINTXXXX" + assert parse_result_insensitive["accesskey"] == "THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=" + + # check that update replace duplicate case insensitive keys + conn_str_duplicate_key = "endpoint=XXXXENDPOINT2XXXX;ACCessKEy=TestKey" + parse_result_insensitive_dupe = parse_connection_string(conn_str_duplicate_key, False) + parse_result_insensitive.update(parse_result_insensitive_dupe) + assert parse_result_insensitive_dupe["endpoint"] == "XXXXENDPOINT2XXXX" + assert parse_result_insensitive_dupe["accesskey"] == "TestKey" + assert len(parse_result_insensitive) == 5 + + def test_case_sensitive_update_with_insensitive_method(self): + conn_str = 'enDpoiNT=XXXXENDPOINTXXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + conn_str2 = 'hostName=XXXXENDPOINTXXXX;ACCessKEy=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;' + parse_result_insensitive = parse_connection_string(conn_str, False) + parse_result_sensitive = parse_connection_string(conn_str2, True) + + parse_result_sensitive.update(parse_result_insensitive) + assert len(parse_result_sensitive) == 5 + assert parse_result_sensitive["hostName"] == "XXXXENDPOINTXXXX" + with pytest.raises(KeyError): + parse_result_sensitive["hostname"] + + def test_case_insensitive_values_method(self): + conn_str = 'enDpoiNT=XXXXENDPOINTXXXX;sharedaccesskeyname=XXXXPOLICYXXXX;SHAREDACCESSKEY=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + parse_result = parse_connection_string(conn_str, False) + values = parse_result.values() + assert len(values) == 3 \ No newline at end of file