Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[22.11] python3Packages.apache-airflow: add patch for CVE-2023-22884 #214588

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
Based on upstream 45dd0c484e16ff56800cc9c047f56b4a909d2d0d with
minor adjustments to apply to airflow 2.4.3

diff --git a/airflow/providers/apache/hive/transfers/hive_to_mysql.py b/airflow/providers/apache/hive/transfers/hive_to_mysql.py
index 9c01b3162b..041f2940a7 100644
--- a/airflow/providers/apache/hive/transfers/hive_to_mysql.py
+++ b/airflow/providers/apache/hive/transfers/hive_to_mysql.py
@@ -53,9 +53,9 @@ class HiveToMySqlOperator(BaseOperator):
import, typically used to move data from staging to
production and issue cleanup commands. (templated)
:param bulk_load: flag to use bulk_load option. This loads mysql directly
- from a tab-delimited text file using the LOAD DATA LOCAL INFILE command.
- This option requires an extra connection parameter for the
- destination MySQL connection: {'local_infile': true}.
+ from a tab-delimited text file using the LOAD DATA LOCAL INFILE command. The MySQL
+ server must support loading local files via this command (it is disabled by default).
+
:param hive_conf:
"""

@@ -108,7 +108,7 @@ class HiveToMySqlOperator(BaseOperator):
output_header=False,
hive_conf=hive_conf,
)
- mysql = self._call_preoperator()
+ mysql = self._call_preoperator(local_infile=self.bulk_load)
mysql.bulk_load(table=self.mysql_table, tmp_file=tmp_file.name)
else:
hive_results = hive.get_records(self.sql, parameters=hive_conf)
@@ -121,8 +121,8 @@ class HiveToMySqlOperator(BaseOperator):

self.log.info("Done.")

- def _call_preoperator(self):
- mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
+ def _call_preoperator(self, local_infile: bool = False) -> MySqlHook:
+ mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id, local_infile=local_infile)
if self.mysql_preoperator:
self.log.info("Running MySQL preoperator")
mysql.run(self.mysql_preoperator)
diff --git a/airflow/providers/mysql/hooks/mysql.py b/airflow/providers/mysql/hooks/mysql.py
index 508ae6c56c..21ddc24a0b 100644
--- a/airflow/providers/mysql/hooks/mysql.py
+++ b/airflow/providers/mysql/hooks/mysql.py
@@ -44,8 +44,12 @@ class MySqlHook(DbApiHook):
in extras.
extras example: ``{"iam":true, "aws_conn_id":"my_aws_conn"}``

+ You can also add "local_infile" parameter to determine whether local_infile feature of MySQL client is
+ going to be enabled (it is disabled by default).
+
:param schema: The MySQL database schema to connect to.
:param connection: The :ref:`MySQL connection id <howto/connection:mysql>` used for MySQL credentials.
+ :param local_infile: Boolean flag determining if local_infile should be used
"""

conn_name_attr = 'mysql_conn_id'
@@ -58,6 +62,7 @@ class MySqlHook(DbApiHook):
super().__init__(*args, **kwargs)
self.schema = kwargs.pop("schema", None)
self.connection = kwargs.pop("connection", None)
+ self.local_infile = kwargs.pop("local_infile", False)

