diff --git a/qds_sdk/engine.py b/qds_sdk/engine.py index b7594b3e..c84ca366 100644 --- a/qds_sdk/engine.py +++ b/qds_sdk/engine.py @@ -33,7 +33,8 @@ def set_engine_config(self, airflow_python_version=None, is_ha=None, enable_rubix=None, - mlflow_version=None): + mlflow_version=None, + mlflow_dbtap_id=None): ''' Args: @@ -79,7 +80,7 @@ def set_engine_config(self, self.set_presto_settings(presto_version, custom_presto_config) self.set_spark_settings(spark_version, custom_spark_config) self.set_airflow_settings(dbtap_id, fernet_key, overrides, airflow_version, airflow_python_version) - self.set_mlflow_settings(mlflow_version) + self.set_mlflow_settings(mlflow_version, mlflow_dbtap_id) def set_fairscheduler_settings(self, fairscheduler_config_xml=None, @@ -127,8 +128,10 @@ def set_airflow_settings(self, self.airflow_settings['airflow_python_version'] = airflow_python_version def set_mlflow_settings(self, - mlflow_version="1.5"): + mlflow_version="1.7", + mlflow_dbtap_id=None): self.mlflow_settings['version'] = mlflow_version + self.mlflow_settings['dbtap_id'] = mlflow_dbtap_id def set_engine_config_settings(self, arguments): custom_hadoop_config = util._read_file(arguments.custom_hadoop_config_file) @@ -150,7 +153,8 @@ def set_engine_config_settings(self, arguments): airflow_version=arguments.airflow_version, airflow_python_version=arguments.airflow_python_version, enable_rubix=arguments.enable_rubix, - mlflow_version=arguments.mlflow_version) + mlflow_version=arguments.mlflow_version, + mlflow_dbtap_id=arguments.mlflow_dbtap_id) @staticmethod def engine_parser(argparser): @@ -253,4 +257,8 @@ def engine_parser(argparser): dest="mlflow_version", default=None, help="mlflow version for mlflow cluster", ) + mlflow_settings_group.add_argument("--mlflow-dbtap-id", + dest="mlflow_dbtap_id", + default=None, + help="dbtap id for mlflow cluster", ) diff --git a/tests/test_clusterv2.py b/tests/test_clusterv2.py index 0c791908..078108ff 100644 --- a/tests/test_clusterv2.py +++ b/tests/test_clusterv2.py @@ -548,7 +548,7 @@ def test_airflow_engine_config(self): def test_mlflow_engine_config(self): sys.argv = ['qds.py', '--version', 'v2', 'cluster', 'create', '--label', 'test_label', - '--flavour', 'mlflow', '--mlflow-version', '1.5'] + '--flavour', 'mlflow', '--mlflow-version', '1.7', '--mlflow-dbtap-id', '-1'] Qubole.cloud = None print_command() Connection._api_call = Mock(return_value={}) @@ -557,7 +557,8 @@ def test_mlflow_engine_config(self): {'engine_config': {'flavour': 'mlflow', 'mlflow_settings': { - 'version': '1.5' + 'version': '1.7', + 'dbtap_id': '-1' }}, 'cluster_info': {'label': ['test_label'], }})