Skip to content

Commit

Permalink
python3Packages.apache-airflow: add patch for CVE-2023-22884
Browse files Browse the repository at this point in the history
  • Loading branch information
risicle committed Feb 12, 2023
1 parent 8e82401 commit 7fa3ff5
Show file tree
Hide file tree
Showing 2 changed files with 222 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
Based on upstream 45dd0c484e16ff56800cc9c047f56b4a909d2d0d with
minor adjustments to apply to airflow 2.4.3

diff --git a/airflow/providers/apache/hive/transfers/hive_to_mysql.py b/airflow/providers/apache/hive/transfers/hive_to_mysql.py
index 9c01b3162b..041f2940a7 100644
--- a/airflow/providers/apache/hive/transfers/hive_to_mysql.py
+++ b/airflow/providers/apache/hive/transfers/hive_to_mysql.py
@@ -53,9 +53,9 @@ class HiveToMySqlOperator(BaseOperator):
import, typically used to move data from staging to
production and issue cleanup commands. (templated)
:param bulk_load: flag to use bulk_load option. This loads mysql directly
- from a tab-delimited text file using the LOAD DATA LOCAL INFILE command.
- This option requires an extra connection parameter for the
- destination MySQL connection: {'local_infile': true}.
+ from a tab-delimited text file using the LOAD DATA LOCAL INFILE command. The MySQL
+ server must support loading local files via this command (it is disabled by default).
+
:param hive_conf:
"""

@@ -108,7 +108,7 @@ class HiveToMySqlOperator(BaseOperator):
output_header=False,
hive_conf=hive_conf,
)
- mysql = self._call_preoperator()
+ mysql = self._call_preoperator(local_infile=self.bulk_load)
mysql.bulk_load(table=self.mysql_table, tmp_file=tmp_file.name)
else:
hive_results = hive.get_records(self.sql, parameters=hive_conf)
@@ -121,8 +121,8 @@ class HiveToMySqlOperator(BaseOperator):

self.log.info("Done.")

- def _call_preoperator(self):
- mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
+ def _call_preoperator(self, local_infile: bool = False) -> MySqlHook:
+ mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id, local_infile=local_infile)
if self.mysql_preoperator:
self.log.info("Running MySQL preoperator")
mysql.run(self.mysql_preoperator)
diff --git a/airflow/providers/mysql/hooks/mysql.py b/airflow/providers/mysql/hooks/mysql.py
index 508ae6c56c..21ddc24a0b 100644
--- a/airflow/providers/mysql/hooks/mysql.py
+++ b/airflow/providers/mysql/hooks/mysql.py
@@ -44,8 +44,12 @@ class MySqlHook(DbApiHook):
in extras.
extras example: ``{"iam":true, "aws_conn_id":"my_aws_conn"}``

+ You can also add "local_infile" parameter to determine whether local_infile feature of MySQL client is
+ going to be enabled (it is disabled by default).
+
:param schema: The MySQL database schema to connect to.
:param connection: The :ref:`MySQL connection id <howto/connection:mysql>` used for MySQL credentials.
+ :param local_infile: Boolean flag determining if local_infile should be used
"""

conn_name_attr = 'mysql_conn_id'
@@ -58,6 +62,7 @@ class MySqlHook(DbApiHook):
super().__init__(*args, **kwargs)
self.schema = kwargs.pop("schema", None)
self.connection = kwargs.pop("connection", None)
+ self.local_infile = kwargs.pop("local_infile", False)

