🐛 Correctly pass kpm_kwargs and log_model_kwargs to pipeline_ml_factory instead of always using default values (#329)
Galileo-Galilei committed Jun 18, 2022
1 parent a5c5781 commit 0deafc9
Showing 4 changed files with 76 additions and 13 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,10 @@

 ## [Unreleased]

+### Fixed
+
+- :bug: Make ``pipeline_ml_factory`` correctly pass ``kpm_kwargs`` and ``log_model_kwargs`` instead of always using the default values. ([#329](https://github.com/Galileo-Galilei/kedro-mlflow/issues/329))
+
 ## [0.11.0] - 2022-06-18

 ### Added
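For context, this is the user-facing code path the fix restores — a minimal sketch, assuming `my_training` and `my_inference` are existing Kedro pipelines and `raw_data` is a dataset of the project (all three names are hypothetical):

    from kedro_mlflow.pipeline import pipeline_ml_factory

    # Before this fix, the two kwargs below were silently discarded and the
    # defaults were always used, whatever the caller passed.
    pipeline_ml = pipeline_ml_factory(
        training=my_training,
        inference=my_inference,
        input_name="raw_data",
        kpm_kwargs={"copy_mode": "assign"},  # assumed valid KedroPipelineModel kwarg
        log_model_kwargs={"artifact_path": "my_custom_model"},
    )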
21 changes: 10 additions & 11 deletions kedro_mlflow/pipeline/pipeline_ml.py
@@ -68,9 +68,9 @@ def __init__(
                when loaded in memory
                - `runner`: the kedro runner to run the model with
            log_model_kwargs:
-               extra arguments to be passed to `mlflow.pyfunc.log_model`
-               - "signature" accepts an extra "auto" which automatically infer the signature
-               based on "input_name" dataset
+               extra arguments to be passed to `mlflow.pyfunc.log_model`, e.g.:
+               - "signature" accepts an extra "auto" which automatically infer the signature
+                   based on "input_name" dataset
        """

@@ -80,16 +80,11 @@ def __init__(
         self.input_name = input_name
         # they will be passed to KedroPipelineModel to enable flexibility

-        kpm_kwargs_with_default = self.KPM_KWARGS_DEFAULT.copy()
         kpm_kwargs = kpm_kwargs or {}
-        kpm_kwargs_with_default.update(kpm_kwargs)
-        self.kpm_kwargs = kpm_kwargs_with_default
+        self.kpm_kwargs = {**self.KPM_KWARGS_DEFAULT, **kpm_kwargs}

-        log_model_kwargs_with_default = self.LOG_MODEL_KWARGS_DEFAULT.copy()
         log_model_kwargs = log_model_kwargs or {}
-        log_model_kwargs_with_default.update(log_model_kwargs)
-        self.log_model_kwargs = log_model_kwargs_with_default
-
+        self.log_model_kwargs = {**self.LOG_MODEL_KWARGS_DEFAULT, **log_model_kwargs}
         self._check_consistency()

     @property
@@ -167,7 +162,11 @@ def _check_consistency(self) -> None:

     def _turn_pipeline_to_ml(self, pipeline: Pipeline):
         return PipelineML(
-            nodes=pipeline.nodes, inference=self.inference, input_name=self.input_name
+            nodes=pipeline.nodes,
+            inference=self.inference,
+            input_name=self.input_name,
+            kpm_kwargs=self.kpm_kwargs,
+            log_model_kwargs=self.log_model_kwargs,
         )

     def only_nodes(self, *node_names: str) -> "Pipeline":  # pragma: no cover
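The one-line replacement relies on standard dict-merge semantics: in `{**defaults, **overrides}` the right-hand mapping wins on key collisions, so caller-supplied values shadow the defaults while unspecified defaults survive. A self-contained illustration (the default value is made up for the example, not the library's actual constant):

    KPM_KWARGS_DEFAULT = {"copy_mode": "assign"}  # illustrative only

    def merge_with_defaults(defaults, overrides):
        # Later keys win on collision: overrides shadow defaults,
        # while defaults the caller did not touch are preserved.
        return {**defaults, **(overrides or {})}

    assert merge_with_defaults(KPM_KWARGS_DEFAULT, None) == {"copy_mode": "assign"}
    assert merge_with_defaults(KPM_KWARGS_DEFAULT, {"copy_mode": "deepcopy"}) == {"copy_mode": "deepcopy"}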
63 changes: 62 additions & 1 deletion tests/framework/hooks/test_hook_pipeline_ml.py
@@ -413,7 +413,6 @@ def test_mlflow_hook_save_pipeline_ml_with_signature(
input_name="raw_data",
log_model_kwargs={
"conda_env": env_from_dict,
"artifact_path": "model",
"signature": model_signature,
},
)
@@ -442,3 +441,65 @@
         # test : parameters should have been logged
         trained_model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model")
         assert trained_model.metadata.signature == expected_signature
+
+
+@pytest.mark.parametrize(
+    "artifact_path,expected_artifact_path",
+    ([None, "model"], ["my_custom_model", "my_custom_model"]),
+)
+def test_mlflow_hook_save_pipeline_ml_with_artifact_path(
+    kedro_project_with_mlflow_conf,
+    env_from_dict,
+    dummy_pipeline,
+    dummy_catalog,
+    dummy_run_params,
+    artifact_path,
+    expected_artifact_path,
+):
+    # config_with_base_mlflow_conf is a conftest fixture
+    bootstrap_project(kedro_project_with_mlflow_conf)
+    with KedroSession.create(project_path=kedro_project_with_mlflow_conf) as session:
+        mlflow_hook = MlflowHook()
+        runner = SequentialRunner()
+
+        log_model_kwargs = {
+            "conda_env": env_from_dict,
+        }
+        if artifact_path is not None:
+            # we need to test what happens if the key is NOT present
+            log_model_kwargs["artifact_path"] = artifact_path
+
+        pipeline_to_run = pipeline_ml_factory(
+            training=dummy_pipeline.only_nodes_with_tags("training"),
+            inference=dummy_pipeline.only_nodes_with_tags("inference"),
+            input_name="raw_data",
+            log_model_kwargs=log_model_kwargs,
+        )
+
+        context = session.load_context()
+        mlflow_hook.after_context_created(context)
+        mlflow_hook.after_catalog_created(
+            catalog=dummy_catalog,
+            # `after_catalog_created` is not using any of the arguments below,
+            # so we are setting them to empty values.
+            conf_catalog={},
+            conf_creds={},
+            feed_dict={},
+            save_version="",
+            load_versions="",
+        )
+        mlflow_hook.before_pipeline_run(
+            run_params=dummy_run_params, pipeline=pipeline_to_run, catalog=dummy_catalog
+        )
+        runner.run(pipeline_to_run, dummy_catalog, session._hook_manager)
+        run_id = mlflow.active_run().info.run_id
+        mlflow_hook.after_pipeline_run(
+            run_params=dummy_run_params, pipeline=pipeline_to_run, catalog=dummy_catalog
+        )
+
+        # test : parameters should have been logged
+        trained_model = mlflow.pyfunc.load_model(
+            f"runs:/{run_id}/{expected_artifact_path}"
+        )
+        # the real test is that the model is loaded without error
+        assert trained_model is not None
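As the parametrized cases assert, omitting `artifact_path` falls back to the default `model` path, while a custom value changes the artifact URI the model must be reloaded from. A minimal sketch, assuming `run_id` holds the id of the MLflow run that logged the model:

    import mlflow

    # Default: no "artifact_path" in log_model_kwargs -> logged under "model"
    default_model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model")

    # Custom: log_model_kwargs={"artifact_path": "my_custom_model"}
    custom_model = mlflow.pyfunc.load_model(f"runs:/{run_id}/my_custom_model")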
1 change: 0 additions & 1 deletion tests/io/models/test_mlflow_model_logger_dataset.py
@@ -336,7 +336,6 @@ def test_pyfunc_flavor_python_model_save_and_load(
     model_config2 = model_config.copy()
     model_config2["config"]["run_id"] = current_run_id
     mlflow_model_ds2 = MlflowModelLoggerDataSet.from_config(**model_config2)
-    print(model_config)

     loaded_model = mlflow_model_ds2.load()
