diff --git a/cabby/abstract.py b/cabby/abstract.py index 73be118..1e586b7 100644 --- a/cabby/abstract.py +++ b/cabby/abstract.py @@ -158,7 +158,9 @@ def prepare_generic_session(self): key_file=self.key_file, key_password=self.key_password, ca_cert=self.ca_cert, - verify_ssl=self.verify_ssl) + verify_ssl=self.verify_ssl, + jwt_token=self.jwt_token, + ) def _execute_request(self, request, uri=None, service_type=None): ''' diff --git a/cabby/dispatcher.py b/cabby/dispatcher.py index 2fc17aa..cc01aa4 100644 --- a/cabby/dispatcher.py +++ b/cabby/dispatcher.py @@ -304,16 +304,17 @@ def __call__(self, r): def get_generic_session( - proxies=None, - headers=None, - username=None, - password=None, - cert_file=None, - key_file=None, - key_password=None, - ca_cert=None, - verify_ssl=True): - + proxies=None, + headers=None, + username=None, + password=None, + cert_file=None, + key_file=None, + key_password=None, + ca_cert=None, + verify_ssl=True, + jwt_token=None, +): session = requests.Session() if ca_cert: session.verify = ca_cert @@ -328,6 +329,8 @@ def get_generic_session( session.auth = HTTPBasicAuth(username, password) if cert_file and key_file: session.cert = (cert_file, key_file) + if jwt_token: + session.auth = JWTAuth(jwt_token) session._cabby_key_password = key_password return session diff --git a/tests/test_common.py b/tests/test_common.py index 3d5527a..ca93a6a 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -306,3 +306,28 @@ def poll_callback(request): list(client.poll(collection_name="X", uri="/poll")) assert client.jwt_token == second_token + + +@pytest.mark.parametrize("version", [11, 10]) +@responses.activate +def test_jwt_token_when_set_directly(version): + given_token = "abcd" + client = make_client(version) + + # The purpose of this test is to check that this assignment has effect: + client.jwt_token = given_token + + def poll_callback(request): + _, _, token = request.headers["Authorization"].partition("Bearer ") + assert token == given_token + return (200, make_taxii_headers(version), get_fix(version).POLL_RESPONSE) + + responses.mock._matches.append( + responses.CallbackResponse( + responses.POST, + url="http://example.localhost/poll", + callback=poll_callback, + stream=True, + ) + ) + list(client.poll(collection_name="X", uri="/poll"))