-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
FIX #12 - Refactor MlflowModelDataSet to distinguish between saver an…
…d logger
- Loading branch information
1 parent
4393681
commit 43f1cbe
Showing
11 changed files
with
1,057 additions
and
256 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,2 @@ | ||
from .mlflow_dataset import MlflowArtifactDataSet | ||
from .mlflow_metrics_dataset import MlflowMetricsDataSet | ||
from .mlflow_model_dataset import MlflowModelDataSet |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .mlflow_model_logger_dataset import MlflowModelLoggerDataSet | ||
from .mlflow_model_saver_dataset import MlflowModelSaverDataSet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import importlib | ||
from typing import Any, Dict, Optional | ||
|
||
from kedro.io import AbstractVersionedDataSet, Version | ||
from kedro.io.core import DataSetError | ||
|
||
|
||
class MlflowAbstractModelDataSet(AbstractVersionedDataSet):
    """Abstract base class for the MLflow model datasets.

    It stores the configuration shared by the model datasets (flavor,
    pyfunc workflow, load/save arguments) and imports the MLflow flavor
    module whose functions the concrete subclasses call.
    """

    def __init__(
        self,
        filepath,
        flavor: str,
        pyfunc_workflow: Optional[str] = None,
        load_args: Optional[Dict[str, Any]] = None,
        save_args: Optional[Dict[str, Any]] = None,
        version: Optional[Version] = None,
    ) -> None:
        """Initialize the shared state of an MLflow model dataset.

        Parameters are passed from the Data Catalog.

        Args:
            filepath (str): Path to store the dataset locally.
            flavor (str): Built-in or custom MLflow model flavor module.
                Must be Python-importable.
            pyfunc_workflow (str, optional): Either `python_model` or `loader_module`.
                Required when `flavor` is `mlflow.pyfunc`.
                See https://www.mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#workflows.
            load_args (Dict[str, Any], optional): Arguments forwarded to the
                loading function of the specified `flavor`. Defaults to {}.
            save_args (Dict[str, Any], optional): Arguments forwarded to the
                saving/logging function of the specified `flavor`.
                Defaults to {}.
            version (Version, optional): Kedro version of the dataset,
                forwarded to `AbstractVersionedDataSet`. Defaults to None.

        Raises:
            DataSetError: When `flavor` is `mlflow.pyfunc` and
                `pyfunc_workflow` is neither `python_model` nor
                `loader_module`.
            ImportError: When the `flavor` module cannot be found.
        """
        super().__init__(filepath, version)
        self._flavor = flavor
        self._pyfunc_workflow = pyfunc_workflow

        if flavor == "mlflow.pyfunc" and pyfunc_workflow not in (
            "python_model",
            "loader_module",
        ):
            raise DataSetError(
                "PyFunc models require specifying `pyfunc_workflow` "
                "(set to either `python_model` or `loader_module`)"
            )

        # Shallow-copy the argument dicts so that later changes made by the
        # dataset never leak into the dictionaries owned by the caller.
        self._load_args = dict(load_args) if load_args else {}
        self._save_args = dict(save_args) if save_args else {}

        self._mlflow_model_module = self._import_module(self._flavor)

    # TODO: check with Kajetan what was originally intended here
    # @classmethod
    # def _parse_args(cls, kwargs_dict: Dict[str, Any]) -> Dict[str, Any]:
    #     parsed_kargs = {}
    #     for key, value in kwargs_dict.items():
    #         if key.endswith("_args"):
    #             continue
    #         if f"{key}_args" in kwargs_dict:
    #             new_value = cls._import_module(value)(
    #                 MlflowModelDataSet._parse_args(kwargs_dict[f"{key}_args"])
    #             )
    #             parsed_kargs[key] = new_value
    #         else:
    #             parsed_kargs[key] = value
    #     return parsed_kargs

    @staticmethod
    def _import_module(import_path: str) -> Any:
        """Import and return the module located at `import_path`.

        Args:
            import_path (str): Dotted path of the module to import.

        Returns:
            Any: The imported module object.

        Raises:
            ImportError: When no module can be found at `import_path`
                (`ModuleNotFoundError`, a subclass, is raised by
                `find_spec` itself when a parent package is missing).
        """
        # `import importlib` alone does not guarantee that the
        # `importlib.util` submodule is loaded, so import it explicitly.
        from importlib.util import find_spec

        if find_spec(import_path) is None:
            raise ImportError(f"{import_path} module not found")

        return importlib.import_module(import_path)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
