From 611e41e814eb62385b1f983432beae375dbcaf25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yolan=20Honor=C3=A9-Roug=C3=A9?= Date: Sun, 27 Nov 2022 22:51:22 +0100 Subject: [PATCH] :sparkles: Add an MlflowModelRegistryDataSet to load from the mlflow model registry (#260) --- CHANGELOG.md | 2 + docs/source/07_python_objects/01_DataSets.md | 46 ++++++++++- docs/source/08_API/kedro_mlflow.io.rst | 5 ++ kedro_mlflow/io/models/__init__.py | 1 + .../models/mlflow_model_registry_dataset.py | 80 +++++++++++++++++++ .../test_mlflow_model_registry_dataset.py | 59 ++++++++++++++ 6 files changed, 192 insertions(+), 1 deletion(-) create mode 100644 kedro_mlflow/io/models/mlflow_model_registry_dataset.py create mode 100644 tests/io/models/test_mlflow_model_registry_dataset.py diff --git a/CHANGELOG.md b/CHANGELOG.md index b1b401a3..e418b2b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ ## [0.11.6] - 2023-01-09 +- :sparkles: Added a `MlflowModelRegistryDataSet` in `kedro_mlflow.io.models` to enable fetching a mlflow model from the mlflow model registry by its name([#260](https://github.com/Galileo-Galilei/kedro-mlflow/issues/260)) + ### Changed - :sparkles: `kedro-mlflow` now uses the default configuration (ignoring `mlflow.yml`) if an active run already exists in the process where the pipeline is started, and uses this active run for logging. This enables using ` kedro-mlflow` with an orchestrator which starts mlflow itself before running kedro (e.g. airflow, the `mlflow run` command, AzureML...) ([#358](https://github.com/Galileo-Galilei/kedro-mlflow/issues/358)) diff --git a/docs/source/07_python_objects/01_DataSets.md b/docs/source/07_python_objects/01_DataSets.md index 073a8080..3a791a08 100644 --- a/docs/source/07_python_objects/01_DataSets.md +++ b/docs/source/07_python_objects/01_DataSets.md @@ -126,7 +126,7 @@ The ``MlflowModelLoggerDataSet`` accepts the following arguments: - save_args (Dict[str, Any], optional): Arguments to `save_model` function from specified `flavor`. Defaults to None. - version (Version, optional): Kedro version to use. Defaults to None. -The use is very similar to MlflowModelLoggerDataSet, but that you specify a filepath instead of a `run_id`: +The use is very similar to ``MlflowModelLoggerDataSet``, but you have to specify a local ``filepath`` instead of a `run_id`: ```python from kedro_mlflow.io.models import MlflowModelLoggerDataSet @@ -158,3 +158,47 @@ my_model: filepath: path/to/where/you/want/model version: ``` + +### ``MlflowModelRegistryDataSet`` + +The ``MlflowModelRegistryDataSet`` accepts the following arguments: + +- model_name (str): The name of the registered model is the mlflow registry +- stage_or_version (str): A valid stage (either "staging" or "production") or version number for the registred model.Default to "latest" which fetch the last version and the higher "stage" available. +- flavor (str): Built-in or custom MLflow model flavor module. Must be Python-importable. +- pyfunc_workflow (str, optional): Either `python_model` or `loader_module`. See [mlflow workflows](https://www.mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#workflows). +- load_args (Dict[str, Any], optional): Arguments to `load_model` function from specified `flavor`. Defaults to None. + +We assume you have registered a mlflow model first, either [with the ``MlflowClient``](https://mlflow.org/docs/latest/model-registry.html#adding-an-mlflow-model-to-the-model-registry) or [within the mlflow ui](https://mlflow.org/docs/latest/model-registry.html#ui-workflow), e.g. : + +```python +from sklearn.tree import DecisionTreeClassifier + +import mlflow +import mlflow.sklearn + +with mlflow.start_run(): + model = DecisionTreeClassifier() + + # Log the sklearn model and register as version 1 + mlflow.sklearn.log_model( + sk_model=model, artifact_path="model", registered_model_name="my_awesome_model" + ) +``` + +You can fetch the model by its name: + +```python +from kedro_mlflow.io.models import MlflowModelRegistryDataSet + +mlflow_model_logger = MlflowModelRegistryDataSet(model_name="my_awesome_model") +my_model = mlflow_model_logger.load() +``` + +and with the YAML API in the `catalog.yml` (only for loading an existing model): + +```yaml +my_model: + type: kedro_mlflow.io.models.MlflowModelRegistryDataSet + model_name: my_awesome_model +``` diff --git a/docs/source/08_API/kedro_mlflow.io.rst b/docs/source/08_API/kedro_mlflow.io.rst index c4bdbe8b..16319fcc 100644 --- a/docs/source/08_API/kedro_mlflow.io.rst +++ b/docs/source/08_API/kedro_mlflow.io.rst @@ -45,3 +45,8 @@ Models DataSet :members: :undoc-members: :show-inheritance: + +.. automodule:: kedro_mlflow.io.models.mlflow_model_registry_dataset + :members: + :undoc-members: + :show-inheritance: diff --git a/kedro_mlflow/io/models/__init__.py b/kedro_mlflow/io/models/__init__.py index f622e3f3..6d102f8d 100644 --- a/kedro_mlflow/io/models/__init__.py +++ b/kedro_mlflow/io/models/__init__.py @@ -1,2 +1,3 @@ from .mlflow_model_logger_dataset import MlflowModelLoggerDataSet +from .mlflow_model_registry_dataset import MlflowModelRegistryDataSet from .mlflow_model_saver_dataset import MlflowModelSaverDataSet diff --git a/kedro_mlflow/io/models/mlflow_model_registry_dataset.py b/kedro_mlflow/io/models/mlflow_model_registry_dataset.py new file mode 100644 index 00000000..13544270 --- /dev/null +++ b/kedro_mlflow/io/models/mlflow_model_registry_dataset.py @@ -0,0 +1,80 @@ +from typing import Any, Dict, Optional, Union + +from kedro_mlflow.io.models.mlflow_abstract_model_dataset import ( + MlflowAbstractModelDataSet, +) + + +class MlflowModelRegistryDataSet(MlflowAbstractModelDataSet): + """Wrapper for saving, logging and loading for all MLflow model flavor.""" + + def __init__( + self, + model_name: str, + stage_or_version: Union[str, int] = "latest", + flavor: Optional[str] = "mlflow.pyfunc", + pyfunc_workflow: Optional[str] = "python_model", + load_args: Optional[Dict[str, Any]] = None, + ) -> None: + """Initialize the Kedro MlflowModelRegistryDataSet. + + Parameters are passed from the Data Catalog. + + During "load", the model is pulled from MLflow model registry by its name. + "save" is not supported. + + Args: + model_name (str): The name of the registered model is the mlflow registry + stage_or_version (str): A valid stage (either "staging" or "production") or version number for the registred model. + Default to "latest" which fetch the last version and the higher "stage" available. + flavor (str): Built-in or custom MLflow model flavor module. + Must be Python-importable. + pyfunc_workflow (str, optional): Either `python_model` or `loader_module`. + See https://www.mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#workflows. + load_args (Dict[str, Any], optional): Arguments to `load_model` + function from specified `flavor`. Defaults to None. + + Raises: + DataSetError: When passed `flavor` does not exist. + """ + super().__init__( + filepath="", + flavor=flavor, + pyfunc_workflow=pyfunc_workflow, + load_args=load_args, + save_args={}, + version=None, + ) + + self.model_name = model_name + self.stage_or_version = stage_or_version + self.model_uri = f"models:/{model_name}/{stage_or_version}" + + def _load(self) -> Any: + """Loads an MLflow model from local path or from MLflow run. + + Returns: + Any: Deserialized model. + """ + + # If `run_id` is specified, pull the model from MLflow. + # TODO: enable loading from another mlflow conf (with a client with another tracking uri) + # Alternatively, use local path to load the model. + return self._mlflow_model_module.load_model( + model_uri=self.model_uri, **self._load_args + ) + + def _save(self, model: Any) -> None: + raise NotImplementedError( + "The 'save' method is not implemented for MlflowModelRegistryDataSet. You can pass 'registered_model_name' argument in 'MLflowModelLoggerDataSet(..., save_args={registered_model_name='my_model'}' to save and register a model in the same step. " + ) + + def _describe(self) -> Dict[str, Any]: + return dict( + model_uri=self.model_uri, + model_name=self.model_name, + stage_or_version=self.stage_or_version, + flavor=self._flavor, + pyfunc_workflow=self._pyfunc_workflow, + load_args=self._load_args, + ) diff --git a/tests/io/models/test_mlflow_model_registry_dataset.py b/tests/io/models/test_mlflow_model_registry_dataset.py new file mode 100644 index 00000000..63d0af82 --- /dev/null +++ b/tests/io/models/test_mlflow_model_registry_dataset.py @@ -0,0 +1,59 @@ +import mlflow +import pytest +from kedro.io.core import DataSetError +from mlflow import MlflowClient +from sklearn.tree import DecisionTreeClassifier + +from kedro_mlflow.io.models import MlflowModelRegistryDataSet + + +def test_mlflow_model_registry_save_not_implemented(tmp_path): + ml_ds = MlflowModelRegistryDataSet(model_name="demo_model") + with pytest.raises( + DataSetError, + match=r"The 'save' method is not implemented for MlflowModelRegistryDataSet", + ): + ml_ds.save(DecisionTreeClassifier()) + + +def test_mlflow_model_registry_load_given_stage_or_version(tmp_path, monkeypatch): + + # we must change the working directory because when + # using mlflow with a local database tracking, the artifacts + # are stored in a relative mlruns/ folder so we need to have + # the same working directory that the one of the tracking uri + monkeypatch.chdir(tmp_path) + tracking_uri = r"sqlite:///" + (tmp_path / "mlruns3.db").as_posix() + mlflow.set_tracking_uri(tracking_uri) + + # setup: we train 3 version of a model under a single + # registered model and stage the 2nd one + runs = {} + for i in range(3): + with mlflow.start_run(): + model = DecisionTreeClassifier() + mlflow.sklearn.log_model( + model, artifact_path="demo_model", registered_model_name="demo_model" + ) + runs[i + 1] = mlflow.active_run().info.run_id + print(f"run_{i+1}={runs[i+1]}") + + client = MlflowClient(tracking_uri=tracking_uri) + client.transition_model_version_stage(name="demo_model", version=2, stage="Staging") + + # case 1: no version is provided, we take the last one + ml_ds = MlflowModelRegistryDataSet(model_name="demo_model") + loaded_model = ml_ds.load() + assert loaded_model.metadata.run_id == runs[3] + + # case 2: a stage is provided, we take the last model with this stage + ml_ds = MlflowModelRegistryDataSet( + model_name="demo_model", stage_or_version="staging" + ) + loaded_model = ml_ds.load() + assert loaded_model.metadata.run_id == runs[2] + + # case 3: a version is provided, we take the associated model + ml_ds = MlflowModelRegistryDataSet(model_name="demo_model", stage_or_version="1") + loaded_model = ml_ds.load() + assert loaded_model.metadata.run_id == runs[1]