def set_autocommit(self, conn: MySQLConnectionTypes, autocommit: bool) -> None:
"""
@@ -119,8 +124,7 @@ class MySqlHook(DbApiHook):
conn_config["cursorclass"] = MySQLdb.cursors.DictCursor
elif (conn.extra_dejson["cursor"]).lower() == 'ssdictcursor':
conn_config["cursorclass"] = MySQLdb.cursors.SSDictCursor
- local_infile = conn.extra_dejson.get('local_infile', False)
- if conn.extra_dejson.get('ssl', False):
+ if conn.extra_dejson.get("ssl", False):
# SSL parameter for MySQL has to be a dictionary and in case
# of extra/dejson we can get string if extra is passed via
# URL parameters
@@ -130,7 +134,7 @@ class MySqlHook(DbApiHook):
conn_config['ssl'] = dejson_ssl
if conn.extra_dejson.get('unix_socket'):
conn_config['unix_socket'] = conn.extra_dejson['unix_socket']
- if local_infile:
+ if self.local_infile:
conn_config["local_infile"] = 1
return conn_config

@@ -143,7 +147,7 @@ class MySqlHook(DbApiHook):
'port': int(conn.port) if conn.port else 3306,
}

- if conn.extra_dejson.get('allow_local_infile', False):
+ if self.local_infile:
conn_config["allow_local_infile"] = True

return conn_config
diff --git a/airflow/providers/mysql/transfers/vertica_to_mysql.py b/airflow/providers/mysql/transfers/vertica_to_mysql.py
index 595b2cb01b..a8ff591d52 100644
--- a/airflow/providers/mysql/transfers/vertica_to_mysql.py
+++ b/airflow/providers/mysql/transfers/vertica_to_mysql.py
@@ -52,9 +52,8 @@ class VerticaToMySqlOperator(BaseOperator):
import, typically used to move data from staging to production
and issue cleanup commands. (templated)
:param bulk_load: flag to use bulk_load option. This loads MySQL directly
- from a tab-delimited text file using the LOAD DATA LOCAL INFILE command.
- This option requires an extra connection parameter for the
- destination MySQL connection: {'local_infile': true}.
+ from a tab-delimited text file using the LOAD DATA LOCAL INFILE command. The MySQL
+ server must support loading local files via this command (it is disabled by default).
"""

template_fields: Sequence[str] = ('sql', 'mysql_table', 'mysql_preoperator', 'mysql_postoperator')
@@ -89,7 +88,7 @@ class VerticaToMySqlOperator(BaseOperator):

def execute(self, context: 'Context'):
vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id)
- mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
+ mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id, local_infile=self.bulk_load)

if self.bulk_load:
self._bulk_load_transfer(mysql, vertica)
diff --git a/docs/apache-airflow-providers-mysql/connections/mysql.rst b/docs/apache-airflow-providers-mysql/connections/mysql.rst
index 95d8e7aaba..e8b8091b83 100644
--- a/docs/apache-airflow-providers-mysql/connections/mysql.rst
+++ b/docs/apache-airflow-providers-mysql/connections/mysql.rst
@@ -46,9 +46,6 @@ Extra (optional)
* ``charset``: specify charset of the connection
* ``cursor``: one of ``sscursor``, ``dictcursor``, ``ssdictcursor`` . Specifies cursor class to be
used
- * ``local_infile``: controls MySQL's LOCAL capability (permitting local data loading by
- clients). See `MySQLdb docs <https://mysqlclient.readthedocs.io/user_guide.html>`_
- for details.
* ``unix_socket``: UNIX socket used instead of the default socket.
* ``ssl``: Dictionary of SSL parameters that control connecting using SSL. Those
parameters are server specific and should contain ``ca``, ``cert``, ``key``, ``capath``,
@@ -99,14 +96,7 @@ Extra (optional)
If encounter UnicodeDecodeError while working with MySQL connection, check
the charset defined is matched to the database charset.

- For ``mysql-connector-python`` the following extras are supported:
+ For ``mysql-connector-python`` no extras are supported:

- * ``allow_local_infile``: Whether to enable ``LOAD DATA LOCAL INFILE`` capability.
-
- Example "extras" field:
-
- .. code-block:: json
-
- {
- "allow_local_infile": true
- }
+In both cases, when you want to use ``LOAD DATA LOCAL INFILE`` SQL commands of MySQl, you need to create the
+Hook with "local_infile" parameter set to True.
diff --git a/tests/providers/apache/hive/transfers/test_hive_to_mysql.py b/tests/providers/apache/hive/transfers/test_hive_to_mysql.py
index 7e056a17ba..97c4680931 100644
--- a/tests/providers/apache/hive/transfers/test_hive_to_mysql.py
+++ b/tests/providers/apache/hive/transfers/test_hive_to_mysql.py
@@ -44,9 +44,11 @@ class TestHiveToMySqlTransfer(TestHiveEnvironment):
def test_execute(self, mock_hive_hook, mock_mysql_hook):
HiveToMySqlOperator(**self.kwargs).execute(context={})

