Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create a deepcopy of config in api client #172

Merged
merged 11 commits into from
Jun 21, 2023
4 changes: 2 additions & 2 deletions databricks/sdk/__init__.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 13 additions & 2 deletions databricks/sdk/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import base64
import configparser
import copy
import functools
import json
import logging
Expand All @@ -13,7 +14,7 @@
import urllib.parse
from datetime import datetime
from json import JSONDecodeError
from typing import Callable, Dict, Iterable, List, Optional
from typing import Callable, Dict, Iterable, List, Optional, Union

import requests
import requests.auth
Expand Down Expand Up @@ -370,7 +371,7 @@ def refresh(self) -> Token:
self.METADATA_SERVICE_VERSION_HEADER: self.METADATA_SERVICE_VERSION,
self.METADATA_SERVICE_HOST_HEADER: self.host
})
json_resp: dict[str, str] = resp.json()
json_resp: dict[str, Union[str, float]] = resp.json()
access_token = json_resp.get("access_token", None)
if access_token is None:
raise ValueError("Metadata Service returned empty token")
Expand Down Expand Up @@ -817,6 +818,16 @@ def _init_auth(self):
def __repr__(self):
return f'<{self.debug_string()}>'

def copy(self):
"""Creates a copy of the config object.
All the copies share most of their internal state (ie, shared reference to fields such as credential_provider).
Copies have their own instances of the following fields
- `_user_agent_other_info`
"""
cpy: Config = copy.copy(self)
cpy._user_agent_other_info = copy.deepcopy(self._user_agent_other_info)
return cpy


class DatabricksError(IOError):
""" Generic error from Databricks REST API """
Expand Down
51 changes: 48 additions & 3 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

import pytest

from databricks.sdk.core import (Config, DatabricksCliTokenSource,
from databricks.sdk.core import (Config, CredentialsProvider,
DatabricksCliTokenSource, HeaderFactory,
databricks_cli)
from databricks.sdk.version import __version__

Expand Down Expand Up @@ -140,10 +141,54 @@ def system(self):
monkeypatch.setenv('DATABRICKS_SDK_UPSTREAM', "upstream-product")
monkeypatch.setenv('DATABRICKS_SDK_UPSTREAM_VERSION', "0.0.1")

config = Config(host='http://localhost', username="something", password="something", product='test', product_version='0.0.0')\
.with_user_agent_extra('test-extra-1', '1')\
config = Config(host='http://localhost', username="something", password="something", product='test',
product_version='0.0.0') \
.with_user_agent_extra('test-extra-1', '1') \
.with_user_agent_extra('test-extra-2', '2')

assert config.user_agent == (
f"test/0.0.0 databricks-sdk-py/{__version__} python/3.0.0 os/testos auth/basic"
f" test-extra-1/1 test-extra-2/2 upstream/upstream-product upstream-version/0.0.1")


def test_config_copy_shallow_copies_credential_provider():

class TestCredentialsProvider(CredentialsProvider):

def __init__(self):
super().__init__()
self._token = "token1"

def auth_type(self) -> str:
return "test"

def __call__(self, cfg: 'Config') -> HeaderFactory:
return lambda: {"token": self._token}

def refresh(self):
self._token = "token2"

credential_provider = TestCredentialsProvider()
config = Config(credentials_provider=credential_provider)
config_copy = config.copy()

assert config.authenticate()["token"] == "token1"
assert config_copy.authenticate()["token"] == "token1"

credential_provider.refresh()

assert config.authenticate()["token"] == "token2"
assert config_copy.authenticate()["token"] == "token2"
assert config._credentials_provider == config_copy._credentials_provider


def test_config_copy_deep_copies_user_agent_other_info(config):
config_copy = config.copy()

config.with_user_agent_extra("test", "test1")
assert "test/test1" not in config_copy.user_agent
assert "test/test1" in config.user_agent

config_copy.with_user_agent_extra("test", "test2")
assert "test/test2" in config_copy.user_agent
assert "test/test2" not in config.user_agent