def set_autocommit(self, conn: MySQLConnectionTypes, autocommit: bool) -> None:
"""
@@ -119,8 +124,7 @@ class MySqlHook(DbApiHook):
conn_config["cursorclass"] = MySQLdb.cursors.DictCursor
elif (conn.extra_dejson["cursor"]).lower() == 'ssdictcursor':
conn_config["cursorclass"] = MySQLdb.cursors.SSDictCursor
- local_infile = conn.extra_dejson.get('local_infile', False)
- if conn.extra_dejson.get('ssl', False):
+ if conn.extra_dejson.get("ssl", False):
# SSL parameter for MySQL has to be a dictionary and in case
# of extra/dejson we can get string if extra is passed via
# URL parameters
@@ -130,7 +134,7 @@ class MySqlHook(DbApiHook):
conn_config['ssl'] = dejson_ssl
if conn.extra_dejson.get('unix_socket'):
conn_config['unix_socket'] = conn.extra_dejson['unix_socket']
- if local_infile:
+ if self.local_infile:
conn_config["local_infile"] = 1
return conn_config

@@ -143,7 +147,7 @@ class MySqlHook(DbApiHook):
'port': int(conn.port) if conn.port else 3306,
}

- if conn.extra_dejson.get('allow_local_infile', False):
+ if self.local_infile:
conn_config["allow_local_infile"] = True

return conn_config
diff --git a/airflow/providers/mysql/transfers/vertica_to_mysql.py b/airflow/providers/mysql/transfers/vertica_to_mysql.py
index 595b2cb01b..a8ff591d52 100644
--- a/airflow/providers/mysql/transfers/vertica_to_mysql.py
+++ b/airflow/providers/mysql/transfers/vertica_to_mysql.py
@@ -52,9 +52,8 @@ class VerticaToMySqlOperator(BaseOperator):
import, typically used to move data from staging to production
and issue cleanup commands. (templated)
:param bulk_load: flag to use bulk_load option. This loads MySQL directly
- from a tab-delimited text file using the LOAD DATA LOCAL INFILE command.
- This option requires an extra connection parameter for the
- destination MySQL connection: {'local_infile': true}.
+ from a tab-delimited text file using the LOAD DATA LOCAL INFILE command. The MySQL
+ server must support loading local files via this command (it is disabled by default).
"""

template_fields: Sequence[str] = ('sql', 'mysql_table', 'mysql_preoperator', 'mysql_postoperator')
@@ -89,7 +88,7 @@ class VerticaToMySqlOperator(BaseOperator):

def execute(self, context: 'Context'):
vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id)
- mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
+ mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id, local_infile=self.bulk_load)

if self.bulk_load:
self._bulk_load_transfer(mysql, vertica)
diff --git a/docs/apache-airflow-providers-mysql/connections/mysql.rst b/docs/apache-airflow-providers-mysql/connections/mysql.rst
index 95d8e7aaba..e8b8091b83 100644
--- a/docs/apache-airflow-providers-mysql/connections/mysql.rst
+++ b/docs/apache-airflow-providers-mysql/connections/mysql.rst
@@ -46,9 +46,6 @@ Extra (optional)
* ``charset``: specify charset of the connection
* ``cursor``: one of ``sscursor``, ``dictcursor``, ``ssdictcursor`` . Specifies cursor class to be
used
- * ``local_infile``: controls MySQL's LOCAL capability (permitting local data loading by
- clients). See `MySQLdb docs <https://mysqlclient.readthedocs.io/user_guide.html>`_
- for details.
* ``unix_socket``: UNIX socket used instead of the default socket.
* ``ssl``: Dictionary of SSL parameters that control connecting using SSL. Those
parameters are server specific and should contain ``ca``, ``cert``, ``key``, ``capath``,
@@ -99,14 +96,7 @@ Extra (optional)
If encounter UnicodeDecodeError while working with MySQL connection, check
the charset defined is matched to the database charset.

- For ``mysql-connector-python`` the following extras are supported:
+ For ``mysql-connector-python`` no extras are supported.

- * ``allow_local_infile``: Whether to enable ``LOAD DATA LOCAL INFILE`` capability.
-
- Example "extras" field:
-
- .. code-block:: json
-
- {
- "allow_local_infile": true
- }
+In both cases, when you want to use ``LOAD DATA LOCAL INFILE`` SQL commands of MySQL, you need to create the
+Hook with "local_infile" parameter set to True.
diff --git a/tests/providers/apache/hive/transfers/test_hive_to_mysql.py b/tests/providers/apache/hive/transfers/test_hive_to_mysql.py
index 7e056a17ba..97c4680931 100644
--- a/tests/providers/apache/hive/transfers/test_hive_to_mysql.py
+++ b/tests/providers/apache/hive/transfers/test_hive_to_mysql.py
@@ -44,9 +44,11 @@ class TestHiveToMySqlTransfer(TestHiveEnvironment):
def test_execute(self, mock_hive_hook, mock_mysql_hook):
HiveToMySqlOperator(**self.kwargs).execute(context={})

