Skip to content

Commit

Permalink
#12 Add MLflowModelDataSet with versioning support
Browse files Browse the repository at this point in the history
  • Loading branch information
kaemo authored and Galileo-Galilei committed Oct 31, 2020
1 parent 9adb5ef commit ecaba73
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 0 deletions.
1 change: 1 addition & 0 deletions kedro_mlflow/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .mlflow_dataset import MlflowArtifactDataSet
from .mlflow_metrics_dataset import MlflowMetricsDataSet
from .mlflow_model_dataset import MlflowModelDataSet
105 changes: 105 additions & 0 deletions kedro_mlflow/io/mlflow_model_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import importlib
import shutil
from pathlib import Path, PurePosixPath
from typing import Any, Dict, Optional

from kedro.io import AbstractVersionedDataSet, Version
from kedro.io.core import DataSetError
from mlflow.tracking import MlflowClient


class MlflowModelDataSet(AbstractVersionedDataSet):
"""Wrapper for saving, logging and loading for all MLflow model flavor."""

def __init__(
self,
flavor: str,
path: str,
run_id: Optional[str] = None,
load_args: Dict[str, Any] = {},
save_args: Dict[str, Any] = {},
log_args: Dict[str, Any] = {},
version: Version = None,
) -> None:
"""Intialize the Kedro MlflowModelDataSet.
Parameters are passed from the Data Catalog.
During save, the model is first saved locally at `path` and then
logged to MLflow.
During load, the model is either pulled from MLflow run with `run_id`
or loaded from the local `path`.
Args:
flavor (str): Built-in or custom MLflow model flavor module.
Must be Python-importable.
path (str): Path to store the dataset locally.
run_id (Optional[str], optional): MLflow run ID to use to load
the model from. Defaults to None.
load_args (Dict[str, Any], optional): Arguments to `load_model`
function from specified `flavor`. Defaults to {}.
save_args (Dict[str, Any], optional): Arguments to `save_model`
function from specified `flavor`. Defaults to {}.
log_args (Dict[str, Any], optional): Arguments to `log_model`
function from specified `flavor`. Defaults to {}.
version (Version, optional): Kedro version to use. Defaults to None.
Raises:
DataSetError: When passed `flavor` does not exist.
"""
super().__init__(PurePosixPath(path), version)
self._flavor = flavor
self._path = path
self._run_id = run_id
self._load_args = load_args
self._save_args = save_args
self._log_args = log_args
self._version = version

mlflow_model_module_exists = importlib.util.find_spec(self._flavor)

if not mlflow_model_module_exists:
raise DataSetError(f"{self._flavor} not found")

self._mlflow_model_module = importlib.import_module(self._flavor)

def _load(self) -> Any:
"""Loads an MLflow model from local path or from MLflow run.
Returns:
Any: Deserialized model.
"""
# If `run_id` is specified, pull the model from MLflow.
if self._run_id:
mlflow_client = MlflowClient()
run = mlflow_client.get_run(self._run_id)
load_path = f"{run.info.artifact_uri}/{Path(self._path).name}"
# Alternatively, use local path to load the model.
else:
load_path = str(self._get_load_path())
return self._mlflow_model_module.load_model(load_path, **self._load_args)

def _save(self, model: Any) -> None:
"""Save a model to local path and then logs it to MLflow.
Args:
model (Any): A model object supported by the given MLflow flavor.
"""
save_path = self._get_save_path()
# In case of an unversioned model we need to remove the save path
# because MLflow cannot overwrite the target directory.
if Path(save_path).exists():
shutil.rmtree(save_path)
self._mlflow_model_module.save_model(model, save_path, **self._save_args)
self._mlflow_model_module.log_model(model, save_path.name, **self._log_args)

def _describe(self) -> Dict[str, Any]:
return dict(
flavor=self._flavor,
path=self._path,
run_id=self._run_id,
load_args=self._load_args,
save_args=self._save_args,
log_args=self._log_args,
version=self._version,
)

0 comments on commit ecaba73

Please sign in to comment.