FIX Galileo-Galilei#9 MlflowMetricsDataSet implemented.
Adrian Piotr Kruszewski committed Aug 20, 2020
1 parent 379b617 commit e86f274
Showing 5 changed files with 378 additions and 1 deletion.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,10 @@

## [Unreleased]

### Added

- Add dataset ``MlflowMetricsDataSet`` for metrics logging ([#9](https://github.com/Galileo-Galilei/kedro-mlflow/issues/9)) and update documentation for metrics.

## [0.2.1] - 2018-08-06

### Added
61 changes: 60 additions & 1 deletion docs/source/03_tutorial/07_version_metrics.md
@@ -1,3 +1,62 @@
# Version metrics

This is coming soon. If you want to keep track of the progress on this feature, [follow this issue](https://github.com/Galileo-Galilei/kedro-mlflow/issues/9).
## What is metric tracking?

MLflow defines metrics as "Key-value metrics, where the value is numeric. Each metric can be updated throughout the course of the run (for example, to track how your model’s loss function is converging), and MLflow records and lets you visualize the metric’s full history".
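
For illustration, here is a minimal sketch (assuming a local MLflow tracking setup) of how such a metric history is recorded and read back with plain MLflow:

```python
import mlflow
from mlflow.tracking import MlflowClient

with mlflow.start_run() as run:
    # Logging the same key several times builds up the metric's full history.
    for step, loss in enumerate([0.9, 0.5, 0.3]):
        mlflow.log_metric("loss", loss, step=step)

# The full history (all three points) can be read back afterwards.
history = MlflowClient().get_metric_history(run.info.run_id, "loss")
```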

## How to version metrics in a kedro project?

kedro-mlflow introduces a new ``AbstractDataSet`` called ``MlflowMetricsDataSet``. It is a wrapper around a dictionary of metrics returned by a node, and it logs these metrics in MLflow.

Since it is an ``AbstractDataSet``, it can be used with the YAML API. You can define it in your catalog as:

```yaml
my_model_metrics:
  type: kedro_mlflow.io.MlflowMetricsDataSet
```

It also accepts a ``prefix`` configuration option. This is especially useful when your pipeline evaluates metrics on different datasets. For example:

```yaml
my_model_metrics_dev:
  type: kedro_mlflow.io.MlflowMetricsDataSet
  prefix: dev
my_model_metrics_test:
  type: kedro_mlflow.io.MlflowMetricsDataSet
  prefix: test
```

In that scenario, the metrics are available in MLflow under the given prefixes. For example, the ``accuracy`` metric from the example above is stored under the key ``test.accuracy`` for ``my_model_metrics_test`` and under ``dev.accuracy`` for ``my_model_metrics_dev``.
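
As a sketch (the evaluation nodes below are illustrative and not part of this plugin), a pipeline could route the metric dictionaries returned by two evaluation nodes to these prefixed datasets:

```python
from typing import Dict

from kedro.pipeline import Pipeline, node


# Hypothetical evaluation nodes returning metric dictionaries.
def evaluate_on_dev() -> Dict[str, float]:
    return {"accuracy": 0.92}


def evaluate_on_test() -> Dict[str, float]:
    return {"accuracy": 0.87}


def create_evaluation_pipeline() -> Pipeline:
    return Pipeline([
        node(evaluate_on_dev, inputs=None, outputs="my_model_metrics_dev", name="evaluate_dev"),
        node(evaluate_on_test, inputs=None, outputs="my_model_metrics_test", name="evaluate_test"),
    ])
```
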
## How to return metrics from a node?

Let's assume you have a node that takes no inputs and returns a dictionary with metrics to log:
```python
from typing import Dict, List, Union


def metrics_node() -> Dict[str, Union[float, List[float]]]:
    return {
        "metric1": 1.0,
        "metric2": [1.0, 1.1],
    }
```

As you can see above, ``kedro_mlflow.io.MlflowMetricsDataSet`` accepts metric values that are either ``float``s or ``list``s of ``float``s. In the first case, a single value is logged under the given metric key; in the second, a series of values is logged, one per step.
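
A minimal sketch of what this means in terms of plain MLflow calls (mirroring the save behavior described above, with illustrative values):

```python
import mlflow

with mlflow.start_run():
    # A float is logged as a single point at step 0 ...
    mlflow.log_metric("metric1", 1.0, step=0)
    # ... while a list is logged as a series, one point per index.
    for step, value in enumerate([1.0, 1.1]):
        mlflow.log_metric("metric2", value, step=step)
```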

To store these metrics, we need to define a metrics dataset in the Kedro catalog:

```yaml
my_model_metrics:
  type: kedro_mlflow.io.MlflowMetricsDataSet
```

To complete the example, we also need a pipeline that uses this node and stores its output under the ``my_model_metrics`` name:
```python
from kedro.pipeline import Pipeline, node


def create_pipeline() -> Pipeline:
    return Pipeline([
        node(
            func=metrics_node,
            inputs=None,
            outputs="my_model_metrics",
            name="log_metrics",
        )
    ])
```
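
Once this pipeline has been run inside an MLflow run, the logged metrics can also be read back through the same dataset. A sketch, assuming ``run_id`` holds the id of the finished MLflow run:

```python
from kedro_mlflow.io import MlflowMetricsDataSet

run_id = "<id of the finished MLflow run>"  # placeholder, replace with a real run id

metrics = MlflowMetricsDataSet(
    # The run id must be passed explicitly because no run is active anymore.
    load_args={"run_id": run_id},
).load()
# e.g. {"metric1": {"step": 0, "value": 1.0},
#       "metric2": [{"step": 0, "value": 1.0}, {"step": 1, "value": 1.1}]}
```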
1 change: 1 addition & 0 deletions kedro_mlflow/io/__init__.py
@@ -1 +1,2 @@
from .mlflow_dataset import MlflowDataSet
from .mlflow_metrics_dataset import MlflowMetricsDataSet
184 changes: 184 additions & 0 deletions kedro_mlflow/io/mlflow_metrics_dataset.py
@@ -0,0 +1,184 @@
from copy import deepcopy
from functools import partial, reduce
from itertools import chain
from typing import Any, Dict, Generator, List, Optional, Tuple, Union

import mlflow
from kedro.io import AbstractDataSet, DataSetError
from mlflow.tracking import MlflowClient


MetricItem = Union[float, List[float], Dict[str, float], List[Dict[str, float]]]
MetricTuple = Tuple[str, float, int]
MetricsDict = Dict[str, MetricItem]


class MlflowMetricsDataSet(AbstractDataSet):
    """This class represents MLflow metrics.
    It logs the given dictionary of metrics to MLflow when its ``save``
    method is called.
    """

    DEFAULT_LOAD_ARGS: Dict[str, Any] = {}
    DEFAULT_SAVE_ARGS: Dict[str, Any] = {}

    def __init__(
        self,
        prefix: Optional[str] = None,
        save_args: Dict[str, Any] = None,
        load_args: Dict[str, Any] = None,
    ):
        """Initialise MlflowMetricsDataSet.
        Args:
            prefix (Optional[str]): Prefix for metrics logged in MLflow.
            save_args (Dict[str, Any]): MLflow options for saving metrics.
            load_args (Dict[str, Any]): MLflow options for loading metrics.
        """
        self._prefix = prefix

        # Handle default load and save arguments
        self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
        if load_args is not None:
            self._load_args.update(load_args)
        self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
        if save_args is not None:
            self._save_args.update(save_args)

    def _load(self) -> MetricsDict:
        """Load MlflowMetricsDataSet.
        Returns:
            MetricsDict: Dictionary with MLflow metrics.
        """
        client = MlflowClient()
        run_id = self._load_args.get("run_id") or self._get_active_run_id(client)
        all_metrics = client._tracking_client.store.get_all_metrics(run_uuid=run_id)
        dataset_metrics = filter(self._is_dataset_metric, all_metrics)
        dataset = reduce(
            lambda xs, x: self._update_metric(
                # get_all_metrics returns the last saved value per metric key.
                # All values are required here.
                client.get_metric_history(run_id, x.key), xs
            ),
            dataset_metrics,
            {},
        )
        return dataset

    def _save(self, data: MetricsDict) -> None:
        """Save the given dataset and log it in MLflow as metrics.
        Args:
            data (MetricsDict): MLflow metrics data.
        """
        client = MlflowClient()
        run_id = self._save_args.get("run_id") or self._get_active_run_id(client)
        log_metric = (
            partial(client.log_metric, run_id)
            if run_id is not None
            else mlflow.log_metric
        )
        metrics = (
            self._build_args_list_from_metric_item(k, v) for k, v in data.items()
        )
        for k, v, i in chain.from_iterable(metrics):
            log_metric(k, v, step=i)

    def _exists(self) -> bool:
        """Check if the dataset exists.
        Returns:
            bool: True if at least one metric of this dataset was logged
            in the run, False otherwise.
        """
        client = MlflowClient()
        run_id = self._load_args.get("run_id") or self._get_active_run_id(client)
        all_metrics = client._tracking_client.store.get_all_metrics(run_uuid=run_id)
        return any(self._is_dataset_metric(x) for x in all_metrics)

    def _describe(self) -> Dict[str, Any]:
        """Describe the dataset.
        Returns:
            Dict[str, Any]: Dictionary with the dataset description.
        """
        return {
            "prefix": self._prefix,
        }

    def _get_active_run_id(self, client: MlflowClient) -> str:
        """Get the id of the active MLflow run.
        Raises a ``DataSetError`` if no run id can be found.
        Args:
            client (MlflowClient): MLflow client instance.
        Returns:
            str: The run id.
        """
        run = mlflow.active_run()
        if run:
            return run.info.run_id
        raise DataSetError("Cannot find run id.")

    def _is_dataset_metric(self, metric: mlflow.entities.Metric) -> bool:
        """Check if the given metric belongs to this dataset.
        Args:
            metric (mlflow.entities.Metric): MLflow metric instance.
        """
        return self._prefix is None or (
            self._prefix and metric.key.startswith(self._prefix)
        )

    @staticmethod
    def _update_metric(
        metrics: List[mlflow.entities.Metric], dataset: MetricsDict = {}
    ) -> MetricsDict:
        """Update the given metrics dataset with a metric's history.
        Args:
            metrics (List[mlflow.entities.Metric]): List of MLflow metric objects.
            dataset (MetricsDict): Dictionary containing the metrics dataset.
        Returns:
            MetricsDict: Dictionary with the updated metrics dataset.
        """
        for metric in metrics:
            metric_dict = {"step": metric.step, "value": metric.value}
            if metric.key in dataset:
                if isinstance(dataset[metric.key], list):
                    dataset[metric.key].append(metric_dict)
                else:
                    dataset[metric.key] = [dataset[metric.key], metric_dict]
            else:
                dataset[metric.key] = metric_dict
        return dataset

    def _build_args_list_from_metric_item(
        self, key: str, value: MetricItem
    ) -> Generator[MetricTuple, None, None]:
        """Build a generator of tuples with metrics.
        The first element of each tuple is the metric key, the second the metric
        value and the third the step.
        If the dataset has a prefix, it is prepended to the key.
        Args:
            key (str): Metric key.
            value (MetricItem): Metric value.
        Returns:
            Generator[MetricTuple, None, None]: Generator of metric tuples.
        """
        if self._prefix:
            key = f"{self._prefix}.{key}"
        if isinstance(value, float):
            return (i for i in [(key, value, 0)])
        if isinstance(value, dict):
            return (i for i in [(key, value["value"], value["step"])])
        if isinstance(value, list) and len(value) > 0:
            if isinstance(value[0], dict):
                return ((key, x["value"], x["step"]) for x in value)
            return ((key, v, i) for i, v in enumerate(value))
        raise DataSetError(
            f"Unexpected metric value. Should be of type `{MetricItem}`, "
            f"got {type(value)}"
        )
129 changes: 129 additions & 0 deletions tests/io/test_mlflow_metrics_dataset.py
@@ -0,0 +1,129 @@
from typing import Dict, List, Optional, Union

import mlflow
import pytest
from mlflow.tracking import MlflowClient
from pytest_lazyfixture import lazy_fixture

from kedro_mlflow.io import MlflowMetricsDataSet


def assert_are_metrics_logged(
    data: Dict[str, Union[float, List[float]]],
    client: MlflowClient,
    run_id: str,
    prefix: Optional[str] = None,
) -> bool:
    """Helper function which checks if the given metrics were logged.
    Args:
        data: (Dict[str, Union[float, List[float]]]): Logged metrics.
        client: (MlflowClient): MLflow client instance.
        run_id: (str): id of the run where the data was logged.
        prefix: (Optional[str]): Prefix of the metric keys, if any.
    """
    for key in data.keys():
        metric_key = f"{prefix}.{key}" if prefix else key
        metric = client.get_metric_history(run_id, metric_key)
        data_len = len(data[key]) if isinstance(data[key], list) else 1
        assert len(metric) == data_len
        for idx, item in enumerate(metric):
            if isinstance(data[key], list):
                data_value = (
                    data[key][idx]
                    if isinstance(data[key][idx], float)
                    else data[key][idx]["value"]
                )
            elif isinstance(data[key], dict):
                data_value = data[key]["value"]
            else:
                data_value = data[key]
            assert item.value == data_value and item.key == metric_key
    assert True


@pytest.fixture
def tracking_uri(tmp_path):
    return tmp_path / "mlruns"


@pytest.fixture
def metrics():
    return {"metric1": 1.1, "metric2": 1.2}


@pytest.fixture
def metric_with_multiple_values():
    return {"metric1": [1.1, 1.2, 1.3]}


@pytest.fixture
def metrics_with_one_and_multiple_values():
    return {"metric1": [1.1, 1.2, 1.3], "metric2": 1.2}


@pytest.fixture
def metrics2():
    return {
        "metric1": [
            {"step": 0, "value": 1.1},
            {"step": 1, "value": 1.2},
            {"step": 2, "value": 1.3},
        ],
        "metric2": 1.2,
        "metric3": {"step": 0, "value": 1.4},
    }


@pytest.mark.parametrize(
    "data, prefix",
    [
        (lazy_fixture("metrics"), None),
        (lazy_fixture("metrics"), "test"),
        (lazy_fixture("metrics2"), None),
        (lazy_fixture("metrics2"), "test"),
        (lazy_fixture("metric_with_multiple_values"), None),
        (lazy_fixture("metric_with_multiple_values"), "test"),
        (lazy_fixture("metrics_with_one_and_multiple_values"), None),
        (lazy_fixture("metrics_with_one_and_multiple_values"), "test"),
    ],
)
def test_mlflow_metrics_dataset_saved_and_logged(tmp_path, tracking_uri, data, prefix):
    """Check if MlflowMetricsDataSet can be saved in the catalog and
    if the metrics are logged in MLflow.
    """
    mlflow.set_tracking_uri(tracking_uri.as_uri())
    mlflow_client = MlflowClient(tracking_uri=tracking_uri.as_uri())
    mlflow_metrics_dataset = MlflowMetricsDataSet(prefix=prefix)

    with mlflow.start_run():
        run_id = mlflow.active_run().info.run_id
        mlflow_metrics_dataset.save(data)

    # Check if metrics were logged correctly in MLflow.
    assert_are_metrics_logged(data, mlflow_client, run_id, prefix)

    # Check if metrics are stored in the catalog.
    catalog_metrics = MlflowMetricsDataSet(
        prefix=prefix,
        # The run id needs to be provided as there is no active run.
        load_args={"run_id": run_id},
    ).load()

    assert len(catalog_metrics) == len(data)
    for k in catalog_metrics.keys():
        data_key = k.split(".")[-1] if prefix is not None else k
        if isinstance(data[data_key], list):
            assert isinstance(catalog_metrics[k], list)
            if isinstance(data[data_key][0], dict):
                assert data[data_key] == catalog_metrics[k]
            elif isinstance(data[data_key][0], float):
                assert data[data_key] == [x["value"] for x in catalog_metrics[k]]
        elif isinstance(data[data_key], dict):
            assert isinstance(catalog_metrics[k], dict)
            assert data[data_key] == catalog_metrics[k]
        else:
            assert isinstance(catalog_metrics[k], dict)
            assert data[data_key] == catalog_metrics[k]["value"]