from typing import Any, Dict, Optional | ||
|
||
import mlflow | ||
from kedro.io.core import DataSetError | ||
|
||
from kedro_mlflow.io.models.mlflow_abstract_model_dataset import ( | ||
MlflowAbstractModelDataSet, | ||
) | ||
|
||
|
||
class MlflowModelLoggerDataSet(MlflowAbstractModelDataSet):
    """Wrapper for logging to and loading from an MLflow run, for all MLflow model flavors."""

    def __init__(
        self,
        flavor: str,
        load_run_id: Optional[str] = None,
        save_run_id: Optional[str] = None,
        artifact_path: Optional[str] = "model",
        pyfunc_workflow: Optional[str] = None,
        load_args: Optional[Dict[str, Any]] = None,
        save_args: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Initialize the Kedro MlflowModelLoggerDataSet.

        Parameters are passed from the Data Catalog.

        During save, the model is logged to an MLflow run.
        During load, the model is pulled from an MLflow run.

        Args:
            flavor (str): Built-in or custom MLflow model flavor module.
                Must be Python-importable.
            load_run_id (Optional[str], optional): MLflow run ID to load the
                model from. When not provided, the active run is used.
                Defaults to None.
            save_run_id (Optional[str], optional): MLflow run ID to log the
                model to. It cannot be combined with an already active
                run. Defaults to None.
            artifact_path (str, optional): the run relative path to
                the model.
            pyfunc_workflow (str, optional): Either `python_model` or `loader_module`.
                See https://www.mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#workflows.
            load_args (Dict[str, Any], optional): Arguments to `load_model`
                function from specified `flavor`. Defaults to None.
            save_args (Dict[str, Any], optional): Arguments to `log_model`
                function from specified `flavor`. Defaults to None.

        Raises:
            DataSetError: When `flavor` is `mlflow.pyfunc` and
                `pyfunc_workflow` is not a valid workflow.
            ImportError: When passed `flavor` does not exist.
        """
        super().__init__(
            filepath="",
            flavor=flavor,
            pyfunc_workflow=pyfunc_workflow,
            load_args=load_args,
            save_args=save_args,
            version=None,
        )

        self._load_run_id = load_run_id
        self._save_run_id = save_run_id
        self._artifact_path = artifact_path

        # Drop the key which MUST be common to save and load and thus is
        # instantiated outside save_args. The dict is rebuilt (instead of
        # popping in place) so the caller's dictionary is never mutated.
        self._save_args = {
            key: value
            for key, value in self._save_args.items()
            if key != "artifact_path"
        }

    @property
    def model_uri(self) -> str:
        """Build the `runs:/<run_id>/<artifact_path>` URI of the model.

        The run is `load_run_id` when it was provided, otherwise the
        currently active MLflow run.

        Raises:
            DataSetError: When no `load_run_id` was given and there is no
                active run.
        """
        run_id = self._load_run_id or None
        if run_id is None:
            active_run = mlflow.active_run()
            if active_run is not None:
                run_id = active_run.info.run_id
        if run_id is None:
            raise DataSetError(
                (
                    "To access the model_uri, you must either: "
                    "\n -  specify 'load_run_id' "
                    "\n -  have an active run to retrieve data from"
                )
            )

        return f"runs:/{run_id}/{self._artifact_path}"

    def _load(self) -> Any:
        """Loads an MLflow model from the MLflow run designated by `model_uri`.

        Returns:
            Any: Deserialized model.
        """

        # TODO: enable loading from another mlflow conf (with a client with another tracking uri)
        return self._mlflow_model_module.load_model(
            model_uri=self.model_uri, **self._load_args
        )

    def _save(self, model: Any) -> None:
        """Logs a model to MLflow.

        Args:
            model (Any): A model object supported by the given MLflow flavor.

        Raises:
            DataSetError: When `save_run_id` is specified while an MLflow
                run is already active.
        """
        if not self._save_run_id:
            # if there is no run_id, log in active run
            # OR open automatically a new run to log
            self._save_model_in_run(model)
            return

        active_run = mlflow.active_run()
        if active_run is not None:
            # it is not possible to log in a run which is not the current opened one
            raise DataSetError(
                (
                    "'save_run_id' cannot be specified"
                    " if there is an mlflow active run. "
                    "Run_id mismatch: "
                    f"\n - 'save_run_id'={self._save_run_id}"
                    f"\n - active_run id={active_run.info.run_id}"
                )
            )

        # the run id is specified and there is no opened run:
        # open the right run before logging
        with mlflow.start_run(run_id=self._save_run_id):
            self._save_model_in_run(model)

    def _save_model_in_run(self, model: Any) -> None:
        """Log `model` with the flavor's `log_model` in the current run.

        Args:
            model (Any): A model object supported by the given MLflow flavor.
        """
        if self._flavor == "mlflow.pyfunc":
            # PyFunc models utilise either `python_model` or `loader_module`
            # workflow. We assign the passed `model` object to one of those keys
            # depending on the chosen `pyfunc_workflow`. A per-call dict is
            # used so the model object is not kept inside `self._save_args`
            # (where it would leak into `_describe` and stay referenced).
            log_kwargs = {**self._save_args, self._pyfunc_workflow: model}
            self._mlflow_model_module.log_model(self._artifact_path, **log_kwargs)
        else:
            # Otherwise we save using the common workflow where first argument is the
            # model object and second is the path.
            self._mlflow_model_module.log_model(
                model, self._artifact_path, **self._save_args
            )

    def _describe(self) -> Dict[str, Any]:
        """Return the dataset configuration as a dictionary."""
        return dict(
            flavor=self._flavor,
            load_run_id=self._load_run_id,
            save_run_id=self._save_run_id,
            artifact_path=self._artifact_path,
            pyfunc_workflow=self._pyfunc_workflow,
            load_args=self._load_args,
            save_args=self._save_args,
        )
Oops, something went wrong.