-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
FIX #12 - Add MLflowModelDataSet with versioning support
- Loading branch information
1 parent
9adb5ef
commit 60d4b1e
Showing
6 changed files
with
286 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,41 @@ | ||
# Version model | ||
|
||
This is coming soon. If you want to keep track of the progress on this feature, [follow this issue](https://github.com/Galileo-Galilei/kedro-mlflow/issues/12). | ||
## What is model tracking? | ||
|
||
MLflow allows you to serialize and deserialize models to a common format, track those models in MLflow Tracking, and manage them using the MLflow Model Registry. Many popular machine learning and deep learning frameworks have built-in support through what MLflow calls flavors. Even if there's no flavor for your framework of choice, it's easy to create your own flavor and integrate it with MLflow.
|
||
## How to track models using MLflow in a Kedro project?
|
||
kedro-mlflow introduces a new dataset type that can be used in the Data Catalog, called ``MlflowModelDataSet``. Suppose you would like to add a scikit-learn model to your Data Catalog. For that, you need to add an entry like this:
|
||
```yaml | ||
my_sklearn_model: | ||
type: kedro_mlflow.io.MlflowModelDataSet | ||
flavor: mlflow.sklearn | ||
path: data/06_models/my_sklearn_model | ||
``` | ||
You are now able to use ``my_sklearn_model`` in your nodes. | ||
## Frequently asked questions
### How does it work under the hood?
During save, a model object from a node output is saved locally under the specified ``path`` using the ``save_model`` function of the specified ``flavor``. It is then logged to MLflow using ``log_model``.
When the model is loaded, the latest version stored locally is read using the ``load_model`` function of the specified ``flavor``. You can also load a model from a specific [Kedro run](#can-i-use-kedro-versioning-with-mlflowmodeldataset) or [MLflow run](#can-i-load-a-model-from-a-specific-mlflow-run-id).
### How can I track a custom MLflow model flavor? | ||
To track a custom MLflow model flavor, set the `flavor` parameter to the import path of your custom flavor:
|
||
```yaml | ||
my_custom_model: | ||
type: kedro_mlflow.io.MlflowModelDataSet | ||
flavor: my_package.custom_mlflow_flavor | ||
path: data/06_models/my_sklearn_model | ||
``` | ||
|
||
### Can I use Kedro versioning with `MlflowModelDataSet`? | ||
|
||
### Can I load a model from a specific MLflow Run ID? |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .mlflow_dataset import MlflowArtifactDataSet | ||
from .mlflow_metrics_dataset import MlflowMetricsDataSet | ||
from .mlflow_model_dataset import MlflowModelDataSet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
import importlib | ||
from os import stat | ||
import shutil | ||
from pathlib import Path, PurePosixPath | ||
from typing import Any, Dict, Optional | ||
|
||
from kedro.io import AbstractVersionedDataSet, Version | ||
from kedro.io.core import DataSetError | ||
from mlflow.tracking import MlflowClient | ||
|
||
|
||
class MlflowModelDataSet(AbstractVersionedDataSet):
    """Kedro dataset that saves and loads models through an MLflow flavor module.

    The flavor is given as an importable module path (e.g. ``mlflow.sklearn``);
    its ``save_model`` and ``load_model`` functions perform the actual
    (de)serialization.
    """

    def __init__(
        self,
        flavor: str,
        path: str,
        run_id: Optional[str] = None,
        pyfunc_workflow: Optional[str] = None,
        load_args: Optional[Dict[str, Any]] = None,
        save_args: Optional[Dict[str, Any]] = None,
        log_args: Optional[Dict[str, Any]] = None,
        version: Version = None,
    ) -> None:
        """Initialize the Kedro MlflowModelDataSet.

        Parameters are passed from the Data Catalog.

        During save, the model is saved locally at `path`.

        During load, the model is either pulled from the MLflow run with
        `run_id` or loaded from the local `path`.

        Args:
            flavor (str): Built-in or custom MLflow model flavor module.
                Must be Python-importable.
            path (str): Path to store the dataset locally.
            run_id (Optional[str], optional): MLflow run ID to use to load
                the model from. Defaults to None.
            pyfunc_workflow (str, optional): Either `python_model` or
                `loader_module`. Required when `flavor` is `mlflow.pyfunc`.
                See https://www.mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#workflows.
            load_args (Dict[str, Any], optional): Arguments to `load_model`
                function from specified `flavor`. Defaults to None, which is
                treated as an empty dict.
            save_args (Dict[str, Any], optional): Arguments to `save_model`
                function from specified `flavor`. Defaults to None, which is
                treated as an empty dict.
            log_args (Dict[str, Any], optional): Arguments to `log_model`
                function from specified `flavor`. They are parsed and
                reported by `_describe`, but not used by `_save` at the
                moment. Defaults to None, which is treated as an empty dict.
            version (Version, optional): Kedro version to use. Defaults to None.

        Raises:
            DataSetError: When `flavor` is `mlflow.pyfunc` and
                `pyfunc_workflow` is neither `python_model` nor
                `loader_module`.
            ImportError: When `flavor` cannot be found as an importable module.
        """
        super().__init__(PurePosixPath(path), version)
        self._flavor = flavor
        self._path = path
        self._run_id = run_id
        self._pyfunc_workflow = pyfunc_workflow

        # PyFunc is the only flavor whose `save_model` takes the model through
        # a keyword chosen by the user, so that keyword must be known up front.
        if flavor == "mlflow.pyfunc" and pyfunc_workflow not in (
            "python_model",
            "loader_module",
        ):
            raise DataSetError(
                "PyFunc models require specifying `pyfunc_workflow` "
                "(set to either `python_model` or `loader_module`)"
            )

        # `None` stands for "no extra arguments"; this avoids mutable defaults.
        self._load_args = self._parse_args(load_args or {})
        self._save_args = self._parse_args(save_args or {})
        self._log_args = self._parse_args(log_args or {})
        self._version = version
        self._mlflow_model_module = self._import_module(self._flavor)

    def _load(self) -> Any:
        """Load an MLflow model from the local path or from an MLflow run.

        Returns:
            Any: Deserialized model, as returned by the flavor's `load_model`.
        """
        if self._run_id:
            # A run ID was given: build the artifact location from the run's
            # artifact URI and the last component of the local path.
            mlflow_client = MlflowClient()
            run = mlflow_client.get_run(self._run_id)
            load_path = f"{run.info.artifact_uri}/{Path(self._path).name}"
        else:
            # Otherwise fall back to the (possibly versioned) local path.
            load_path = str(self._get_load_path())
        return self._mlflow_model_module.load_model(load_path, **self._load_args)

    def _save(self, model: Any) -> None:
        """Save a model to the local path using the flavor's `save_model`.

        Logging the model to MLflow is not performed here yet (see the TODO
        at the end of the method).

        Args:
            model (Any): A model object supported by the given MLflow flavor.
        """
        save_path = self._get_save_path()
        # In case of an unversioned model we need to remove the save path
        # because MLflow cannot overwrite the target directory.
        if Path(save_path).exists():
            shutil.rmtree(save_path)

        if self._flavor == "mlflow.pyfunc":
            # PyFunc models utilise either the `python_model` or the
            # `loader_module` workflow: the model is passed under that keyword.
            # Build a per-call dict so the model object is not stored in
            # `self._save_args` (which `_describe` exposes and later saves reuse).
            pyfunc_save_args = {**self._save_args, self._pyfunc_workflow: model}
            self._mlflow_model_module.save_model(save_path, **pyfunc_save_args)
        else:
            # Otherwise we save using the common workflow where the first
            # argument is the model object and the second is the path.
            self._mlflow_model_module.save_model(model, save_path, **self._save_args)
        # TODO: logging to MLflow is currently disabled. The intended call is
        # `log_model(model, save_path.name, **self._log_args)` on the flavor module.

    def _describe(self) -> Dict[str, Any]:
        """Return the dataset's configuration as a dictionary."""
        return dict(
            flavor=self._flavor,
            path=self._path,
            run_id=self._run_id,
            load_args=self._load_args,
            save_args=self._save_args,
            log_args=self._log_args,
            version=self._version,
        )

    @classmethod
    def _parse_args(cls, kwargs_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Return a new dict of arguments with `<key>_args` companions resolved.

        Keys ending in `_args` are dropped from the result. When both `key`
        and `f"{key}_args"` are present, the value of `key` is treated as an
        import path and the imported object is called with the recursively
        parsed `f"{key}_args"` dict. All other entries are copied unchanged.

        Args:
            kwargs_dict (Dict[str, Any]): Raw arguments from the Data Catalog.

        Returns:
            Dict[str, Any]: Parsed arguments; the input dict is not modified.
        """
        parsed_kwargs = {}
        for key, value in kwargs_dict.items():
            if key.endswith("_args"):
                continue
            nested_key = f"{key}_args"
            if nested_key in kwargs_dict:
                # NOTE(review): `_import_module` returns a module object, which
                # is not callable, so this branch looks like it would fail at
                # runtime. Presumably `value` was meant to resolve to a class
                # or function -- confirm the intended behaviour.
                parsed_kwargs[key] = cls._import_module(value)(
                    cls._parse_args(kwargs_dict[nested_key])
                )
            else:
                parsed_kwargs[key] = value
        return parsed_kwargs

    @staticmethod
    def _import_module(import_path: str) -> Any:
        """Import and return the module found at `import_path`.

        Args:
            import_path (str): Dotted path of the module to import.

        Returns:
            Any: The imported module.

        Raises:
            ImportError: When no module spec can be found for `import_path`.
        """
        # `import importlib` alone does not guarantee that the `util`
        # submodule is loaded, so import it explicitly before using it.
        import importlib.util

        if importlib.util.find_spec(import_path) is None:
            raise ImportError(f"{import_path} module not found")

        return importlib.import_module(import_path)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import mlflow | ||
import pytest | ||
from kedro.io.core import DataSetError | ||
from mlflow.tracking import MlflowClient | ||
from sklearn.linear_model import LinearRegression | ||
|
||
from kedro_mlflow.io import MlflowModelDataSet | ||
|
||
|
||
@pytest.fixture
def linreg_model():
    """Provide a fresh, unfitted scikit-learn linear regression model."""
    model = LinearRegression()
    return model
|
||
|
||
@pytest.fixture
def linreg_path(tmp_path):
    """Provide the model location inside pytest's per-test temporary directory."""
    return tmp_path.joinpath("06_models/linreg")
|
||
|
||
@pytest.fixture
def mlflow_client_run_id(tmp_path):
    """Start an MLflow run tracked under ``tmp_path / "mlruns"``.

    Yields a ``(MlflowClient, run_id)`` pair for the active run and ends the
    run once the test is done.
    """
    mlruns_uri = (tmp_path / "mlruns").as_uri()
    # Point both the fluent API and the explicit client at the same store.
    mlflow.set_tracking_uri(mlruns_uri)
    client = MlflowClient(tracking_uri=mlruns_uri)
    mlflow.start_run()
    active_run_id = mlflow.active_run().info.run_id
    yield client, active_run_id
    mlflow.end_run()
|
||
|
||
def test_flavor_does_not_exists(linreg_path):
    """Building the dataset from config with an unknown flavor raises DataSetError."""
    bad_config = {
        "type": "kedro_mlflow.io.MlflowModelDataSet",
        "flavor": "mlflow.whoops",
        "path": linreg_path,
    }
    with pytest.raises(DataSetError):
        MlflowModelDataSet.from_config(name="whoops", config=bad_config)
|
||
|
||
def test_save_unversioned_under_same_path(
    linreg_path, linreg_model, mlflow_client_run_id
):
    """Saving an unversioned model twice to the same path must not raise."""
    dataset = MlflowModelDataSet.from_config(
        name="linreg",
        config={
            "type": "kedro_mlflow.io.MlflowModelDataSet",
            "flavor": "mlflow.sklearn",
            "path": linreg_path,
        },
    )
    # The second save targets a path that already exists.
    for _ in range(2):
        dataset.save(linreg_model)
|
||
|
||
@pytest.mark.parametrize(
    "versioned,from_run_id",
    [(False, False), (True, False), (False, True), (True, True)],
)
def test_save_load_local(
    linreg_path, linreg_model, mlflow_client_run_id, versioned, from_run_id
):
    """Save a model, check the local and MLflow copies, then load it back.

    Runs for versioned and unversioned datasets, loading either from the local
    path or from the MLflow run of the ``mlflow_client_run_id`` fixture.
    """
    client, run_id = mlflow_client_run_id
    label = "Versioned" if versioned else "Unversioned"

    dataset_kwargs = {
        "name": "linreg",
        "config": {
            "type": "kedro_mlflow.io.MlflowModelDataSet",
            "flavor": "mlflow.sklearn",
            "path": linreg_path,
            "versioned": versioned,
        },
    }
    dataset = MlflowModelDataSet.from_config(**dataset_kwargs)
    dataset.save(linreg_model)

    # Local copy: versioned saves go to <path>/<save version>/<name>.
    if not versioned:
        assert linreg_path.exists(), "Unversioned model saved locally"
    else:
        assert (
            linreg_path / dataset._version.save / linreg_path.name
        ).exists(), "Versioned model saved locally"

    # MLflow copy: the first artifact of the run is expected to be the model.
    first_artifact = client.list_artifacts(run_id=run_id)[0]
    assert linreg_path.name == first_artifact.path, f"{label} model logged to MLflow"

    # Optionally rebuild the dataset so that it loads from the MLflow run.
    if from_run_id:
        dataset_kwargs["config"]["run_id"] = run_id
        dataset = MlflowModelDataSet.from_config(**dataset_kwargs)

    loaded_model = dataset.load()
    assert isinstance(loaded_model, LinearRegression), f"{label} model loaded"