FIX #12 - Add MLflowModelDataSet with versioning support
kaemo authored and Galileo-Galilei committed Nov 3, 2020
1 parent 9adb5ef commit d89e2fa
Showing 6 changed files with 286 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .isort.cfg
@@ -7,4 +7,4 @@ line_length=88
ensure_newline_before_comments=True
sections=FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
known_first_party=kedro_mlflow
known_third_party=anyconfig,click,cookiecutter,jinja2,kedro,mlflow,packaging,pandas,pytest,pytest_lazyfixture,setuptools,yaml
known_third_party=anyconfig,black,click,cookiecutter,flake8,isort,jinja2,kedro,mlflow,pandas,pytest,pytest_lazyfixture,setuptools,sklearn,yaml
40 changes: 39 additions & 1 deletion docs/source/03_tutorial/06_version_models.md
@@ -1,3 +1,41 @@
# Version model

This is coming soon. If you want to keep track of the progress on this feature, [follow this issue](https://github.com/Galileo-Galilei/kedro-mlflow/issues/12).
## What is model tracking?

MLflow lets you serialize and deserialize models to a common format, track those models in MLflow Tracking, and manage them using the MLflow Model Registry. Many popular machine learning and deep learning frameworks have built-in support through what MLflow calls flavors. Even if there is no flavor for your framework of choice, it is easy to create your own flavor and integrate it with MLflow.

## How to track models using MLflow in Kedro project?

kedro-mlflow introduces a new dataset type for the Data Catalog called ``MlflowModelDataSet``. Suppose you would like to add a scikit-learn model to your Data Catalog. For that, you need to add an entry like this:

```yaml
my_sklearn_model:
  type: kedro_mlflow.io.MlflowModelDataSet
  flavor: mlflow.sklearn
  path: data/06_models/my_sklearn_model
```
You are now able to use ``my_sklearn_model`` in your nodes.
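
For example, a training node can simply return the fitted model and declare ``my_sklearn_model`` as its output; Kedro then delegates saving (and MLflow logging) to the dataset. Below is a minimal sketch; the ``train_model`` function and the ``X_train``/``y_train`` inputs are hypothetical and not part of this commit:

```python
# nodes.py -- hypothetical training function
from sklearn.linear_model import LinearRegression


def train_model(X_train, y_train) -> LinearRegression:
    # The returned model is saved and logged through the `my_sklearn_model` entry.
    return LinearRegression().fit(X_train, y_train)


# pipeline.py -- wire the node output to the catalog entry above
from kedro.pipeline import Pipeline, node


def create_pipeline() -> Pipeline:
    return Pipeline(
        [
            node(
                func=train_model,
                inputs=["X_train", "y_train"],
                outputs="my_sklearn_model",  # name of the catalog entry
            ),
        ]
    )
```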

## Frequently asked questions

### How does it work under the hood?

During save, the model object returned by the node is saved locally under the specified ``path`` using the ``save_model`` function of the specified ``flavor``. It is then logged to MLflow using ``log_model``.

When the model is loaded, the latest version stored locally is read using the ``load_model`` function of the specified ``flavor``. You can also load a model from a specific [Kedro run](#can-i-use-kedro-versioning-with-mlflowmodeldataset) or [MLflow run](#can-i-load-a-model-from-a-specific-mlflow-run-id).
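
For the ``mlflow.sklearn`` entry above, the save/load cycle is therefore roughly equivalent to the following (a simplified sketch of the documented behaviour, not the actual implementation):

```python
import mlflow.sklearn
from sklearn.linear_model import LinearRegression

model = LinearRegression()

# Save: serialize the model locally under the configured `path` ...
mlflow.sklearn.save_model(model, "data/06_models/my_sklearn_model")
# ... and log it as an artifact of the active MLflow run.
mlflow.sklearn.log_model(model, "my_sklearn_model")

# Load: read the model back with the same flavor.
model = mlflow.sklearn.load_model("data/06_models/my_sklearn_model")
```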
### How can I track a custom MLflow model flavor?

To track a custom MLflow model flavor, you need to set the `flavor` parameter to the import path of your custom flavor:

```yaml
my_custom_model:
  type: kedro_mlflow.io.MlflowModelDataSet
  flavor: my_package.custom_mlflow_flavor
  path: data/06_models/my_custom_model
```
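
The only contract a custom flavor module has to fulfil is the one the dataset relies on: it must be importable and expose ``save_model`` and ``load_model`` functions (and ``log_model`` if you want MLflow logging). A hypothetical ``my_package/custom_mlflow_flavor.py`` could look like this (illustrative skeleton only, not part of this commit):

```python
# my_package/custom_mlflow_flavor.py -- hypothetical skeleton of a custom flavor
import pickle
from pathlib import Path
from typing import Any


def save_model(model: Any, path: str, **kwargs) -> None:
    # Serialize `model` under `path`, mirroring the built-in flavors' layout.
    model_dir = Path(path)
    model_dir.mkdir(parents=True, exist_ok=True)
    with open(model_dir / "model.pkl", "wb") as handle:
        pickle.dump(model, handle)


def load_model(path: str, **kwargs) -> Any:
    # Deserialize the model previously written by `save_model`.
    with open(Path(path) / "model.pkl", "rb") as handle:
        return pickle.load(handle)
```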

### Can I use Kedro versioning with `MlflowModelDataSet`?
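
``MlflowModelDataSet`` extends Kedro's ``AbstractVersionedDataSet``, and the tests added in this commit exercise the standard ``versioned`` flag, so a versioned entry should look like the following sketch:

```yaml
my_sklearn_model:
  type: kedro_mlflow.io.MlflowModelDataSet
  flavor: mlflow.sklearn
  path: data/06_models/my_sklearn_model
  versioned: true
```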

### Can I load a model from a specific MLflow Run ID?
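
The dataset exposes a ``run_id`` parameter for exactly this; an entry along the following lines should load the model logged in that run (sketch, with a placeholder run ID):

```yaml
my_sklearn_model:
  type: kedro_mlflow.io.MlflowModelDataSet
  flavor: mlflow.sklearn
  path: data/06_models/my_sklearn_model
  run_id: 0123456789abcdef  # placeholder: ID of the MLflow run that logged the model
```
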
1 change: 1 addition & 0 deletions kedro_mlflow/io/__init__.py
@@ -1,2 +1,3 @@
from .mlflow_dataset import MlflowArtifactDataSet
from .mlflow_metrics_dataset import MlflowMetricsDataSet
from .mlflow_model_dataset import MlflowModelDataSet
148 changes: 148 additions & 0 deletions kedro_mlflow/io/mlflow_model_dataset.py
@@ -0,0 +1,148 @@
import importlib
import importlib.util
import shutil
from pathlib import Path, PurePosixPath
from typing import Any, Dict, Optional

from kedro.io import AbstractVersionedDataSet, Version
from kedro.io.core import DataSetError
from mlflow.tracking import MlflowClient


class MlflowModelDataSet(AbstractVersionedDataSet):
    """Wrapper for saving, logging and loading for all MLflow model flavors."""

    def __init__(
        self,
        flavor: str,
        path: str,
        run_id: Optional[str] = None,
        pyfunc_workflow: Optional[str] = None,
        load_args: Dict[str, Any] = {},
        save_args: Dict[str, Any] = {},
        log_args: Dict[str, Any] = {},
        version: Version = None,
    ) -> None:
        """Initialize the Kedro MlflowModelDataSet.

        Parameters are passed from the Data Catalog.

        During save, the model is first saved locally at `path` and then
        logged to MLflow.
        During load, the model is either pulled from the MLflow run with
        `run_id` or loaded from the local `path`.

        Args:
            flavor (str): Built-in or custom MLflow model flavor module.
                Must be Python-importable.
            path (str): Path to store the dataset locally.
            run_id (Optional[str], optional): MLflow run ID to use to load
                the model from. Defaults to None.
            pyfunc_workflow (str, optional): Either `python_model` or `loader_module`.
                See https://www.mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#workflows.
            load_args (Dict[str, Any], optional): Arguments to the `load_model`
                function of the specified `flavor`. Defaults to {}.
            save_args (Dict[str, Any], optional): Arguments to the `save_model`
                function of the specified `flavor`. Defaults to {}.
            log_args (Dict[str, Any], optional): Arguments to the `log_model`
                function of the specified `flavor`. Defaults to {}.
            version (Version, optional): Kedro version to use. Defaults to None.

        Raises:
            DataSetError: When the passed `flavor` does not exist.
        """
        super().__init__(PurePosixPath(path), version)
        self._flavor = flavor
        self._path = path
        self._run_id = run_id
        self._pyfunc_workflow = pyfunc_workflow

        if flavor == "mlflow.pyfunc" and pyfunc_workflow not in (
            "python_model",
            "loader_module",
        ):
            raise DataSetError(
                "PyFunc models require specifying `pyfunc_workflow` "
                "(set to either `python_model` or `loader_module`)"
            )

        self._load_args = self._parse_args(load_args)
        self._save_args = self._parse_args(save_args)
        self._log_args = self._parse_args(log_args)
        self._version = version
        self._mlflow_model_module = self._import_module(self._flavor)

    def _load(self) -> Any:
        """Loads an MLflow model from a local path or from an MLflow run.

        Returns:
            Any: Deserialized model.
        """
        # If `run_id` is specified, pull the model from MLflow.
        if self._run_id:
            mlflow_client = MlflowClient()
            run = mlflow_client.get_run(self._run_id)
            load_path = f"{run.info.artifact_uri}/{Path(self._path).name}"
        # Alternatively, use the local path to load the model.
        else:
            load_path = str(self._get_load_path())
        return self._mlflow_model_module.load_model(load_path, **self._load_args)

    def _save(self, model: Any) -> None:
        """Saves a model to the local path and then logs it to MLflow.

        Args:
            model (Any): A model object supported by the given MLflow flavor.
        """
        save_path = self._get_save_path()
        # In case of an unversioned model we need to remove the save path
        # because MLflow cannot overwrite the target directory.
        if Path(save_path).exists():
            shutil.rmtree(save_path)

        if self._flavor == "mlflow.pyfunc":
            # PyFunc models use either the `python_model` or the `loader_module`
            # workflow. We assign the passed `model` object to one of those keys
            # depending on the chosen `pyfunc_workflow`.
            self._save_args[self._pyfunc_workflow] = model
            self._mlflow_model_module.save_model(save_path, **self._save_args)
        else:
            # Otherwise we save using the common workflow where the first argument
            # is the model object and the second is the path.
            self._mlflow_model_module.save_model(model, save_path, **self._save_args)
        # self._mlflow_model_module.log_model(model, save_path.name, **self._log_args)

    def _describe(self) -> Dict[str, Any]:
        return dict(
            flavor=self._flavor,
            path=self._path,
            run_id=self._run_id,
            load_args=self._load_args,
            save_args=self._save_args,
            log_args=self._log_args,
            version=self._version,
        )

    @classmethod
    def _parse_args(cls, kwargs_dict: Dict[str, Any]) -> Dict[str, Any]:
        # For every `key` that comes with a matching `{key}_args` entry, import
        # the object at the path given by its value and call it with the
        # recursively parsed `{key}_args`; plain values are passed through unchanged.
        parsed_kwargs = {}
        for key, value in kwargs_dict.items():
            if key.endswith("_args"):
                continue
            if f"{key}_args" in kwargs_dict:
                new_value = cls._import_module(value)(
                    MlflowModelDataSet._parse_args(kwargs_dict[f"{key}_args"])
                )
                parsed_kwargs[key] = new_value
            else:
                parsed_kwargs[key] = value
        return parsed_kwargs

    @staticmethod
    def _import_module(import_path: str) -> Any:
        exists = importlib.util.find_spec(import_path)

        if not exists:
            raise ImportError(f"{import_path} module not found")

        return importlib.import_module(import_path)
1 change: 1 addition & 0 deletions requirements/test_requirements.txt
@@ -2,6 +2,7 @@ pytest>=5.4.0, <6.0.0
pytest-cov>=2.8.0, <3.0.0
pytest-lazy-fixture>=0.6.0, <1.0.0
pytest-mock>=3.1.0, <4.0.0
scikit-learn>=0.23.0, <0.24.0
flake8>=3.0.0, <4.0.0
black==19.10b0 # pin black version because it is not compatible with a pip range (because of non semver version number)
isort>=5.0.0, <6.0.0
96 changes: 96 additions & 0 deletions tests/io/test_mlflow_model_dataset.py
@@ -0,0 +1,96 @@
import mlflow
import pytest
from kedro.io.core import DataSetError
from mlflow.tracking import MlflowClient
from sklearn.linear_model import LinearRegression

from kedro_mlflow.io import MlflowModelDataSet


@pytest.fixture
def linreg_model():
    return LinearRegression()


@pytest.fixture
def linreg_path(tmp_path):
    return tmp_path / "06_models/linreg"


@pytest.fixture
def mlflow_client_run_id(tmp_path):
    tracking_uri = tmp_path / "mlruns"
    mlflow.set_tracking_uri(tracking_uri.as_uri())
    mlflow_client = MlflowClient(tracking_uri=tracking_uri.as_uri())
    mlflow.start_run()
    yield mlflow_client, mlflow.active_run().info.run_id
    mlflow.end_run()


def test_flavor_does_not_exists(linreg_path):
    with pytest.raises(DataSetError):
        MlflowModelDataSet.from_config(
            name="whoops",
            config={
                "type": "kedro_mlflow.io.MlflowModelDataSet",
                "flavor": "mlflow.whoops",
                "path": linreg_path,
            },
        )


def test_save_unversioned_under_same_path(
    linreg_path, linreg_model, mlflow_client_run_id
):
    model_config = {
        "name": "linreg",
        "config": {
            "type": "kedro_mlflow.io.MlflowModelDataSet",
            "flavor": "mlflow.sklearn",
            "path": linreg_path,
        },
    }
    mlflow_model_ds = MlflowModelDataSet.from_config(**model_config)
    mlflow_model_ds.save(linreg_model)
    mlflow_model_ds.save(linreg_model)


@pytest.mark.parametrize(
    "versioned,from_run_id",
    [(False, False), (True, False), (False, True), (True, True)],
)
def test_save_load_local(
    linreg_path, linreg_model, mlflow_client_run_id, versioned, from_run_id
):
    model_config = {
        "name": "linreg",
        "config": {
            "type": "kedro_mlflow.io.MlflowModelDataSet",
            "flavor": "mlflow.sklearn",
            "path": linreg_path,
            "versioned": versioned,
        },
    }
    mlflow_model_ds = MlflowModelDataSet.from_config(**model_config)
    mlflow_model_ds.save(linreg_model)

    if versioned:
        assert (
            linreg_path / mlflow_model_ds._version.save / linreg_path.name
        ).exists(), "Versioned model saved locally"
    else:
        assert linreg_path.exists(), "Unversioned model saved locally"

    mlflow_client, run_id = mlflow_client_run_id
    artifact = mlflow_client.list_artifacts(run_id=run_id)[0]
    versioned_str = "Versioned" if versioned else "Unversioned"
    assert linreg_path.name == artifact.path, f"{versioned_str} model logged to MLflow"

    if from_run_id:
        model_config["config"]["run_id"] = run_id
        mlflow_model_ds = MlflowModelDataSet.from_config(**model_config)

    linreg_model_loaded = mlflow_model_ds.load()
    assert isinstance(
        linreg_model_loaded, LinearRegression
    ), f"{versioned_str} model loaded"
