Skip to content

Commit

Permalink
FIX #12 - Refactor MlflowModelDataSet to distinguish between saver and logger
Browse files Browse the repository at this point in the history
  • Loading branch information
Galileo-Galilei committed Nov 1, 2020
1 parent 4393681 commit 43f1cbe
Show file tree
Hide file tree
Showing 11 changed files with 1,057 additions and 256 deletions.
1 change: 0 additions & 1 deletion kedro_mlflow/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
from .mlflow_dataset import MlflowArtifactDataSet
from .mlflow_metrics_dataset import MlflowMetricsDataSet
from .mlflow_model_dataset import MlflowModelDataSet
148 changes: 0 additions & 148 deletions kedro_mlflow/io/mlflow_model_dataset.py

This file was deleted.

2 changes: 2 additions & 0 deletions kedro_mlflow/io/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .mlflow_model_logger_dataset import MlflowModelLoggerDataSet
from .mlflow_model_saver_dataset import MlflowModelSaverDataSet
87 changes: 87 additions & 0 deletions kedro_mlflow/io/models/mlflow_abstract_model_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import importlib
from typing import Any, Dict, Optional

from kedro.io import AbstractVersionedDataSet, Version
from kedro.io.core import DataSetError


class MlflowAbstractModelDataSet(AbstractVersionedDataSet):
    """Abstract parent class for the MLflow model datasets.

    It stores the configuration shared by the model datasets (flavor,
    pyfunc workflow, load and save arguments) and imports the MLflow
    flavor module given by ``flavor``.
    """

    def __init__(
        self,
        filepath,
        flavor: str,
        pyfunc_workflow: Optional[str] = None,
        load_args: Optional[Dict[str, Any]] = None,
        save_args: Optional[Dict[str, Any]] = None,
        version: Optional[Version] = None,
    ) -> None:
        """Initialize the abstract MLflow model dataset.

        Parameters are passed from the Data Catalog.

        Args:
            filepath (str): Path to store the dataset locally.
            flavor (str): Built-in or custom MLflow model flavor module.
                Must be Python-importable.
            pyfunc_workflow (str, optional): Either `python_model` or
                `loader_module`. Required when `flavor` is `mlflow.pyfunc`.
                See https://www.mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#workflows.
            load_args (Dict[str, Any], optional): Arguments to the loading
                function of the specified `flavor`. Defaults to {}.
            save_args (Dict[str, Any], optional): Arguments to the saving
                function of the specified `flavor`. Defaults to {}.
            version (Version, optional): Kedro version of the dataset,
                forwarded to the parent class. Defaults to None.

        Raises:
            DataSetError: When `flavor` is `mlflow.pyfunc` and
                `pyfunc_workflow` is neither `python_model` nor
                `loader_module`.
            ImportError: When the `flavor` module cannot be found.
        """
        super().__init__(filepath, version)
        self._flavor = flavor
        self._pyfunc_workflow = pyfunc_workflow

        if flavor == "mlflow.pyfunc" and pyfunc_workflow not in (
            "python_model",
            "loader_module",
        ):
            raise DataSetError(
                "PyFunc models require specifying `pyfunc_workflow` "
                "(set to either `python_model` or `loader_module`)"
            )

        # Shallow-copy the argument dicts so that later modifications made by
        # the dataset never leak into the dictionaries owned by the caller.
        self._load_args = dict(load_args) if load_args else {}
        self._save_args = dict(save_args) if save_args else {}

        self._mlflow_model_module = self._import_module(self._flavor)

    # TODO: check with Kajetan what was originally intended here
    # @classmethod
    # def _parse_args(cls, kwargs_dict: Dict[str, Any]) -> Dict[str, Any]:
    #     parsed_kargs = {}
    #     for key, value in kwargs_dict.items():
    #         if key.endswith("_args"):
    #             continue
    #         if f"{key}_args" in kwargs_dict:
    #             new_value = cls._import_module(value)(
    #                 MlflowModelDataSet._parse_args(kwargs_dict[f"{key}_args"])
    #             )
    #             parsed_kargs[key] = new_value
    #         else:
    #             parsed_kargs[key] = value
    #     return parsed_kargs

    @staticmethod
    def _import_module(import_path: str) -> Any:
        """Import and return the module located at ``import_path``.

        Args:
            import_path (str): Absolute dotted path of the module to import.

        Returns:
            Any: The imported module.

        Raises:
            ImportError: When the module (or one of its parent packages)
                cannot be found.
        """
        # `import importlib` alone does not guarantee that the `util`
        # submodule is loaded, so import it explicitly before using it.
        import importlib.util

        try:
            spec = importlib.util.find_spec(import_path)
        except ImportError as err:
            # `find_spec` imports parent packages of a dotted path and raises
            # when one of them is missing: report it with the same message.
            raise ImportError(f"{import_path} module not found") from err

        if not spec:
            raise ImportError(f"{import_path} module not found")

        return importlib.import_module(import_path)
152 changes: 152 additions & 0 deletions kedro_mlflow/io/models/mlflow_model_logger_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
from typing import Any, Dict, Optional

import mlflow
from kedro.io.core import DataSetError

from kedro_mlflow.io.models.mlflow_abstract_model_dataset import (
MlflowAbstractModelDataSet,
)


