Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/optional mlflow #164

Merged
merged 12 commits into from
Dec 1, 2023
64 changes: 47 additions & 17 deletions howto/register_new_algorithm.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,10 @@ from sheeprl.models.models import MLP
from sheeprl.utils.metric import MetricAggregator
from sheeprl.utils.registry import register_algorithm
from sheeprl.utils.env import make_env
from sheeprl.utils.imports import _IS_MLFLOW_AVAILABLE
from sheeprl.utils.logger import get_logger, get_log_dir
from sheeprl.utils.timer import timer
from sheeprl.utils.utils import register_model, unwrap_fabric
from sheeprl.utils.utils import unwrap_fabric


def train(
Expand Down Expand Up @@ -131,9 +132,6 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):
# the optimizer and set up it with Fabric
optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters())

# In case you want to give the possiblity to register your models
local_vars = locals()

# Create a metric aggregator to log the metrics
aggregator = None
if not MetricAggregator.disabled:
Expand Down Expand Up @@ -293,20 +291,52 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):

# Optional part in case you want to give the possibility to register your models with MLFlow
if not cfg.model_manager.disabled and fabric.is_global_zero:
from sheeprl.algos.sota.utils import log_models
from sheeprl.utils.mlflow import register_model

models_to_log = {"agent": agent}
register_model(fabric, log_models, cfg, models_to_log)
```

where `log_models` has to be defined in the `sheeprl.algo.sota.utils` module, for example like this:

```python
from __future__ import annotations

def log_models(
run_id: str, experiment_id: str | None = None, run_name: str | None = None
) -> Dict[str, ModelInfo]:
with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=True) as _:
model_info = {}
unwrapped_models = {}
for k in cfg.model_manager.models.keys():
unwrapped_models[k] = unwrap_fabric(local_vars[k])
model_info[k] = mlflow.pytorch.log_model(unwrapped_models[k], artifact_path=k)
mlflow.log_dict(cfg, "config.json")
return model_info

register_model(fabric, log_models, cfg)
import warnings
from typing import TYPE_CHECKING, Any, Dict

import torch
from lightning.fabric.wrappers import _FabricModule

from sheeprl.utils.imports import _IS_MLFLOW_AVAILABLE
from sheeprl.utils.utils import unwrap_fabric

if TYPE_CHECKING:
from mlflow.models.model import ModelInfo

def log_models(
cfg: Dict[str, Any],
models_to_log: Dict[str, torch.nn.Module | _FabricModule],
run_id: str,
experiment_id: str | None = None,
run_name: str | None = None,
) -> Dict[str, "ModelInfo"]:
if not _IS_MLFLOW_AVAILABLE:
raise ModuleNotFoundError(str(_IS_MLFLOW_AVAILABLE))
import mlflow # noqa