- mock_hive_hook.assert_called_once_with(hiveserver2_conn_id=self.kwargs['hiveserver2_conn_id'])
- mock_hive_hook.return_value.get_records.assert_called_once_with('sql', parameters={})
- mock_mysql_hook.assert_called_once_with(mysql_conn_id=self.kwargs['mysql_conn_id'])
+ mock_hive_hook.assert_called_once_with(hiveserver2_conn_id=self.kwargs["hiveserver2_conn_id"])
+ mock_hive_hook.return_value.get_records.assert_called_once_with("sql", parameters={})
+ mock_mysql_hook.assert_called_once_with(
+ mysql_conn_id=self.kwargs["mysql_conn_id"], local_infile=False
+ )
mock_mysql_hook.return_value.insert_rows.assert_called_once_with(
table=self.kwargs['mysql_table'], rows=mock_hive_hook.return_value.get_records.return_value
)
@@ -81,6 +83,7 @@ class TestHiveToMySqlTransfer(TestHiveEnvironment):

HiveToMySqlOperator(**self.kwargs).execute(context=context)

+ mock_mysql_hook.assert_called_once_with(mysql_conn_id=self.kwargs["mysql_conn_id"], local_infile=True)
mock_tmp_file_context.assert_called_once_with()
mock_hive_hook.return_value.to_csv.assert_called_once_with(
self.kwargs['sql'],
diff --git a/tests/providers/mysql/hooks/test_mysql.py b/tests/providers/mysql/hooks/test_mysql.py
index 911b9765c5..85d01ca830 100644
--- a/tests/providers/mysql/hooks/test_mysql.py
+++ b/tests/providers/mysql/hooks/test_mysql.py
@@ -119,7 +119,7 @@ class TestMySqlHookConn(unittest.TestCase):

@mock.patch('MySQLdb.connect')
def test_get_conn_local_infile(self, mock_connect):
- self.connection.extra = json.dumps({'local_infile': True})
+ self.db_hook.local_infile = True
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
@@ -208,8 +208,8 @@ class TestMySqlHookConnMySqlConnectorPython(unittest.TestCase):
@mock.patch('mysql.connector.connect')
def test_get_conn_allow_local_infile(self, mock_connect):
extra_dict = self.connection.extra_dejson
- extra_dict.update(allow_local_infile=True)
self.connection.extra = json.dumps(extra_dict)
+ self.db_hook.local_infile = True
self.db_hook.get_conn()
assert mock_connect.call_count == 1
args, kwargs = mock_connect.call_args
@@ -391,7 +391,7 @@ class TestMySql(unittest.TestCase):
@mock.patch.dict(
'os.environ',
{
- 'AIRFLOW_CONN_AIRFLOW_DB': 'mysql://root@mysql/airflow?charset=utf8mb4&local_infile=1',
+ "AIRFLOW_CONN_AIRFLOW_DB": "mysql://root@mysql/airflow?charset=utf8mb4",
},
)
def test_mysql_hook_test_bulk_load(self, client):
@@ -404,7 +404,7 @@ class TestMySql(unittest.TestCase):
f.write("\n".join(records).encode('utf8'))
f.flush()

- hook = MySqlHook('airflow_db')
+ hook = MySqlHook("airflow_db", local_infile=True)
with closing(hook.get_conn()) as conn:
with closing(conn.cursor()) as cursor:
cursor.execute(
4 changes: 4 additions & 0 deletions pkgs/development/python-modules/apache-airflow/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,10 @@ buildPythonPackage rec {
# above
INSTALL_PROVIDERS_FROM_SOURCES = "true";

patches = [
./2.4.3-CVE-2023-22884.patch
];

postPatch = ''
substituteInPlace setup.cfg \
--replace "colorlog>=4.0.2, <5.0" "colorlog" \
Expand Down