forked from Galileo-Galilei/kedro-mlflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
FIX Galileo-Galilei#9 MlflowMetricsDataSet implemented.
- Loading branch information
Adrian Piotr Kruszewski
committed
Aug 20, 2020
1 parent
379b617
commit b71c81d
Showing
5 changed files
with
380 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,62 @@ | ||
# Version metrics | ||
|
||
This is coming soon. If you want to keep track of the progress on this feature, [follow this issue](https://github.com/Galileo-Galilei/kedro-mlflow/issues/9). | ||
## What is metric tracking? | ||
|
||
Mlflow defines metrics as "Key-value metrics, where the value is numeric. Each metric can be updated throughout the course of the run (for example, to track how your model’s loss function is converging), and MLflow records and lets you visualize the metric’s full history". | ||
|
||
## How to version metrics in a kedro project? | ||
|
||
kedro-mlflow introduces a new ``AbstractDataSet`` called ``MlflowMetricsDataSet``. It is wrapper around dictionary with metrics which is returned by node and log metrics in MLflow. | ||
|
||
Since it is a ``AbstractDataSet``, it can be used with the YAML API. You can define it as: | ||
|
||
```yaml | ||
my_model_metrics: | ||
type: kedro_mlflow.io.MlflowMetricsDataSet | ||
``` | ||
It can get also ``prefix`` configuration option. This is useful especially when your pipeline evaluate metrics on different datasets. For example: | ||
```yaml | ||
my_model_metrics_dev: | ||
type: kedro_mlflow.io.MlflowMetricsDataSet | ||
prefix: dev | ||
my_model_metrics_test: | ||
type: kedro_mlflow.io.MlflowMetricsDataSet | ||
prefix: test | ||
``` | ||
In that scenario metrics will be available in MLflow with given prefixes. For example your ``accuracy`` metric from example above, for ``my_model_metrics_test`` will be stored under key ``test.accuracy``, for ``my_model_metrics_dev``, under key ``dev.accuracy``. | ||
## How to return metrics from node? | ||
Let assume that you have node which doesn't have any inputs and returns dictionary with metrics to log: | ||
```python | ||
def metrics_node() -> Dict[str, Union[float, List[float]]]: | ||
return { | ||
"metric1": 1.0, | ||
"metric2": [1.0, 1.1] | ||
} | ||
``` | ||
|
||
As you can see above, ``kedro_mlflow.io.MlflowMetricsDataSet`` can take as metrics ``floats`` or ``lists`` of ``floats``. In first case under the given metric key just one value will be logged, in second a series of values. | ||
|
||
To store metrics we need to define metrics dataset in Kedro Catalog: | ||
|
||
```yaml | ||
my_model_metrics: | ||
type: kedro_mlflow.io.MlflowMetricsDataSet | ||
``` | ||
To fulfill example we also need pipeline which will use this node and store metrics under ``my_model_metrics`` name. | ||
```python | ||
def create_pipeline() -> Pipeline: | ||
return Pipeline(node( | ||
func=metrics_node, | ||
inputs=None, | ||
outputs="my_model_metrics", | ||
name="log_metrics", | ||
)) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from .mlflow_dataset import MlflowDataSet | ||
from .mlflow_metrics_dataset import MlflowMetricsDataSet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
from copy import deepcopy | ||
from functools import partial, reduce | ||
from itertools import chain | ||
from typing import Any, Dict, Generator, List, Optional, Tuple, Union | ||
|
||
import mlflow | ||
from kedro.io import AbstractDataSet, DataSetError | ||
from mlflow.tracking import MlflowClient | ||
|
||
MetricItem = Union[float, List[float], Dict[str, float], List[Dict[str, float]]] | ||
MetricTuple = Tuple[str, float, int] | ||
MetricsDict = Dict[str, MetricItem] | ||
|
||
|
||
class MlflowMetricsDataSet(AbstractDataSet): | ||
"""This class represent MLflow metrics dataset.""" | ||
|
||
DEFAULT_LOAD_ARGS: Dict[str, Any] = {} | ||
DEFAULT_SAVE_ARGS: Dict[str, Any] = {} | ||
|
||
def __init__( | ||
self, | ||
prefix: Optional[str] = None, | ||
save_args: Dict[str, Any] = None, | ||
load_args: Dict[str, Any] = None, | ||
): | ||
"""Initialise MlflowMetricsDataSet. | ||
Args: | ||
prefix (Optional[str]): Prefix for metrics logged in MLflow. | ||
save_args (Dict[str, Any]): MLflow options for loading metrics. | ||
load_args (Dict[str, Any]): MLflow options for saving metrics. | ||
""" | ||
self._prefix = prefix | ||
|
||
# Handle default load and save arguments | ||
self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) | ||
if load_args is not None: | ||
self._load_args.update(load_args) | ||
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) | ||
if save_args is not None: | ||
self._save_args.update(save_args) | ||
|
||
def _load(self) -> MetricsDict: | ||
"""Load MlflowMetricDataSet. | ||
Returns: | ||
Dict[str, Union[int, float]]: Dictionary with MLflow metrics dataset. | ||
""" | ||
client = MlflowClient() | ||
run_id = self._load_args.get("run_id") or self._get_active_run_id() | ||
all_metrics = client._tracking_client.store.get_all_metrics(run_uuid=run_id) | ||
dataset_metrics = filter(self._is_dataset_metric, all_metrics) | ||
dataset = reduce( | ||
lambda xs, x: self._update_metric( | ||
# get_all_metrics returns last saved values per metric key. | ||
# All values are required here. | ||
client.get_metric_history(run_id, x.key), | ||
xs, | ||
), | ||
dataset_metrics, | ||
{}, | ||
) | ||
return dataset | ||
|
||
def _save(self, data: MetricsDict) -> None: | ||
"""Save given MLflow metrics dataset and log it in MLflow as metrics. | ||
Args: | ||
data (MetricsDict): MLflow metrics dataset. | ||
""" | ||
client = MlflowClient() | ||
try: | ||
run_id = self._save_args.get("run_id") or self._get_active_run_id() | ||
except DataSetError: | ||
# If run_id can't be found log_metric would create new run. | ||
run_id = None | ||
|
||
log_metric = ( | ||
partial(client.log_metric, run_id) | ||
if run_id is not None | ||
else mlflow.log_metric | ||
) | ||
metrics = ( | ||
self._build_args_list_from_metric_item(k, v) for k, v, in data.items() | ||
) | ||
for k, v, i in chain.from_iterable(metrics): | ||
log_metric(k, v, step=i) | ||
|
||
def _exists(self) -> bool: | ||
"""Check if MLflow metrics dataset exists. | ||
Returns: | ||
bool: Is MLflow metrics dataset exists? | ||
""" | ||
client = MlflowClient() | ||
run_id = self._load_args.get("run_id") or self._get_active_run_id() | ||
all_metrics = client._tracking_client.store.get_all_metrics(run_uuid=run_id) | ||
return any(self._is_dataset_metric(x) for x in all_metrics) | ||
|
||
def _describe(self) -> Dict[str, Any]: | ||
"""Describe MLflow metrics dataset. | ||
Returns: | ||
Dict[str, Any]: Dictionary with MLflow metrics dataset description. | ||
""" | ||
return { | ||
"prefix": self._prefix, | ||
} | ||
|
||
def _get_active_run_id(self) -> str: | ||
"""Get run id. | ||
If active run is not found, tries to find last experiment. | ||
Raise `DataSetError` exception if run id can't be found. | ||
Returns: | ||
str: String contains run_id. | ||
""" | ||
run = mlflow.active_run() | ||
if run: | ||
return run.info.run_id | ||
raise DataSetError("Cannot find run id.") | ||
|
||
def _is_dataset_metric(self, metric: mlflow.entities.Metric) -> bool: | ||
"""Check if given metric belongs to dataset. | ||
Args: | ||
metric (mlflow.entities.Metric): MLflow metric instance. | ||
""" | ||
return self._prefix is None or ( | ||
self._prefix and metric.key.startswith(self._prefix) | ||
) | ||
|
||
@staticmethod | ||
def _update_metric( | ||
metrics: List[mlflow.entities.Metric], dataset: MetricsDict = {} | ||
) -> MetricsDict: | ||
"""Update metric in given dataset. | ||
Args: | ||
metrics (List[mlflow.entities.Metric]): List with MLflow metric objects. | ||
dataset (MetricsDict): Dictionary contains MLflow metrics dataset. | ||
Returns: | ||
MetricsDict: Dictionary with MLflow metrics dataset. | ||
""" | ||
for metric in metrics: | ||
metric_dict = {"step": metric.step, "value": metric.value} | ||
if metric.key in dataset: | ||
if isinstance(dataset[metric.key], list): | ||
dataset[metric.key].append(metric_dict) | ||
else: | ||
dataset[metric.key] = [dataset[metric.key], metric_dict] | ||
else: | ||
dataset[metric.key] = metric_dict | ||
return dataset | ||
|
||
def _build_args_list_from_metric_item( | ||
self, key: str, value: MetricItem | ||
) -> Generator[MetricTuple, None, None]: | ||
"""Build list of tuples with metrics. | ||
First element of a tuple is key, second metric value, third step. | ||
If MLflow metrics dataset has prefix, it will be attached to key. | ||
Args: | ||
key (str): Metric key. | ||
value (MetricItem): Metric value | ||
Returns: | ||
List[MetricTuple]: List with metrics as tuples. | ||
""" | ||
if self._prefix: | ||
key = f"{self._prefix}.{key}" | ||
if isinstance(value, float): | ||
return (i for i in [(key, value, 0)]) | ||
if isinstance(value, dict): | ||
return (i for i in [(key, value["value"], value["step"])]) | ||
if isinstance(value, list) and len(value) > 0: | ||
if isinstance(value[0], dict): | ||
return ((key, x["value"], x["step"]) for x in value) | ||
return ((key, v, i) for i, v in enumerate(value)) | ||
raise DataSetError( | ||
f"Unexpected metric value. Should be of type `{MetricItem}`, got {type(value)}" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
from typing import Dict, List, Optional, Union | ||
|
||
import mlflow | ||
import pytest | ||
from mlflow.tracking import MlflowClient | ||
from pytest_lazyfixture import lazy_fixture | ||
|
||
from kedro_mlflow.io import MlflowMetricsDataSet | ||
|
||
|
||
def assert_are_metrics_logged( | ||
data: Dict[str, Union[float, List[float]]], | ||
client: MlflowClient, | ||
run_id: str, | ||
prefix: Optional[str] = None, | ||
) -> bool: | ||
"""Helper function which checks if given metrics where logged. | ||
Args: | ||
data: (Dict[str, Union[float, List[float]]]): Logged metrics. | ||
client: (MlflowClient): MLflow client instance. | ||
run_id: (str): id of run where data was logged. | ||
prefix: (Optional[str]) | ||
""" | ||
for key in data.keys(): | ||
metric_key = f"{prefix}.{key}" if prefix else key | ||
metric = client.get_metric_history(run_id, metric_key) | ||
data_len = len(data[key]) if isinstance(data[key], list) else 1 | ||
assert len(metric) == data_len | ||
for idx, item in enumerate(metric): | ||
if isinstance(data[key], list): | ||
data_value = ( | ||
data[key][idx] | ||
if isinstance(data[key][idx], float) | ||
else data[key][idx]["value"] | ||
) | ||
elif isinstance(data[key], dict): | ||
data_value = data[key]["value"] | ||
else: | ||
data_value = data[key] | ||
assert item.value == data_value and item.key == metric_key | ||
assert True | ||
|
||
|
||
@pytest.fixture | ||
def tracking_uri(tmp_path): | ||
return tmp_path / "mlruns" | ||
|
||
|
||
@pytest.fixture | ||
def metrics(): | ||
return {"metric1": 1.1, "metric2": 1.2} | ||
|
||
|
||
@pytest.fixture | ||
def metric_with_multiple_values(): | ||
return {"metric1": [1.1, 1.2, 1.3]} | ||
|
||
|
||
@pytest.fixture | ||
def metrics_with_one_and_multiple_values(): | ||
return {"metric1": [1.1, 1.2, 1.3], "metric2": 1.2} | ||
|
||
|
||
@pytest.fixture | ||
def metrics2(): | ||
return { | ||
"metric1": [ | ||
{"step": 0, "value": 1.1}, | ||
{"step": 1, "value": 1.2}, | ||
{"step": 2, "value": 1.3}, | ||
], | ||
"metric2": 1.2, | ||
"metric3": {"step": 0, "value": 1.4}, | ||
} | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"data, prefix", | ||
[ | ||
(lazy_fixture("metrics"), None), | ||
(lazy_fixture("metrics"), "test"), | ||
(lazy_fixture("metrics2"), None), | ||
(lazy_fixture("metrics2"), "test"), | ||
(lazy_fixture("metric_with_multiple_values"), None), | ||
(lazy_fixture("metric_with_multiple_values"), "test"), | ||
(lazy_fixture("metrics_with_one_and_multiple_values"), None), | ||
(lazy_fixture("metrics_with_one_and_multiple_values"), "test"), | ||
], | ||
) | ||
def test_mlflow_metrics_dataset_saved_and_logged(tmp_path, tracking_uri, data, prefix): | ||
"""Check if MlflowMetricsDataSet can be saved in catalog when filepath is given, | ||
and if logged in mlflow. | ||
""" | ||
mlflow.set_tracking_uri(tracking_uri.as_uri()) | ||
mlflow_client = MlflowClient(tracking_uri=tracking_uri.as_uri()) | ||
mlflow_metrics_dataset = MlflowMetricsDataSet(prefix=prefix) | ||
|
||
with mlflow.start_run(): | ||
run_id = mlflow.active_run().info.run_id | ||
mlflow_metrics_dataset.save(data) | ||
|
||
# Check if metrics where logged corectly in MLflow. | ||
assert_are_metrics_logged(data, mlflow_client, run_id, prefix) | ||
|
||
# Check if metrics are stored in catalog. | ||
catalog_metrics = MlflowMetricsDataSet( | ||
prefix=prefix, | ||
# Run id needs to be provided as there is no active run. | ||
load_args={"run_id": run_id}, | ||
).load() | ||
|
||
assert len(catalog_metrics) == len(data) | ||
for k in catalog_metrics.keys(): | ||
data_key = k.split(".")[-1] if prefix is not None else k | ||
if isinstance(data[data_key], list): | ||
assert isinstance(catalog_metrics[k], list) | ||
if isinstance(data[data_key][0], dict): | ||
assert data[data_key] == catalog_metrics[k] | ||
elif isinstance(data[data_key][0], float): | ||
assert data[data_key] == [x["value"] for x in catalog_metrics[k]] | ||
elif isinstance(data[data_key], dict): | ||
assert isinstance(catalog_metrics[k], dict) | ||
assert data[data_key] == catalog_metrics[k] | ||
else: | ||
assert isinstance(catalog_metrics[k], dict) | ||
assert data[data_key] == catalog_metrics[k]["value"] |