Skip to content

Commit

Permalink
♻️ Refactor KedroMlflowConfig with pydantic for robustness (#77)
Browse files Browse the repository at this point in the history
  • Loading branch information
Galileo-Galilei committed Sep 1, 2021
1 parent 6b060c0 commit 480cae2
Show file tree
Hide file tree
Showing 16 changed files with 279 additions and 128 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## [Unreleased]

### Changed

- :recycle: ``KedroMlflowConfig`` was refactored with pydantic for improved type checking when loading the configuration, overall robustness and autocompletion. Its keys have changed, but this is not considered a user-facing change since the public functions ``get_mlflow_config()`` and ``KedroMlflowConfig().setup()`` are not modified.

## [0.7.4] - 2021-08-30

### Added
Expand Down
2 changes: 1 addition & 1 deletion docs/source/02_installation/03_migration_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ There are no breaking change in this patch release except if you retrieve the ml

```python
from kedro.framework.context import load_context
from kedro_mlflow.framework.context import get_mlflow_config
from kedro_mlflow.config import get_mlflow_config

context=load_context(".")
mlflow_config=get_mlflow_config(context)
Expand Down
4 changes: 4 additions & 0 deletions kedro_mlflow/config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""kedro-mlflow context imports
"""

from .kedro_mlflow_config import get_mlflow_config
197 changes: 197 additions & 0 deletions kedro_mlflow/config/kedro_mlflow_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
import os
from pathlib import Path, PurePath
from typing import List, Optional

import mlflow
from kedro.config import MissingConfigException
from kedro.framework.session import KedroSession, get_current_session
from kedro.framework.startup import _is_project
from mlflow.entities import Experiment
from mlflow.tracking.client import MlflowClient
from pydantic import BaseModel, PrivateAttr, StrictBool, validator
from typing_extensions import Literal


class DisableTrackingOptions(BaseModel):
    """Options listing what must not be tracked in mlflow.

    The ``pipelines`` list is compared against the name of the pipeline
    being run: a pipeline whose name appears in it has mlflow disabled.
    """

    class Config:
        # reject any key which is not declared below
        extra = "forbid"

    # A mutable default is safe here, pydantic handles it for each instance:
    # https://stackoverflow.com/questions/63793662/how-to-give-a-pydantic-list-field-a-default-value
    pipelines: List[str] = []


class ExperimentOptions(BaseModel):
    """Options describing the mlflow experiment the runs are attached to."""

    class Config:
        # reject any key which is not declared below
        extra = "forbid"

    # name used to look the experiment up in the tracking server
    name: str = "Default"
    # whether a missing (or deleted) experiment is created (or restored);
    # StrictBool refuses values which are not real booleans
    create: StrictBool = True


class RunOptions(BaseModel):
    """Options forwarded to ``mlflow.start_run`` when a pipeline starts."""

    class Config:
        # reject any key which is not declared below
        extra = "forbid"

    # both default to None when they are not provided
    id: Optional[str] = None
    # when no name is given, the pipeline hook falls back to the pipeline name
    name: Optional[str] = None
    nested: StrictBool = True


class UiOptions(BaseModel):
    """Default host and port of the mlflow ui.

    The ``ui`` CLI command uses these values when the corresponding
    command line option is not provided.
    """

    class Config:
        # reject any key which is not declared below
        extra = "forbid"

    port: str = "5000"
    host: str = "127.0.0.1"


class NodeHookOptions(BaseModel):
    """Options read by the node hook at the beginning of a pipeline run."""

    class Config:
        # reject any key which is not declared below
        extra = "forbid"

    flatten_dict_params: StrictBool = False
    recursive: StrictBool = True
    sep: str = "."
    # only these three strategies are accepted, anything else fails validation
    long_parameters_strategy: Literal["fail", "truncate", "tag"] = "fail"


class HookOptions(BaseModel):
    """Container grouping the options of each kedro-mlflow hook."""

    class Config:
        # reject any key which is not declared below
        extra = "forbid"

    # options of the node hook, with their default values when omitted
    node: NodeHookOptions = NodeHookOptions()


class KedroMlflowConfig(BaseModel):
    """Validated representation of the ``mlflow.yml`` configuration.

    The public fields mirror the keys of the configuration file. Two private
    attributes hold mlflow objects:

    - ``_mlflow_client`` is created in ``__init__`` once the tracking uri
      has been validated;
    - ``_experiment`` is only set by ``setup()``, to avoid opening a
      connection to the tracking backend when the object is created.
    """

    project_path: Path  # if str, will be converted
    mlflow_tracking_uri: str = "mlruns"
    credentials: Optional[str] = None
    disable_tracking: DisableTrackingOptions = DisableTrackingOptions()
    experiment: ExperimentOptions = ExperimentOptions()
    run: RunOptions = RunOptions()
    ui: UiOptions = UiOptions()
    hooks: HookOptions = HookOptions()
    _mlflow_client: MlflowClient = PrivateAttr()
    _experiment: Experiment = PrivateAttr()
    # do not create _experiment immediately to avoid creating
    # a database connection when creating the object
    # it will be instantiated on setup() call

    class Config:
        # force triggering type control when setting value instead of init
        validate_assignment = True
        # raise an error if an unknown key is passed to the constructor
        extra = "forbid"

    def __init__(self, **kwargs):
        """Validate the fields, then create the mlflow client.

        Arguments:
            **kwargs -- The fields of the configuration.
        """
        super().__init__(**kwargs)
        # init after validating the uri, else mlflow creates a mlruns folder at the root
        self._mlflow_client = MlflowClient(tracking_uri=self.mlflow_tracking_uri)

    def setup(self, session: KedroSession = None):
        """Setup all the mlflow configuration.

        Export the credentials as environment variables, set the mlflow
        tracking uri, then get or create the experiment.

        Arguments:
            session {KedroSession} -- The session used to load the
                credentials. Defaults to the current session.
        """

        self._export_credentials(session)

        # we set the configuration now: it takes priority
        # if it has already be set in export_credentials
        mlflow.set_tracking_uri(self.mlflow_tracking_uri)

        self._get_or_create_experiment()

    def _export_credentials(self, session: KedroSession = None):
        """Export the mlflow credentials as environment variables.

        The entry of the project credentials whose key is ``self.credentials``
        is looked up; each of its key/value pairs is written to
        ``os.environ``. Nothing is exported if the entry does not exist.

        Arguments:
            session {KedroSession} -- The session used to load the
                credentials. Defaults to the current session.
        """
        session = session or get_current_session()
        context = session.load_context()
        conf_creds = context._get_config_credentials()
        mlflow_creds = conf_creds.get(self.credentials, {})
        for key, value in mlflow_creds.items():
            os.environ[key] = value

    def _get_or_create_experiment(self):
        """Best effort to get the experiment associated to the configuration.

        The result is stored in ``self._experiment``; nothing is returned.
        When ``experiment.create`` is False, the experiment is neither
        created nor restored: ``self._experiment`` is then whatever the
        lookup by name returned, which is None if nothing was found.
        """

        # retrieve the experiment
        self._experiment = self._mlflow_client.get_experiment_by_name(
            name=self.experiment.name
        )

        if not self.experiment.create:
            return

        # Deal with two side cases when retrieving the experiment
        if self._experiment is None:
            # case 1 : the experiment does not exist, it must be created manually
            experiment_id = self._mlflow_client.create_experiment(
                name=self.experiment.name
            )
            self._experiment = self._mlflow_client.get_experiment(
                experiment_id=experiment_id
            )
        elif self._experiment.lifecycle_stage == "deleted":
            # case 2: the experiment was created, then deleted : we have to restore it manually
            experiment_id = self._experiment.experiment_id
            self._mlflow_client.restore_experiment(experiment_id)
            # reload the entity, else the stored one still claims to be "deleted"
            self._experiment = self._mlflow_client.get_experiment(
                experiment_id=experiment_id
            )

    @validator("project_path")
    def _is_kedro_project(cls, folder_path):
        """Check that 'project_path' is the root of a kedro project.

        Arguments:
            folder_path {Path} -- The path to check.
        Returns:
            Path -- The unchanged path.
        Raises:
            KedroMlflowConfigError -- If the folder is not a kedro project.
        """
        if not _is_project(folder_path):
            raise KedroMlflowConfigError(
                f"'project_path' = '{folder_path}' is not the root of kedro project"
            )
        return folder_path

    # pre=make a conversion before it is set
    # always=even for default value
    # values enable access to the other field, see https://pydantic-docs.helpmanual.io/usage/validators/
    @validator("mlflow_tracking_uri", pre=True, always=True)
    def _validate_uri(cls, uri, values):
        """Format the uri provided to match mlflow expectations.

        Arguments:
            uri {Union[None, str]} -- A valid filepath for mlflow uri
            values {dict} -- The previously validated fields
        Returns:
            str -- A valid mlflow_tracking_uri
        Raises:
            ValueError -- If the uri is a relative path and 'project_path'
                is not available to make it absolute.
        """
        from urllib.parse import urlparse

        # if no tracking uri is provided, we register the runs locally at the root of the project
        if uri is None:
            uri = "mlruns"

        pathlib_uri = PurePath(uri)
        if pathlib_uri.is_absolute():
            return pathlib_uri.as_uri()

        if urlparse(uri).scheme != "":
            # not a local path: assume it is an uri
            return uri

        # if it is a local relative path, make it absolute
        # .resolve() does not work well on windows
        # .absolute is undocumented and have known bugs
        # Path.cwd() / uri is the recommend way by core developpers.
        # See : https://discuss.python.org/t/pathlib-absolute-vs-resolve/2573/6
        project_path = values.get("project_path")
        if project_path is None:
            # 'project_path' is absent from values when it is missing or
            # invalid: raise a ValueError so that pydantic reports a regular
            # validation error instead of an unrelated KeyError
            raise ValueError(
                f"cannot make the relative path '{uri}' absolute "
                "because 'project_path' is missing or invalid"
            )
        return (project_path / uri).as_uri()


class KedroMlflowConfigError(Exception):
    """Raised when the kedro-mlflow configuration cannot be loaded."""

    pass


def get_mlflow_config(session: Optional[KedroSession] = None):
    """Load the mlflow configuration of the project as a KedroMlflowConfig.

    Arguments:
        session {Optional[KedroSession]} -- The session used to load the
            context. Defaults to the current session.
    Returns:
        KedroMlflowConfig -- The validated configuration.
    Raises:
        KedroMlflowConfigError -- If no mlflow configuration file is found.
    """
    session = session or get_current_session()
    context = session.load_context()
    try:
        conf_mlflow_yml = context.config_loader.get("mlflow*", "mlflow*/**")
    except MissingConfigException as err:
        # chain explicitly to keep the original cause in the traceback
        raise KedroMlflowConfigError(
            "No 'mlflow.yml' config file found in environment. Use ``kedro mlflow init`` command in CLI to create a default config file."
        ) from err
    # the project path is not in the file: it comes from the context
    conf_mlflow_yml["project_path"] = context.project_path
    return KedroMlflowConfig.parse_obj(conf_mlflow_yml)
6 changes: 3 additions & 3 deletions kedro_mlflow/framework/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from kedro.framework.session import KedroSession
from kedro.framework.startup import _is_project, bootstrap_project

from kedro_mlflow.config import get_mlflow_config
from kedro_mlflow.framework.cli.cli_utils import write_jinja_template
from kedro_mlflow.framework.context import get_mlflow_config

TEMPLATE_FOLDER_PATH = Path(__file__).parent.parent.parent / "template" / "project"

Expand Down Expand Up @@ -149,8 +149,8 @@ def ui(env, port, host):
):

mlflow_conf = get_mlflow_config()
host = host or mlflow_conf.ui_opts.get("host")
port = port or mlflow_conf.ui_opts.get("port")
host = host or mlflow_conf.ui.host
port = port or mlflow_conf.ui.port

# call mlflow ui with specific options
# TODO : add more options for ui
Expand Down
16 changes: 8 additions & 8 deletions kedro_mlflow/framework/hooks/node_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from kedro.pipeline.node import Node
from mlflow.utils.validation import MAX_PARAM_VAL_LENGTH

from kedro_mlflow.framework.context import get_mlflow_config
from kedro_mlflow.config import get_mlflow_config
from kedro_mlflow.framework.hooks.utils import _assert_mlflow_enabled, _flatten_dict


Expand Down Expand Up @@ -53,14 +53,14 @@ def before_pipeline_run(
self._is_mlflow_enabled = _assert_mlflow_enabled(run_params["pipeline_name"])

if self._is_mlflow_enabled:
config = get_mlflow_config()
mlflow_config = get_mlflow_config()

self.flatten = config.node_hook_opts["flatten_dict_params"]
self.recursive = config.node_hook_opts["recursive"]
self.sep = config.node_hook_opts["sep"]
self.long_parameters_strategy = config.node_hook_opts[
"long_parameters_strategy"
]
self.flatten = mlflow_config.hooks.node.flatten_dict_params
self.recursive = mlflow_config.hooks.node.recursive
self.sep = mlflow_config.hooks.node.sep
self.long_parameters_strategy = (
mlflow_config.hooks.node.long_parameters_strategy
)

@hook_impl
def before_node_run(
Expand Down
19 changes: 8 additions & 11 deletions kedro_mlflow/framework/hooks/pipeline_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from mlflow.entities import RunStatus
from mlflow.models import infer_signature

from kedro_mlflow.framework.context import get_mlflow_config
from kedro_mlflow.config import get_mlflow_config
from kedro_mlflow.framework.hooks.utils import _assert_mlflow_enabled
from kedro_mlflow.io.catalog.switch_catalog_logging import switch_catalog_logging
from kedro_mlflow.io.metrics import (
Expand Down Expand Up @@ -111,19 +111,16 @@ def before_pipeline_run(
self._is_mlflow_enabled = _assert_mlflow_enabled(run_params["pipeline_name"])

if self._is_mlflow_enabled:
mlflow_conf = get_mlflow_config()
mlflow_conf.setup()
mlflow_config = get_mlflow_config()
mlflow_config.setup()

run_name = mlflow_config.run.name or run_params["pipeline_name"]

run_name = (
mlflow_conf.run_opts["name"]
if mlflow_conf.run_opts["name"] is not None
else run_params["pipeline_name"]
)
mlflow.start_run(
run_id=mlflow_conf.run_opts["id"],
experiment_id=mlflow_conf.experiment.experiment_id,
run_id=mlflow_config.run.id,
experiment_id=mlflow_config._experiment.experiment_id,
run_name=run_name,
nested=mlflow_conf.run_opts["nested"],
nested=mlflow_config.run.nested,
)
# Set tags only for run parameters that have values.
mlflow.set_tags({k: v for k, v in run_params.items() if v})
Expand Down
4 changes: 2 additions & 2 deletions kedro_mlflow/framework/hooks/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Dict

from kedro_mlflow.framework.context.mlflow_context import get_mlflow_config
from kedro_mlflow.config.kedro_mlflow_config import get_mlflow_config


def _assert_mlflow_enabled(pipeline_name: str) -> bool:
Expand All @@ -9,7 +9,7 @@ def _assert_mlflow_enabled(pipeline_name: str) -> bool:
# TODO: we may want to enable to filter on tags
# but we need to deal with the case when several tags are passed
# what to do if 1 out of 2 is in the list?
disabled_pipelines = mlflow_config.disable_tracking_opts.get("pipelines") or []
disabled_pipelines = mlflow_config.disable_tracking.pipelines
if pipeline_name in disabled_pipelines:
return False

Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from kedro.framework.session import KedroSession
from kedro.framework.startup import bootstrap_project

from kedro_mlflow.framework.context import get_mlflow_config
from kedro_mlflow.framework.context.config import KedroMlflowConfigError
from kedro_mlflow.config import get_mlflow_config
from kedro_mlflow.config.kedro_mlflow_config import KedroMlflowConfigError


def _write_yaml(filepath, config):
Expand Down Expand Up @@ -43,7 +43,7 @@ def test_get_mlflow_config(kedro_project):
"mlflow_tracking_uri": (kedro_project / "mlruns").as_uri(),
"credentials": None,
"disable_tracking": {"pipelines": ["my_disabled_pipeline"]},
"experiments": {"name": "fake_package", "create": True},
"experiment": {"name": "fake_package", "create": True},
"run": {"id": "123456789", "name": "my_run", "nested": True},
"ui": {"port": "5151", "host": "localhost"},
"hooks": {
Expand All @@ -58,7 +58,7 @@ def test_get_mlflow_config(kedro_project):

bootstrap_project(kedro_project)
with KedroSession.create(project_path=kedro_project):
assert get_mlflow_config().to_dict() == expected
assert get_mlflow_config().dict(exclude={"project_path"}) == expected


def test_get_mlflow_config_in_uninitialized_project(kedro_project):
Expand Down Expand Up @@ -109,7 +109,7 @@ def test_mlflow_config_with_templated_config_loader(
"mlflow_tracking_uri": (kedro_project_with_tcl / "dynamic_mlruns").as_uri(),
"credentials": None,
"disable_tracking": {"pipelines": ["my_disabled_pipeline"]},
"experiments": {"name": "fake_package", "create": True},
"experiment": {"name": "fake_package", "create": True},
"run": {"id": "123456789", "name": "my_run", "nested": True},
"ui": {"port": "5151", "host": "localhost"},
"hooks": {
Expand All @@ -123,4 +123,4 @@ def test_mlflow_config_with_templated_config_loader(
}
bootstrap_project(kedro_project_with_tcl)
with KedroSession.create(project_path=kedro_project_with_tcl):
assert get_mlflow_config().to_dict() == expected
assert get_mlflow_config().dict(exclude={"project_path"}) == expected
Loading

0 comments on commit 480cae2

Please sign in to comment.