diff --git a/CHANGELOG.md b/CHANGELOG.md index 8a7ae830..9b351c69 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ ## Changes +- Added [DremioRestClient](dbt/adapters/dremio/api/rest/client.py) to isolate all Dremio API calls inside one class + ## Dependency - [#222](https://github.com/dremio/dbt-dremio/issues/222) Upgrade dbt-core to 1.8.8 and dbt-tests-adapter to 1.8.0 diff --git a/dbt/adapters/dremio/api/__init__.py b/dbt/adapters/dremio/api/__init__.py deleted file mode 100644 index 66db7a8e..00000000 --- a/dbt/adapters/dremio/api/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (C) 2022 Dremio Corporation - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# __init__.py -from .rest.endpoints import ( - delete_catalog, - sql_endpoint, - job_status, - create_catalog_api, - get_catalog_item, - login, -) diff --git a/dbt/adapters/dremio/api/cursor.py b/dbt/adapters/dremio/api/cursor.py index 462b6eae..cc58264c 100644 --- a/dbt/adapters/dremio/api/cursor.py +++ b/dbt/adapters/dremio/api/cursor.py @@ -17,13 +17,7 @@ import agate -from dbt.adapters.dremio.api.rest.endpoints import ( - sql_endpoint, - job_status, - job_results, - job_cancel_api, -) -from dbt.adapters.dremio.api.parameters import Parameters +from dbt.adapters.dremio.api.rest.client import DremioRestClient from dbt.adapters.events.logging import AdapterLogger @@ -31,8 +25,8 @@ class DremioCursor: - def __init__(self, api_parameters: Parameters): - self._parameters = api_parameters + def __init__(self, rest_client: DremioRestClient): + self._rest_client = rest_client self._closed = False self._job_id = None @@ -41,10 +35,6 @@ def __init__(self, api_parameters: Parameters): self._table_results: agate.Table = None self._description = None - @property - def parameters(self): - return self._parameters - @property def description(self): return self._description @@ -80,7 +70,7 @@ def job_results(self): def job_cancel(self): # cancels current job logger.debug(f"Cancelling job {self._job_id}") - return job_cancel_api(self._parameters, self._job_id) + return self._rest_client.job_cancel_api(self._job_id) def close(self): if self.closed: @@ -94,7 +84,7 @@ def execute(self, sql, bindings=None, fetch=False): if bindings is None: self._initialize() - json_payload = sql_endpoint(self._parameters, sql, context=None) + json_payload = self._rest_client.sql_endpoint(sql, context=None) self._job_id = json_payload["id"] @@ -130,7 +120,7 @@ def _populate_rowcount(self): job_id = self._job_id last_job_state = "" - job_status_response = job_status(self._parameters, job_id) + job_status_response = self._rest_client.job_status(job_id) job_status_state = job_status_response["jobState"] while True: @@ -145,7 +135,7 @@ def _populate_rowcount(self): if job_status_state == "COMPLETED" or job_status_state == "CANCELLED": break last_job_state = job_status_state - job_status_response = job_status(self._parameters, job_id) + job_status_response = self._rest_client.job_status(job_id) job_status_state = job_status_response["jobState"] # this is done as job status does not return a rowCount if there are no rows affected (even in completed job_state) @@ -161,8 +151,7 @@ def _populate_rowcount(self): def _populate_job_results(self, row_limit=500): if self._job_results == None: - combined_job_results = job_results( - self._parameters, + combined_job_results = self._rest_client.job_results( self._job_id, offset=0, limit=row_limit, @@ -177,8 +166,7 @@ def _populate_job_results(self, row_limit=500): while current_row_count < total_row_count: combined_job_results["rows"].extend( - job_results( - self._parameters, + self._rest_client.job_results( self._job_id, offset=current_row_count, limit=row_limit, diff --git a/dbt/adapters/dremio/api/handle.py b/dbt/adapters/dremio/api/handle.py index e1825267..addac087 100644 --- a/dbt/adapters/dremio/api/handle.py +++ b/dbt/adapters/dremio/api/handle.py @@ -14,7 +14,7 @@ from dbt.adapters.dremio.api.cursor import DremioCursor from dbt.adapters.dremio.api.parameters import Parameters -from dbt.adapters.dremio.api.rest.endpoints import login +from dbt.adapters.dremio.api.rest.client import DremioRestClient from dbt.adapters.events.logging import AdapterLogger @@ -23,19 +23,19 @@ class DremioHandle: def __init__(self, parameters: Parameters): - self._parameters = parameters + self._rest_client = DremioRestClient(parameters) self._cursor = None self.closed = False - def get_parameters(self): - return self._parameters + def get_client(self): + return self._rest_client def cursor(self): if self.closed: raise Exception("HandleClosed") if self._cursor is None or self._cursor.closed: - self._parameters = login(self._parameters) - self._cursor = DremioCursor(self._parameters) + self._rest_client.start() + self._cursor = DremioCursor(self._rest_client) return self._cursor def close(self): diff --git a/dbt/adapters/dremio/api/rest/client.py b/dbt/adapters/dremio/api/rest/client.py new file mode 100644 index 00000000..26a2b307 --- /dev/null +++ b/dbt/adapters/dremio/api/rest/client.py @@ -0,0 +1,135 @@ +# Copyright (C) 2022 Dremio Corporation + +# Copyright (c) 2019 Ryan Murray. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + + +import requests + +from dbt.adapters.dremio.api.authentication import DremioPatAuthentication +from dbt.adapters.dremio.api.parameters import Parameters +from dbt.adapters.dremio.api.rest.utils import _post, _get, _delete +from dbt.adapters.dremio.api.rest.url_builder import UrlBuilder + +from dbt.adapters.events.logging import AdapterLogger + +logger = AdapterLogger("dremio") + +session = requests.Session() + + +class DremioRestClient: + def __init__(self, api_parameters: Parameters): + self._parameters = api_parameters + + def start(self): + self._parameters = self.__login() + + def __login(self, timeout=10): + if isinstance(self._parameters.authentication, DremioPatAuthentication): + return self._parameters + + url = UrlBuilder.login_url(self._parameters) + response = _post( + url, + json={ + "userName": self._parameters.authentication.username, + "password": self._parameters.authentication.password, + }, + timeout=timeout, + ssl_verify=self._parameters.authentication.verify_ssl, + ) + + self._parameters.authentication.token = response["token"] + + return self._parameters + + def sql_endpoint(self, query, context=None): + url = UrlBuilder.sql_url(self._parameters) + return _post( + url, + self._parameters.authentication.get_headers(), + ssl_verify=self._parameters.authentication.verify_ssl, + json={"sql": query, "context": context}, + ) + + def job_status(self, job_id): + url = UrlBuilder.job_status_url(self._parameters, job_id) + return _get( + url, + self._parameters.authentication.get_headers(), + ssl_verify=self._parameters.authentication.verify_ssl, + ) + + def job_cancel_api(self, job_id): + url = UrlBuilder.job_cancel_url(self._parameters, job_id) + return _post( + url, + self._parameters.authentication.get_headers(), + json=None, + ssl_verify=self._parameters.authentication.verify_ssl, + ) + + def job_results(self, job_id, offset=0, limit=100): + url = UrlBuilder.job_results_url( + self._parameters, + job_id, + offset, + limit, + ) + return _get( + url, + self._parameters.authentication.get_headers(), + ssl_verify=self._parameters.authentication.verify_ssl, + ) + + def create_catalog_api(self, json): + url = UrlBuilder.catalog_url(self._parameters) + return _post( + url, + self._parameters.authentication.get_headers(), + json=json, + ssl_verify=self._parameters.authentication.verify_ssl, + ) + + def get_catalog_item(self, catalog_id=None, catalog_path=None): + if catalog_id is None and catalog_path is None: + raise TypeError( + "both id and path can't be None for a catalog_item call") + + # Will use path if both id and path are specified + if catalog_path: + url = UrlBuilder.catalog_item_by_path_url( + self._parameters, + catalog_path, + ) + else: + url = UrlBuilder.catalog_item_by_id_url( + self._parameters, + catalog_id, + ) + return _get( + url, + self._parameters.authentication.get_headers(), + ssl_verify=self._parameters.authentication.verify_ssl, + ) + + def delete_catalog(self, cid): + url = UrlBuilder.delete_catalog_url(self._parameters, cid) + return _delete( + url, + self._parameters.authentication.get_headers(), + ssl_verify=self._parameters.authentication.verify_ssl, + ) \ No newline at end of file diff --git a/dbt/adapters/dremio/api/rest/endpoints.py b/dbt/adapters/dremio/api/rest/utils.py similarity index 57% rename from dbt/adapters/dremio/api/rest/endpoints.py rename to dbt/adapters/dremio/api/rest/utils.py index f26c418f..39e8bec7 100644 --- a/dbt/adapters/dremio/api/rest/endpoints.py +++ b/dbt/adapters/dremio/api/rest/utils.py @@ -31,9 +31,6 @@ import json as jsonlib from requests.exceptions import HTTPError -from dbt.adapters.dremio.api.authentication import DremioPatAuthentication -from dbt.adapters.dremio.api.parameters import Parameters -from dbt.adapters.dremio.api.rest.url_builder import UrlBuilder from dbt.adapters.events.logging import AdapterLogger @@ -41,6 +38,7 @@ session = requests.Session() + def _get(url, request_headers, details="", ssl_verify=True): response = session.get(url, headers=request_headers, verify=ssl_verify) return _check_error(response, details) @@ -116,11 +114,14 @@ def _check_error(response, details=""): except: # NOQA return response.text if code == 400: - raise DremioBadRequestException("Bad request:" + details, error, response) + raise DremioBadRequestException("Bad request:" + details, error, + response) if code == 401: - raise DremioUnauthorizedException("Unauthorized:" + details, error, response) + raise DremioUnauthorizedException("Unauthorized:" + details, error, + response) if code == 403: - raise DremioPermissionException("No permission:" + details, error, response) + raise DremioPermissionException("No permission:" + details, error, + response) if code == 404: raise DremioNotFoundException("Not found:" + details, error, response) if code == 408: @@ -128,7 +129,8 @@ def _check_error(response, details=""): "Request timeout:" + details, error, response ) if code == 409: - raise DremioAlreadyExistsException("Already exists:" + details, error, response) + raise DremioAlreadyExistsException("Already exists:" + details, error, + response) if code == 429: raise DremioTooManyRequestsException( "Too many requests:" + details, error, response @@ -148,106 +150,3 @@ def _check_error(response, details=""): raise DremioException("Unknown error", error) -def login(api_parameters: Parameters, timeout=10): - - if isinstance(api_parameters.authentication, DremioPatAuthentication): - return api_parameters - - url = UrlBuilder.login_url(api_parameters) - response = _post( - url, - json={ - "userName": api_parameters.authentication.username, - "password": api_parameters.authentication.password, - }, - timeout=timeout, - ssl_verify=api_parameters.authentication.verify_ssl, - ) - - api_parameters.authentication.token = response["token"] - - return api_parameters - - -def sql_endpoint(api_parameters: Parameters, query, context=None): - url = UrlBuilder.sql_url(api_parameters) - return _post( - url, - api_parameters.authentication.get_headers(), - ssl_verify=api_parameters.authentication.verify_ssl, - json={"sql": query, "context": context}, - ) - - -def job_status(api_parameters: Parameters, job_id): - url = UrlBuilder.job_status_url(api_parameters, job_id) - return _get( - url, - api_parameters.authentication.get_headers(), - ssl_verify=api_parameters.authentication.verify_ssl, - ) - - -def job_cancel_api(api_parameters: Parameters, job_id): - url = UrlBuilder.job_cancel_url(api_parameters, job_id) - return _post( - url, - api_parameters.authentication.get_headers(), - json=None, - ssl_verify=api_parameters.authentication.verify_ssl, - ) - - -def job_results(api_parameters: Parameters, job_id, offset=0, limit=100): - url = UrlBuilder.job_results_url( - api_parameters, - job_id, - offset, - limit, - ) - return _get( - url, - api_parameters.authentication.get_headers(), - ssl_verify=api_parameters.authentication.verify_ssl, - ) - - -def create_catalog_api(api_parameters, json): - url = UrlBuilder.catalog_url(api_parameters) - return _post( - url, - api_parameters.authentication.get_headers(), - json=json, - ssl_verify=api_parameters.authentication.verify_ssl, - ) - - -def get_catalog_item(api_parameters, catalog_id=None, catalog_path=None): - if catalog_id is None and catalog_path is None: - raise TypeError("both id and path can't be None for a catalog_item call") - - # Will use path if both id and path are specified - if catalog_path: - url = UrlBuilder.catalog_item_by_path_url( - api_parameters, - catalog_path, - ) - else: - url = UrlBuilder.catalog_item_by_id_url( - api_parameters, - catalog_id, - ) - return _get( - url, - api_parameters.authentication.get_headers(), - ssl_verify=api_parameters.authentication.verify_ssl, - ) - - -def delete_catalog(api_parameters, cid): - url = UrlBuilder.delete_catalog_url(api_parameters, cid) - return _delete( - url, - api_parameters.authentication.get_headers(), - ssl_verify=api_parameters.authentication.verify_ssl, - ) diff --git a/dbt/adapters/dremio/connections.py b/dbt/adapters/dremio/connections.py index 3096efaa..24ca039e 100644 --- a/dbt/adapters/dremio/connections.py +++ b/dbt/adapters/dremio/connections.py @@ -29,11 +29,8 @@ from dbt.adapters.sql import SQLConnectionManager from dbt.adapters.contracts.connection import AdapterResponse -from dbt.adapters.dremio.api.rest.endpoints import ( - delete_catalog, - create_catalog_api, - get_catalog_item, -) +from dbt.adapters.dremio.api.rest.client import DremioRestClient + from dbt.adapters.dremio.api.rest.error import ( DremioAlreadyExistsException, DremioNotFoundException, @@ -133,17 +130,20 @@ def add_commit_query(self): # Auto_begin may not be relevant with the rest_api def add_query( - self, sql, auto_begin=True, bindings=None, abridge_sql_log=False, fetch=False + self, sql, auto_begin=True, bindings=None, abridge_sql_log=False, + fetch=False ): connection = self.get_thread_connection() if auto_begin and connection.transaction_open is False: self.begin() - logger.debug(f'Using {self.TYPE} connection "{connection.name}". fetch={fetch}') + logger.debug( + f'Using {self.TYPE} connection "{connection.name}". fetch={fetch}') with self.exception_handler(sql): if abridge_sql_log: - logger.debug("On {}: {}....".format(connection.name, sql[0:512])) + logger.debug( + "On {}: {}....".format(connection.name, sql[0:512])) else: logger.debug("On {}: {}".format(connection.name, sql)) @@ -196,13 +196,12 @@ def drop_catalog(self, database, schema): thread_connection = self.get_thread_connection() connection = self.open(thread_connection) credentials = connection.credentials - api_parameters = connection.handle.get_parameters() + rest_client = connection.handle.get_client() path_list = self._create_path_list(database, schema) if database != credentials.datalake: try: - catalog_info = get_catalog_item( - api_parameters, + catalog_info = rest_client.get_catalog_item( catalog_id=None, catalog_path=path_list, ) @@ -210,13 +209,13 @@ def drop_catalog(self, database, schema): logger.debug("Catalog not found. Returning") return - delete_catalog(api_parameters, catalog_info["id"]) + rest_client.delete_catalog(catalog_info["id"]) def create_catalog(self, relation): thread_connection = self.get_thread_connection() connection = self.open(thread_connection) credentials = connection.credentials - api_parameters = connection.handle.get_parameters() + rest_client = connection.handle.get_client() database = relation.database schema = relation.schema @@ -224,11 +223,11 @@ def create_catalog(self, relation): logger.debug("Database is default: creating folders only") else: logger.debug(f"Creating space: {database}") - self._create_space(database, api_parameters) + self._create_space(database, rest_client) if database != credentials.datalake: logger.debug(f"Creating folder(s): {database}.{schema}") - self._create_folders(database, schema, api_parameters) + self._create_folders(database, schema, rest_client) return def _make_new_space_json(self, name) -> json: @@ -239,20 +238,21 @@ def _make_new_folder_json(self, path) -> json: python_dict = {"entityType": "folder", "path": path} return json.dumps(python_dict) - def _create_space(self, database, api_parameters): + def _create_space(self, database, rest_client: DremioRestClient): space_json = self._make_new_space_json(database) try: - create_catalog_api(api_parameters, space_json) + rest_client.create_catalog_api(space_json) except DremioAlreadyExistsException: - logger.debug(f"Database {database} already exists. Creating folders only.") + logger.debug( + f"Database {database} already exists. Creating folders only.") - def _create_folders(self, database, schema, api_parameters): + def _create_folders(self, database, schema, rest_client: DremioRestClient): temp_path_list = [database] for folder in schema.split("."): temp_path_list.append(folder) folder_json = self._make_new_folder_json(temp_path_list) try: - create_catalog_api(api_parameters, folder_json) + rest_client.create_catalog_api(folder_json) except DremioAlreadyExistsException: logger.debug(f"Folder {folder} already exists.") except DremioBadRequestException as e: diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 58dbc422..37153cc2 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -17,7 +17,7 @@ class TestRetryConnection: - @patch("dbt.adapters.dremio.api.rest.endpoints._post") + @patch("dbt.adapters.dremio.api.rest.client._post") @patch("dbt.adapters.contracts.connection.Connection") # When you nest patch decorators the mocks are passed in to the decorated function in bottom up order. def test_connection_retry( diff --git a/tests/unit/test_job_results.py b/tests/unit/test_job_results.py index 1eca997d..8100bc7f 100644 --- a/tests/unit/test_job_results.py +++ b/tests/unit/test_job_results.py @@ -15,6 +15,7 @@ from dbt.adapters.dremio.api.cursor import DremioCursor from dbt.adapters.dremio.api.parameters import Parameters from dbt.adapters.dremio.api.authentication import DremioAuthentication +from dbt.adapters.dremio.api.rest.client import DremioRestClient class TestJobResults: @@ -60,14 +61,14 @@ class TestJobResults: ], } - @patch("dbt.adapters.dremio.api.cursor.job_results") + @patch("dbt.adapters.dremio.api.rest.client.DremioRestClient.job_results") def test_job_result_pagination(self, mocked_job_results_func): # Arrange ROW_LIMIT = 2 JOB_RESULT_TOTAL_CALLS = ceil( self.mocked_job_results_dict[0]["rowCount"] / ROW_LIMIT ) - dremio_cursor_obj = DremioCursor(Parameters("base_url", DremioAuthentication())) + dremio_cursor_obj = DremioCursor(DremioRestClient(Parameters("base_url", DremioAuthentication()))) mocked_job_results_func.side_effect = self.mocked_job_results_dict # Act diff --git a/tests/unit/test_payload_error.py b/tests/unit/test_payload_error.py index e5049670..676c6350 100644 --- a/tests/unit/test_payload_error.py +++ b/tests/unit/test_payload_error.py @@ -13,13 +13,14 @@ from dbt.adapters.dremio.api.cursor import DremioCursor from dbt.adapters.dremio.api.parameters import Parameters from dbt.adapters.dremio.api.authentication import DremioAuthentication +from dbt.adapters.dremio.api.rest.client import DremioRestClient class TestPayloadErrorMessage(TestCase): - @mock.patch("dbt.adapters.dremio.api.cursor.job_status") + @mock.patch("dbt.adapters.dremio.api.rest.client.DremioRestClient.job_status") def test_payload_error(self, mocked_job_status): dremio_cursor_object = DremioCursor( - Parameters("base_url", DremioAuthentication) + DremioRestClient(Parameters("base_url", DremioAuthentication())) ) mocked_job_status.return_value = { "jobState": "FAILED",