FIX Galileo-Galilei#9 MlflowMetricsDataSet implemented.
Adrian Piotr Kruszewski committed Aug 20, 2020
1 parent 379b617 commit b71c81d
Showing 5 changed files with 380 additions and 1 deletion.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,10 @@

## [Unreleased]

### Added

- Add dataset ``MlflowMetricsDataSet`` for metrics logging ([#9](https://github.com/Galileo-Galilei/kedro-mlflow/issues/9)) and update documentation for metrics.

## [0.2.1] - 2018-08-06

### Added
61 changes: 60 additions & 1 deletion docs/source/03_tutorial/07_version_metrics.md
@@ -1,3 +1,62 @@
# Version metrics

This is coming soon. If you want to keep track of the progress on this feature, [follow this issue](https://github.com/Galileo-Galilei/kedro-mlflow/issues/9).
## What is metric tracking?

MLflow defines metrics as "Key-value metrics, where the value is numeric. Each metric can be updated throughout the course of the run (for example, to track how your model’s loss function is converging), and MLflow records and lets you visualize the metric’s full history".
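
For reference, this is roughly what metric logging looks like with the plain MLflow API, outside of any Kedro pipeline (the ``loss`` values below are made up for illustration):

```python
import mlflow

# Log a metric at successive steps so that MLflow keeps its full history.
with mlflow.start_run():
    for step, loss in enumerate([0.9, 0.5, 0.3]):
        mlflow.log_metric("loss", loss, step=step)
```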

## How to version metrics in a kedro project?

kedro-mlflow introduces a new ``AbstractDataSet`` called ``MlflowMetricsDataSet``. It is a wrapper around a dictionary of metrics returned by a node, and it logs those metrics in MLflow.

Since it is an ``AbstractDataSet``, it can be used with the YAML API. You can define it as:

```yaml
my_model_metrics:
type: kedro_mlflow.io.MlflowMetricsDataSet
```

It also accepts a ``prefix`` configuration option. This is especially useful when your pipeline evaluates metrics on different datasets. For example:

```yaml
my_model_metrics_dev:
type: kedro_mlflow.io.MlflowMetricsDataSet
prefix: dev
my_model_metrics_test:
type: kedro_mlflow.io.MlflowMetricsDataSet
prefix: test
```

In this scenario, the metrics will be available in MLflow under the given prefixes. For example, an ``accuracy`` metric logged through ``my_model_metrics_test`` will be stored under the key ``test.accuracy``, and through ``my_model_metrics_dev`` under the key ``dev.accuracy``.

## How to return metrics from a node?

Let's assume you have a node which doesn't have any inputs and returns a dictionary with the metrics to log:

```python
from typing import Dict, List, Union


def metrics_node() -> Dict[str, Union[float, List[float]]]:
    return {
        "metric1": 1.0,
        "metric2": [1.0, 1.1],
    }
```

As you can see above, ``kedro_mlflow.io.MlflowMetricsDataSet`` accepts metrics either as a ``float`` or as a ``list`` of ``float``. In the first case, a single value is logged under the given metric key; in the second, a series of values is logged.
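
Judging from the implementation in this commit (``_build_args_list_from_metric_item``) and the accompanying tests, an explicit ``{"step": ..., "value": ...}`` form also appears to be accepted when you want control over the step index. A hedged sketch of such a node (the metric names are made up):

```python
from typing import Any, Dict


def metrics_with_steps_node() -> Dict[str, Any]:
    # Explicit step/value pairs; the step controls where MLflow places
    # each point in the metric's history.
    return {
        "metric3": {"step": 0, "value": 1.4},
        "metric4": [
            {"step": 0, "value": 1.1},
            {"step": 1, "value": 1.2},
        ],
    }
```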

To store the metrics, we need to define the metrics dataset in the Kedro catalog:

```yaml
my_model_metrics:
type: kedro_mlflow.io.MlflowMetricsDataSet
```

To complete the example, we also need a pipeline which uses this node and stores its output under the ``my_model_metrics`` name:

```python
from kedro.pipeline import Pipeline, node


def create_pipeline() -> Pipeline:
    return Pipeline(
        [
            node(
                func=metrics_node,
                inputs=None,
                outputs="my_model_metrics",
                name="log_metrics",
            )
        ]
    )
```
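
To tie this back to the ``prefix`` example above, here is a hedged sketch of how one evaluation node could feed the two prefixed datasets declared earlier (``my_model_metrics_dev`` and ``my_model_metrics_test``); the input dataset names and the accuracy computation are hypothetical:

```python
from typing import Dict, List

from kedro.pipeline import Pipeline, node


def evaluate(predictions: List[int], labels: List[int]) -> Dict[str, float]:
    # Hypothetical evaluation returning plain floats keyed by metric name.
    accuracy = sum(p == y for p, y in zip(predictions, labels)) / len(labels)
    return {"accuracy": accuracy}


def create_evaluation_pipeline() -> Pipeline:
    return Pipeline(
        [
            node(evaluate, ["dev_predictions", "dev_labels"], "my_model_metrics_dev"),
            node(evaluate, ["test_predictions", "test_labels"], "my_model_metrics_test"),
        ]
    )
```

With this wiring, MLflow would show the same metric twice, as ``dev.accuracy`` and ``test.accuracy``.
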
1 change: 1 addition & 0 deletions kedro_mlflow/io/__init__.py
@@ -1 +1,2 @@
from .mlflow_dataset import MlflowDataSet
from .mlflow_metrics_dataset import MlflowMetricsDataSet
188 changes: 188 additions & 0 deletions kedro_mlflow/io/mlflow_metrics_dataset.py
@@ -0,0 +1,188 @@
from copy import deepcopy
from functools import partial, reduce
from itertools import chain
from typing import Any, Dict, Generator, List, Optional, Tuple, Union

import mlflow
from kedro.io import AbstractDataSet, DataSetError
from mlflow.tracking import MlflowClient

MetricItem = Union[float, List[float], Dict[str, float], List[Dict[str, float]]]
MetricTuple = Tuple[str, float, int]
MetricsDict = Dict[str, MetricItem]


class MlflowMetricsDataSet(AbstractDataSet):
"""This class represent MLflow metrics dataset."""

DEFAULT_LOAD_ARGS: Dict[str, Any] = {}
DEFAULT_SAVE_ARGS: Dict[str, Any] = {}

def __init__(
self,
prefix: Optional[str] = None,
save_args: Dict[str, Any] = None,
load_args: Dict[str, Any] = None,
):
"""Initialise MlflowMetricsDataSet.
Args:
prefix (Optional[str]): Prefix for metrics logged in MLflow.
            save_args (Dict[str, Any]): MLflow options for saving metrics.
            load_args (Dict[str, Any]): MLflow options for loading metrics.
"""
self._prefix = prefix

# Handle default load and save arguments
self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)

def _load(self) -> MetricsDict:
"""Load MlflowMetricDataSet.
Returns:
Dict[str, Union[int, float]]: Dictionary with MLflow metrics dataset.
"""
client = MlflowClient()
run_id = self._load_args.get("run_id") or self._get_active_run_id()
all_metrics = client._tracking_client.store.get_all_metrics(run_uuid=run_id)
dataset_metrics = filter(self._is_dataset_metric, all_metrics)
dataset = reduce(
lambda xs, x: self._update_metric(
# get_all_metrics returns last saved values per metric key.
# All values are required here.
client.get_metric_history(run_id, x.key),
xs,
),
dataset_metrics,
{},
)
return dataset

def _save(self, data: MetricsDict) -> None:
"""Save given MLflow metrics dataset and log it in MLflow as metrics.
Args:
data (MetricsDict): MLflow metrics dataset.
"""
client = MlflowClient()
try:
run_id = self._save_args.get("run_id") or self._get_active_run_id()
except DataSetError:
            # If no run_id can be found, mlflow.log_metric will create a new run.
run_id = None

log_metric = (
partial(client.log_metric, run_id)
if run_id is not None
else mlflow.log_metric
)
metrics = (
            self._build_args_list_from_metric_item(k, v) for k, v in data.items()
)
for k, v, i in chain.from_iterable(metrics):
log_metric(k, v, step=i)

def _exists(self) -> bool:
"""Check if MLflow metrics dataset exists.
Returns:
            bool: Whether the MLflow metrics dataset exists.
"""
client = MlflowClient()
run_id = self._load_args.get("run_id") or self._get_active_run_id()
all_metrics = client._tracking_client.store.get_all_metrics(run_uuid=run_id)
return any(self._is_dataset_metric(x) for x in all_metrics)

def _describe(self) -> Dict[str, Any]:
"""Describe MLflow metrics dataset.
Returns:
Dict[str, Any]: Dictionary with MLflow metrics dataset description.
"""
return {
"prefix": self._prefix,
}

def _get_active_run_id(self) -> str:
"""Get run id.
If active run is not found, tries to find last experiment.
Raise `DataSetError` exception if run id can't be found.
Returns:
str: String contains run_id.
"""
run = mlflow.active_run()
if run:
return run.info.run_id
raise DataSetError("Cannot find run id.")

def _is_dataset_metric(self, metric: mlflow.entities.Metric) -> bool:
"""Check if given metric belongs to dataset.
Args:
metric (mlflow.entities.Metric): MLflow metric instance.
"""
return self._prefix is None or (
self._prefix and metric.key.startswith(self._prefix)
)

@staticmethod
def _update_metric(
metrics: List[mlflow.entities.Metric], dataset: MetricsDict = {}
) -> MetricsDict:
"""Update metric in given dataset.
Args:
metrics (List[mlflow.entities.Metric]): List with MLflow metric objects.
dataset (MetricsDict): Dictionary contains MLflow metrics dataset.
Returns:
MetricsDict: Dictionary with MLflow metrics dataset.
"""
for metric in metrics:
metric_dict = {"step": metric.step, "value": metric.value}
if metric.key in dataset:
if isinstance(dataset[metric.key], list):
dataset[metric.key].append(metric_dict)
else:
dataset[metric.key] = [dataset[metric.key], metric_dict]
else:
dataset[metric.key] = metric_dict
return dataset

def _build_args_list_from_metric_item(
self, key: str, value: MetricItem
) -> Generator[MetricTuple, None, None]:
"""Build list of tuples with metrics.
First element of a tuple is key, second metric value, third step.
If MLflow metrics dataset has prefix, it will be attached to key.
Args:
key (str): Metric key.
value (MetricItem): Metric value
Returns:
List[MetricTuple]: List with metrics as tuples.
"""
if self._prefix:
key = f"{self._prefix}.{key}"
if isinstance(value, float):
return (i for i in [(key, value, 0)])
if isinstance(value, dict):
return (i for i in [(key, value["value"], value["step"])])
if isinstance(value, list) and len(value) > 0:
if isinstance(value[0], dict):
return ((key, x["value"], x["step"]) for x in value)
return ((key, v, i) for i, v in enumerate(value))
raise DataSetError(
f"Unexpected metric value. Should be of type `{MetricItem}`, got {type(value)}"
)
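
As a side note on the ``_load`` path above: once metrics have been logged, they can be read back through the dataset by passing a ``run_id`` in ``load_args``, as the test below does. A minimal usage sketch (the run id is a placeholder):

```python
from kedro_mlflow.io import MlflowMetricsDataSet

# "abc123" is a placeholder; use the id of an actual MLflow run.
metrics_ds = MlflowMetricsDataSet(prefix="test", load_args={"run_id": "abc123"})
metrics = metrics_ds.load()
# e.g. {"test.accuracy": {"step": 0, "value": 0.93}}
```
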
127 changes: 127 additions & 0 deletions tests/io/test_mlflow_metrics_dataset.py
@@ -0,0 +1,127 @@
from typing import Dict, List, Optional, Union

import mlflow
import pytest
from mlflow.tracking import MlflowClient
from pytest_lazyfixture import lazy_fixture

from kedro_mlflow.io import MlflowMetricsDataSet


def assert_are_metrics_logged(
data: Dict[str, Union[float, List[float]]],
client: MlflowClient,
run_id: str,
prefix: Optional[str] = None,
) -> bool:
"""Helper function which checks if given metrics where logged.
Args:
data: (Dict[str, Union[float, List[float]]]): Logged metrics.
client: (MlflowClient): MLflow client instance.
run_id: (str): id of run where data was logged.
        prefix: (Optional[str]): Prefix of the logged metric keys, if any.
"""
for key in data.keys():
metric_key = f"{prefix}.{key}" if prefix else key
metric = client.get_metric_history(run_id, metric_key)
data_len = len(data[key]) if isinstance(data[key], list) else 1
assert len(metric) == data_len
for idx, item in enumerate(metric):
if isinstance(data[key], list):
data_value = (
data[key][idx]
if isinstance(data[key][idx], float)
else data[key][idx]["value"]
)
elif isinstance(data[key], dict):
data_value = data[key]["value"]
else:
data_value = data[key]
assert item.value == data_value and item.key == metric_key
assert True


@pytest.fixture
def tracking_uri(tmp_path):
return tmp_path / "mlruns"


@pytest.fixture
def metrics():
return {"metric1": 1.1, "metric2": 1.2}


@pytest.fixture
def metric_with_multiple_values():
return {"metric1": [1.1, 1.2, 1.3]}


@pytest.fixture
def metrics_with_one_and_multiple_values():
return {"metric1": [1.1, 1.2, 1.3], "metric2": 1.2}


@pytest.fixture
def metrics2():
return {
"metric1": [
{"step": 0, "value": 1.1},
{"step": 1, "value": 1.2},
{"step": 2, "value": 1.3},
],
"metric2": 1.2,
"metric3": {"step": 0, "value": 1.4},
}


@pytest.mark.parametrize(
"data, prefix",
[
(lazy_fixture("metrics"), None),
(lazy_fixture("metrics"), "test"),
(lazy_fixture("metrics2"), None),
(lazy_fixture("metrics2"), "test"),
(lazy_fixture("metric_with_multiple_values"), None),
(lazy_fixture("metric_with_multiple_values"), "test"),
(lazy_fixture("metrics_with_one_and_multiple_values"), None),
(lazy_fixture("metrics_with_one_and_multiple_values"), "test"),
],
)
def test_mlflow_metrics_dataset_saved_and_logged(tmp_path, tracking_uri, data, prefix):
"""Check if MlflowMetricsDataSet can be saved in catalog when filepath is given,
and if logged in mlflow.
"""
mlflow.set_tracking_uri(tracking_uri.as_uri())
mlflow_client = MlflowClient(tracking_uri=tracking_uri.as_uri())
mlflow_metrics_dataset = MlflowMetricsDataSet(prefix=prefix)

with mlflow.start_run():
run_id = mlflow.active_run().info.run_id
mlflow_metrics_dataset.save(data)

    # Check if metrics were logged correctly in MLflow.
assert_are_metrics_logged(data, mlflow_client, run_id, prefix)

# Check if metrics are stored in catalog.
catalog_metrics = MlflowMetricsDataSet(
prefix=prefix,
# Run id needs to be provided as there is no active run.
load_args={"run_id": run_id},
).load()

assert len(catalog_metrics) == len(data)
for k in catalog_metrics.keys():
data_key = k.split(".")[-1] if prefix is not None else k
if isinstance(data[data_key], list):
assert isinstance(catalog_metrics[k], list)
if isinstance(data[data_key][0], dict):
assert data[data_key] == catalog_metrics[k]
elif isinstance(data[data_key][0], float):
assert data[data_key] == [x["value"] for x in catalog_metrics[k]]
elif isinstance(data[data_key], dict):
assert isinstance(catalog_metrics[k], dict)
assert data[data_key] == catalog_metrics[k]
else:
assert isinstance(catalog_metrics[k], dict)
assert data[data_key] == catalog_metrics[k]["value"]