with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=True) as _:
model_info = {}
unwrapped_models = {}
for k in cfg.model_manager.models.keys():
if k not in models_to_log:
warnings.warn(f"Model {k} not found in models_to_log, skipping.", category=UserWarning)
continue
unwrapped_models[k] = unwrap_fabric(models_to_log[k])
model_info[k] = mlflow.pytorch.log_model(unwrapped_models[k], artifact_path=k)
mlflow.log_dict(cfg, "config.json")
return model_info
```

### Metrics and Model Manager
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ dependencies = [
"moviepy>=1.0.3",
"tensordict==0.2.*",
"tensorboard>=2.10",
"mlflow==2.8.0",
"python-dotenv>=1.0.0",
"lightning==2.1.*",
"lightning-utilities<=0.9",
Expand Down Expand Up @@ -84,6 +83,7 @@ minedojo = ["minedojo==0.1", "importlib_resources==5.12.0"]
minerl = ["setuptools==66.0.0", "minerl==0.4.4"]
diambra = ["diambra==0.0.16", "diambra-arena==2.2.2"]
crafter = ["crafter==1.8.1"]
mlflow = ["mlflow==2.8.0"]

[tool.ruff]
line-length = 120
Expand Down
22 changes: 5 additions & 17 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@

import gymnasium as gym
import hydra
import mlflow
import numpy as np
import torch
import torch.nn.functional as F
from lightning.fabric import Fabric
from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer
from mlflow.models.model import ModelInfo
from tensordict import TensorDict
from tensordict.tensordict import TensorDictBase
from torch.distributions import Bernoulli, Independent, Normal
Expand All @@ -30,7 +28,7 @@
from sheeprl.utils.metric import MetricAggregator
from sheeprl.utils.registry import register_algorithm
from sheeprl.utils.timer import timer
from sheeprl.utils.utils import polynomial_decay, register_model, save_configs, unwrap_fabric
from sheeprl.utils.utils import polynomial_decay, save_configs

# Decomment the following two lines if you cannot start an experiment with DMC environments
# os.environ["PYOPENGL_PLATFORM"] = ""
Expand Down Expand Up @@ -505,7 +503,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
world_optimizer, actor_optimizer, critic_optimizer
)

local_vars = locals()
if fabric.is_global_zero:
save_configs(cfg, log_dir)

Expand Down Expand Up @@ -792,17 +789,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
test(player, fabric, cfg, log_dir)

if not cfg.model_manager.disabled and fabric.is_global_zero:
from sheeprl.algos.dreamer_v1.utils import log_models
from sheeprl.utils.mlflow import register_model

def log_models(
run_id: str, experiment_id: str | None = None, run_name: str | None = None
) -> Dict[str, ModelInfo]:
with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=True) as _:
model_info = {}
unwrapped_models = {}
for k in cfg.model_manager.models.keys():
unwrapped_models[k] = unwrap_fabric(local_vars[k])
model_info[k] = mlflow.pytorch.log_model(unwrapped_models[k], artifact_path=k)
mlflow.log_dict(cfg, "config.json")
return model_info

register_model(fabric, log_models, cfg)
models_to_log = {"world_model": world_model, "actor": actor, "critic": critic}
register_model(fabric, log_models, cfg, models_to_log)
41 changes: 37 additions & 4 deletions sheeprl/algos/dreamer_v1/utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
from __future__ import annotations

from typing import Any, Dict, Sequence, Tuple
import warnings
from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple

import gymnasium as gym
import mlflow
import torch
import torch.nn.functional as F
from lightning import Fabric
from mlflow.models.model import ModelInfo
from lightning.fabric.wrappers import _FabricModule
from torch import Tensor
from torch.distributions import Distribution, Independent, Normal

from sheeprl.utils.imports import _IS_MLFLOW_AVAILABLE
from sheeprl.utils.utils import unwrap_fabric

if TYPE_CHECKING:
from mlflow.models.model import ModelInfo


AGGREGATOR_KEYS = {
"Rewards/rew_avg",
"Game/ep_len_avg",
Expand Down Expand Up @@ -102,9 +107,37 @@ def compute_stochastic_state(
return (mean, std), stochastic_state


def log_models(
cfg: Dict[str, Any],
models_to_log: Dict[str, torch.nn.Module | _FabricModule],
run_id: str,
experiment_id: str | None = None,
run_name: str | None = None,
) -> Dict[str, "ModelInfo"]:
if not _IS_MLFLOW_AVAILABLE:
raise ModuleNotFoundError(str(_IS_MLFLOW_AVAILABLE))
import mlflow # noqa

with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=True) as _:
model_info = {}
unwrapped_models = {}
for k in cfg.model_manager.models.keys():
if k not in models_to_log:
warnings.warn(f"Model {k} not found in models_to_log, skipping.", category=UserWarning)
continue
unwrapped_models[k] = unwrap_fabric(models_to_log[k])
model_info[k] = mlflow.pytorch.log_model(unwrapped_models[k], artifact_path=k)
mlflow.log_dict(cfg, "config.json")
return model_info


def log_models_from_checkpoint(
fabric: Fabric, env: gym.Env | gym.Wrapper, cfg: Dict[str, Any], state: Dict[str, Any]
) -> Sequence[ModelInfo]:
) -> Sequence["ModelInfo"]:
if not _IS_MLFLOW_AVAILABLE:
raise ModuleNotFoundError(str(_IS_MLFLOW_AVAILABLE))
import mlflow # noqa

from sheeprl.algos.dreamer_v1.agent import build_agent

# Create the models
Expand Down
21 changes: 5 additions & 16 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@

import gymnasium as gym
import hydra
import mlflow
import numpy as np
import torch
import torch.nn.functional as F
from lightning.fabric import Fabric
from lightning.fabric.wrappers import _FabricModule
from mlflow.models.model import ModelInfo
from tensordict import TensorDict
from tensordict.tensordict import TensorDictBase
from torch import Tensor
Expand All @@ -35,7 +33,7 @@
from sheeprl.utils.metric import MetricAggregator
from sheeprl.utils.registry import register_algorithm
from sheeprl.utils.timer import timer
from sheeprl.utils.utils import polynomial_decay, register_model, save_configs, unwrap_fabric
from sheeprl.utils.utils import polynomial_decay, save_configs

# Decomment the following two lines if you cannot start an experiment with DMC environments
# os.environ["PYOPENGL_PLATFORM"] = ""
Expand Down Expand Up @@ -872,17 +870,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
test(player, fabric, cfg, log_dir)

if not cfg.model_manager.disabled and fabric.is_global_zero:
from sheeprl.algos.dreamer_v1.utils import log_models
from sheeprl.utils.mlflow import register_model

def log_models(
run_id: str, experiment_id: str | None = None, run_name: str | None = None
) -> Dict[str, ModelInfo]:
with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=True) as _:
model_info = {}
unwrapped_models = {}
for k in cfg.model_manager.models.keys():
unwrapped_models[k] = unwrap_fabric(local_vars[k])
model_info[k] = mlflow.pytorch.log_model(unwrapped_models[k], artifact_path=k)
mlflow.log_dict(cfg, "config.json")
return model_info

register_model(fabric, log_models, cfg)
models_to_log = {"world_model": world_model, "actor": actor, "critic": critic, "target_critic": target_critic}
register_model(fabric, log_models, cfg, models_to_log)
11 changes: 8 additions & 3 deletions sheeprl/algos/dreamer_v2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,21 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union

import gymnasium as gym
import mlflow
import numpy as np
import torch
import torch.nn as nn
from lightning import Fabric
from mlflow.models.model import ModelInfo
from torch import Tensor
from torch.distributions import Independent

from sheeprl.utils.distribution import OneHotCategoricalStraightThroughValidateArgs
from sheeprl.utils.env import make_env
from sheeprl.utils.imports import _IS_MLFLOW_AVAILABLE
from sheeprl.utils.utils import unwrap_fabric

if TYPE_CHECKING:
from mlflow.models.model import ModelInfo

from sheeprl.algos.dreamer_v1.agent import PlayerDV1
from sheeprl.algos.dreamer_v2.agent import PlayerDV2

Expand Down Expand Up @@ -165,7 +166,11 @@ def test(

def log_models_from_checkpoint(
fabric: Fabric, env: gym.Env | gym.Wrapper, cfg: Dict[str, Any], state: Dict[str, Any]
) -> Sequence[ModelInfo]:
) -> Sequence["ModelInfo"]:
if not _IS_MLFLOW_AVAILABLE:
raise ModuleNotFoundError(str(_IS_MLFLOW_AVAILABLE))
import mlflow # noqa

from sheeprl.algos.dreamer_v2.agent import build_agent

# Create the models
Expand Down
33 changes: 13 additions & 20 deletions sheeprl/algos/dreamer_v3/dreamer_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@

import gymnasium as gym
import hydra
import mlflow
import numpy as np
import torch
import torch.nn.functional as F
from lightning.fabric import Fabric
from lightning.fabric.wrappers import _FabricModule
from mlflow.models.model import ModelInfo
from tensordict import TensorDict
from tensordict.tensordict import TensorDictBase
from torch import Tensor
Expand All @@ -41,7 +39,7 @@
from sheeprl.utils.metric import MetricAggregator
from sheeprl.utils.registry import register_algorithm
from sheeprl.utils.timer import timer
from sheeprl.utils.utils import polynomial_decay, register_model, save_configs, unwrap_fabric
from sheeprl.utils.utils import polynomial_decay, save_configs

# Decomment the following two lines if you cannot start an experiment with DMC environments
# os.environ["PYOPENGL_PLATFORM"] = ""
Expand Down Expand Up @@ -295,7 +293,7 @@ def train(
policies: Sequence[Distribution] = actor(imagined_trajectories.detach())[1]

baseline = predicted_values[:-1]
offset, invscale = moments(lambda_values)
offset, invscale = moments(lambda_values, fabric)
normed_lambda_values = (lambda_values - offset) / invscale
normed_baseline = (baseline - offset) / invscale
advantage = normed_lambda_values - normed_baseline
Expand Down Expand Up @@ -466,7 +464,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
world_optimizer, actor_optimizer, critic_optimizer
)
moments = Moments(
fabric,
cfg.algo.actor.moments.decay,
cfg.algo.actor.moments.max,
cfg.algo.actor.moments.percentile.low,
Expand All @@ -475,7 +472,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
if cfg.checkpoint.resume_from:
moments.load_state_dict(state["moments"])

local_vars = locals()
if fabric.is_global_zero:
save_configs(cfg, log_dir)

Expand Down Expand Up @@ -791,17 +787,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
test(player, fabric, cfg, log_dir, sample_actions=True)

if not cfg.model_manager.disabled and fabric.is_global_zero:

def log_models(
run_id: str, experiment_id: str | None = None, run_name: str | None = None
) -> Dict[str, ModelInfo]:
with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=True) as _:
model_info = {}
unwrapped_models = {}
for k in cfg.model_manager.models.keys():
unwrapped_models[k] = unwrap_fabric(local_vars[k])
model_info[k] = mlflow.pytorch.log_model(unwrapped_models[k], artifact_path=k)
mlflow.log_dict(cfg, "config.json")
return model_info

register_model(fabric, log_models, cfg)
from sheeprl.algos.dreamer_v1.utils import log_models
from sheeprl.utils.mlflow import register_model

models_to_log = {
"world_model": world_model,
"actor": actor,
"critic": critic,
"target_critic": target_critic,
"moments": moments,
}
register_model(fabric, log_models, cfg, models_to_log)
Loading