✨ Add an MlflowModelRegistryDataSet to load from the mlflow model registry (#260)
Galileo-Galilei committed Jan 16, 2023
1 parent e2afa95 commit 611e41e
Showing 6 changed files with 192 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

## [0.11.6] - 2023-01-09

- :sparkles: Added a `MlflowModelRegistryDataSet` in `kedro_mlflow.io.models` to enable fetching an mlflow model from the mlflow model registry by its name ([#260](https://github.com/Galileo-Galilei/kedro-mlflow/issues/260))

### Changed

- :sparkles: `kedro-mlflow` now uses the default configuration (ignoring `mlflow.yml`) if an active run already exists in the process where the pipeline is started, and uses this active run for logging. This enables using `kedro-mlflow` with an orchestrator which starts mlflow itself before running kedro (e.g. airflow, the `mlflow run` command, AzureML...) ([#358](https://github.com/Galileo-Galilei/kedro-mlflow/issues/358))
Expand Down
46 changes: 45 additions & 1 deletion docs/source/07_python_objects/01_DataSets.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ The ``MlflowModelLoggerDataSet`` accepts the following arguments:
- save_args (Dict[str, Any], optional): Arguments to `save_model` function from specified `flavor`. Defaults to None.
- version (Version, optional): Kedro version to use. Defaults to None.

The use is very similar to MlflowModelLoggerDataSet, but that you specify a filepath instead of a `run_id`:
The use is very similar to ``MlflowModelLoggerDataSet``, but you have to specify a local ``filepath`` instead of a `run_id`:

```python
from kedro_mlflow.io.models import MlflowModelLoggerDataSet
Expand Down Expand Up @@ -158,3 +158,47 @@ my_model:
    filepath: path/to/where/you/want/model
    version: <valid-kedro-version>
```

### ``MlflowModelRegistryDataSet``

The ``MlflowModelRegistryDataSet`` accepts the following arguments:

- model_name (str): The name of the registered model in the mlflow registry.
- stage_or_version (str): A valid stage (either "staging" or "production") or version number for the registered model. Defaults to "latest", which fetches the latest version in the highest stage available.
- flavor (str): Built-in or custom MLflow model flavor module. Must be Python-importable.
- pyfunc_workflow (str, optional): Either `python_model` or `loader_module`. See [mlflow workflows](https://www.mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#workflows).
- load_args (Dict[str, Any], optional): Arguments to `load_model` function from specified `flavor`. Defaults to None.
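
Internally, `model_name` and `stage_or_version` are simply combined into an MLflow registry URI with the `models:/` scheme (a minimal sketch mirroring how the dataset builds its `model_uri`; the values shown are illustrative):

```python
# Sketch: "latest", a stage name, or a version number all slot into the
# same "models:/<name>/<stage_or_version>" URI that mlflow resolves.
model_name = "my_awesome_model"
stage_or_version = "staging"

model_uri = f"models:/{model_name}/{stage_or_version}"
print(model_uri)  # models:/my_awesome_model/staging
```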

We assume you have registered an mlflow model first, either [with the ``MlflowClient``](https://mlflow.org/docs/latest/model-registry.html#adding-an-mlflow-model-to-the-model-registry) or [within the mlflow ui](https://mlflow.org/docs/latest/model-registry.html#ui-workflow), e.g.:

```python
from sklearn.tree import DecisionTreeClassifier
import mlflow
import mlflow.sklearn

with mlflow.start_run():
    model = DecisionTreeClassifier()
    # Log the sklearn model and register as version 1
    mlflow.sklearn.log_model(
        sk_model=model, artifact_path="model", registered_model_name="my_awesome_model"
    )
```

You can fetch the model by its name:

```python
from kedro_mlflow.io.models import MlflowModelRegistryDataSet

mlflow_model_registry = MlflowModelRegistryDataSet(model_name="my_awesome_model")
my_model = mlflow_model_registry.load()
```

and with the YAML API in the `catalog.yml` (only for loading an existing model):

```yaml
my_model:
    type: kedro_mlflow.io.models.MlflowModelRegistryDataSet
    model_name: my_awesome_model
```
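
To pin a stage or a version instead of taking the latest model, `stage_or_version` can be set in the catalog entry as well (a sketch; the `my_staged_model` entry name is illustrative):

```yaml
my_staged_model:
    type: kedro_mlflow.io.models.MlflowModelRegistryDataSet
    model_name: my_awesome_model
    stage_or_version: staging
```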
5 changes: 5 additions & 0 deletions docs/source/08_API/kedro_mlflow.io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,8 @@ Models DataSet
:members:
:undoc-members:
:show-inheritance:

.. automodule:: kedro_mlflow.io.models.mlflow_model_registry_dataset
:members:
:undoc-members:
:show-inheritance:
1 change: 1 addition & 0 deletions kedro_mlflow/io/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .mlflow_model_logger_dataset import MlflowModelLoggerDataSet
from .mlflow_model_registry_dataset import MlflowModelRegistryDataSet
from .mlflow_model_saver_dataset import MlflowModelSaverDataSet
80 changes: 80 additions & 0 deletions kedro_mlflow/io/models/mlflow_model_registry_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import Any, Dict, Optional, Union

from kedro_mlflow.io.models.mlflow_abstract_model_dataset import (
    MlflowAbstractModelDataSet,
)


class MlflowModelRegistryDataSet(MlflowAbstractModelDataSet):
    """Wrapper to load a model of any MLflow flavor from the MLflow model registry."""

    def __init__(
        self,
        model_name: str,
        stage_or_version: Union[str, int] = "latest",
        flavor: Optional[str] = "mlflow.pyfunc",
        pyfunc_workflow: Optional[str] = "python_model",
        load_args: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Initialize the Kedro MlflowModelRegistryDataSet.

        Parameters are passed from the Data Catalog.
        During "load", the model is pulled from the MLflow model registry by its name.
        "save" is not supported.

        Args:
            model_name (str): The name of the registered model in the mlflow registry.
            stage_or_version (str): A valid stage (either "staging" or "production")
                or version number for the registered model. Defaults to "latest",
                which fetches the latest version in the highest stage available.
            flavor (str): Built-in or custom MLflow model flavor module.
                Must be Python-importable.
            pyfunc_workflow (str, optional): Either `python_model` or `loader_module`.
                See https://www.mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#workflows.
            load_args (Dict[str, Any], optional): Arguments to the `load_model`
                function of the specified `flavor`. Defaults to None.

        Raises:
            DataSetError: When the passed `flavor` does not exist.
        """
        super().__init__(
            filepath="",
            flavor=flavor,
            pyfunc_workflow=pyfunc_workflow,
            load_args=load_args,
            save_args={},
            version=None,
        )

        self.model_name = model_name
        self.stage_or_version = stage_or_version
        self.model_uri = f"models:/{model_name}/{stage_or_version}"

    def _load(self) -> Any:
        """Loads an MLflow model from the model registry through its "models:/" URI.

        Returns:
            Any: Deserialized model.
        """
        # TODO: enable loading from another mlflow conf (with a client with another tracking uri)
        return self._mlflow_model_module.load_model(
            model_uri=self.model_uri, **self._load_args
        )

    def _save(self, model: Any) -> None:
        raise NotImplementedError(
            "The 'save' method is not implemented for MlflowModelRegistryDataSet. "
            "To save and register a model in the same step, pass "
            "save_args={'registered_model_name': 'my_model'} to MlflowModelLoggerDataSet."
        )

    def _describe(self) -> Dict[str, Any]:
        return dict(
            model_uri=self.model_uri,
            model_name=self.model_name,
            stage_or_version=self.stage_or_version,
            flavor=self._flavor,
            pyfunc_workflow=self._pyfunc_workflow,
            load_args=self._load_args,
        )
59 changes: 59 additions & 0 deletions tests/io/models/test_mlflow_model_registry_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import mlflow
import pytest
from kedro.io.core import DataSetError
from mlflow import MlflowClient
from sklearn.tree import DecisionTreeClassifier

from kedro_mlflow.io.models import MlflowModelRegistryDataSet


def test_mlflow_model_registry_save_not_implemented(tmp_path):
    ml_ds = MlflowModelRegistryDataSet(model_name="demo_model")
    with pytest.raises(
        DataSetError,
        match=r"The 'save' method is not implemented for MlflowModelRegistryDataSet",
    ):
        ml_ds.save(DecisionTreeClassifier())


def test_mlflow_model_registry_load_given_stage_or_version(tmp_path, monkeypatch):
    # we must change the working directory because when
    # using mlflow with a local database tracking, the artifacts
    # are stored in a relative mlruns/ folder, so we need to have
    # the same working directory as the one of the tracking uri
    monkeypatch.chdir(tmp_path)
    tracking_uri = r"sqlite:///" + (tmp_path / "mlruns3.db").as_posix()
    mlflow.set_tracking_uri(tracking_uri)

    # setup: we train 3 versions of a model under a single
    # registered model and stage the 2nd one
    runs = {}
    for i in range(3):
        with mlflow.start_run():
            model = DecisionTreeClassifier()
            mlflow.sklearn.log_model(
                model, artifact_path="demo_model", registered_model_name="demo_model"
            )
            runs[i + 1] = mlflow.active_run().info.run_id
            print(f"run_{i+1}={runs[i+1]}")

    client = MlflowClient(tracking_uri=tracking_uri)
    client.transition_model_version_stage(name="demo_model", version=2, stage="Staging")

    # case 1: no version is provided, we take the last one
    ml_ds = MlflowModelRegistryDataSet(model_name="demo_model")
    loaded_model = ml_ds.load()
    assert loaded_model.metadata.run_id == runs[3]

    # case 2: a stage is provided, we take the last model with this stage
    ml_ds = MlflowModelRegistryDataSet(
        model_name="demo_model", stage_or_version="staging"
    )
    loaded_model = ml_ds.load()
    assert loaded_model.metadata.run_id == runs[2]

    # case 3: a version is provided, we take the associated model
    ml_ds = MlflowModelRegistryDataSet(model_name="demo_model", stage_or_version="1")
    loaded_model = ml_ds.load()
    assert loaded_model.metadata.run_id == runs[1]
