From e661698eb6db9b07416655ea4b6df6583b9fbfec Mon Sep 17 00:00:00 2001 From: Ren Ren Date: Wed, 30 Mar 2022 15:49:53 -0400 Subject: [PATCH 1/8] change connection from class to module --- nasdaqdatalink/connection.py | 192 ++++++++++++++-------------- nasdaqdatalink/model/database.py | 2 +- nasdaqdatalink/model/datatable.py | 2 +- nasdaqdatalink/operations/get.py | 2 +- nasdaqdatalink/operations/list.py | 2 +- test/test_connection.py | 4 +- test/test_data.py | 2 +- test/test_database.py | 6 +- test/test_dataset.py | 4 +- test/test_datatable.py | 10 +- test/test_datatable_data.py | 4 +- test/test_get.py | 4 +- test/test_get_point_in_time_data.py | 22 ++-- test/test_get_table.py | 12 +- test/test_point_in_time.py | 6 +- test/test_retries.py | 2 +- 16 files changed, 135 insertions(+), 141 deletions(-) diff --git a/nasdaqdatalink/connection.py b/nasdaqdatalink/connection.py index 2338a5f..73a6241 100644 --- a/nasdaqdatalink/connection.py +++ b/nasdaqdatalink/connection.py @@ -13,103 +13,97 @@ AuthenticationError, ForbiddenError, InvalidRequestError, NotFoundError, ServiceUnavailableError) - -class Connection: - @classmethod - def request(cls, http_verb, url, **options): - if 'headers' in options: - headers = options['headers'] +def request(http_verb, url, **options): + if 'headers' in options: + headers = options['headers'] + else: + headers = {} + + accept_value = 'application/json' + if ApiConfig.api_version: + accept_value += ", application/vnd.data.nasdaq+json;version=%s" % ApiConfig.api_version + + headers = Util.merge_to_dicts({'accept': accept_value, + 'request-source': 'python', + 'request-source-version': VERSION}, headers) + if ApiConfig.api_key: + headers = Util.merge_to_dicts({'x-api-token': ApiConfig.api_key}, headers) + + options['headers'] = headers + + abs_url = '%s/%s' % (ApiConfig.api_base, url) + + return execute_request(http_verb, abs_url, **options) + +def execute_request(http_verb, url, **options): + session = get_session(url) + + try: + response = session.request(method=http_verb, + url=url, + verify=ApiConfig.verify_ssl, + **options) + if response.status_code < 200 or response.status_code >= 300: + handle_api_error(response) else: - headers = {} - - accept_value = 'application/json' - if ApiConfig.api_version: - accept_value += ", application/vnd.data.nasdaq+json;version=%s" % ApiConfig.api_version - - headers = Util.merge_to_dicts({'accept': accept_value, - 'request-source': 'python', - 'request-source-version': VERSION}, headers) - if ApiConfig.api_key: - headers = Util.merge_to_dicts({'x-api-token': ApiConfig.api_key}, headers) - - options['headers'] = headers - - abs_url = '%s/%s' % (ApiConfig.api_base, url) - - return cls.execute_request(http_verb, abs_url, **options) - - @classmethod - def execute_request(cls, http_verb, url, **options): - session = cls.get_session() - - try: - response = session.request(method=http_verb, - url=url, - verify=ApiConfig.verify_ssl, - **options) - if response.status_code < 200 or response.status_code >= 300: - cls.handle_api_error(response) - else: - return response - except requests.exceptions.RequestException as e: - if e.response: - cls.handle_api_error(e.response) - raise e - - @classmethod - def get_session(cls): - session = requests.Session() - adapter = HTTPAdapter(max_retries=cls.get_retries()) - session.mount(ApiConfig.api_protocol, adapter) - - return session - - @classmethod - def get_retries(cls): - if not ApiConfig.use_retries: - return Retry(total=0) - - Retry.BACKOFF_MAX = ApiConfig.max_wait_between_retries - retries = Retry(total=ApiConfig.number_of_retries, - connect=ApiConfig.number_of_retries, - read=ApiConfig.number_of_retries, - status_forcelist=ApiConfig.retry_status_codes, - backoff_factor=ApiConfig.retry_backoff_factor, - raise_on_status=False) - - return retries - - @classmethod - def parse(cls, response): - try: - return response.json() - except ValueError: - raise DataLinkError(http_status=response.status_code, http_body=response.text) - - @classmethod - def handle_api_error(cls, resp): - error_body = cls.parse(resp) - - # if our app does not form a proper data_link_error response - # throw generic error - if 'error' not in error_body: - raise DataLinkError(http_status=resp.status_code, http_body=resp.text) - - code = error_body['error']['code'] - message = error_body['error']['message'] - prog = re.compile('^QE([a-zA-Z])x') - if prog.match(code): - code_letter = prog.match(code).group(1) - - d_klass = { - 'L': LimitExceededError, - 'M': InternalServerError, - 'A': AuthenticationError, - 'P': ForbiddenError, - 'S': InvalidRequestError, - 'C': NotFoundError, - 'X': ServiceUnavailableError - } - klass = d_klass.get(code_letter, DataLinkError) - - raise klass(message, resp.status_code, resp.text, resp.headers, code) + return response + except requests.exceptions.RequestException as e: + if e.response: + handle_api_error(e.response) + raise e + +def get_retries(): + if not ApiConfig.use_retries: + return Retry(total=0) + + Retry.BACKOFF_MAX = ApiConfig.max_wait_between_retries + retries = Retry(total=ApiConfig.number_of_retries, + connect=ApiConfig.number_of_retries, + read=ApiConfig.number_of_retries, + status_forcelist=ApiConfig.retry_status_codes, + backoff_factor=ApiConfig.retry_backoff_factor, + raise_on_status=False) + + return retries + +session = requests.Session() + +def get_session(url = ApiConfig.api_protocol): + adapter = HTTPAdapter(max_retries=get_retries()) + session.mount(url, adapter) + return session + +def parse(response): + try: + return response.json() + except ValueError: + raise DataLinkError(http_status=response.status_code, http_body=response.text) + + + +def handle_api_error(resp): + error_body = parse(resp) + + # if our app does not form a proper data_link_error response + # throw generic error + if 'error' not in error_body: + raise DataLinkError(http_status=resp.status_code, http_body=resp.text) + + code = error_body['error']['code'] + message = error_body['error']['message'] + prog = re.compile('^QE([a-zA-Z])x') + if prog.match(code): + code_letter = prog.match(code).group(1) + + d_klass = { + 'L': LimitExceededError, + 'M': InternalServerError, + 'A': AuthenticationError, + 'P': ForbiddenError, + 'S': InvalidRequestError, + 'C': NotFoundError, + 'X': ServiceUnavailableError + } + klass = d_klass.get(code_letter, DataLinkError) + + raise klass(message, resp.status_code, resp.text, resp.headers, code) \ No newline at end of file diff --git a/nasdaqdatalink/model/database.py b/nasdaqdatalink/model/database.py index 870dedc..5cde79a 100644 --- a/nasdaqdatalink/model/database.py +++ b/nasdaqdatalink/model/database.py @@ -4,7 +4,7 @@ import nasdaqdatalink.model.dataset from nasdaqdatalink.api_config import ApiConfig -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as Connection from nasdaqdatalink.errors.data_link_error import DataLinkError from nasdaqdatalink.message import Message from nasdaqdatalink.operations.get import GetOperation diff --git a/nasdaqdatalink/model/datatable.py b/nasdaqdatalink/model/datatable.py index 2edadb8..b253764 100644 --- a/nasdaqdatalink/model/datatable.py +++ b/nasdaqdatalink/model/datatable.py @@ -3,7 +3,7 @@ from six.moves.urllib.request import urlopen -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as Connection from nasdaqdatalink.errors.data_link_error import DataLinkError from nasdaqdatalink.message import Message from nasdaqdatalink.operations.get import GetOperation diff --git a/nasdaqdatalink/operations/get.py b/nasdaqdatalink/operations/get.py index 8f93b95..3d70a79 100644 --- a/nasdaqdatalink/operations/get.py +++ b/nasdaqdatalink/operations/get.py @@ -1,7 +1,7 @@ from inflection import singularize from .operation import Operation -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as Connection from nasdaqdatalink.util import Util diff --git a/nasdaqdatalink/operations/list.py b/nasdaqdatalink/operations/list.py index 6aa020a..fb2f5cd 100644 --- a/nasdaqdatalink/operations/list.py +++ b/nasdaqdatalink/operations/list.py @@ -1,5 +1,5 @@ from .operation import Operation -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as Connection from nasdaqdatalink.util import Util from nasdaqdatalink.model.paginated_list import PaginatedList from nasdaqdatalink.utils.request_type_util import RequestType diff --git a/test/test_connection.py b/test/test_connection.py index 96d8380..3ee62eb 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -1,4 +1,4 @@ -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as Connection from nasdaqdatalink.api_config import ApiConfig from nasdaqdatalink.errors.data_link_error import ( DataLinkError, LimitExceededError, InternalServerError, @@ -65,7 +65,7 @@ def test_non_data_link_error(self, request_method): DataLinkError, lambda: Connection.request(request_method, 'databases')) @parameterized.expand(['GET', 'POST']) - @patch('nasdaqdatalink.connection.Connection.execute_request') + @patch('nasdaqdatalink.connection.execute_request') def test_build_request(self, request_method, mock): ApiConfig.api_key = 'api_token' ApiConfig.api_version = '2015-04-09' diff --git a/test/test_data.py b/test/test_data.py index 7852dbe..53817a1 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -77,7 +77,7 @@ def tearDownClass(cls): httpretty.disable() httpretty.reset() - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_data_calls_connection(self, mock): Data.all(params={'database_code': 'NSE', 'dataset_code': 'OIL'}) expected = call('get', 'datasets/NSE/OIL/data', params={}) diff --git a/test/test_database.py b/test/test_database.py index bbae558..0b11cec 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -7,7 +7,7 @@ from six.moves.urllib.parse import parse_qs, urlparse from nasdaqdatalink.api_config import ApiConfig -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as Connection from nasdaqdatalink.errors.data_link_error import (InternalServerError, DataLinkError) from nasdaqdatalink.model.database import Database from test.factories.database import DatabaseFactory @@ -34,7 +34,7 @@ def tearDownClass(cls): httpretty.disable() httpretty.reset() - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_database_calls_connection(self, mock): database = Database('NSE') database.data_fields() @@ -80,7 +80,7 @@ def tearDownClass(cls): httpretty.disable() httpretty.reset() - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_databases_calls_connection(self, mock): Database.all() expected = call('get', 'databases', params={}) diff --git a/test/test_dataset.py b/test/test_dataset.py index c44ea65..aed9b8a 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -30,7 +30,7 @@ def tearDownClass(cls): httpretty.disable() httpretty.reset() - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_dataset_calls_connection(self, mock): d = Dataset('NSE/OIL') d.data_fields() @@ -84,7 +84,7 @@ def tearDownClass(cls): httpretty.disable() httpretty.reset() - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datasets_calls_connection(self, mock): Dataset.all() expected = call('get', 'datasets', params={}) diff --git a/test/test_datatable.py b/test/test_datatable.py index ab80194..ff5525b 100644 --- a/test/test_datatable.py +++ b/test/test_datatable.py @@ -37,26 +37,26 @@ def tearDownClass(cls): def tearDown(self): RequestType.USE_GET_REQUEST = True - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_metadata_calls_connection(self, mock): Datatable('ZACKS/FC').data_fields() expected = call('get', 'datatables/ZACKS/FC/metadata', params={}) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_data_calls_connection_with_no_params_for_get_request(self, mock): Datatable('ZACKS/FC').data() expected = call('get', 'datatables/ZACKS/FC', params={}) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_data_calls_connection_with_no_params_for_post_request(self, mock): RequestType.USE_GET_REQUEST = False Datatable('ZACKS/FC').data() expected = call('post', 'datatables/ZACKS/FC', json={}) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_calls_connection_with_params_for_get_request(self, mock): params = {'ticker': ['AAPL', 'MSFT'], 'per_end_date': {'gte': '2015-01-01'}, @@ -76,7 +76,7 @@ def test_datatable_calls_connection_with_params_for_get_request(self, mock): expected = call('get', 'datatables/ZACKS/FC', params=expected_params) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_calls_connection_with_params_for_post_request(self, mock): RequestType.USE_GET_REQUEST = False params = {'ticker': ['AAPL', 'MSFT'], diff --git a/test/test_datatable_data.py b/test/test_datatable_data.py index 7ba53f3..b837fc9 100644 --- a/test/test_datatable_data.py +++ b/test/test_datatable_data.py @@ -83,7 +83,7 @@ def tearDownClass(cls): def tearDown(self): RequestType.USE_GET_REQUEST = True - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_data_calls_connection_get(self, mock): datatable = Datatable('ZACKS/FC') Data.page(datatable, params={'ticker': ['AAPL', 'MSFT'], @@ -95,7 +95,7 @@ def test_data_calls_connection_get(self, mock): 'qopts.columns[]': ['ticker', 'per_end_date']}) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_data_calls_connection_post(self, mock): RequestType.USE_GET_REQUEST = False datatable = Datatable('ZACKS/FC') diff --git a/test/test_get.py b/test/test_get.py index 950c5c5..8e753c7 100644 --- a/test/test_get.py +++ b/test/test_get.py @@ -8,7 +8,7 @@ from nasdaqdatalink.model.merged_dataset import MergedDataset from nasdaqdatalink.get import get from nasdaqdatalink.api_config import ApiConfig -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as Connection class GetSingleDatasetTest(unittest.TestCase): @@ -37,7 +37,7 @@ def test_returns_numpys_when_requested(self): def test_setting_api_key_config(self): mock_connection = Mock(wraps=Connection) - with patch('nasdaqdatalink.connection.Connection.execute_request', + with patch('nasdaqdatalink.connection.execute_request', new=mock_connection.execute_request) as mock: ApiConfig.api_key = 'api_key_configured' get('NSE/OIL') diff --git a/test/test_get_point_in_time_data.py b/test/test_get_point_in_time_data.py index 8fa57f7..ef0b236 100644 --- a/test/test_get_point_in_time_data.py +++ b/test/test_get_point_in_time_data.py @@ -27,7 +27,7 @@ def tearDownClass(cls): def tearDown(self): RequestType.USE_GET_REQUEST = True - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_get_point_in_time_returns_data_frame_object(self, mock): with self.assertWarns(UserWarning): df = nasdaqdatalink.get_point_in_time( @@ -36,7 +36,7 @@ def test_get_point_in_time_returns_data_frame_object(self, mock): self.assertIsInstance(df, pandas.core.frame.DataFrame) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_asofdate_call_connection(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time('ZACKS/FC', interval='asofdate', date='2020-01-01') @@ -44,7 +44,7 @@ def test_asofdate_call_connection(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_asofdate_call_connection_with_datetimes(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time( @@ -54,7 +54,7 @@ def test_asofdate_call_connection_with_datetimes(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_asofdate_call_without_date(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time('ZACKS/FC', interval='asofdate') @@ -62,7 +62,7 @@ def test_asofdate_call_without_date(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_from_call_connection(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time( @@ -75,7 +75,7 @@ def test_from_call_connection(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_from_call_connection_with_datetimes(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time( @@ -90,7 +90,7 @@ def test_from_call_connection_with_datetimes(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_between_call_connection(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time( @@ -103,7 +103,7 @@ def test_between_call_connection(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_between_call_connection_with_datetimes(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time( @@ -118,7 +118,7 @@ def test_between_call_connection_with_datetimes(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_invalid_interval_connection(self, mock): self.assertRaises(InvalidRequestError, lambda: nasdaqdatalink.get_point_in_time('ZACKS/FC')) self.assertRaises( @@ -126,7 +126,7 @@ def test_invalid_interval_connection(self, mock): lambda: nasdaqdatalink.get_point_in_time('ZACKS/FC', interval='nasdaqdatalink') ) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_invalid_from_connection(self, mock): self.assertRaises( InvalidRequestError, @@ -145,7 +145,7 @@ def test_invalid_from_connection(self, mock): ) ) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_invalid_between_connection(self, mock): self.assertRaises( InvalidRequestError, diff --git a/test/test_get_table.py b/test/test_get_table.py index 8100b66..7f49f84 100644 --- a/test/test_get_table.py +++ b/test/test_get_table.py @@ -37,21 +37,21 @@ def tearDownClass(cls): def tearDown(self): RequestType.USE_GET_REQUEST = True - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_returns_datatable_object(self, mock): with self.assertWarns(UserWarning): df = nasdaqdatalink.get_table('ZACKS/FC', params={}) self.assertIsInstance(df, pandas.core.frame.DataFrame) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_with_code_returns_datatable_object(self, mock): with self.assertWarns(UserWarning): df = nasdaqdatalink.get_table('AR/MWCF', code="ICEP_WAC_Z2017_S") self.assertIsInstance(df, pandas.core.frame.DataFrame) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_get_table_calls_connection_with_no_params_for_get_request(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_table('ZACKS/FC') @@ -59,7 +59,7 @@ def test_get_table_calls_connection_with_no_params_for_get_request(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_get_table_calls_connection_with_no_params_for_post_request(self, mock): with self.assertWarns(UserWarning): RequestType.USE_GET_REQUEST = False @@ -69,7 +69,7 @@ def test_get_table_calls_connection_with_no_params_for_post_request(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_get_table_calls_connection_with_params_for_get_request(self, mock): with self.assertWarns(UserWarning): params = { @@ -93,7 +93,7 @@ def test_get_table_calls_connection_with_params_for_get_request(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_get_table_calls_connection_with_params_for_post_request(self, mock): with self.assertWarns(UserWarning): RequestType.USE_GET_REQUEST = False diff --git a/test/test_point_in_time.py b/test/test_point_in_time.py index 07918e1..f73e33b 100644 --- a/test/test_point_in_time.py +++ b/test/test_point_in_time.py @@ -26,7 +26,7 @@ def tearDownClass(cls): def tearDown(self): RequestType.USE_GET_REQUEST = True - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_asofdate_call_connection(self, mock): PointInTime( 'ZACKS/FC', @@ -38,7 +38,7 @@ def test_asofdate_call_connection(self, mock): expected = call('get', 'pit/ZACKS/FC/asofdate/2020-01-01', params={}) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_from_call_connection(self, mock): PointInTime( 'ZACKS/FC', @@ -51,7 +51,7 @@ def test_from_call_connection(self, mock): expected = call('get', 'pit/ZACKS/FC/from/2020-01-01/to/2020-01-02', params={}) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_between_call_connection(self, mock): PointInTime( 'ZACKS/FC', diff --git a/test/test_retries.py b/test/test_retries.py index 3028095..857580e 100644 --- a/test/test_retries.py +++ b/test/test_retries.py @@ -1,7 +1,7 @@ import unittest import json -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as Connection from nasdaqdatalink.api_config import ApiConfig from test.factories.datatable import DatatableFactory from test.helpers.httpretty_extension import httpretty From 24e8845c514f1989fba863e7fed4b0ba7dd44e57 Mon Sep 17 00:00:00 2001 From: Ren Ren Date: Fri, 1 Apr 2022 13:20:27 -0400 Subject: [PATCH 2/8] reuse adapter --- nasdaqdatalink/connection.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nasdaqdatalink/connection.py b/nasdaqdatalink/connection.py index 73a6241..72335f8 100644 --- a/nasdaqdatalink/connection.py +++ b/nasdaqdatalink/connection.py @@ -36,7 +36,7 @@ def request(http_verb, url, **options): return execute_request(http_verb, abs_url, **options) def execute_request(http_verb, url, **options): - session = get_session(url) + session = get_session() try: response = session.request(method=http_verb, @@ -67,10 +67,10 @@ def get_retries(): return retries session = requests.Session() +adapter = HTTPAdapter(max_retries=get_retries()) +session.mount(ApiConfig.api_protocol, adapter) -def get_session(url = ApiConfig.api_protocol): - adapter = HTTPAdapter(max_retries=get_retries()) - session.mount(url, adapter) +def get_session(): return session def parse(response): From c258098f3093f8d587c56a72909092f15a78c9ed Mon Sep 17 00:00:00 2001 From: Ren Ren <97460234+runawaycoast@users.noreply.github.com> Date: Tue, 5 Apr 2022 11:22:27 -0400 Subject: [PATCH 3/8] change initialization flow to allow configurate --- nasdaqdatalink/connection.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/nasdaqdatalink/connection.py b/nasdaqdatalink/connection.py index 72335f8..94b80f0 100644 --- a/nasdaqdatalink/connection.py +++ b/nasdaqdatalink/connection.py @@ -66,11 +66,14 @@ def get_retries(): return retries -session = requests.Session() -adapter = HTTPAdapter(max_retries=get_retries()) -session.mount(ApiConfig.api_protocol, adapter) +session = None def get_session(): + global session + if session is None: + session = requests.Session() + adapter = HTTPAdapter(max_retries=get_retries()) + session.mount(ApiConfig.api_protocol, adapter) return session def parse(response): From a715192f6d8ff58e7c15948061c33d34dcd55cad Mon Sep 17 00:00:00 2001 From: Ren Ren <97460234+runawaycoast@users.noreply.github.com> Date: Tue, 5 Apr 2022 15:36:20 -0400 Subject: [PATCH 4/8] change module naming, fix test --- nasdaqdatalink/model/database.py | 4 ++-- nasdaqdatalink/model/datatable.py | 4 ++-- nasdaqdatalink/operations/get.py | 4 ++-- nasdaqdatalink/operations/list.py | 6 +++--- test/test_connection.py | 10 +++++----- test/test_database.py | 4 ++-- test/test_get.py | 4 ++-- test/test_retries.py | 20 +++++++++++--------- 8 files changed, 29 insertions(+), 27 deletions(-) diff --git a/nasdaqdatalink/model/database.py b/nasdaqdatalink/model/database.py index 5cde79a..fbf9e73 100644 --- a/nasdaqdatalink/model/database.py +++ b/nasdaqdatalink/model/database.py @@ -4,7 +4,7 @@ import nasdaqdatalink.model.dataset from nasdaqdatalink.api_config import ApiConfig -import nasdaqdatalink.connection as Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.errors.data_link_error import DataLinkError from nasdaqdatalink.message import Message from nasdaqdatalink.operations.get import GetOperation @@ -43,7 +43,7 @@ def bulk_download_to_file(self, file_or_folder_path, **options): path_url = self._bulk_download_path() options['stream'] = True - r = Connection.request('get', path_url, **options) + r = connection.request('get', path_url, **options) file_path = file_or_folder_path if os.path.isdir(file_or_folder_path): file_path = file_or_folder_path + '/' + os.path.basename(urlparse(r.url).path) diff --git a/nasdaqdatalink/model/datatable.py b/nasdaqdatalink/model/datatable.py index b253764..935590e 100644 --- a/nasdaqdatalink/model/datatable.py +++ b/nasdaqdatalink/model/datatable.py @@ -3,7 +3,7 @@ from six.moves.urllib.request import urlopen -import nasdaqdatalink.connection as Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.errors.data_link_error import DataLinkError from nasdaqdatalink.message import Message from nasdaqdatalink.operations.get import GetOperation @@ -51,7 +51,7 @@ def _request_file_info(self, file_or_folder_path, **options): updated_options = Util.convert_options(request_type=request_type, **options) - r = Connection.request(request_type, url, **updated_options) + r = connection.request(request_type, url, **updated_options) response_data = r.json() diff --git a/nasdaqdatalink/operations/get.py b/nasdaqdatalink/operations/get.py index 3d70a79..efe3a26 100644 --- a/nasdaqdatalink/operations/get.py +++ b/nasdaqdatalink/operations/get.py @@ -1,7 +1,7 @@ from inflection import singularize from .operation import Operation -import nasdaqdatalink.connection as Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.util import Util @@ -21,7 +21,7 @@ def __get_raw_data__(self): path = Util.constructed_path(cls.get_path(), options['params']) - r = Connection.request('get', path, **options) + r = connection.request('get', path, **options) response_data = r.json() Util.convert_to_dates(response_data) self._raw_data = response_data[singularize(cls.lookup_key())] diff --git a/nasdaqdatalink/operations/list.py b/nasdaqdatalink/operations/list.py index fb2f5cd..6e94e78 100644 --- a/nasdaqdatalink/operations/list.py +++ b/nasdaqdatalink/operations/list.py @@ -1,5 +1,5 @@ from .operation import Operation -import nasdaqdatalink.connection as Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.util import Util from nasdaqdatalink.model.paginated_list import PaginatedList from nasdaqdatalink.utils.request_type_util import RequestType @@ -12,7 +12,7 @@ def all(cls, **options): if 'params' not in options: options['params'] = {} path = Util.constructed_path(cls.list_path(), options['params']) - r = Connection.request('get', path, **options) + r = connection.request('get', path, **options) response_data = r.json() Util.convert_to_dates(response_data) resource = cls.create_list_from_response(response_data) @@ -27,7 +27,7 @@ def page(cls, datatable, **options): updated_options = Util.convert_options(request_type=request_type, **options) - r = Connection.request(request_type, path, **updated_options) + r = connection.request(request_type, path, **updated_options) response_data = r.json() Util.convert_to_dates(response_data) diff --git a/test/test_connection.py b/test/test_connection.py index 3ee62eb..7cf1845 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -1,4 +1,4 @@ -import nasdaqdatalink.connection as Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.api_config import ApiConfig from nasdaqdatalink.errors.data_link_error import ( DataLinkError, LimitExceededError, InternalServerError, @@ -42,7 +42,7 @@ def test_nasdaqdatalink_exceptions_no_retries(self, request_method): for expected_error in data_link_errors: self.assertRaises( - expected_error[2], lambda: Connection.request(request_method, 'databases')) + expected_error[2], lambda: connection.request(request_method, 'databases')) @parameterized.expand(['GET', 'POST']) def test_parse_error(self, request_method): @@ -51,7 +51,7 @@ def test_parse_error(self, request_method): "https://data.nasdaq.com/api/v3/databases", body="not json", status=500) self.assertRaises( - DataLinkError, lambda: Connection.request(request_method, 'databases')) + DataLinkError, lambda: connection.request(request_method, 'databases')) @parameterized.expand(['GET', 'POST']) def test_non_data_link_error(self, request_method): @@ -62,7 +62,7 @@ def test_non_data_link_error(self, request_method): {'foobar': {'code': 'blah', 'message': 'something went wrong'}}), status=500) self.assertRaises( - DataLinkError, lambda: Connection.request(request_method, 'databases')) + DataLinkError, lambda: connection.request(request_method, 'databases')) @parameterized.expand(['GET', 'POST']) @patch('nasdaqdatalink.connection.execute_request') @@ -71,7 +71,7 @@ def test_build_request(self, request_method, mock): ApiConfig.api_version = '2015-04-09' params = {'per_page': 10, 'page': 2} headers = {'x-custom-header': 'header value'} - Connection.request(request_method, 'databases', headers=headers, params=params) + connection.request(request_method, 'databases', headers=headers, params=params) expected = call(request_method, 'https://data.nasdaq.com/api/v3/databases', headers={'x-custom-header': 'header value', 'x-api-token': 'api_token', diff --git a/test/test_database.py b/test/test_database.py index 0b11cec..7b38b98 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -7,7 +7,7 @@ from six.moves.urllib.parse import parse_qs, urlparse from nasdaqdatalink.api_config import ApiConfig -import nasdaqdatalink.connection as Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.errors.data_link_error import (InternalServerError, DataLinkError) from nasdaqdatalink.model.database import Database from test.factories.database import DatabaseFactory @@ -148,7 +148,7 @@ def test_get_bulk_download_url_without_download_type(self): def test_bulk_download_to_fileaccepts_download_type(self): m = mock_open() - with patch.object(Connection, 'request') as mock_method: + with patch.object(connection, 'request') as mock_method: mock_method.return_value.url = 'https://www.blah.com/download/db.zip' with patch('nasdaqdatalink.model.database.open', m, create=True): self.database.bulk_download_to_file( diff --git a/test/test_get.py b/test/test_get.py index 8e753c7..66f2ba6 100644 --- a/test/test_get.py +++ b/test/test_get.py @@ -8,7 +8,7 @@ from nasdaqdatalink.model.merged_dataset import MergedDataset from nasdaqdatalink.get import get from nasdaqdatalink.api_config import ApiConfig -import nasdaqdatalink.connection as Connection +import nasdaqdatalink.connection as connection class GetSingleDatasetTest(unittest.TestCase): @@ -36,7 +36,7 @@ def test_returns_numpys_when_requested(self): self.assertIsInstance(result, numpy.core.records.recarray) def test_setting_api_key_config(self): - mock_connection = Mock(wraps=Connection) + mock_connection = Mock(wraps=connection) with patch('nasdaqdatalink.connection.execute_request', new=mock_connection.execute_request) as mock: ApiConfig.api_key = 'api_key_configured' diff --git a/test/test_retries.py b/test/test_retries.py index 857580e..69c7653 100644 --- a/test/test_retries.py +++ b/test/test_retries.py @@ -1,7 +1,7 @@ import unittest import json -import nasdaqdatalink.connection as Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.api_config import ApiConfig from test.factories.datatable import DatatableFactory from test.helpers.httpretty_extension import httpretty @@ -28,6 +28,8 @@ def tearDown(self): class TestRetries(ModifyRetrySettingsTestCase): def setUp(self): + # reset session to None before every test + connection.session = None ApiConfig.use_retries = True super(TestRetries, self).setUp() @@ -47,13 +49,13 @@ def setUpClass(cls): def test_modifying_use_retries(self): ApiConfig.use_retries = False - retries = Connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries + retries = connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries self.assertEqual(retries.total, 0) def test_modifying_number_of_retries(self): ApiConfig.number_of_retries = 3000 - retries = Connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries + retries = connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries self.assertEqual(retries.total, ApiConfig.number_of_retries) self.assertEqual(retries.connect, ApiConfig.number_of_retries) @@ -62,19 +64,19 @@ def test_modifying_number_of_retries(self): def test_modifying_retry_backoff_factor(self): ApiConfig.retry_backoff_factor = 3000 - retries = Connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries + retries = connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries self.assertEqual(retries.backoff_factor, ApiConfig.retry_backoff_factor) def test_modifying_retry_status_codes(self): ApiConfig.retry_status_codes = [1, 2, 3] - retries = Connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries + retries = connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries self.assertEqual(retries.status_forcelist, ApiConfig.retry_status_codes) def test_modifying_max_wait_between_retries(self): ApiConfig.max_wait_between_retries = 3000 - retries = Connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries + retries = connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries self.assertEqual(retries.BACKOFF_MAX, ApiConfig.max_wait_between_retries) @httpretty.enabled @@ -87,7 +89,7 @@ def test_correct_response_returned_if_retries_succeed(self): "https://data.nasdaq.com/api/v3/databases", responses=mock_responses) - response = Connection.request('get', 'databases') + response = connection.request('get', 'databases') self.assertEqual(response.json(), self.datatable) self.assertEqual(response.status_code, self.success_response.status) @@ -100,7 +102,7 @@ def test_correct_response_exception_raised_if_retries_fail(self): "https://data.nasdaq.com/api/v3/databases", responses=mock_responses) - self.assertRaises(InternalServerError, Connection.request, 'get', 'databases') + self.assertRaises(InternalServerError, connection.request, 'get', 'databases') @httpretty.enabled def test_correct_response_exception_raised_for_errors_not_in_retry_status_codes(self): @@ -110,4 +112,4 @@ def test_correct_response_exception_raised_for_errors_not_in_retry_status_codes( "https://data.nasdaq.com/api/v3/databases", responses=mock_responses) - self.assertRaises(InternalServerError, Connection.request, 'get', 'databases') + self.assertRaises(InternalServerError, connection.request, 'get', 'databases') From d9105388e72621ae50498bb884b49b4dc2a4aabe Mon Sep 17 00:00:00 2001 From: Ren Ren <97460234+runawaycoast@users.noreply.github.com> Date: Tue, 5 Apr 2022 16:34:53 -0400 Subject: [PATCH 5/8] fix lint --- nasdaqdatalink/connection.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/nasdaqdatalink/connection.py b/nasdaqdatalink/connection.py index 94b80f0..ffd7d4a 100644 --- a/nasdaqdatalink/connection.py +++ b/nasdaqdatalink/connection.py @@ -13,6 +13,10 @@ AuthenticationError, ForbiddenError, InvalidRequestError, NotFoundError, ServiceUnavailableError) +# global session +session = None + + def request(http_verb, url, **options): if 'headers' in options: headers = options['headers'] @@ -23,9 +27,11 @@ def request(http_verb, url, **options): if ApiConfig.api_version: accept_value += ", application/vnd.data.nasdaq+json;version=%s" % ApiConfig.api_version - headers = Util.merge_to_dicts({'accept': accept_value, - 'request-source': 'python', - 'request-source-version': VERSION}, headers) + headers = Util.merge_to_dicts({ + 'accept': accept_value, + 'request-source': 'python', + 'request-source-version': VERSION + }, headers) if ApiConfig.api_key: headers = Util.merge_to_dicts({'x-api-token': ApiConfig.api_key}, headers) @@ -35,14 +41,17 @@ def request(http_verb, url, **options): return execute_request(http_verb, abs_url, **options) + def execute_request(http_verb, url, **options): session = get_session() try: - response = session.request(method=http_verb, - url=url, - verify=ApiConfig.verify_ssl, - **options) + response = session.request( + method=http_verb, + url=url, + verify=ApiConfig.verify_ssl, + **options + ) if response.status_code < 200 or response.status_code >= 300: handle_api_error(response) else: @@ -52,6 +61,7 @@ def execute_request(http_verb, url, **options): handle_api_error(e.response) raise e + def get_retries(): if not ApiConfig.use_retries: return Retry(total=0) @@ -66,16 +76,17 @@ def get_retries(): return retries -session = None def get_session(): global session if session is None: + print("initialized") session = requests.Session() adapter = HTTPAdapter(max_retries=get_retries()) session.mount(ApiConfig.api_protocol, adapter) return session + def parse(response): try: return response.json() @@ -83,7 +94,6 @@ def parse(response): raise DataLinkError(http_status=response.status_code, http_body=response.text) - def handle_api_error(resp): error_body = parse(resp) @@ -109,4 +119,4 @@ def handle_api_error(resp): } klass = d_klass.get(code_letter, DataLinkError) - raise klass(message, resp.status_code, resp.text, resp.headers, code) \ No newline at end of file + raise klass(message, resp.status_code, resp.text, resp.headers, code) From c79d4501c6e436a2cfbd84a9669206b20125155d Mon Sep 17 00:00:00 2001 From: Ren Ren <97460234+runawaycoast@users.noreply.github.com> Date: Tue, 5 Apr 2022 16:38:49 -0400 Subject: [PATCH 6/8] remove debug print --- nasdaqdatalink/connection.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nasdaqdatalink/connection.py b/nasdaqdatalink/connection.py index ffd7d4a..718db1e 100644 --- a/nasdaqdatalink/connection.py +++ b/nasdaqdatalink/connection.py @@ -80,7 +80,6 @@ def get_retries(): def get_session(): global session if session is None: - print("initialized") session = requests.Session() adapter = HTTPAdapter(max_retries=get_retries()) session.mount(ApiConfig.api_protocol, adapter) From 2b0f25d955a2b998352f060a36be62e50e7cf4bd Mon Sep 17 00:00:00 2001 From: Ren Ren <97460234+runawaycoast@users.noreply.github.com> Date: Tue, 5 Apr 2022 22:14:52 -0400 Subject: [PATCH 7/8] add test for session reuse --- test/test_connection.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/test_connection.py b/test/test_connection.py index 7cf1845..c5f75e9 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -81,3 +81,15 @@ def test_build_request(self, request_method, mock): 'request-source-version': VERSION}, params={'per_page': 10, 'page': 2}) self.assertEqual(mock.call_args, expected) + + def test_session_reuse(self): + session1 = connection.get_session() + session2 = connection.get_session() + areSessionsSame = session1 is session2 + + adapter1 = connection.get_session().get_adapter(ApiConfig.api_protocol) + adapter2 = connection.get_session().get_adapter(ApiConfig.api_protocol) + areAdaptersSame = adapter1 is adapter2 + + self.assertEqual(areAdaptersSame, True) + self.assertEqual(areSessionsSame, True) From 9bc8dd0729882daddb478e08d7454dcd5f8eafa4 Mon Sep 17 00:00:00 2001 From: Ren Ren <97460234+runawaycoast@users.noreply.github.com> Date: Tue, 5 Apr 2022 22:17:50 -0400 Subject: [PATCH 8/8] fix lint --- test/test_connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_connection.py b/test/test_connection.py index c5f75e9..384d1e0 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -90,6 +90,6 @@ def test_session_reuse(self): adapter1 = connection.get_session().get_adapter(ApiConfig.api_protocol) adapter2 = connection.get_session().get_adapter(ApiConfig.api_protocol) areAdaptersSame = adapter1 is adapter2 - + self.assertEqual(areAdaptersSame, True) self.assertEqual(areSessionsSame, True)