From 4ac1bf1de31a08ce9606d8bd2cc1c9b2011e581d Mon Sep 17 00:00:00 2001 From: Rohit Chatterjee Date: Fri, 6 Dec 2024 19:55:54 +0530 Subject: [PATCH 1/5] handle inline certificates --- dbt_automation/utils/postgres.py | 33 ++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/dbt_automation/utils/postgres.py b/dbt_automation/utils/postgres.py index 148502b..7829708 100644 --- a/dbt_automation/utils/postgres.py +++ b/dbt_automation/utils/postgres.py @@ -1,6 +1,7 @@ """helpers for postgres""" import os +import tempfile from logging import basicConfig, getLogger, INFO import psycopg2 from sshtunnel import SSHTunnelForwarder @@ -36,11 +37,39 @@ def get_connection(conn_info): "user", "password", "database", - "sslmode", - "sslrootcert", ]: if key in conn_info: connect_params[key] = conn_info[key] + + if "sslmode" in conn_info: + # sslmode can be a string or a boolean or a dict + if isinstance(conn_info["sslmode"], str): + # "require", "disable", "verify-ca", "verify-full" + connect_params["sslmode"] = conn_info["sslmode"] + elif isinstance(conn_info["sslmode"], bool): + # true = require, false = disable + connect_params["sslmode"] = ( + "require" if conn_info["sslmode"] else "disable" + ) + elif ( + isinstance(conn_info["sslmode"], dict) + and "mode" in conn_info["sslmode"] + ): + # mode is "require", "disable", "verify-ca", "verify-full" etc + connect_params["sslmode"] = conn_info["sslmode"]["mode"] + if "ca_certificate" in conn_info["sslmode"]: + # connect_params['sslcert'] needs a file path but + # conn_info['sslmode']['ca_certificate'] + # is a string (i.e. the actual certificate). so we write + # it to disk and pass the file path + with tempfile.NamedTemporaryFile(delete=False) as fp: + fp.write(conn_info["ssl_mode"]["ca_certificate"].encode()) + connect_params["sslrootcert"] = fp.name + connect_params["sslcert"] = fp.name + + if "sslrootcert" in conn_info: + connect_params["sslrootcert"] = conn_info["sslrootcert"] + connection = psycopg2.connect(**connect_params) return connection From 75a14096e34763f58c2b7df87354a33b194fd4b7 Mon Sep 17 00:00:00 2001 From: Rohit Chatterjee Date: Fri, 6 Dec 2024 20:28:42 +0530 Subject: [PATCH 2/5] added tests for get_connwection --- dbt_automation/utils/postgres.py | 10 ++- tests/utils/test_postgres.py | 114 +++++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 1 deletion(-) create mode 100644 tests/utils/test_postgres.py diff --git a/dbt_automation/utils/postgres.py b/dbt_automation/utils/postgres.py index 7829708..5e3b248 100644 --- a/dbt_automation/utils/postgres.py +++ b/dbt_automation/utils/postgres.py @@ -41,6 +41,14 @@ def get_connection(conn_info): if key in conn_info: connect_params[key] = conn_info[key] + # make sure dbname is set + if "database" in connect_params: + connect_params["dbname"] = connect_params["database"] + + # ssl_mode is an alias for sslmode + if "ssl_mode" in conn_info: + conn_info["sslmode"] = conn_info["ssl_mode"] + if "sslmode" in conn_info: # sslmode can be a string or a boolean or a dict if isinstance(conn_info["sslmode"], str): @@ -63,7 +71,7 @@ def get_connection(conn_info): # is a string (i.e. the actual certificate). so we write # it to disk and pass the file path with tempfile.NamedTemporaryFile(delete=False) as fp: - fp.write(conn_info["ssl_mode"]["ca_certificate"].encode()) + fp.write(conn_info["sslmode"]["ca_certificate"].encode()) connect_params["sslrootcert"] = fp.name connect_params["sslcert"] = fp.name diff --git a/tests/utils/test_postgres.py b/tests/utils/test_postgres.py new file mode 100644 index 0000000..fb01b09 --- /dev/null +++ b/tests/utils/test_postgres.py @@ -0,0 +1,114 @@ +from unittest.mock import patch, ANY +from dbt_automation.utils.postgres import PostgresClient + + +def test_get_connection_1(): + """tests PostgresClient.get_connection""" + with patch("dbt_automation.utils.postgres.psycopg2.connect") as mock_connect: + PostgresClient.get_connection( + {"host": "HOST", "port": 1234, "user": "USER", "password": "PASSWORD"} + ) + mock_connect.assert_called_once() + mock_connect.assert_called_with( + host="HOST", + port=1234, + user="USER", + password="PASSWORD", + ) + + +def test_get_connection_2(): + """tests PostgresClient.get_connection""" + with patch("dbt_automation.utils.postgres.psycopg2.connect") as mock_connect: + PostgresClient.get_connection( + { + "host": "HOST", + "port": 1234, + "user": "USER", + "password": "PASSWORD", + "database": "DATABASE", + } + ) + mock_connect.assert_called_once() + mock_connect.assert_called_with( + host="HOST", + port=1234, + user="USER", + password="PASSWORD", + database="DATABASE", + dbname="DATABASE", + ) + + +def test_get_connection_3(): + """tests PostgresClient.get_connection""" + with patch("dbt_automation.utils.postgres.psycopg2.connect") as mock_connect: + PostgresClient.get_connection( + { + "sslmode": "verify-ca", + "sslrootcert": "/path/to/cert", + } + ) + mock_connect.assert_called_once() + mock_connect.assert_called_with( + sslmode="verify-ca", + sslrootcert="/path/to/cert", + ) + + +def test_get_connection_4(): + """tests PostgresClient.get_connection""" + with patch("dbt_automation.utils.postgres.psycopg2.connect") as mock_connect: + PostgresClient.get_connection( + { + "sslmode": True, + "sslrootcert": "/path/to/cert", + } + ) + mock_connect.assert_called_once() + mock_connect.assert_called_with( + sslmode="require", + sslrootcert="/path/to/cert", + ) + + +def test_get_connection_5(): + """tests PostgresClient.get_connection""" + with patch("dbt_automation.utils.postgres.psycopg2.connect") as mock_connect: + PostgresClient.get_connection( + { + "sslmode": False, + "sslrootcert": "/path/to/cert", + } + ) + mock_connect.assert_called_once() + mock_connect.assert_called_with( + sslmode="disable", + sslrootcert="/path/to/cert", + ) + + +def test_get_connection_6(): + """tests PostgresClient.get_connection""" + with patch("dbt_automation.utils.postgres.psycopg2.connect") as mock_connect: + PostgresClient.get_connection( + { + "sslmode": { + "mode": "disable", + } + } + ) + mock_connect.assert_called_once() + mock_connect.assert_called_with( + sslmode="disable", + ) + + +def test_get_connection_7(): + """tests PostgresClient.get_connection""" + with patch("dbt_automation.utils.postgres.psycopg2.connect") as mock_connect: + PostgresClient.get_connection( + {"sslmode": {"mode": "disable", "ca_certification": "LONG-CERTIFICATE"}} + ) + mock_connect.assert_called_once() + mock_connect.assert_called_with(sslmode="disable", sslrootcert=ANY) From 4d2e4f8c245824a45aa891c1c2e96f44efc9c845 Mon Sep 17 00:00:00 2001 From: Rohit Chatterjee Date: Fri, 6 Dec 2024 20:34:24 +0530 Subject: [PATCH 3/5] can't have both dbname and database --- dbt_automation/utils/postgres.py | 4 ---- tests/utils/test_postgres.py | 1 - 2 files changed, 5 deletions(-) diff --git a/dbt_automation/utils/postgres.py b/dbt_automation/utils/postgres.py index 5e3b248..54dd331 100644 --- a/dbt_automation/utils/postgres.py +++ b/dbt_automation/utils/postgres.py @@ -41,10 +41,6 @@ def get_connection(conn_info): if key in conn_info: connect_params[key] = conn_info[key] - # make sure dbname is set - if "database" in connect_params: - connect_params["dbname"] = connect_params["database"] - # ssl_mode is an alias for sslmode if "ssl_mode" in conn_info: conn_info["sslmode"] = conn_info["ssl_mode"] diff --git a/tests/utils/test_postgres.py b/tests/utils/test_postgres.py index fb01b09..ae89cd7 100644 --- a/tests/utils/test_postgres.py +++ b/tests/utils/test_postgres.py @@ -36,7 +36,6 @@ def test_get_connection_2(): user="USER", password="PASSWORD", database="DATABASE", - dbname="DATABASE", ) From cf6b3ce4d0fc626feed82bf66a08408c361a8813 Mon Sep 17 00:00:00 2001 From: Rohit Chatterjee Date: Fri, 6 Dec 2024 22:42:27 +0530 Subject: [PATCH 4/5] fix typo --- tests/utils/test_postgres.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/test_postgres.py b/tests/utils/test_postgres.py index ae89cd7..aa84c0f 100644 --- a/tests/utils/test_postgres.py +++ b/tests/utils/test_postgres.py @@ -107,7 +107,7 @@ def test_get_connection_7(): """tests PostgresClient.get_connection""" with patch("dbt_automation.utils.postgres.psycopg2.connect") as mock_connect: PostgresClient.get_connection( - {"sslmode": {"mode": "disable", "ca_certification": "LONG-CERTIFICATE"}} + {"sslmode": {"mode": "disable", "ca_certificate": "LONG-CERTIFICATE"}} ) mock_connect.assert_called_once() mock_connect.assert_called_with(sslmode="disable", sslrootcert=ANY) From 9daffabffa46c1fbba6894733bb576940ab5c3e1 Mon Sep 17 00:00:00 2001 From: Rohit Chatterjee Date: Fri, 6 Dec 2024 23:27:40 +0530 Subject: [PATCH 5/5] fix test --- tests/utils/test_postgres.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/test_postgres.py b/tests/utils/test_postgres.py index aa84c0f..2d9dc5b 100644 --- a/tests/utils/test_postgres.py +++ b/tests/utils/test_postgres.py @@ -110,4 +110,4 @@ def test_get_connection_7(): {"sslmode": {"mode": "disable", "ca_certificate": "LONG-CERTIFICATE"}} ) mock_connect.assert_called_once() - mock_connect.assert_called_with(sslmode="disable", sslrootcert=ANY) + mock_connect.assert_called_with(sslmode="disable", sslrootcert=ANY, sslcert=ANY)