-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
✨ Add an MlflowModelRegistryDataSet to load from the mlflow model reg…
…istry (#260)
- Loading branch information
1 parent
e2afa95
commit 611e41e
Showing
6 changed files
with
192 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .mlflow_model_logger_dataset import MlflowModelLoggerDataSet | ||
from .mlflow_model_registry_dataset import MlflowModelRegistryDataSet | ||
from .mlflow_model_saver_dataset import MlflowModelSaverDataSet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from typing import Any, Dict, Optional, Union | ||
|
||
from kedro_mlflow.io.models.mlflow_abstract_model_dataset import ( | ||
MlflowAbstractModelDataSet, | ||
) | ||
|
||
|
||
class MlflowModelRegistryDataSet(MlflowAbstractModelDataSet): | ||
"""Wrapper for saving, logging and loading for all MLflow model flavor.""" | ||
|
||
def __init__( | ||
self, | ||
model_name: str, | ||
stage_or_version: Union[str, int] = "latest", | ||
flavor: Optional[str] = "mlflow.pyfunc", | ||
pyfunc_workflow: Optional[str] = "python_model", | ||
load_args: Optional[Dict[str, Any]] = None, | ||
) -> None: | ||
"""Initialize the Kedro MlflowModelRegistryDataSet. | ||
Parameters are passed from the Data Catalog. | ||
During "load", the model is pulled from MLflow model registry by its name. | ||
"save" is not supported. | ||
Args: | ||
model_name (str): The name of the registered model is the mlflow registry | ||
stage_or_version (str): A valid stage (either "staging" or "production") or version number for the registred model. | ||
Default to "latest" which fetch the last version and the higher "stage" available. | ||
flavor (str): Built-in or custom MLflow model flavor module. | ||
Must be Python-importable. | ||
pyfunc_workflow (str, optional): Either `python_model` or `loader_module`. | ||
See https://www.mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#workflows. | ||
load_args (Dict[str, Any], optional): Arguments to `load_model` | ||
function from specified `flavor`. Defaults to None. | ||
Raises: | ||
DataSetError: When passed `flavor` does not exist. | ||
""" | ||
super().__init__( | ||
filepath="", | ||
flavor=flavor, | ||
pyfunc_workflow=pyfunc_workflow, | ||
load_args=load_args, | ||
save_args={}, | ||
version=None, | ||
) | ||
|
||
self.model_name = model_name | ||
self.stage_or_version = stage_or_version | ||
self.model_uri = f"models:/{model_name}/{stage_or_version}" | ||
|
||
def _load(self) -> Any: | ||
"""Loads an MLflow model from local path or from MLflow run. | ||
Returns: | ||
Any: Deserialized model. | ||
""" | ||
|
||
# If `run_id` is specified, pull the model from MLflow. | ||
# TODO: enable loading from another mlflow conf (with a client with another tracking uri) | ||
# Alternatively, use local path to load the model. | ||
return self._mlflow_model_module.load_model( | ||
model_uri=self.model_uri, **self._load_args | ||
) | ||
|
||
def _save(self, model: Any) -> None: | ||
raise NotImplementedError( | ||
"The 'save' method is not implemented for MlflowModelRegistryDataSet. You can pass 'registered_model_name' argument in 'MLflowModelLoggerDataSet(..., save_args={registered_model_name='my_model'}' to save and register a model in the same step. " | ||
) | ||
|
||
def _describe(self) -> Dict[str, Any]: | ||
return dict( | ||
model_uri=self.model_uri, | ||
model_name=self.model_name, | ||
stage_or_version=self.stage_or_version, | ||
flavor=self._flavor, | ||
pyfunc_workflow=self._pyfunc_workflow, | ||
load_args=self._load_args, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import mlflow | ||
import pytest | ||
from kedro.io.core import DataSetError | ||
from mlflow import MlflowClient | ||
from sklearn.tree import DecisionTreeClassifier | ||
|
||
from kedro_mlflow.io.models import MlflowModelRegistryDataSet | ||
|
||
|
||
def test_mlflow_model_registry_save_not_implemented(tmp_path): | ||
ml_ds = MlflowModelRegistryDataSet(model_name="demo_model") | ||
with pytest.raises( | ||
DataSetError, | ||
match=r"The 'save' method is not implemented for MlflowModelRegistryDataSet", | ||
): | ||
ml_ds.save(DecisionTreeClassifier()) | ||
|
||
|
||
def test_mlflow_model_registry_load_given_stage_or_version(tmp_path, monkeypatch): | ||
|
||
# we must change the working directory because when | ||
# using mlflow with a local database tracking, the artifacts | ||
# are stored in a relative mlruns/ folder so we need to have | ||
# the same working directory that the one of the tracking uri | ||
monkeypatch.chdir(tmp_path) | ||
tracking_uri = r"sqlite:///" + (tmp_path / "mlruns3.db").as_posix() | ||
mlflow.set_tracking_uri(tracking_uri) | ||
|
||
# setup: we train 3 version of a model under a single | ||
# registered model and stage the 2nd one | ||
runs = {} | ||
for i in range(3): | ||
with mlflow.start_run(): | ||
model = DecisionTreeClassifier() | ||
mlflow.sklearn.log_model( | ||
model, artifact_path="demo_model", registered_model_name="demo_model" | ||
) | ||
runs[i + 1] = mlflow.active_run().info.run_id | ||
print(f"run_{i+1}={runs[i+1]}") | ||
|
||
client = MlflowClient(tracking_uri=tracking_uri) | ||
client.transition_model_version_stage(name="demo_model", version=2, stage="Staging") | ||
|
||
# case 1: no version is provided, we take the last one | ||
ml_ds = MlflowModelRegistryDataSet(model_name="demo_model") | ||
loaded_model = ml_ds.load() | ||
assert loaded_model.metadata.run_id == runs[3] | ||
|
||
# case 2: a stage is provided, we take the last model with this stage | ||
ml_ds = MlflowModelRegistryDataSet( | ||
model_name="demo_model", stage_or_version="staging" | ||
) | ||
loaded_model = ml_ds.load() | ||
assert loaded_model.metadata.run_id == runs[2] | ||
|
||
# case 3: a version is provided, we take the associated model | ||
ml_ds = MlflowModelRegistryDataSet(model_name="demo_model", stage_or_version="1") | ||
loaded_model = ml_ds.load() | ||
assert loaded_model.metadata.run_id == runs[1] |