- mock_hive_hook.assert_called_once_with(hiveserver2_conn_id=self.kwargs['hiveserver2_conn_id'])
- mock_hive_hook.return_value.get_records.assert_called_once_with('sql', parameters={})
- mock_mysql_hook.assert_called_once_with(mysql_conn_id=self.kwargs['mysql_conn_id'])
+ mock_hive_hook.assert_called_once_with(hiveserver2_conn_id=self.kwargs["hiveserver2_conn_id"])
+ mock_hive_hook.return_value.get_records.assert_called_once_with("sql", parameters={})
+ mock_mysql_hook.assert_called_once_with(
+ mysql_conn_id=self.kwargs["mysql_conn_id"], local_infile=False
+ )
mock_mysql_hook.return_value.insert_rows.assert_called_once_with(
table=self.kwargs['mysql_table'], rows=mock_hive_hook.return_value.get_records.return_value
)
@@ -81,6 +83,7 @@ class TestHiveToMySqlTransfer(TestHiveEnvironment):

HiveToMySqlOperator(**self.kwargs).execute(context=context)

+ mock_mysql_hook.assert_called_once_with(mysql_conn_id=self.kwargs["mysql_conn_id"], local_infile=True)
mock_tmp_file_context.assert_called_once_with()
mock_hive_hook.return_value.to_csv.assert_called_once_with(
self.kwargs['sql'],
diff --git a/tests/providers/mysql/hooks/test_mysql.py b/tests/providers/mysql/hooks/test_mysql.py
index 911b9765c5..85d01ca830 100644
--- a/tests/providers/mysql/hooks/test_mysql.py
+++ b/tests/providers/mysql/hooks/test_mysql.py
@@ -119,7 +119,7 @@ class TestMySqlHookConn(unittest.TestCase):

@mock.patch('MySQLdb.connect')
def test_get_conn_local_infile(self, mock_connect):
- self.connection.extra = json.dumps({'local_infile': True})
+ self.db_hook.local_infile = True
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
@@ -208,8 +208,8 @@ class TestMySqlHookConnMySqlConnectorPython(unittest.TestCase):
@mock.patch('mysql.connector.connect')
def test_get_conn_allow_local_infile(self, mock_connect):
extra_dict = self.connection.extra_dejson
- extra_dict.update(allow_local_infile=True)
self.connection.extra = json.dumps(extra_dict)
+ self.db_hook.local_infile = True
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
@@ -391,7 +391,7 @@ class TestMySql(unittest.TestCase):
@mock.patch.dict(
'os.environ',
{
- 'AIRFLOW_CONN_AIRFLOW_DB': 'mysql://root@mysql/airflow?charset=utf8mb4&local_infile=1',
+ "AIRFLOW_CONN_AIRFLOW_DB": "mysql://root@mysql/airflow?charset=utf8mb4",
},
)
def test_mysql_hook_test_bulk_load(self, client):
@@ -404,7 +404,7 @@ class TestMySql(unittest.TestCase):
f.write("\n".join(records).encode('utf8'))
f.flush()

- hook = MySqlHook('airflow_db')
+ hook = MySqlHook("airflow_db", local_infile=True)
with closing(hook.get_conn()) as conn:
with closing(conn.cursor()) as cursor:
cursor.execute(
4 changes: 4 additions & 0 deletions pkgs/development/python-modules/apache-airflow/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,10 @@ buildPythonPackage rec {
# above
INSTALL_PROVIDERS_FROM_SOURCES = "true";

patches = [
./2.4.3-CVE-2023-22884.patch
];

postPatch = ''
substituteInPlace setup.cfg \
--replace "colorlog>=4.0.2, <5.0" "colorlog" \
Expand Down

0 comments on commit 7fa3ff5

Please sign in to comment.