class MlflowModelLoggerDataSet(MlflowAbstractModelDataSet):
    """Wrapper for logging and loading models of any MLflow model flavor."""

    def __init__(
        self,
        flavor: str,
        load_run_id: Optional[str] = None,
        save_run_id: Optional[str] = None,
        artifact_path: Optional[str] = "model",
        pyfunc_workflow: Optional[str] = None,
        load_args: Optional[Dict[str, Any]] = None,
        save_args: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Initialize the Kedro MlflowModelLoggerDataSet.

        Parameters are passed from the Data Catalog.
        During save, the model is logged to an MLflow run.
        During load, the model is pulled from an MLflow run.

        Args:
            flavor (str): Built-in or custom MLflow model flavor module.
                Must be Python-importable.
            load_run_id (Optional[str], optional): MLflow run ID to load
                the model from. If not provided, the active run is used.
                Defaults to None.
            save_run_id (Optional[str], optional): MLflow run ID to log
                the model to. Cannot be used while a run is active.
                Defaults to None.
            artifact_path (str, optional): the run relative path to
                the model.
            pyfunc_workflow (str, optional): Either `python_model` or `loader_module`.
                See https://www.mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#workflows.
            load_args (Dict[str, Any], optional): Arguments to `load_model`
                function from specified `flavor`. Defaults to None.
            save_args (Dict[str, Any], optional): Arguments to `log_model`
                function from specified `flavor`. Defaults to None.
        """
        super().__init__(
            filepath="",
            flavor=flavor,
            pyfunc_workflow=pyfunc_workflow,
            load_args=load_args,
            save_args=save_args,
            version=None,
        )

        self._load_run_id = load_run_id
        self._save_run_id = save_run_id
        self._artifact_path = artifact_path

        # Drop the key which MUST be common to save and load and thus is
        # instantiated outside save_args. A new dict is built (instead of
        # popping in place) so the caller's `save_args` is never mutated.
        self._save_args = {
            key: value
            for key, value in self._save_args.items()
            if key != "artifact_path"
        }

    @property
    def model_uri(self) -> str:
        """Return the `runs:/<run_id>/<artifact_path>` URI of the model.

        The run id is `load_run_id` when it is set, otherwise the id of
        the active MLflow run.

        Raises:
            DataSetError: When no `load_run_id` was given and there is no
                active MLflow run.
        """
        run_id = self._load_run_id or None
        if run_id is None:
            # Query the active run only once and reuse the result.
            active_run = mlflow.active_run()
            if active_run is not None:
                run_id = active_run.info.run_id

        if run_id is None:
            raise DataSetError(
                "To access the model_uri, you must either: "
                "\n - specify 'load_run_id' "
                "\n - have an active run to retrieve data from"
            )

        return f"runs:/{run_id}/{self._artifact_path}"

    def _load(self) -> Any:
        """Load an MLflow model from the run designated by `model_uri`.

        Returns:
            Any: Deserialized model.
        """
        # TODO: enable loading from another mlflow conf (with a client with another tracking uri)
        return self._mlflow_model_module.load_model(
            model_uri=self.model_uri, **self._load_args
        )

    def _save(self, model: Any) -> None:
        """Log a model to MLflow.

        Args:
            model (Any): A model object supported by the given MLflow flavor.

        Raises:
            DataSetError: When `save_run_id` is set while an MLflow run
                is already active.
        """
        if not self._save_run_id:
            # if there is no run_id, log in active run
            # OR open automatically a new run to log
            self._save_model_in_run(model)
            return

        active_run = mlflow.active_run()
        if active_run:
            # it is not possible to log in a run which is not the current opened one
            raise DataSetError(
                "'save_run_id' cannot be specified"
                " if there is an mlflow active run. "
                "Run_id mismatch: "
                f"\n - 'save_run_id'={self._save_run_id}"
                f"\n - active_run id={active_run.info.run_id}"
            )

        # if the run id is specified and there is no opened run,
        # open the right run before logging
        with mlflow.start_run(run_id=self._save_run_id):
            self._save_model_in_run(model)

    def _save_model_in_run(self, model: Any) -> None:
        """Log `model` with the flavor module's `log_model` function.

        Args:
            model (Any): A model object supported by the given MLflow flavor.
        """
        if self._flavor == "mlflow.pyfunc":
            # PyFunc models utilise either `python_model` or `loader_module`
            # workflow. The passed `model` object is assigned to one of those
            # keys depending on the chosen `pyfunc_workflow`. A per-call copy
            # is used so the model object is not stored in `self._save_args`
            # (which is what `_describe` reports).
            save_args = {**self._save_args, self._pyfunc_workflow: model}
            self._mlflow_model_module.log_model(self._artifact_path, **save_args)
        else:
            # Otherwise we save using the common workflow where first argument is the
            # model object and second is the path.
            self._mlflow_model_module.log_model(
                model, self._artifact_path, **self._save_args
            )

    def _describe(self) -> Dict[str, Any]:
        """Return the configuration of the dataset as a dictionary."""
        return dict(
            flavor=self._flavor,
            load_run_id=self._load_run_id,
            save_run_id=self._save_run_id,
            artifact_path=self._artifact_path,
            pyfunc_workflow=self._pyfunc_workflow,
            load_args=self._load_args,
            save_args=self._save_args,
        )
Loading

0 comments on commit 43f1cbe

Please sign in to comment.