From 18064e33e16824abcaf6c2c33ef6b369e00fcba9 Mon Sep 17 00:00:00 2001 From: michele-milesi <74559684+michele-milesi@users.noreply.github.com> Date: Tue, 28 Nov 2023 15:43:57 +0100 Subject: [PATCH] Feature/mlflow (#159) * feat: added mlflow logger * feat: unified get_logger methods * feat: generalized model register * feat: removed signature * feat: added mlflow register model to sac, sac_decoupled and droq * feat: added model manager to dreamers and sac_ae * feat: added model manager to p2e algorithms * fix: removed order dependencies between configs and code when registering models * fix: avoid p2e exploration models registered during finetuning * Feature/add build agents (#153) * [skip ci] Update README.md * [skip ci] Update README.md * feat: renamed build_models function into build_agent * feat: added build_agent() function to all the algorithms * feat: added build_agent() to evaluate() functions --------- Co-authored-by: Federico Belotti * feat: split model manager configs * feat: added script to register models from checkpoints * fix: bugs * fix: configs * fix: configs + registration model script * feat: added ensembles creation to build agent function (#154) * feat: added possibility to select experiment and run where to upload the models * fix: bugs * feat: added configs to artifact when model is registered from checkpoint * docs: update logs_and_checkpoints how to * feat: added model_manager howto * docs: update * docs: update * fix: added 'from __future__ import annotations' * feat: added mlflow model manager tutorial in examples * fix: bugs * fix: access to cnn and mlp keys * fix: experiment and run names * fix: bugs * feat: MlflowModelManager.register_best_models() function * fix: p2e build_agent * docs: update * fix: mlflow model manager * fix: mlflow model manager register best models --------- Co-authored-by: Federico Belotti --- .gitignore | 5 +- examples/model_manager.ipynb | 978 ++++++++++++++++++ howto/logs_and_checkpoints.md | 63 ++ howto/model_manager.md | 103 ++ howto/register_new_algorithm.md | 167 ++- pyproject.toml | 1 + sheeprl/algos/dreamer_v1/agent.py | 2 +- sheeprl/algos/dreamer_v1/dreamer_v1.py | 34 +- sheeprl/algos/dreamer_v1/evaluate.py | 10 +- sheeprl/algos/dreamer_v1/utils.py | 45 +- sheeprl/algos/dreamer_v2/agent.py | 5 +- sheeprl/algos/dreamer_v2/dreamer_v2.py | 34 +- sheeprl/algos/dreamer_v2/evaluate.py | 10 +- sheeprl/algos/dreamer_v2/utils.py | 44 +- sheeprl/algos/dreamer_v3/agent.py | 5 +- sheeprl/algos/dreamer_v3/dreamer_v3.py | 34 +- sheeprl/algos/dreamer_v3/evaluate.py | 10 +- sheeprl/algos/dreamer_v3/utils.py | 53 +- sheeprl/algos/droq/agent.py | 43 +- sheeprl/algos/droq/droq.py | 59 +- sheeprl/algos/droq/evaluate.py | 34 +- sheeprl/algos/droq/utils.py | 27 + sheeprl/algos/p2e_dv1/agent.py | 48 +- sheeprl/algos/p2e_dv1/evaluate.py | 10 +- sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py | 60 +- sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py | 41 +- sheeprl/algos/p2e_dv1/utils.py | 70 ++ sheeprl/algos/p2e_dv2/agent.py | 65 +- sheeprl/algos/p2e_dv2/evaluate.py | 11 +- sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py | 75 +- sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py | 39 +- sheeprl/algos/p2e_dv2/utils.py | 82 ++ sheeprl/algos/p2e_dv3/agent.py | 61 +- sheeprl/algos/p2e_dv3/evaluate.py | 11 +- sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py | 83 +- sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py | 40 +- sheeprl/algos/p2e_dv3/utils.py | 112 ++ sheeprl/algos/ppo/agent.py | 31 +- sheeprl/algos/ppo/evaluate.py | 23 +- sheeprl/algos/ppo/ppo.py | 69 +- 
sheeprl/algos/ppo/ppo_decoupled.py | 42 +- sheeprl/algos/ppo/utils.py | 39 +- sheeprl/algos/ppo_recurrent/agent.py | 33 +- sheeprl/algos/ppo_recurrent/evaluate.py | 26 +- sheeprl/algos/ppo_recurrent/ppo_recurrent.py | 71 +- sheeprl/algos/ppo_recurrent/utils.py | 38 +- sheeprl/algos/sac/agent.py | 36 +- sheeprl/algos/sac/evaluate.py | 26 +- sheeprl/algos/sac/sac.py | 51 +- sheeprl/algos/sac/sac_decoupled.py | 66 +- sheeprl/algos/sac/utils.py | 25 +- sheeprl/algos/sac_ae/agent.py | 118 ++- sheeprl/algos/sac_ae/evaluate.py | 82 +- sheeprl/algos/sac_ae/sac_ae.py | 142 +-- sheeprl/algos/sac_ae/utils.py | 29 +- sheeprl/cli.py | 89 +- sheeprl/configs/config.yaml | 3 +- sheeprl/configs/exp/dreamer_v1.yaml | 1 + sheeprl/configs/exp/dreamer_v2.yaml | 1 + sheeprl/configs/exp/dreamer_v3.yaml | 1 + sheeprl/configs/exp/droq.yaml | 3 +- sheeprl/configs/exp/p2e_dv1_exploration.yaml | 1 + sheeprl/configs/exp/p2e_dv1_finetuning.yaml | 1 + sheeprl/configs/exp/p2e_dv2_exploration.yaml | 1 + sheeprl/configs/exp/p2e_dv2_finetuning.yaml | 1 + sheeprl/configs/exp/p2e_dv3_exploration.yaml | 7 +- sheeprl/configs/exp/p2e_dv3_finetuning.yaml | 1 + ...doapp_64px_gray_combo_discrete_5Mstps.yaml | 2 +- sheeprl/configs/exp/ppo.yaml | 1 + sheeprl/configs/exp/ppo_recurrent.yaml | 1 + sheeprl/configs/exp/sac.yaml | 1 + sheeprl/configs/exp/sac_ae.yaml | 1 + sheeprl/configs/logger/mlflow.yaml | 10 + sheeprl/configs/logger/tensorboard.yaml | 7 + sheeprl/configs/metric/default.yaml | 4 + sheeprl/configs/model_manager/default.yaml | 2 + sheeprl/configs/model_manager/dreamer_v1.yaml | 17 + sheeprl/configs/model_manager/dreamer_v2.yaml | 21 + sheeprl/configs/model_manager/dreamer_v3.yaml | 25 + sheeprl/configs/model_manager/droq.yaml | 9 + .../model_manager/p2e_dv1_exploration.yaml | 29 + .../model_manager/p2e_dv1_finetuning.yaml | 17 + .../model_manager/p2e_dv2_exploration.yaml | 37 + .../model_manager/p2e_dv2_finetuning.yaml | 21 + .../model_manager/p2e_dv3_exploration.yaml | 57 + .../model_manager/p2e_dv3_finetuning.yaml | 25 + sheeprl/configs/model_manager/ppo.yaml | 9 + .../configs/model_manager/ppo_recurrent.yaml | 9 + sheeprl/configs/model_manager/sac.yaml | 9 + sheeprl/configs/model_manager/sac_ae.yaml | 17 + sheeprl/configs/model_manager_config.yaml | 22 + sheeprl/utils/logger.py | 31 +- sheeprl/utils/model_manager.py | 323 ++++++ sheeprl/utils/utils.py | 108 +- sheeprl_model_manager.py | 4 + 95 files changed, 3689 insertions(+), 761 deletions(-) create mode 100644 examples/model_manager.ipynb create mode 100644 howto/model_manager.md create mode 100644 sheeprl/configs/logger/mlflow.yaml create mode 100644 sheeprl/configs/logger/tensorboard.yaml create mode 100644 sheeprl/configs/model_manager/default.yaml create mode 100644 sheeprl/configs/model_manager/dreamer_v1.yaml create mode 100644 sheeprl/configs/model_manager/dreamer_v2.yaml create mode 100644 sheeprl/configs/model_manager/dreamer_v3.yaml create mode 100644 sheeprl/configs/model_manager/droq.yaml create mode 100644 sheeprl/configs/model_manager/p2e_dv1_exploration.yaml create mode 100644 sheeprl/configs/model_manager/p2e_dv1_finetuning.yaml create mode 100644 sheeprl/configs/model_manager/p2e_dv2_exploration.yaml create mode 100644 sheeprl/configs/model_manager/p2e_dv2_finetuning.yaml create mode 100644 sheeprl/configs/model_manager/p2e_dv3_exploration.yaml create mode 100644 sheeprl/configs/model_manager/p2e_dv3_finetuning.yaml create mode 100644 sheeprl/configs/model_manager/ppo.yaml create mode 100644 sheeprl/configs/model_manager/ppo_recurrent.yaml create 
mode 100644 sheeprl/configs/model_manager/sac.yaml create mode 100644 sheeprl/configs/model_manager/sac_ae.yaml create mode 100644 sheeprl/configs/model_manager_config.yaml create mode 100644 sheeprl/utils/model_manager.py create mode 100644 sheeprl_model_manager.py diff --git a/.gitignore b/.gitignore index bc2f89e8..3ef1d5a0 100644 --- a/.gitignore +++ b/.gitignore @@ -167,4 +167,7 @@ pytest_* !sheeprl/configs/env .diambra* .hydra -.pypirc \ No newline at end of file +.pypirc +mlruns +mlartifacts +examples/models \ No newline at end of file diff --git a/examples/model_manager.ipynb b/examples/model_manager.ipynb new file mode 100644 index 00000000..c0c3d029 --- /dev/null +++ b/examples/model_manager.ipynb @@ -0,0 +1,978 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "> **Note**\n", + ">\n", + "> This notebook was inspired by [https://orobix.github.io/quadra/v1.3.6/tutorials/model_management.html](https://orobix.github.io/quadra/v1.3.6/tutorials/model_management.html)\n", + "\n", + "# Model Manager\n", + "\n", + "In this notebook, we present the [MlflowModelManager](../sheeprl/utils/model_manager.py) and possible use.\n", + "It includes methods such as:\n", + "* Register the model\n", + "* Retrieve the latest version\n", + "* Transition the model to a new stage\n", + "* Delete the model\n", + "\n", + "First of all, we need to run the Mlflow server with the artifact store. You can find the instructions for running the Mlflow server [here](https://mlflow.org/docs/latest/tracking.html#tracking-ui). Let's open a new terminal and run the following command:\n", + "```bash\n", + "mlflow ui\n", + "```\n", + "\n", + "> **Note**\n", + ">\n", + "> This is one of the possibilities, you could have the server running on another machine, so you just need to set the `tracking_uri` parameter properly." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Running the Experiment and Registering the Model\n", + "Second, we launch an experiment, so we need to retrieve the configs and execute the `run_algorithm` function. We train a PPO agent in the CartPole-v1 environment for few steps (we do not want to reach the best performance, but we want to show how SheepRL interprets model management for reinforcement learning)." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Seed set to 42\n", + "Experiment with name mlflow_example not found. 
Creating it.\n", + "/home/mmilesi/repos/sheeprl/sheeprl/utils/logger.py:79: UserWarning: Missing logger folder: logs/runs/ppo/CartPole-v1/2023-11-28_11-50-38_mlflow_example_42\n", + " warnings.warn(\"Missing logger folder: %s\" % save_dir, UserWarning)\n", + "/home/mmilesi/miniconda3/envs/sheeprl/lib/python3.10/site-packages/gymnasium/experimental/wrappers/rendering.py:166: UserWarning: \u001b[33mWARN: Overwriting existing videos at /home/mmilesi/repos/sheeprl/examples/logs/runs/ppo/CartPole-v1/2023-11-28_11-50-38_mlflow_example_42/version_0/train_videos folder (try specifying a different `video_folder` for the `RecordVideo` wrapper if this is not desired)\u001b[0m\n", + " logger.warn(\n", + "/home/mmilesi/repos/sheeprl/sheeprl/algos/ppo/ppo.py:233: UserWarning: The metric.log_every parameter (5000) is not a multiple of the policy_steps_per_update value (512), so the metrics will be logged at the nearest greater multiple of the policy_steps_per_update value.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Encoder CNN keys: []\n", + "Encoder MLP keys: ['state']\n", + "Rank-0: policy_step=44, reward_env_1=11.0\n", + "Rank-0: policy_step=64, reward_env_3=16.0\n", + "Rank-0: policy_step=92, reward_env_2=23.0\n", + "Rank-0: policy_step=104, reward_env_3=10.0\n", + "Rank-0: policy_step=108, reward_env_0=27.0\n", + "Rank-0: policy_step=128, reward_env_1=21.0\n", + "Rank-0: policy_step=144, reward_env_2=13.0\n", + "Rank-0: policy_step=160, reward_env_0=13.0\n", + "Rank-0: policy_step=168, reward_env_1=10.0\n", + "Rank-0: policy_step=168, reward_env_3=16.0\n", + "Rank-0: policy_step=260, reward_env_1=23.0\n", + "Rank-0: policy_step=284, reward_env_0=31.0\n", + "Rank-0: policy_step=300, reward_env_2=39.0\n", + "Rank-0: policy_step=352, reward_env_1=23.0\n", + "Rank-0: policy_step=376, reward_env_0=23.0\n", + "Rank-0: policy_step=424, reward_env_0=12.0\n", + "Rank-0: policy_step=444, reward_env_1=23.0\n", + "Rank-0: policy_step=484, reward_env_0=15.0\n", + "Rank-0: policy_step=536, reward_env_1=23.0\n", + "Rank-0: policy_step=556, reward_env_3=97.0\n", + "Rank-0: policy_step=592, reward_env_2=73.0\n", + "Rank-0: policy_step=600, reward_env_1=16.0\n", + "Rank-0: policy_step=636, reward_env_0=38.0\n", + "Rank-0: policy_step=644, reward_env_3=22.0\n", + "Rank-0: policy_step=660, reward_env_1=15.0\n", + "Rank-0: policy_step=672, reward_env_2=20.0\n", + "Rank-0: policy_step=720, reward_env_2=12.0\n", + "Rank-0: policy_step=728, reward_env_0=23.0\n", + "Rank-0: policy_step=792, reward_env_3=37.0\n", + "Rank-0: policy_step=796, reward_env_1=34.0\n", + "Rank-0: policy_step=800, reward_env_0=18.0\n", + "Rank-0: policy_step=800, reward_env_2=20.0\n", + "Rank-0: policy_step=848, reward_env_1=13.0\n", + "Rank-0: policy_step=856, reward_env_2=14.0\n", + "Rank-0: policy_step=868, reward_env_3=19.0\n", + "Rank-0: policy_step=916, reward_env_2=15.0\n", + "Rank-0: policy_step=920, reward_env_0=30.0\n", + "Rank-0: policy_step=932, reward_env_3=16.0\n", + "Rank-0: policy_step=948, reward_env_1=25.0\n", + "Rank-0: policy_step=964, reward_env_2=12.0\n", + "Rank-0: policy_step=980, reward_env_3=12.0\n", + "Rank-0: policy_step=996, reward_env_0=19.0\n", + "Rank-0: policy_step=1008, reward_env_1=15.0\n", + "Test - Reward: 48.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023/11/28 11:50:46 WARNING mlflow.utils.requirements_utils: The following packages were not found in the public PyPI package index as of 2023-10-28; if these packages are 
not present in the public PyPI index, you must install them manually before loading your model: {'sheeprl'}\n", + "/home/mmilesi/miniconda3/envs/sheeprl/lib/python3.10/site-packages/_distutils_hack/__init__.py:33: UserWarning: Setuptools is replacing distutils.\n", + " warnings.warn(\"Setuptools is replacing distutils.\")\n", + "Successfully registered model 'mlflow_example'.\n", + "2023/11/28 11:50:46 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: mlflow_example, version 1\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Registered model mlflow_example with version 1\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Created version '1' of model 'mlflow_example'.\n" + ] + } + ], + "source": [ + "import hydra\n", + "from omegaconf import OmegaConf\n", + "from sheeprl.utils.utils import dotdict\n", + "from sheeprl.cli import check_configs, run_algorithm\n", + "\n", + "# To retrieve the configs, we can simulate the cli command\n", + "# `python sheeprl.py exp=ppo algo.total_steps=1024 model_manager.disabled=False logger@metric.logger=mlflow checkpoint.every=1024 exp_name=mlflow_example metric.logger.tracking_uri=\"http://localhost:5000\"`\n", + "with hydra.initialize(version_base=\"1.3\", config_path=\"../sheeprl/configs\"):\n", + " cfg = hydra.compose(\n", + " config_name=\"config.yaml\",\n", + " overrides=[\n", + " \"exp=ppo\",\n", + " \"algo.total_steps=1024\",\n", + " \"model_manager.disabled=False\",\n", + " \"logger@metric.logger=mlflow\",\n", + " \"checkpoint.every=1024\",\n", + " \"exp_name=mlflow_example\",\n", + " \"metric.logger.tracking_uri=http://localhost:5000\",\n", + " ],\n", + " )\n", + " cfg = dotdict(OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True))\n", + "check_configs(cfg)\n", + "run_algorithm(cfg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Get Experiment Info\n", + "\n", + "The experiment is logged on MLFlow, and we can retrieve it just with the following instructions. Moreover, given the experiment, it is possible to retrieve all the runs with the `mlflow.search_runs()` function.\n", + "\n", + "> **Note**\n", + ">\n", + "> You can check this information from a browser, by entering the MLFlow address in a browser, e.g., `http://localhost:5000` if you are running mlflow locally." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Experiment: \n", + "Experiment (257965132461445889) runs:\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
run_idexperiment_idstatusartifact_uristart_timeend_timemetrics.Loss/value_lossmetrics.Test/cumulative_rewardmetrics.Loss/entropy_lossmetrics.Rewards/rew_avg...params.num_threadsparams.algo/mlp_keys/encoderparams.metric/logger/experiment_nameparams.algo/loss_reductionparams.algo/critic/dense_acttags.mlflow.source.typetags.mlflow.usertags.mlflow.source.nametags.mlflow.log-model.historytags.mlflow.runName
07432e9e1e0a2491e9cf2b80c7cefe6c7257965132461445889FINISHEDmlflow-artifacts:/257965132461445889/7432e9e1e...2023-11-28 10:50:39.184000+00:002023-11-28 10:50:46.452000+00:0036.16628348.0-0.68703122.953489...1['state']mlflow_examplemeantorch.nn.TanhLOCALmmilesi/home/mmilesi/miniconda3/envs/sheeprl/lib/pyth...[{\"run_id\": \"7432e9e1e0a2491e9cf2b80c7cefe6c7\"...ppo_CartPole-v1_2023-11-28_11-50-38
\n", + "

1 rows × 129 columns

\n", + "
" + ], + "text/plain": [ + " run_id experiment_id status \\\n", + "0 7432e9e1e0a2491e9cf2b80c7cefe6c7 257965132461445889 FINISHED \n", + "\n", + " artifact_uri \\\n", + "0 mlflow-artifacts:/257965132461445889/7432e9e1e... \n", + "\n", + " start_time end_time \\\n", + "0 2023-11-28 10:50:39.184000+00:00 2023-11-28 10:50:46.452000+00:00 \n", + "\n", + " metrics.Loss/value_loss metrics.Test/cumulative_reward \\\n", + "0 36.166283 48.0 \n", + "\n", + " metrics.Loss/entropy_loss metrics.Rewards/rew_avg ... \\\n", + "0 -0.687031 22.953489 ... \n", + "\n", + " params.num_threads params.algo/mlp_keys/encoder \\\n", + "0 1 ['state'] \n", + "\n", + " params.metric/logger/experiment_name params.algo/loss_reduction \\\n", + "0 mlflow_example mean \n", + "\n", + " params.algo/critic/dense_act tags.mlflow.source.type tags.mlflow.user \\\n", + "0 torch.nn.Tanh LOCAL mmilesi \n", + "\n", + " tags.mlflow.source.name \\\n", + "0 /home/mmilesi/miniconda3/envs/sheeprl/lib/pyth... \n", + "\n", + " tags.mlflow.log-model.history \\\n", + "0 [{\"run_id\": \"7432e9e1e0a2491e9cf2b80c7cefe6c7\"... \n", + "\n", + " tags.mlflow.runName \n", + "0 ppo_CartPole-v1_2023-11-28_11-50-38 \n", + "\n", + "[1 rows x 129 columns]" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import mlflow\n", + "\n", + "mlflow.set_tracking_uri(cfg.metric.logger.tracking_uri)\n", + "exp = mlflow.get_experiment_by_name(\"mlflow_example\")\n", + "print(\"Experiment:\", exp)\n", + "runs = mlflow.search_runs(experiment_ids=[exp.experiment_id])\n", + "print(f\"Experiment ({exp.experiment_id}) runs:\")\n", + "runs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Retrieve Model Info\n", + "Since we set the `model_manager.disabled` to `False` the PPO Agent is registered in MLFLow, we can get its information with the following instructions." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Name: mlflow_example\n", + "Description: # MODEL CHANGELOG\n", + "## **Version 1**\n", + "### Author: mmilesi\n", + "### Date: 28/11/2023 11:50:46 CET\n", + "### Description: \n", + "PPO Agent in CartPole-v1 Environment\n", + "\n", + "Tags: {}\n", + "Latest Version: 1\n" + ] + } + ], + "source": [ + "from sheeprl.utils.model_manager import MlflowModelManager\n", + "from lightning import Fabric\n", + "\n", + "fabric = Fabric(devices=1, accelerator=cfg.fabric.accelerator, precision=cfg.fabric.precision)\n", + "fabric.launch()\n", + "model_manager = MlflowModelManager(fabric, cfg.model_manager.tracking_uri)\n", + "\n", + "model_info = mlflow.search_registered_models(filter_string=\"name='mlflow_example'\")[-1]\n", + "model_name = model_info.name\n", + "print(\"Name:\", model_name)\n", + "print(\"Description:\", model_info.description)\n", + "print(\"Tags:\", model_info.tags)\n", + "latest_version = model_manager.get_latest_version(model_info.name)\n", + "print(\"Latest Version:\", latest_version.version)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Registering a New Model Version from Checkpoint\n", + "\n", + "Suppose to train a new PPO Agent in the CartPole-v1 environment and to obtain better results than before. You can register a new version of the model. 
To do this, we show another method to register models, not directly after training, but from a checkpoint.\n", + "\n", + "First of all, we need to run another experiment with different hyper-parameters." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Seed set to 42\n", + "/home/mmilesi/repos/sheeprl/sheeprl/utils/logger.py:79: UserWarning: Missing logger folder: logs/runs/ppo/CartPole-v1/2023-11-28_11-50-46_mlflow_example_42\n", + " warnings.warn(\"Missing logger folder: %s\" % save_dir, UserWarning)\n", + "/home/mmilesi/miniconda3/envs/sheeprl/lib/python3.10/site-packages/gymnasium/experimental/wrappers/rendering.py:166: UserWarning: \u001b[33mWARN: Overwriting existing videos at /home/mmilesi/repos/sheeprl/examples/logs/runs/ppo/CartPole-v1/2023-11-28_11-50-46_mlflow_example_42/version_0/train_videos folder (try specifying a different `video_folder` for the `RecordVideo` wrapper if this is not desired)\u001b[0m\n", + " logger.warn(\n", + "/home/mmilesi/repos/sheeprl/sheeprl/algos/ppo/ppo.py:233: UserWarning: The metric.log_every parameter (5000) is not a multiple of the policy_steps_per_update value (512), so the metrics will be logged at the nearest greater multiple of the policy_steps_per_update value.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Encoder CNN keys: []\n", + "Encoder MLP keys: ['state']\n", + "Rank-0: policy_step=44, reward_env_1=11.0\n", + "Rank-0: policy_step=64, reward_env_3=16.0\n", + "Rank-0: policy_step=92, reward_env_2=23.0\n", + "Rank-0: policy_step=104, reward_env_3=10.0\n", + "Rank-0: policy_step=108, reward_env_0=27.0\n", + "Rank-0: policy_step=128, reward_env_1=21.0\n", + "Rank-0: policy_step=144, reward_env_2=13.0\n", + "Rank-0: policy_step=160, reward_env_0=13.0\n", + "Rank-0: policy_step=168, reward_env_1=10.0\n", + "Rank-0: policy_step=168, reward_env_3=16.0\n", + "Rank-0: policy_step=260, reward_env_1=23.0\n", + "Rank-0: policy_step=284, reward_env_0=31.0\n", + "Rank-0: policy_step=300, reward_env_2=39.0\n", + "Rank-0: policy_step=352, reward_env_1=23.0\n", + "Rank-0: policy_step=376, reward_env_0=23.0\n", + "Rank-0: policy_step=424, reward_env_0=12.0\n", + "Rank-0: policy_step=444, reward_env_1=23.0\n", + "Rank-0: policy_step=484, reward_env_0=15.0\n", + "Rank-0: policy_step=536, reward_env_1=23.0\n", + "Rank-0: policy_step=556, reward_env_3=97.0\n", + "Rank-0: policy_step=592, reward_env_2=73.0\n", + "Rank-0: policy_step=600, reward_env_1=16.0\n", + "Rank-0: policy_step=636, reward_env_0=38.0\n", + "Rank-0: policy_step=644, reward_env_3=22.0\n", + "Rank-0: policy_step=660, reward_env_1=15.0\n", + "Rank-0: policy_step=672, reward_env_2=20.0\n", + "Rank-0: policy_step=720, reward_env_2=12.0\n", + "Rank-0: policy_step=728, reward_env_0=23.0\n", + "Rank-0: policy_step=792, reward_env_3=37.0\n", + "Rank-0: policy_step=796, reward_env_1=34.0\n", + "Rank-0: policy_step=800, reward_env_0=18.0\n", + "Rank-0: policy_step=800, reward_env_2=20.0\n", + "Rank-0: policy_step=848, reward_env_1=13.0\n", + "Rank-0: policy_step=856, reward_env_2=14.0\n", + "Rank-0: policy_step=868, reward_env_3=19.0\n", + "Rank-0: policy_step=916, reward_env_2=15.0\n", + "Rank-0: policy_step=920, reward_env_0=30.0\n", + "Rank-0: policy_step=932, reward_env_3=16.0\n", + "Rank-0: policy_step=948, reward_env_1=25.0\n", + "Rank-0: policy_step=964, reward_env_2=12.0\n", + "Rank-0: policy_step=980, reward_env_3=12.0\n", + 
"Rank-0: policy_step=996, reward_env_0=19.0\n", + "Rank-0: policy_step=1008, reward_env_1=15.0\n", + "Rank-0: policy_step=1056, reward_env_3=19.0\n", + "Rank-0: policy_step=1068, reward_env_0=18.0\n", + "Rank-0: policy_step=1080, reward_env_2=29.0\n", + "Rank-0: policy_step=1112, reward_env_1=26.0\n", + "Rank-0: policy_step=1124, reward_env_2=11.0\n", + "Rank-0: policy_step=1136, reward_env_0=17.0\n", + "Rank-0: policy_step=1152, reward_env_3=24.0\n", + "Rank-0: policy_step=1212, reward_env_0=19.0\n", + "Rank-0: policy_step=1240, reward_env_2=29.0\n", + "Rank-0: policy_step=1284, reward_env_0=18.0\n", + "Rank-0: policy_step=1284, reward_env_1=43.0\n", + "Rank-0: policy_step=1320, reward_env_2=20.0\n", + "Rank-0: policy_step=1336, reward_env_3=46.0\n", + "Rank-0: policy_step=1364, reward_env_1=20.0\n", + "Rank-0: policy_step=1384, reward_env_0=25.0\n", + "Rank-0: policy_step=1408, reward_env_3=18.0\n", + "Rank-0: policy_step=1432, reward_env_2=28.0\n", + "Rank-0: policy_step=1448, reward_env_0=16.0\n", + "Rank-0: policy_step=1472, reward_env_1=27.0\n", + "Rank-0: policy_step=1480, reward_env_3=18.0\n", + "Rank-0: policy_step=1500, reward_env_2=17.0\n", + "Rank-0: policy_step=1516, reward_env_0=17.0\n", + "Rank-0: policy_step=1580, reward_env_0=16.0\n", + "Rank-0: policy_step=1644, reward_env_1=43.0\n", + "Rank-0: policy_step=1664, reward_env_2=41.0\n", + "Rank-0: policy_step=1668, reward_env_0=22.0\n", + "Rank-0: policy_step=1692, reward_env_3=53.0\n", + "Rank-0: policy_step=1744, reward_env_1=25.0\n", + "Rank-0: policy_step=1768, reward_env_0=25.0\n", + "Rank-0: policy_step=1796, reward_env_2=33.0\n", + "Rank-0: policy_step=1864, reward_env_2=17.0\n", + "Rank-0: policy_step=1956, reward_env_0=47.0\n", + "Rank-0: policy_step=1960, reward_env_1=54.0\n", + "Rank-0: policy_step=1964, reward_env_2=25.0\n", + "Rank-0: policy_step=2012, reward_env_2=12.0\n", + "Rank-0: policy_step=2040, reward_env_1=20.0\n", + "Rank-0: policy_step=2152, reward_env_1=28.0\n", + "Rank-0: policy_step=2164, reward_env_3=118.0\n", + "Rank-0: policy_step=2196, reward_env_1=11.0\n", + "Rank-0: policy_step=2260, reward_env_2=62.0\n", + "Rank-0: policy_step=2268, reward_env_0=78.0\n", + "Rank-0: policy_step=2276, reward_env_1=20.0\n", + "Rank-0: policy_step=2308, reward_env_3=36.0\n", + "Rank-0: policy_step=2492, reward_env_2=58.0\n", + "Rank-0: policy_step=2520, reward_env_0=63.0\n", + "Rank-0: policy_step=2584, reward_env_2=23.0\n", + "Rank-0: policy_step=2608, reward_env_1=83.0\n", + "Rank-0: policy_step=2632, reward_env_3=81.0\n", + "Rank-0: policy_step=2692, reward_env_0=43.0\n", + "Rank-0: policy_step=2780, reward_env_1=43.0\n", + "Rank-0: policy_step=2796, reward_env_2=53.0\n", + "Rank-0: policy_step=2928, reward_env_0=59.0\n", + "Rank-0: policy_step=2940, reward_env_2=36.0\n", + "Rank-0: policy_step=2944, reward_env_3=78.0\n", + "Rank-0: policy_step=3012, reward_env_1=58.0\n", + "Rank-0: policy_step=3080, reward_env_0=38.0\n", + "Rank-0: policy_step=3224, reward_env_3=70.0\n", + "Rank-0: policy_step=3284, reward_env_1=68.0\n", + "Rank-0: policy_step=3284, reward_env_2=86.0\n", + "Rank-0: policy_step=3288, reward_env_0=52.0\n", + "Rank-0: policy_step=3480, reward_env_2=49.0\n", + "Rank-0: policy_step=3548, reward_env_3=81.0\n", + "Rank-0: policy_step=3588, reward_env_1=76.0\n", + "Rank-0: policy_step=3712, reward_env_2=58.0\n", + "Rank-0: policy_step=3840, reward_env_1=63.0\n", + "Rank-0: policy_step=3928, reward_env_2=54.0\n", + "Rank-0: policy_step=3956, reward_env_3=102.0\n", + "Rank-0: policy_step=4016, 
reward_env_0=182.0\n", + "Rank-0: policy_step=4096, reward_env_1=64.0\n", + "Rank-0: policy_step=4252, reward_env_1=39.0\n", + "Rank-0: policy_step=4256, reward_env_0=60.0\n", + "Rank-0: policy_step=4316, reward_env_3=90.0\n", + "Rank-0: policy_step=4416, reward_env_1=41.0\n", + "Rank-0: policy_step=4476, reward_env_2=137.0\n", + "Rank-0: policy_step=4776, reward_env_1=90.0\n", + "Rank-0: policy_step=4800, reward_env_0=136.0\n", + "Rank-0: policy_step=4836, reward_env_2=90.0\n", + "Rank-0: policy_step=4920, reward_env_3=151.0\n", + "Rank-0: policy_step=5176, reward_env_1=100.0\n", + "Rank-0: policy_step=5308, reward_env_0=127.0\n", + "Rank-0: policy_step=5332, reward_env_2=124.0\n", + "Rank-0: policy_step=5388, reward_env_3=117.0\n", + "Rank-0: policy_step=5556, reward_env_0=62.0\n", + "Rank-0: policy_step=5736, reward_env_3=87.0\n", + "Rank-0: policy_step=5744, reward_env_1=142.0\n", + "Rank-0: policy_step=5832, reward_env_2=125.0\n", + "Rank-0: policy_step=5968, reward_env_0=103.0\n", + "Rank-0: policy_step=6168, reward_env_1=106.0\n", + "Rank-0: policy_step=6272, reward_env_3=134.0\n", + "Rank-0: policy_step=6340, reward_env_2=127.0\n", + "Rank-0: policy_step=6408, reward_env_0=110.0\n", + "Rank-0: policy_step=6756, reward_env_1=147.0\n", + "Rank-0: policy_step=6856, reward_env_3=146.0\n", + "Rank-0: policy_step=6876, reward_env_2=134.0\n", + "Rank-0: policy_step=6920, reward_env_0=128.0\n", + "Rank-0: policy_step=7308, reward_env_1=138.0\n", + "Rank-0: policy_step=7644, reward_env_3=197.0\n", + "Rank-0: policy_step=7688, reward_env_2=203.0\n", + "Rank-0: policy_step=7700, reward_env_0=195.0\n", + "Rank-0: policy_step=8080, reward_env_1=193.0\n", + "Rank-0: policy_step=8636, reward_env_2=237.0\n", + "Rank-0: policy_step=8644, reward_env_3=250.0\n", + "Rank-0: policy_step=8756, reward_env_0=264.0\n", + "Rank-0: policy_step=8868, reward_env_1=197.0\n", + "Rank-0: policy_step=9728, reward_env_2=273.0\n", + "Rank-0: policy_step=9864, reward_env_3=305.0\n", + "Rank-0: policy_step=10756, reward_env_0=500.0\n", + "Rank-0: policy_step=10848, reward_env_1=495.0\n", + "Rank-0: policy_step=11124, reward_env_2=349.0\n", + "Rank-0: policy_step=11472, reward_env_3=402.0\n", + "Rank-0: policy_step=12756, reward_env_0=500.0\n", + "Rank-0: policy_step=12848, reward_env_1=500.0\n", + "Rank-0: policy_step=13124, reward_env_2=500.0\n", + "Rank-0: policy_step=13472, reward_env_3=500.0\n", + "Rank-0: policy_step=14324, reward_env_1=369.0\n", + "Rank-0: policy_step=14532, reward_env_0=444.0\n", + "Rank-0: policy_step=14796, reward_env_3=331.0\n", + "Rank-0: policy_step=15112, reward_env_2=497.0\n", + "Rank-0: policy_step=15924, reward_env_1=400.0\n", + "Rank-0: policy_step=16000, reward_env_0=367.0\n", + "Rank-0: policy_step=16044, reward_env_2=233.0\n", + "Rank-0: policy_step=16152, reward_env_3=339.0\n", + "Test - Reward: 392.0\n" + ] + } + ], + "source": [ + "# To retrieve the configs, we can simulate the cli command\n", + "# `python sheeprl.py exp=ppo algo.total_steps=16384 checkpoint.every=16384 logger@metric.logger=mlflow exp_name=mlflow_example metric.logger.tracking_uri=\"http://localhost:5000\"`\n", + "import os\n", + "\n", + "with hydra.initialize(version_base=\"1.3\", config_path=\"../sheeprl/configs\"):\n", + " cfg_ = hydra.compose(\n", + " config_name=\"config.yaml\",\n", + " overrides=[\n", + " \"exp=ppo\",\n", + " \"algo.total_steps=16384\",\n", + " \"checkpoint.every=16384\",\n", + " \"logger@metric.logger=mlflow\",\n", + " \"exp_name=mlflow_example\",\n", + " 
\"metric.logger.tracking_uri=http://localhost:5000\",\n", + " ],\n", + " )\n", + " cfg = dotdict(OmegaConf.to_container(cfg_, resolve=True, throw_on_missing=True))\n", + "run_algorithm(cfg)\n", + "os.mkdir(f\"./logs/runs/{cfg.root_dir}/{cfg.run_name}/.hydra/\")\n", + "OmegaConf.save(cfg_, f\"./logs/runs/{cfg.root_dir}/{cfg.run_name}/.hydra/config.yaml\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can use the `./sheeprl_model_manager.py` script to take a checkpoint and register the models of the checkpoint.\n", + "We want to retrieve the id of the last run, to associate the model to the correct run. We can take it from the UI (from the browser) or by retrieving it with the `mlflow.search_runs(experiment_ids=[exp.experiment_id])` instruction." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023/11/28 11:51:55 WARNING mlflow.utils.requirements_utils: The following packages were not found in the public PyPI package index as of 2023-10-28; if these packages are not present in the public PyPI index, you must install them manually before loading your model: {'sheeprl'}\n", + "Registered model 'mlflow_example' already exists. Creating a new version of this model...\n", + "2023/11/28 11:51:55 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: mlflow_example, version 2\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Registered model mlflow_example with version 2\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Created version '2' of model 'mlflow_example'.\n" + ] + } + ], + "source": [ + "from sheeprl.cli import registration\n", + "\n", + "# To retrieve the configs, we can simulate the cli command\n", + "# `python sheeprl_model_manager.py checkpoint_path= \\\n", + "# model_manager=ppo model_manager.models.agent.description='New PPO Agent version trained in CartPole-v1 environment' \\\n", + "# run.id=`\n", + "runs = mlflow.search_runs(experiment_ids=[exp.experiment_id])\n", + "run_id = runs[\"run_id\"][0]\n", + "with hydra.initialize(version_base=\"1.3\", config_path=\"../sheeprl/configs\"):\n", + " cfg = hydra.compose(\n", + " config_name=\"model_manager_config.yaml\",\n", + " overrides=[\n", + " # Substitute the checkpoint path with your /path/to/checkpoint.ckpt\n", + " \"checkpoint_path=./logs/runs/ppo/CartPole-v1/2023-11-28_11-50-46_mlflow_example_42/version_0/checkpoint/ckpt_16384_0.ckpt\",\n", + " \"model_manager=ppo\",\n", + " \"model_manager.models.agent.description='New PPO Agent version trained in CartPole-v1 environment'\",\n", + " f\"run.id={run_id}\",\n", + " ],\n", + " )\n", + "registration(cfg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And, of course, we can retrieve the new information of the registered model." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Name: mlflow_example\n", + "Description: # MODEL CHANGELOG\n", + "## **Version 1**\n", + "### Author: mmilesi\n", + "### Date: 28/11/2023 11:50:46 CET\n", + "### Description: \n", + "PPO Agent in CartPole-v1 Environment\n", + "## **Version 2**\n", + "### Author: mmilesi\n", + "### Date: 28/11/2023 11:51:55 CET\n", + "### Description: \n", + "New PPO Agent version trained in CartPole-v1 environment\n", + "\n", + "Tags: {}\n", + "Latest Version: 2\n" + ] + } + ], + "source": [ + "model_info = mlflow.search_registered_models(filter_string=f\"name='{model_name}'\")[-1]\n", + "print(\"Name:\", model_info.name)\n", + "print(\"Description:\", model_info.description)\n", + "print(\"Tags:\", model_info.tags)\n", + "latest_version = model_manager.get_latest_version(model_info.name)\n", + "print(\"Latest Version:\", latest_version.version)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Staging the Model\n", + "After registering the model, we can transition the model to a new stage. We can transition the model to the `\"staging\"` stage with the following command." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Transitioning model mlflow_example version 2 from None to staging\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_manager.transition_model(\n", + " model_name=\"mlflow_example\", version=latest_version.version, stage=\"staging\", description=\"Staging Model for demo\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Downloading the Model\n", + "You can download the registered models and load them with the `torch.load()` function." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/mmilesi/miniconda3/envs/sheeprl/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading model mlflow_example version 2 from mlflow-artifacts:/257965132461445889/7b8fcd3b2615483a9546380cf8f313c4/artifacts/agent to ./models/ppo-agent-cartpole\n", + "Creating output path ./models/ppo-agent-cartpole\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Downloading artifacts: 100%|██████████| 6/6 [00:00<00:00, 218.04it/s] \n" + ] + }, + { + "data": { + "text/plain": [ + "PPOAgent(\n", + " (feature_extractor): MultiEncoder(\n", + " (mlp_encoder): MLPEncoder(\n", + " (model): MLP(\n", + " (_model): Sequential(\n", + " (0): Linear(in_features=4, out_features=64, bias=True)\n", + " (1): Tanh()\n", + " (2): Linear(in_features=64, out_features=64, bias=True)\n", + " (3): Tanh()\n", + " (4): Linear(in_features=64, out_features=64, bias=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (critic): MLP(\n", + " (_model): Sequential(\n", + " (0): Linear(in_features=64, out_features=64, bias=True)\n", + " (1): Tanh()\n", + " (2): Linear(in_features=64, out_features=64, bias=True)\n", + " (3): Tanh()\n", + " (4): Linear(in_features=64, out_features=1, bias=True)\n", + " )\n", + " )\n", + " (actor_backbone): MLP(\n", + " (_model): Sequential(\n", + " (0): Linear(in_features=64, out_features=64, bias=True)\n", + " (1): Tanh()\n", + " (2): Linear(in_features=64, out_features=64, bias=True)\n", + " (3): Tanh()\n", + " )\n", + " )\n", + " (actor_heads): ModuleList(\n", + " (0): Linear(in_features=64, out_features=2, bias=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch\n", + "\n", + "download_path = \"./models/ppo-agent-cartpole\"\n", + "model_manager.download_model(model_name, latest_version.version, download_path)\n", + "agent = torch.load(\"models/ppo-agent-cartpole/agent/data/model.pth\")\n", + "agent" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Register Best Models\n", + "Another possibility is to register the best models of a specific experiment. Let us suppose we want to register the best model of the two experiments we ran before: the only thing we have to do is to call the `model_manager.register_best_models()` function by specifying the `experiment_name`, the `metric`, and the `models_info` (a python dictionary containing the name, the path, the description and the tags of the models we want to register), as shown below.\n", + "\n", + "> **Note**\n", + ">\n", + "> If your experiment contains different agents, and each agent has different model paths, then you have to specify in the `models_info` all the models you want to register (i.e., the union of the models of all the agents). The MLFlow model manager will automatically select the correct models for each agent." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Successfully registered model 'ppo_agent_cartpole_best_reward'.\n", + "2023/11/28 11:52:11 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. 
Model name: ppo_agent_cartpole_best_reward, version 1\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Registered model ppo_agent_cartpole_best_reward with version 1\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Created version '1' of model 'ppo_agent_cartpole_best_reward'.\n" + ] + }, + { + "data": { + "text/plain": [ + "{'agent': }" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "models_info = {\n", + " \"agent\": {\n", + " \"name\": \"ppo_agent_cartpole_best_reward\",\n", + " \"path\": \"agent\",\n", + " \"tags\": {},\n", + " \"description\": \"The best PPO Agent in CartPole environment.\",\n", + " }\n", + "}\n", + "model_manager.register_best_models(\"mlflow_example\", models_info)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Delete Model\n", + "Finally, you can delete registered models you no longer need." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Deleting model mlflow_example version 1\n" + ] + }, + { + "data": { + "text/plain": [ + "], name='mlflow_example', tags={}>" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_manager.delete_model(\n", + " model_name, int(latest_version.version) - 1, f\"Delete model version {int(latest_version.version)-1}\"\n", + ")\n", + "mlflow.search_registered_models(filter_string=\"name='mlflow_example'\")[-1]" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sheeprl", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/howto/logs_and_checkpoints.md b/howto/logs_and_checkpoints.md index 708f7129..df2c255f 100644 --- a/howto/logs_and_checkpoints.md +++ b/howto/logs_and_checkpoints.md @@ -7,6 +7,10 @@ By default the logging of metrics is enabled with the following settings: ```yaml # ./sheeprl/configs/metric/default.yaml +defaults: + - _self_ + - /logger@logger: tensorboard + log_every: 5000 disable_timer: False @@ -33,6 +37,7 @@ aggregator: ``` where +* `logger` is the configuration of the logger you want to use for logging. There are two possible values: `tensorboard` (default) and `mlflow`, but one can define and choose its own logger. * `log_every` is the number of policy steps (number of steps played in the environment, e.g. if one has 2 processes with 4 environments per process then the policy steps are 2*4=8) between two consecutive logging operations. For more info about the policy steps, check the [Work with Steps Tutorial](./work_with_steps.md). * `disable_timer` is a boolean flag that enables/disables the timer to measure both the time spent in the environment and the time spent during the agent training. The timer class used can be found [here](../sheeprl/utils/timer.py). * `log_level` is the level of logging: $0$ means no log (it disables also the timer), whereas $1$ means logging everything. 
@@ -41,6 +46,64 @@ where So, if one wants to disable everything related to logging, he/she can set `log_level` to $0$ if one wants to disable the timer, he/she can set `disable_timer` to `True`. +### Loggers +Two loggers are made available: the Tensorboard logger and the MLFlow one. In any case, it is possible to define or choose another logger. +The configurations of the loggers are under the `./sheeprl/configs/logger/` folder. + +#### Tensorboard +Let us start with the Tensorboard logger, which is the default logger used in SheepRL. + +```yaml +# ./sheeprl/configs/logger/tensorboard.yaml + +# For more information, check https://lightning.ai/docs/fabric/stable/api/generated/lightning.fabric.loggers.TensorBoardLogger.html +_target_: lightning.fabric.loggers.TensorBoardLogger +name: ${run_name} +root_dir: logs/runs/${root_dir} +version: null +default_hp_metric: True +prefix: "" +sub_dir: null +``` +As shown in the configurations, it is necessary to specify the `_target_` class to instantiate. For the Tensorboard logger, it is necessary to specify the `name` and the `root_dir` arguments equal to the `run_name` and `logs/runs/` parameters, respectively, because we want that all the logs and files (configs, checkpoint, videos, ...) are under the same folder for a specific experiment. + +> **Note** +> +> In general we want the path of the logs files to be in the same folder created by Hydra when the experiment is launched, so make sure to properly define the `root_dir` and `name` parameters of the logger so that it is within the folder created by hydra (defined by the `hydra.run.dir` parameter). The tensorboard logger will save the logs in the `////` folder (if `sub_dir` is defined, otherwise in the `///` folder). + +The documentation of the TensorboardLogger class can be found [here](https://lightning.ai/docs/fabric/stable/api/generated/lightning.fabric.loggers.TensorBoardLogger.html). + +#### MLFlow +Another possibility provided by SheepRL is [MLFlow](https://mlflow.org/docs/2.8.0/index.html). + +```yaml +# ./sheeprl/configs/logger/mlflow.yaml + +# For more information, check https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.mlflow.html#lightning.pytorch.loggers.mlflow.MLFlowLogger +_target_: lightning.pytorch.loggers.MLFlowLogger +experiment_name: ${exp_name} +tracking_uri: ${oc.env:MLFLOW_TRACKING_URI} +run_name: ${algo.name}_${env.id}_${now:%Y-%m-%d_%H-%M-%S} +tags: null +save_dir: null +prefix: "" +artifact_location: null +run_id: null +log_model: false +``` + +The parameters that can be specified for creating the MLFlow logger are explained [here](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.mlflow.html#lightning.pytorch.loggers.mlflow.MLFlowLogger). + +You can specify the MLFlow logger instead of the Tensorboard one in the CLI, by adding the `logger@metric.logger=mlflow` argument. In this way, hydra will take the configurations defined in the `./sheeprl/configs/logger/mlflow.yaml` file. + +```bash +python sheeprl.py exp=ppo exp_name=ppo-cartpole logger@metric.logger=mlflow +``` + +> **Note** +> +> If you are using an MLFlow server, you can specify the `tracking_uri` in the config file or with the `MLFLOW_TRACKING_URI` environment variable (that is the default value in the configs). + ### Logged metrics Every algorithm should specify a set of default metrics to log, called `AGGREGATOR_KEYS`, under its own `utils.py` file. 
For instance, the default metrics logged by DreamerV2 are the following: diff --git a/howto/model_manager.md b/howto/model_manager.md new file mode 100644 index 00000000..ceb8d00c --- /dev/null +++ b/howto/model_manager.md @@ -0,0 +1,103 @@ +# Model Manager + +SheepRL makes it possible to register trained models on MLFLow, so as to keep track of model versions and stages. + +## Register models with training +The configurations of the model manager are placed in the `./sheeprl/configs/model_manager/` folder, and the default configuration is defined as follows: +```yaml +# ./sheeprl/configs/model_manager/default.yaml + +disabled: True +models: {} +``` +Since the algorithms have different models, then the `models` parameter is set to an empty python dictionary, and each agent will define its own configuration. The `disabled` parameter indicates whether or not the user wants to register the agent when the training is finished (`False` means that the agent will be registered, otherwise not). + +> **Note** +> +> The model manager can be used even if the chosen logger is Tensorboard, the only requirement is that an instance of the MLFlow server is running and is accessible, moreover, it is necessary to specify its URI in the `MLFLOW_TRACKING_URI` environment variable. + +To better understand how to define the configurations of the models you want to register, take a look at the DreamerV3 model manager configuration: +```yaml +# ./sheeprl/configs/model_manager/dreamer_v3.yaml + +defaults: + - default + - _self_ + +models: + world_model: + model_name: "${exp_name}_world_model" + description: "DreamerV3 World Model used in ${env.id} Environment" + tags: {} + actor: + model_name: "${exp_name}_actor" + description: "DreamerV3 Actor used in ${env.id} Environment" + tags: {} + critic: + model_name: "${exp_name}_critic" + description: "DreamerV3 Critic used in ${env.id} Environment" + tags: {} + target_critic: + model_name: "${exp_name}_target_critic" + description: "DreamerV3 Target Critic used in ${env.id} Environment" + tags: {} + moments: + model_name: "${exp_name}_moments" + description: "DreamerV3 Moments used in ${env.id} Environment" + tags: {} +``` +For each model, it is necessary to define the `model_name`, the `description`, and the `tags` (i.e., a python dictionary with strings as keys and values). The keys that can be specified are defined by the `MODELS_TO_REGISTER` variable in the `./sheeprl/algos//utils.py`. For DreamerV3, it is defined as follows: `MODELS_TO_REGISTER = {"world_model", "actor", "critic", "target_critic", "moments"}`. +If you do not want to log some models, then, you just need to remove it from the configuration file. + +> **Note** +> +> The name of the models in the `MODELS_TO_REGISTER` variable is equal to the name of the variables of the models in the `./sheeprl/algos//.py` file. +> +> Make sure that the models specified in the configuration file are a subset of the models defined by the `MODELS_TO_REGISTER` variable. + +## Register models from checkpoints +Another possibility is to register the models after the training, by manually selecting the checkpoint where to retrieve the agent. To do this, it is possible to run the `sheeprl_model_manager.py` script by properly specifying the `checkpoint_path`, the `model_manager`, and the MLFlow-related configurations. 
+The default configurations are defined in the `./sheeprl/configs/model_manager_config.yaml` file, that is reported below: +```yaml +# ./sheeprl/configs/model_manager_config.yaml + +# @package _global_ +defaults: + - _self_ + - model_manager: ??? + - override hydra/hydra_logging: disabled + - override hydra/job_logging: disabled + +hydra: + output_subdir: null + run: + dir: . + +checkpoint_path: ??? +run: + id: null + name: ${now:%Y-%m-%d_%H-%M-%S}_${exp_name} +experiment: + id: null + name: ${exp_name}_${now:%Y-%m-%d_%H-%M-%S} +tracking_uri: ${oc.env:MLFLOW_TRACKING_URI} +``` + +As before, it is necessary to specify the `model_manager` configurations (the models we want to register with names, descriptions, and tags). Moreover, it is mandatory to set the `checkpoint_path`, which must be the path to the `ckpt` file created during the training. Finally, the `run` and `experiment` parameters contain the MLFlow configurations: +* If you set the `run.id` to a value different from `null`, then all the other parameters are ignored, indeed, the models will be logged and registered under the run with the specified ID. +* If you want to create a new run (with a name equal to `run.name`) and put it into an existing experiment, then you have to set `run.id=null` and `experiment.id=`. +* If you set `experiment.id=null` and `run.id=null`, then a new experiment and a new run are created with the specified names. + +> **Note** +> +> Also, in this case, the models specified in the `model_manager` configuration must be a subset of the `MODELS_TO_REGISTER` variable. + +For instance, you can register the DreamerV3 models from a checkpoint with the following command: + +```bash +python sheeprl_model_manager.py model_manager=dreamer_v3 checkpoint_path=/path/to/checkpoint.ckpt +``` + +## Delete, Transition and Download Models +The MLFlow model manager enables the deletion of the registered models, moving them from one stage to another or downloading them. +[This notebook](../examples/model_manager.ipynb) contains a tutorial on how to use the MLFlow model manager. We recommend taking a look to see what APIs the model manager makes available. 
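+
+As a quick reference, below is a minimal sketch that chains the stage-transition, download, and delete calls demonstrated in that notebook. The tracking URI, model name, and accelerator are illustrative assumptions; adapt them to your own MLFlow setup.
+
+```python
+from lightning import Fabric
+
+from sheeprl.utils.model_manager import MlflowModelManager
+
+# Assumes an MLFlow server reachable at http://localhost:5000 and a model
+# already registered under the (example) name "mlflow_example"
+fabric = Fabric(devices=1, accelerator="cpu")
+fabric.launch()
+model_manager = MlflowModelManager(fabric, "http://localhost:5000")
+
+# Promote the latest registered version to the "staging" stage
+latest_version = model_manager.get_latest_version("mlflow_example")
+model_manager.transition_model(
+    model_name="mlflow_example",
+    version=latest_version.version,
+    stage="staging",
+    description="Staging model for demo",
+)
+
+# Download the model artifacts locally, then drop the previous version
+model_manager.download_model("mlflow_example", latest_version.version, "./models/mlflow_example")
+model_manager.delete_model("mlflow_example", int(latest_version.version) - 1, "Remove outdated version")
+```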
\ No newline at end of file diff --git a/howto/register_new_algorithm.md b/howto/register_new_algorithm.md index cc6ee326..20fae37e 100644 --- a/howto/register_new_algorithm.md +++ b/howto/register_new_algorithm.md @@ -49,6 +49,7 @@ from tensordict.tensordict import TensorDictBase from torch.optim import Adam from torchmetrics import MeanMetric, SumMetric +from sheeprl.algos.ppo.agent import build_agent from sheeprl.algos.sota.loss import loss1, loss2 from sheeprl.algos.sota.utils import test from sheeprl.data import ReplayBuffer @@ -56,8 +57,9 @@ from sheeprl.models.models import MLP from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_logger, get_log_dir from sheeprl.utils.timer import timer +from sheeprl.utils.utils import register_model, unwrap_fabric def train( @@ -77,8 +79,9 @@ def train( optimizer.step() # Update metrics - aggregator.update("Loss/loss1", l1.detach()) - aggregator.update("Loss/loss2", l2.detach()) + if aggregator and not aggregator.disabled: + aggregator.update("Loss/loss1", l1.detach()) + aggregator.update("Loss/loss2", l2.detach()) @register_algorithm(decoupled=False) @@ -88,10 +91,10 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]): device = fabric.device fabric.seed_everything(cfg.seed) - # Create TensorBoardLogger. This will create the logger only on the + # Create Logger. This will create the logger only on the # rank-0 process - logger = create_tensorboard_logger(fabric, cfg) - if fabric.is_global_zero: + logger = get_logger(fabric, cfg) + if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) @@ -113,27 +116,28 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]): ) # Create the agent model: this should be a torch.nn.Module to be accelerated with Fabric - # Given that the environment has been created with the `make_dict_env` method, the agent + # Given that the environment has been created with the `make_env` method, the agent # forward method must accept as input a dictionary like {"obs1_name": obs1, "obs2_name": obs2, ...}. # The agent should be able to process both image and vector-like observations. - agent = ... 
+ agent = build_agent( + fabric, + actions_dim, + is_continuous, + cfg, + observation_space, + state["agent"] if cfg.checkpoint.resume_from else None, + ) - # Define the agent and the optimizer and set up them with Fabric - optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=list(agent.parameters())) - agent = fabric.setup_module(agent) - optimizer = fabric.setup_optimizers(optimizer) + # the optimizer and set up it with Fabric + optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters()) + + # In case you want to give the possiblity to register your models + local_vars = locals() # Create a metric aggregator to log the metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(), - "Game/ep_len_avg": MeanMetric(), - "Loss/value_loss": MeanMetric(), - "Loss/policy_loss": MeanMetric(), - "Loss/entropy_loss": MeanMetric(), - } - ) + aggregator = None + if not MetricAggregator.disabled: + aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator).to(device) # Local data rb = ReplayBuffer(cfg.algo.rollout_steps, cfg.env.num_envs, device=device, memmap=cfg.buffer.memmap) @@ -226,8 +230,10 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]): if agent_ep_info is not None: ep_rew = agent_ep_info["episode"]["r"] ep_len = agent_ep_info["episode"]["l"] - aggregator.update("Rewards/rew_avg", ep_rew) - aggregator.update("Game/ep_len_avg", ep_len) + if aggregator and "Rewards/rew_avg" in aggregator: + aggregator.update("Rewards/rew_avg", ep_rew) + if aggregator and "Game/ep_len_avg" in aggregator: + aggregator.update("Game/ep_len_avg", ep_len) fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Flatten the batch @@ -239,26 +245,28 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]): # Log metrics if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: # Sync distributed metrics - metrics_dict = aggregator.compute() - fabric.log_dict(metrics_dict, policy_step) - aggregator.reset() + if aggregator and not aggregator.disabled: + metrics_dict = aggregator.compute() + fabric.log_dict(metrics_dict, policy_step) + aggregator.reset() # Sync distributed timers - timer_metrics = timer.compute() - if "Time/train_time" in timer_metrics: - fabric.log( - "Time/sps_train", - (train_step - last_train) / timer_metrics["Time/train_time"], - policy_step, - ) - if "Time/env_interaction_time" in timer_metrics: - fabric.log( - "Time/sps_env_interaction", - ((policy_step - last_log) / world_size * cfg.env.action_repeat) - / timer_metrics["Time/env_interaction_time"], - policy_step, - ) - timer.reset() + if not timer.disabled: + timer_metrics = timer.compute() + if "Time/train_time" in timer_metrics: + fabric.log( + "Time/sps_train", + (train_step - last_train) / timer_metrics["Time/train_time"], + policy_step, + ) + if "Time/env_interaction_time" in timer_metrics: + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) / world_size * cfg.env.action_repeat) + / timer_metrics["Time/env_interaction_time"], + policy_step, + ) + timer.reset() # Reset counters last_log = policy_step @@ -281,7 +289,46 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero: - test(actor.module, envs, fabric, cfg) + test(agent.module, fabric, cfg, log_dir) + + # Optional part in case you want to give the possibility to register your models with MLFlow + if not cfg.model_manager.disabled and fabric.is_global_zero: + + def log_models( + run_id: 
str, experiment_id: str | None = None, run_name: str | None = None + ) -> Dict[str, ModelInfo]: + with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=True) as _: + model_info = {} + unwrapped_models = {} + for k in cfg.model_manager.models.keys(): + unwrapped_models[k] = unwrap_fabric(local_vars[k]) + model_info[k] = mlflow.pytorch.log_model(unwrapped_models[k], artifact_path=k) + mlflow.log_dict(cfg, "config.json") + return model_info + + register_model(fabric, log_models, cfg) +``` + +### Metrics and Model Manager +Each algorithm logs its own metrics, during training or environment interaction. To define which are the metrics that can be logged, you need to define the `AGGREGATOR_KEYS` variable in the `./sheeprl/algos/sota/utils.py` file. It must be a set of strings (the name of the metrics to log). Then, you can decide which metrics to log by defining the `metric.aggregator.metrics` in the configs. + +> **Remember** +> +> The intersection between the keys in the `AGGREGATOR_KEYS` and the ones in the `metric.aggregator.metrics` config will be logged. + +As for metrics, you have to specify which are the models that can be registered after training, you need to define the `MODELS_TO_REGISTER` variable in the `./sheeprl/algos/sota/utils.py` file. It must be a set of strings (the name of the variables of the models you want to register). As before, you can easily select which agents to register by defining the `model_manager.models` in the configs. Also in this case, the models that will be registered are the intersection between the `MODELS_TO_REGISTER` variable and the keys of the `model_manager.models` config. + +In this case, the `./sheeprl/algos/sota/utils.py` file could be defined as below: + +```python +# `./sheeprl/algos/sota/utils.py` + +... + +AGGREGATOR_KEYS = {"Rewards/rew_avg", "Game/ep_len_avg", "Loss/loss1", "Loss/loss2"} +MODELS_TO_REGISTER = {"agent"} + +... ``` ## Config files @@ -306,7 +353,7 @@ configs └── sota.yaml ``` -#### Algo configs +#### Algo Configs In the `./sheeprl/configs/algo/sota.yaml` we need to specify all the configs needed to initialize and train your agent. Here is an example of the `./sheeprl/configs/algo/sota.yaml` config file: @@ -370,7 +417,24 @@ will add two optimizers, one accessible with `algo.encoder.optimizer`, the other > > The field `algo.name` **must** be set and **must** be equal to the name of the file.py, found under the `sheeprl/algos/sota` folder, where the implementation of the algorithm is defined. For example, if your implementation is defined in a python file named `my_sota.py`, i.e. `sheeprl/algos/sota/my_sota.py`, then `algo.name="my_sota"` -#### Experiment config +#### Model Manager Configs +In the `./sheeprl/configs/model_manager/sota.yaml` we need to specify all the configs needed to register your agent. You can specify a name, a description, and some tags for each model you want to register. The `disabled` parameter indicates whether or not you want to register your models. +Here is an example of the `./sheeprl/configs/model_manager/sota.yaml` config file: + +```yaml +defaults: + - default + - _self_ + +disabled: False +models: + agent: + model_name: "${exp_name}" + description: "SOTA Agent in ${env.id} Environment" + tags: {} +``` + +#### Experiment Configs In the second file, you have to specify all the elements you want in your experiment and you can override all the parameters you want. 
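+The `metric.aggregator.metrics` and `model_manager.models` entries you select in this file are intersected with the `AGGREGATOR_KEYS` and `MODELS_TO_REGISTER` sets defined in `./sheeprl/algos/sota/utils.py`. A minimal sketch of that rule is shown below; the config contents are hypothetical and only illustrate the intersection:
+
+```python
+# Python-side constants defined in ./sheeprl/algos/sota/utils.py
+AGGREGATOR_KEYS = {"Rewards/rew_avg", "Game/ep_len_avg", "Loss/loss1", "Loss/loss2"}
+MODELS_TO_REGISTER = {"agent"}
+
+# Hypothetical selections coming from the Hydra configs
+selected_metrics = {"Loss/loss1", "Loss/loss3"}  # keys of metric.aggregator.metrics
+selected_models = {"agent", "critic"}            # keys of model_manager.models
+
+# What actually gets logged / registered
+logged_metrics = AGGREGATOR_KEYS & selected_metrics       # {"Loss/loss1"}
+registered_models = MODELS_TO_REGISTER & selected_models  # {"agent"}
+```
+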
Here is an example of the `./sheeprl/configs/exp/sota.yaml` config file: @@ -380,6 +444,8 @@ Here is an example of the `./sheeprl/configs/exp/sota.yaml` config file: defaults: - override /algo: sota - override /env: atari + # select the model manager configs + - override /model_manager: sota - _self_ algo: @@ -393,6 +459,17 @@ buffer: env: env: id: MsPacmanNoFrameskip-v4 + +# select which metrics to log +metric: + aggregator: + metrics: + Loss/loss1: + _target_: torchmetrics.MeanMetric + sync_on_compute: ${metric.sync_on_compute} + Loss/loss2: + _target_: torchmetrics.MeanMetric + sync_on_compute: ${metric.sync_on_compute} ``` With `override /algo: sota` in `defaults` you are specifying you want to use the new `sota` algorithm, whereas, with `override /env: gym` you are specifying that you want to train your agent on an *Atari* environment. diff --git a/pyproject.toml b/pyproject.toml index 78297304..58ef2058 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "moviepy>=1.0.3", "tensordict==0.2.*", "tensorboard>=2.10", + "mlflow==2.8.0", "python-dotenv>=1.0.0", "lightning==2.1.*", "lightning-utilities<=0.9", diff --git a/sheeprl/algos/dreamer_v1/agent.py b/sheeprl/algos/dreamer_v1/agent.py index 9b2ea9c0..8c6bc351 100644 --- a/sheeprl/algos/dreamer_v1/agent.py +++ b/sheeprl/algos/dreamer_v1/agent.py @@ -324,7 +324,7 @@ def get_greedy_action( return actions -def build_models( +def build_agent( fabric: Fabric, actions_dim: Sequence[int], is_continuous: bool, diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 1b382852..22db32c2 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -7,28 +7,30 @@ import gymnasium as gym import hydra +import mlflow import numpy as np import torch import torch.nn.functional as F from lightning.fabric import Fabric from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer +from mlflow.models.model import ModelInfo from tensordict import TensorDict from tensordict.tensordict import TensorDictBase from torch.distributions import Bernoulli, Independent, Normal from torch.utils.data import BatchSampler from torchmetrics import SumMetric -from sheeprl.algos.dreamer_v1.agent import PlayerDV1, WorldModel, build_models +from sheeprl.algos.dreamer_v1.agent import PlayerDV1, WorldModel, build_agent from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss from sheeprl.algos.dreamer_v1.utils import compute_lambda_values from sheeprl.algos.dreamer_v2.utils import test from sheeprl.data.buffers import AsyncReplayBuffer from sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay +from sheeprl.utils.utils import polynomial_decay, register_model, unwrap_fabric # Decomment the following two lines if you cannot start an experiment with DMC environments # os.environ["PYOPENGL_PLATFORM"] = "" @@ -414,9 +416,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): cfg.env.screen_size = 64 cfg.env.frame_stack = 1 - # Create TensorBoardLogger. This will create the logger only on the + # Create Logger. 
This will create the logger only on the # rank-0 process - logger = create_tensorboard_logger(fabric, cfg) + logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) @@ -442,7 +444,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r @@ -470,7 +472,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder - world_model, actor, critic = build_models( + world_model, actor, critic = build_agent( fabric, actions_dim, is_continuous, @@ -504,6 +506,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): world_optimizer, actor_optimizer, critic_optimizer ) + local_vars = locals() + # Metrics aggregator = None if not MetricAggregator.disabled: @@ -777,3 +781,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero: test(player, fabric, cfg, log_dir) + + if not cfg.model_manager.disabled and fabric.is_global_zero: + + def log_models( + run_id: str, experiment_id: str | None = None, run_name: str | None = None + ) -> Dict[str, ModelInfo]: + with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=True) as _: + model_info = {} + unwrapped_models = {} + for k in cfg.model_manager.models.keys(): + unwrapped_models[k] = unwrap_fabric(local_vars[k]) + model_info[k] = mlflow.pytorch.log_model(unwrapped_models[k], artifact_path=k) + mlflow.log_dict(cfg, "config.json") + return model_info + + register_model(fabric, log_models, cfg) diff --git a/sheeprl/algos/dreamer_v1/evaluate.py b/sheeprl/algos/dreamer_v1/evaluate.py index 1b4a30cc..8772ffe6 100644 --- a/sheeprl/algos/dreamer_v1/evaluate.py +++ b/sheeprl/algos/dreamer_v1/evaluate.py @@ -5,16 +5,16 @@ import gymnasium as gym from lightning import Fabric -from sheeprl.algos.dreamer_v1.agent import PlayerDV1, build_models +from sheeprl.algos.dreamer_v1.agent import PlayerDV1, build_agent from sheeprl.algos.dreamer_v2.utils import test from sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.registry import register_evaluation @register_evaluation(algorithms="dreamer_v1") def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): - logger = create_tensorboard_logger(fabric, cfg) + logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) @@ -39,11 +39,11 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) # Create the actor and critic models - world_model, actor, _ = build_models( + world_model, actor, _ = build_agent( fabric, actions_dim, is_continuous, diff --git a/sheeprl/algos/dreamer_v1/utils.py b/sheeprl/algos/dreamer_v1/utils.py 
index 95eec658..37783078 100644 --- a/sheeprl/algos/dreamer_v1/utils.py +++ b/sheeprl/algos/dreamer_v1/utils.py @@ -1,10 +1,18 @@ -from typing import Tuple +from __future__ import annotations +from typing import Any, Dict, Sequence, Tuple + +import gymnasium as gym +import mlflow import torch import torch.nn.functional as F +from lightning import Fabric +from mlflow.models.model import ModelInfo from torch import Tensor from torch.distributions import Distribution, Independent, Normal +from sheeprl.utils.utils import unwrap_fabric + AGGREGATOR_KEYS = { "Rewards/rew_avg", "Game/ep_len_avg", @@ -23,6 +31,7 @@ "Grads/actor", "Grads/critic", } +MODELS_TO_REGISTER = {"world_model", "actor", "critic"} def compute_lambda_values( @@ -91,3 +100,37 @@ def compute_stochastic_state( state_distribution = Independent(state_distribution, event_shape, validate_args=validate_args) stochastic_state = state_distribution.rsample() return (mean, std), stochastic_state + + +def log_models_from_checkpoint( + fabric: Fabric, env: gym.Env | gym.Wrapper, cfg: Dict[str, Any], state: Dict[str, Any] +) -> Sequence[ModelInfo]: + from sheeprl.algos.dreamer_v1.agent import build_agent + + # Create the models + is_continuous = isinstance(env.action_space, gym.spaces.Box) + is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete) + actions_dim = tuple( + env.action_space.shape + if is_continuous + else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) + ) + world_model, actor, critic = build_agent( + fabric, + actions_dim, + is_continuous, + cfg, + env.observation_space, + state["world_model"], + state["actor"], + state["critic"], + ) + + # Log the model, create a new run if `cfg.run_id` is None. + model_info = {} + with mlflow.start_run(run_id=cfg.run.id, experiment_id=cfg.experiment.id, run_name=cfg.run.name, nested=True) as _: + model_info["world_model"] = mlflow.pytorch.log_model(unwrap_fabric(world_model), artifact_path="world_model") + model_info["actor"] = mlflow.pytorch.log_model(unwrap_fabric(actor), artifact_path="actor") + model_info["critic"] = mlflow.pytorch.log_model(unwrap_fabric(critic), artifact_path="critic") + mlflow.log_dict(cfg.to_log, "config.json") + return model_info diff --git a/sheeprl/algos/dreamer_v2/agent.py b/sheeprl/algos/dreamer_v2/agent.py index 34a94bd0..c83e46b1 100644 --- a/sheeprl/algos/dreamer_v2/agent.py +++ b/sheeprl/algos/dreamer_v2/agent.py @@ -452,8 +452,7 @@ def __init__( ) -> None: super().__init__() self.distribution_cfg = distribution_cfg - self.distribution = distribution_cfg.pop("type", "auto").lower() - self.distribution_cfg.type = self.distribution + self.distribution = distribution_cfg.get("type", "auto").lower() if self.distribution not in ("auto", "normal", "tanh_normal", "discrete", "trunc_normal"): raise ValueError( "The distribution must be on of: `auto`, `discrete`, `normal`, `tanh_normal` and `trunc_normal`. 
" @@ -862,7 +861,7 @@ def get_greedy_action( return actions -def build_models( +def build_agent( fabric: Fabric, actions_dim: Sequence[int], is_continuous: bool, diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index 326a27db..1ff0c7a7 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -10,11 +10,13 @@ import gymnasium as gym import hydra +import mlflow import numpy as np import torch import torch.nn.functional as F from lightning.fabric import Fabric from lightning.fabric.wrappers import _FabricModule +from mlflow.models.model import ModelInfo from tensordict import TensorDict from tensordict.tensordict import TensorDictBase from torch import Tensor @@ -23,17 +25,17 @@ from torch.utils.data import BatchSampler from torchmetrics import SumMetric -from sheeprl.algos.dreamer_v2.agent import PlayerDV2, WorldModel, build_models +from sheeprl.algos.dreamer_v2.agent import PlayerDV2, WorldModel, build_agent from sheeprl.algos.dreamer_v2.loss import reconstruction_loss from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, test from sheeprl.data.buffers import AsyncReplayBuffer, EpisodeBuffer from sheeprl.utils.distribution import OneHotCategoricalValidateArgs from sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay +from sheeprl.utils.utils import polynomial_decay, register_model, unwrap_fabric # Decomment the following two lines if you cannot start an experiment with DMC environments # os.environ["PYOPENGL_PLATFORM"] = "" @@ -438,9 +440,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): cfg.env.screen_size = 64 cfg.env.frame_stack = 1 - # Create TensorBoardLogger. This will create the logger only on the + # Create Logger. 
This will create the logger only on the # rank-0 process - logger = create_tensorboard_logger(fabric, cfg) + logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) @@ -466,7 +468,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r @@ -495,7 +497,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder - world_model, actor, critic, target_critic = build_models( + world_model, actor, critic, target_critic = build_agent( fabric, actions_dim, is_continuous, @@ -531,6 +533,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): world_optimizer, actor_optimizer, critic_optimizer ) + local_vars = locals() + # Metrics aggregator = None if not MetricAggregator.disabled: @@ -857,3 +861,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero: test(player, fabric, cfg, log_dir) + + if not cfg.model_manager.disabled and fabric.is_global_zero: + + def log_models( + run_id: str, experiment_id: str | None = None, run_name: str | None = None + ) -> Dict[str, ModelInfo]: + with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=True) as _: + model_info = {} + unwrapped_models = {} + for k in cfg.model_manager.models.keys(): + unwrapped_models[k] = unwrap_fabric(local_vars[k]) + model_info[k] = mlflow.pytorch.log_model(unwrapped_models[k], artifact_path=k) + mlflow.log_dict(cfg, "config.json") + return model_info + + register_model(fabric, log_models, cfg) diff --git a/sheeprl/algos/dreamer_v2/evaluate.py b/sheeprl/algos/dreamer_v2/evaluate.py index a320324e..78ebee5b 100644 --- a/sheeprl/algos/dreamer_v2/evaluate.py +++ b/sheeprl/algos/dreamer_v2/evaluate.py @@ -5,16 +5,16 @@ import gymnasium as gym from lightning import Fabric -from sheeprl.algos.dreamer_v2.agent import PlayerDV2, build_models +from sheeprl.algos.dreamer_v2.agent import PlayerDV2, build_agent from sheeprl.algos.dreamer_v2.utils import test from sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.registry import register_evaluation @register_evaluation(algorithms="dreamer_v2") def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): - logger = create_tensorboard_logger(fabric, cfg) + logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) @@ -39,11 +39,11 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) # Create the actor and critic models - world_model, actor, _, _ = build_models( + world_model, actor, _, _ = build_agent( fabric, actions_dim, is_continuous, diff --git a/sheeprl/algos/dreamer_v2/utils.py 
b/sheeprl/algos/dreamer_v2/utils.py index 29153658..1aa41711 100644 --- a/sheeprl/algos/dreamer_v2/utils.py +++ b/sheeprl/algos/dreamer_v2/utils.py @@ -1,15 +1,20 @@ -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union import gymnasium as gym +import mlflow import numpy as np import torch import torch.nn as nn from lightning import Fabric +from mlflow.models.model import ModelInfo from torch import Tensor from torch.distributions import Independent from sheeprl.utils.distribution import OneHotCategoricalStraightThroughValidateArgs from sheeprl.utils.env import make_env +from sheeprl.utils.utils import unwrap_fabric if TYPE_CHECKING: from sheeprl.algos.dreamer_v1.agent import PlayerDV1 @@ -34,6 +39,7 @@ "Grads/actor", "Grads/critic", } +MODELS_TO_REGISTER = {"world_model", "actor", "critic", "target_critic"} def compute_stochastic_state(logits: Tensor, discrete: int = 32, sample=True, validate_args=False) -> Tensor: @@ -155,3 +161,39 @@ def test( if cfg.metric.log_level > 0 and len(fabric.loggers) > 0: fabric.logger.log_metrics({"Test/cumulative_reward": cumulative_rew}, 0) env.close() + + +def log_models_from_checkpoint( + fabric: Fabric, env: gym.Env | gym.Wrapper, cfg: Dict[str, Any], state: Dict[str, Any] +) -> Sequence[ModelInfo]: + from sheeprl.algos.dreamer_v2.agent import build_agent + + # Create the models + is_continuous = isinstance(env.action_space, gym.spaces.Box) + is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete) + actions_dim = tuple( + env.action_space.shape + if is_continuous + else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) + ) + world_model, actor, critic, target_critic = build_agent( + fabric, + actions_dim, + is_continuous, + cfg, + env.observation_space, + state["world_model"], + state["actor"], + state["critic"], + state["target_critic"], + ) + + # Log the model, create a new run if `cfg.run_id` is None. + model_info = {} + with mlflow.start_run(run_id=cfg.run.id, experiment_id=cfg.experiment.id, run_name=cfg.run.name, nested=True) as _: + model_info["world_model"] = mlflow.pytorch.log_model(unwrap_fabric(world_model), artifact_path="world_model") + model_info["actor"] = mlflow.pytorch.log_model(unwrap_fabric(actor), artifact_path="actor") + model_info["critic"] = mlflow.pytorch.log_model(unwrap_fabric(critic), artifact_path="critic") + model_info["target_critic"] = mlflow.pytorch.log_model(target_critic, artifact_path="target_critic") + mlflow.log_dict(cfg.to_log, "config.json") + return model_info diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py index 4174eccf..2c3e1d8a 100644 --- a/sheeprl/algos/dreamer_v3/agent.py +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -633,8 +633,7 @@ def __init__( ) -> None: super().__init__() self.distribution_cfg = distribution_cfg - self.distribution = distribution_cfg.pop("type", "auto").lower() - self.distribution_cfg.type = self.distribution + self.distribution = distribution_cfg.get("type", "auto").lower() if self.distribution not in ("auto", "normal", "tanh_normal", "discrete", "trunc_normal"): raise ValueError( "The distribution must be on of: `auto`, `discrete`, `normal`, `tanh_normal` and `trunc_normal`. 
" @@ -897,7 +896,7 @@ def add_exploration_noise( return tuple(expl_actions) -def build_models( +def build_agent( fabric: Fabric, actions_dim: Sequence[int], is_continuous: bool, diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index a631ae1f..4fe6fbec 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -10,11 +10,13 @@ import gymnasium as gym import hydra +import mlflow import numpy as np import torch import torch.nn.functional as F from lightning.fabric import Fabric from lightning.fabric.wrappers import _FabricModule +from mlflow.models.model import ModelInfo from tensordict import TensorDict from tensordict.tensordict import TensorDictBase from torch import Tensor @@ -23,7 +25,7 @@ from torch.utils.data import BatchSampler from torchmetrics import SumMetric -from sheeprl.algos.dreamer_v3.agent import PlayerDV3, WorldModel, build_models +from sheeprl.algos.dreamer_v3.agent import PlayerDV3, WorldModel, build_agent from sheeprl.algos.dreamer_v3.loss import reconstruction_loss from sheeprl.algos.dreamer_v3.utils import Moments, compute_lambda_values, test from sheeprl.data.buffers import AsyncReplayBuffer @@ -35,11 +37,11 @@ TwoHotEncodingDistribution, ) from sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay +from sheeprl.utils.utils import polynomial_decay, register_model, unwrap_fabric # Decomment the following two lines if you cannot start an experiment with DMC environments # os.environ["PYOPENGL_PLATFORM"] = "" @@ -370,9 +372,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if 2 ** int(np.log2(cfg.env.screen_size)) != cfg.env.screen_size: raise ValueError(f"The screen size must be a power of 2, got: {cfg.env.screen_size}") - # Create TensorBoardLogger. This will create the logger only on the + # Create Logger. 
This will create the logger only on the # rank-0 process - logger = create_tensorboard_logger(fabric, cfg) + logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) @@ -401,7 +403,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r @@ -430,7 +432,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder - world_model, actor, critic, target_critic = build_models( + world_model, actor, critic, target_critic = build_agent( fabric, actions_dim, is_continuous, @@ -474,6 +476,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if cfg.checkpoint.resume_from: moments.load_state_dict(state["moments"]) + local_vars = locals() + # Metrics aggregator = None if not MetricAggregator.disabled: @@ -776,3 +780,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero: test(player, fabric, cfg, log_dir, sample_actions=True) + + if not cfg.model_manager.disabled and fabric.is_global_zero: + + def log_models( + run_id: str, experiment_id: str | None = None, run_name: str | None = None + ) -> Dict[str, ModelInfo]: + with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=True) as _: + model_info = {} + unwrapped_models = {} + for k in cfg.model_manager.models.keys(): + unwrapped_models[k] = unwrap_fabric(local_vars[k]) + model_info[k] = mlflow.pytorch.log_model(unwrapped_models[k], artifact_path=k) + mlflow.log_dict(cfg, "config.json") + return model_info + + register_model(fabric, log_models, cfg) diff --git a/sheeprl/algos/dreamer_v3/evaluate.py b/sheeprl/algos/dreamer_v3/evaluate.py index 63a84dd5..94775a45 100644 --- a/sheeprl/algos/dreamer_v3/evaluate.py +++ b/sheeprl/algos/dreamer_v3/evaluate.py @@ -5,16 +5,16 @@ import gymnasium as gym from lightning import Fabric -from sheeprl.algos.dreamer_v3.agent import PlayerDV3, build_models +from sheeprl.algos.dreamer_v3.agent import PlayerDV3, build_agent from sheeprl.algos.dreamer_v3.utils import test from sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.registry import register_evaluation @register_evaluation(algorithms="dreamer_v3") def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): - logger = create_tensorboard_logger(fabric, cfg) + logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) @@ -39,11 +39,11 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) # Create the actor and critic models - world_model, actor, _, _ = build_models( + world_model, actor, _, _ = build_agent( fabric, actions_dim, is_continuous, diff 
--git a/sheeprl/algos/dreamer_v3/utils.py b/sheeprl/algos/dreamer_v3/utils.py index 7c31278f..48ad57f6 100644 --- a/sheeprl/algos/dreamer_v3/utils.py +++ b/sheeprl/algos/dreamer_v3/utils.py @@ -1,12 +1,17 @@ -from typing import TYPE_CHECKING, Any, Dict +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Sequence import gymnasium as gym +import mlflow import numpy as np import torch from lightning import Fabric +from mlflow.models.model import ModelInfo from torch import Tensor, nn from sheeprl.utils.env import make_env +from sheeprl.utils.utils import unwrap_fabric if TYPE_CHECKING: from sheeprl.algos.dreamer_v3.agent import PlayerDV3 @@ -29,6 +34,7 @@ "Grads/actor", "Grads/critic", } +MODELS_TO_REGISTER = {"world_model", "actor", "critic", "target_critic", "moments"} class Moments(nn.Module): @@ -176,3 +182,48 @@ def f(m): m.bias.data.fill_(0.0) return f + + +def log_models_from_checkpoint( + fabric: Fabric, env: gym.Env | gym.Wrapper, cfg: Dict[str, Any], state: Dict[str, Any] +) -> Sequence[ModelInfo]: + from sheeprl.algos.dreamer_v3.agent import build_agent + + # Create the models + is_continuous = isinstance(env.action_space, gym.spaces.Box) + is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete) + actions_dim = tuple( + env.action_space.shape + if is_continuous + else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) + ) + world_model, actor, critic, target_critic = build_agent( + fabric, + actions_dim, + is_continuous, + cfg, + env.observation_space, + state["world_model"], + state["actor"], + state["critic"], + state["target_critic"], + ) + moments = Moments( + fabric, + cfg.algo.actor.moments.decay, + cfg.algo.actor.moments.max, + cfg.algo.actor.moments.percentile.low, + cfg.algo.actor.moments.percentile.high, + ) + moments.load_state_dict(state["moments"]) + + # Log the model, create a new run if `cfg.run_id` is None. 
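+ # `world_model`, `actor` and `critic` are Fabric-wrapped, so they are unwrapped before being
+ # logged, while `target_critic` and `moments` are plain `nn.Module`s and are logged as-is.
+ # Each component is logged under its own `artifact_path` so it can be registered independently.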
+ model_info = {} + with mlflow.start_run(run_id=cfg.run.id, experiment_id=cfg.experiment.id, run_name=cfg.run.name, nested=True) as _: + model_info["world_model"] = mlflow.pytorch.log_model(unwrap_fabric(world_model), artifact_path="world_model") + model_info["actor"] = mlflow.pytorch.log_model(unwrap_fabric(actor), artifact_path="actor") + model_info["critic"] = mlflow.pytorch.log_model(unwrap_fabric(critic), artifact_path="critic") + model_info["target_critic"] = mlflow.pytorch.log_model(target_critic, artifact_path="target_critic") + model_info["moments"] = mlflow.pytorch.log_model(moments, artifact_path="moments") + mlflow.log_dict(cfg.to_log, "config.json") + return model_info diff --git a/sheeprl/algos/droq/agent.py b/sheeprl/algos/droq/agent.py index a3c88b12..7a02da77 100644 --- a/sheeprl/algos/droq/agent.py +++ b/sheeprl/algos/droq/agent.py @@ -1,8 +1,11 @@ import copy -from typing import Sequence, Tuple, Union +from math import prod +from typing import Any, Dict, Optional, Sequence, Tuple, Union +import gymnasium import torch import torch.nn as nn +from lightning import Fabric from lightning.fabric.wrappers import _FabricModule from torch import Tensor @@ -198,3 +201,41 @@ def qfs_target_ema(self, critic_idx: int) -> None: self.qfs_unwrapped[critic_idx].parameters(), self.qfs_target[critic_idx].parameters() ): target_param.data.copy_(self._tau * param.data + (1 - self._tau) * target_param.data) + + +def build_agent( + fabric: Fabric, + cfg: Dict[str, Any], + obs_space: gymnasium.spaces.Dict, + action_space: gymnasium.spaces.Box, + agent_state: Optional[Dict[str, Tensor]] = None, +) -> DROQAgent: + act_dim = prod(action_space.shape) + obs_dim = sum([prod(obs_space[k].shape) for k in cfg.algo.mlp_keys.encoder]) + actor = SACActor( + observation_dim=obs_dim, + action_dim=act_dim, + distribution_cfg=cfg.distribution, + hidden_size=cfg.algo.actor.hidden_size, + action_low=action_space.low, + action_high=action_space.high, + ) + critics = [ + DROQCritic( + observation_dim=obs_dim + act_dim, + hidden_size=cfg.algo.critic.hidden_size, + num_critics=1, + dropout=cfg.algo.critic.dropout, + ) + for _ in range(cfg.algo.critic.n) + ] + target_entropy = -act_dim + agent = DROQAgent( + actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device + ) + if agent_state: + agent.load_state_dict(agent_state) + agent.actor = fabric.setup_module(agent.actor) + agent.critics = [fabric.setup_module(critic) for critic in agent.critics] + + return agent diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index f7d8fcee..7728e24c 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -3,31 +3,32 @@ import copy import os import warnings -from math import prod from typing import Any, Dict import gymnasium as gym import hydra +import mlflow import numpy as np import torch import torch.nn.functional as F from lightning.fabric import Fabric +from mlflow.models.model import ModelInfo from tensordict import TensorDict, make_tensordict from torch.optim import Optimizer from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import BatchSampler from torchmetrics import SumMetric -from sheeprl.algos.droq.agent import DROQAgent, DROQCritic -from sheeprl.algos.sac.agent import SACActor +from sheeprl.algos.droq.agent import DROQAgent, build_agent from sheeprl.algos.sac.loss import entropy_loss, policy_loss from sheeprl.algos.sac.sac import test from sheeprl.data.buffers import ReplayBuffer from 
sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer +from sheeprl.utils.utils import register_model, unwrap_fabric def train( @@ -152,9 +153,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): warnings.warn("DroQ algorithm cannot allow to use images as observations, the CNN keys will be ignored") cfg.algo.cnn_keys.encoder = [] - # Create TensorBoardLogger. This will create the logger only on the + # Create Logger. This will create the logger only on the # rank-0 process - logger = create_tensorboard_logger(fabric, cfg) + logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) @@ -193,33 +194,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) # Define the agent and the optimizer and setup them with Fabric - act_dim = prod(action_space.shape) - obs_dim = sum([prod(observation_space[k].shape) for k in cfg.algo.mlp_keys.encoder]) - actor = SACActor( - observation_dim=obs_dim, - action_dim=act_dim, - distribution_cfg=cfg.distribution, - hidden_size=cfg.algo.actor.hidden_size, - action_low=action_space.low, - action_high=action_space.high, + agent = build_agent( + fabric, cfg, observation_space, action_space, state["agent"] if cfg.checkpoint.resume_from else None ) - critics = [ - DROQCritic( - observation_dim=obs_dim + act_dim, - hidden_size=cfg.algo.critic.hidden_size, - num_critics=1, - dropout=cfg.algo.critic.dropout, - ) - for _ in range(cfg.algo.critic.n) - ] - target_entropy = -act_dim - agent = DROQAgent( - actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device - ) - if cfg.checkpoint.resume_from: - agent.load_state_dict(state["agent"]) - agent.actor = fabric.setup_module(agent.actor) - agent.critics = [fabric.setup_module(critic) for critic in agent.critics] # Optimizers qf_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.qfs.parameters()) @@ -233,6 +210,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): qf_optimizer, actor_optimizer, alpha_optimizer ) + local_vars = locals() + # Metrics aggregator = None if not MetricAggregator.disabled: @@ -407,3 +386,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero: test(agent.actor.module, fabric, cfg, log_dir) + + if not cfg.model_manager.disabled and fabric.is_global_zero: + + def log_models( + run_id: str, experiment_id: str | None = None, run_name: str | None = None + ) -> Dict[str, ModelInfo]: + with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=True) as _: + model_info = {} + unwrapped_models = {} + for k in cfg.model_manager.models.keys(): + unwrapped_models[k] = unwrap_fabric(local_vars[k]) + model_info[k] = mlflow.pytorch.log_model(unwrapped_models[k], artifact_path=k) + mlflow.log_dict(cfg, "config.json") + return model_info + + register_model(fabric, log_models, cfg) diff --git a/sheeprl/algos/droq/evaluate.py b/sheeprl/algos/droq/evaluate.py index 5738bafc..a80f0ef0 100644 --- a/sheeprl/algos/droq/evaluate.py +++ b/sheeprl/algos/droq/evaluate.py @@ -1,22 +1,20 @@ from __future__ import annotations -from math import prod from typing import Any, Dict import gymnasium as gym from lightning import Fabric 
-from sheeprl.algos.droq.agent import DROQAgent, DROQCritic -from sheeprl.algos.sac.agent import SACActor +from sheeprl.algos.droq.agent import build_agent from sheeprl.algos.sac.utils import test from sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.registry import register_evaluation @register_evaluation(algorithms="droq") def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): - logger = create_tensorboard_logger(fabric, cfg) + logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) @@ -47,29 +45,5 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): if cfg.metric.log_level > 0: fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) - act_dim = prod(action_space.shape) - obs_dim = sum([prod(observation_space[k].shape) for k in cfg.algo.mlp_keys.encoder]) - actor = SACActor( - observation_dim=obs_dim, - action_dim=act_dim, - distribution_cfg=cfg.distribution, - hidden_size=cfg.algo.actor.hidden_size, - action_low=action_space.low, - action_high=action_space.high, - ) - critics = [ - DROQCritic( - observation_dim=obs_dim + act_dim, - hidden_size=cfg.algo.critic.hidden_size, - num_critics=1, - dropout=cfg.algo.critic.dropout, - ) - for _ in range(cfg.algo.critic.n) - ] - target_entropy = -act_dim - agent = DROQAgent( - actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device - ) - agent.load_state_dict(state["agent"]) - agent = fabric.setup_module(agent) + agent = build_agent(fabric, cfg, observation_space, action_space, state["agent"]) test(agent.actor, fabric, cfg, log_dir) diff --git a/sheeprl/algos/droq/utils.py b/sheeprl/algos/droq/utils.py index dbb49561..ef6800c7 100644 --- a/sheeprl/algos/droq/utils.py +++ b/sheeprl/algos/droq/utils.py @@ -1,3 +1,30 @@ +from __future__ import annotations + +from typing import Any, Dict, Sequence + +import gymnasium as gym +import mlflow +from lightning import Fabric +from mlflow.models.model import ModelInfo + +from sheeprl.algos.droq.agent import build_agent from sheeprl.algos.sac.utils import AGGREGATOR_KEYS as sac_aggregator_keys +from sheeprl.algos.sac.utils import MODELS_TO_REGISTER as sac_models_to_register +from sheeprl.utils.utils import unwrap_fabric AGGREGATOR_KEYS = sac_aggregator_keys +MODELS_TO_REGISTER = sac_models_to_register + + +def log_models_from_checkpoint( + fabric: Fabric, env: gym.Env | gym.Wrapper, cfg: Dict[str, Any], state: Dict[str, Any] +) -> Sequence[ModelInfo]: + # Create the models + agent = build_agent(fabric, cfg, env.observation_space, env.action_space, state["agent"]) + + # Log the model, create a new run if `cfg.run_id` is None. 
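+ # Unlike the Dreamer algorithms above, DroQ logs the whole agent as a single MLflow
+ # artifact (after unwrapping it from Fabric) rather than one artifact per sub-module.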
+ model_info = {} + with mlflow.start_run(run_id=cfg.run.id, experiment_id=cfg.experiment.id, run_name=cfg.run.name, nested=True) as _: + model_info["agent"] = mlflow.pytorch.log_model(unwrap_fabric(agent), artifact_path="agent") + mlflow.log_dict(cfg.to_log, "config.json") + return model_info diff --git a/sheeprl/algos/p2e_dv1/agent.py b/sheeprl/algos/p2e_dv1/agent.py index 9c844dd8..7d13b21d 100644 --- a/sheeprl/algos/p2e_dv1/agent.py +++ b/sheeprl/algos/p2e_dv1/agent.py @@ -5,9 +5,11 @@ import torch from lightning.fabric import Fabric from lightning.fabric.wrappers import _FabricModule +from lightning.pytorch.utilities.seed import isolate_rng +from torch import nn from sheeprl.algos.dreamer_v1.agent import WorldModel -from sheeprl.algos.dreamer_v1.agent import build_models as dv1_build_models +from sheeprl.algos.dreamer_v1.agent import build_agent as dv1_build_agent from sheeprl.algos.dreamer_v2.agent import Actor as DV2Actor from sheeprl.algos.dreamer_v2.agent import MinedojoActor as DV2MinedojoActor from sheeprl.models.models import MLP @@ -20,18 +22,19 @@ MinedojoActor = DV2MinedojoActor -def build_models( +def build_agent( fabric: Fabric, actions_dim: Sequence[int], is_continuous: bool, cfg: Dict[str, Any], obs_space: gymnasium.spaces.Dict, world_model_state: Optional[Dict[str, torch.Tensor]] = None, + ensembles_state: Optional[Dict[str, torch.Tensor]] = None, actor_task_state: Optional[Dict[str, torch.Tensor]] = None, critic_task_state: Optional[Dict[str, torch.Tensor]] = None, actor_exploration_state: Optional[Dict[str, torch.Tensor]] = None, critic_exploration_state: Optional[Dict[str, torch.Tensor]] = None, -) -> Tuple[WorldModel, _FabricModule, _FabricModule, _FabricModule, _FabricModule]: +) -> Tuple[WorldModel, _FabricModule, _FabricModule, _FabricModule, _FabricModule, _FabricModule]: """Build the models and wrap them with Fabric. Args: @@ -42,6 +45,8 @@ def build_models( obs_space (Dict[str, Any]): the observation space. world_model_state (Dict[str, Tensor], optional): the state of the world model. Default to None. + ensembles_state (Dict[str, Tensor], optional): the state of the ensembles. + Default to None. actor_task_state (Dict[str, Tensor], optional): the state of the actor_task. Default to None. critic_task_state (Dict[str, Tensor], optional): the state of the critic_task. @@ -53,11 +58,12 @@ def build_models( Returns: The world model (WorldModel): composed by the encoder, rssm, observation and - reward models and the continue model. - The actor_task (_FabricModule). - The critic_task (_FabricModule). - The actor_exploration (_FabricModule). - The critic_exploration (_FabricModule). + reward models and the continue model. + The ensembles (_FabricModule): for estimating the intrinsic reward. + The actor_task (_FabricModule): for learning the task. + The critic_task (_FabricModule): for predicting the values of the task. + The actor_exploration (_FabricModule): for exploring the environment. + The critic_exploration (_FabricModule): for predicting the values of the exploration. 
""" world_model_cfg = cfg.algo.world_model actor_cfg = cfg.algo.actor @@ -67,7 +73,7 @@ def build_models( latent_state_size = world_model_cfg.stochastic_size + world_model_cfg.recurrent_model.recurrent_state_size # Create exploration models - world_model, actor_exploration, critic_exploration = dv1_build_models( + world_model, actor_exploration, critic_exploration = dv1_build_agent( fabric, actions_dim=actions_dim, is_continuous=is_continuous, @@ -110,4 +116,26 @@ def build_models( actor_task = fabric.setup_module(actor_task) critic_task = fabric.setup_module(critic_task) - return world_model, actor_task, critic_task, actor_exploration, critic_exploration + ens_list = [] + with isolate_rng(): + for i in range(cfg.algo.ensembles.n): + fabric.seed_everything(cfg.seed + i) + ens_list.append( + MLP( + input_dims=( + int(sum(actions_dim)) + + cfg.algo.world_model.recurrent_model.recurrent_state_size + + cfg.algo.world_model.stochastic_size + ), + output_dim=world_model.encoder.cnn_output_dim + world_model.encoder.mlp_output_dim, + hidden_sizes=[cfg.algo.ensembles.dense_units] * cfg.algo.ensembles.mlp_layers, + activation=eval(cfg.algo.ensembles.dense_act), + ).apply(init_weights) + ) + ensembles = nn.ModuleList(ens_list) + if ensembles_state: + ensembles.load_state_dict(ensembles_state) + for i in range(len(ensembles)): + ensembles[i] = fabric.setup_module(ensembles[i]) + + return world_model, ensembles, actor_task, critic_task, actor_exploration, critic_exploration diff --git a/sheeprl/algos/p2e_dv1/evaluate.py b/sheeprl/algos/p2e_dv1/evaluate.py index 27b1396a..e6d2de92 100644 --- a/sheeprl/algos/p2e_dv1/evaluate.py +++ b/sheeprl/algos/p2e_dv1/evaluate.py @@ -7,15 +7,15 @@ from sheeprl.algos.dreamer_v1.agent import PlayerDV1 from sheeprl.algos.dreamer_v2.utils import test -from sheeprl.algos.p2e_dv1.agent import build_models +from sheeprl.algos.p2e_dv1.agent import build_agent from sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.registry import register_evaluation @register_evaluation(algorithms=["p2e_dv1_exploration", "p2e_dv1_finetuning"]) def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): - logger = create_tensorboard_logger(fabric, cfg) + logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) @@ -40,11 +40,11 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) # Create the actor and critic models - world_model, actor_task, _, _, _ = build_models( + world_model, _, actor_task, _, _, _ = build_agent( fabric, actions_dim, is_continuous, diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py index edc4e8d2..d66d9161 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py @@ -7,15 +7,15 @@ import gymnasium as gym import hydra +import mlflow import numpy as np import torch import torch.nn.functional as F from lightning.fabric import Fabric from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer -from lightning.pytorch.utilities.seed import 
isolate_rng +from mlflow.models.model import ModelInfo from tensordict import TensorDict from tensordict.tensordict import TensorDictBase -from torch import nn from torch.distributions import Bernoulli, Independent, Normal from torch.utils.data import BatchSampler from torchmetrics import SumMetric @@ -24,15 +24,14 @@ from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss from sheeprl.algos.dreamer_v1.utils import compute_lambda_values from sheeprl.algos.dreamer_v2.utils import test -from sheeprl.algos.p2e_dv1.agent import build_models +from sheeprl.algos.p2e_dv1.agent import build_agent from sheeprl.data.buffers import AsyncReplayBuffer -from sheeprl.models.models import MLP from sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import init_weights, polynomial_decay +from sheeprl.utils.utils import polynomial_decay, register_model, unwrap_fabric # Decomment the following line if you are using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -426,14 +425,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): cfg.env.frame_stack = 1 cfg.algo.player.actor_type = "exploration" - # Create TensorBoardLogger. This will create the logger only on the + # Create Logger. This will create the logger only on the # rank-0 process - logger = create_tensorboard_logger(fabric, cfg) + logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) - # Environment setup vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv envs = vectorized_env( @@ -454,7 +452,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r @@ -483,40 +481,20 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder - world_model, actor_task, critic_task, actor_exploration, critic_exploration = build_models( + world_model, ensembles, actor_task, critic_task, actor_exploration, critic_exploration = build_agent( fabric, actions_dim, is_continuous, cfg, observation_space, state["world_model"] if cfg.checkpoint.resume_from else None, + state["ensembles"] if cfg.checkpoint.resume_from else None, state["actor_task"] if cfg.checkpoint.resume_from else None, state["critic_task"] if cfg.checkpoint.resume_from else None, state["actor_exploration"] if cfg.checkpoint.resume_from else None, state["critic_exploration"] if cfg.checkpoint.resume_from else None, ) - # initialize the ensembles with different seeds to be sure they have different weights - ens_list = [] - with isolate_rng(): - for i in range(cfg.algo.ensembles.n): - fabric.seed_everything(cfg.seed + i) - ens_list.append( - MLP( - input_dims=( - int(sum(actions_dim)) - + cfg.algo.world_model.recurrent_model.recurrent_state_size - + 
cfg.algo.world_model.stochastic_size - ), - output_dim=world_model.encoder.cnn_output_dim + world_model.encoder.mlp_output_dim, - hidden_sizes=[cfg.algo.ensembles.dense_units] * cfg.algo.ensembles.mlp_layers, - activation=eval(cfg.algo.ensembles.dense_act), - ).apply(init_weights) - ) - ensembles = nn.ModuleList(ens_list) - if cfg.checkpoint.resume_from: - ensembles.load_state_dict(state["ensembles"]) - fabric.setup_module(ensembles) player = PlayerDV1( world_model.encoder.module, world_model.rssm.recurrent_model.module, @@ -565,6 +543,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): critic_exploration_optimizer, ) + local_vars = locals() + # Metrics aggregator = None if not MetricAggregator.disabled: @@ -868,3 +848,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): player.actor = actor_task.module player.actor_type = "task" test(player, fabric, cfg, log_dir, "zero-shot") + + if not cfg.model_manager.disabled and fabric.is_global_zero: + + def log_models( + run_id: str, experiment_id: str | None = None, run_name: str | None = None + ) -> Dict[str, ModelInfo]: + with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=True) as _: + model_info = {} + unwrapped_models = {} + for k in cfg.model_manager.models.keys(): + unwrapped_models[k] = unwrap_fabric(local_vars[k]) + model_info[k] = mlflow.pytorch.log_model(unwrapped_models[k], artifact_path=k) + mlflow.log_dict(cfg, "config.json") + return model_info + + register_model(fabric, log_models, cfg) diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py index 109f9c83..174dac8b 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py @@ -8,9 +8,11 @@ import gymnasium as gym import hydra +import mlflow import numpy as np import torch from lightning.fabric import Fabric +from mlflow.models.model import ModelInfo from tensordict import TensorDict from torch.utils.data import BatchSampler from torchmetrics import SumMetric @@ -18,14 +20,14 @@ from sheeprl.algos.dreamer_v1.agent import PlayerDV1 from sheeprl.algos.dreamer_v1.dreamer_v1 import train from sheeprl.algos.dreamer_v2.utils import test -from sheeprl.algos.p2e_dv1.agent import build_models +from sheeprl.algos.p2e_dv1.agent import build_agent from sheeprl.data.buffers import AsyncReplayBuffer from sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay +from sheeprl.utils.utils import polynomial_decay, register_model, unwrap_fabric # Decomment the following line if you are using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -48,7 +50,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): else: state = fabric.load(ckpt_path) - # All the models must be equal to the ones of the exploration phase + # All the models must be equal to the ones of the exploration phase cfg.algo.gamma = exploration_cfg.algo.gamma cfg.algo.lmbda = exploration_cfg.algo.lmbda cfg.algo.horizon = exploration_cfg.algo.horizon @@ -76,9 +78,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): cfg.env.screen_size = 64 cfg.env.frame_stack = 1 - # Create TensorBoardLogger. 
This will create the logger only on the + # Create Logger. This will create the logger only on the # rank-0 process - logger = create_tensorboard_logger(fabric, cfg) + logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) @@ -104,7 +106,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r @@ -133,13 +135,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder - world_model, actor_task, critic_task, actor_exploration, _ = build_models( + world_model, _, actor_task, critic_task, actor_exploration, _ = build_agent( fabric, actions_dim, is_continuous, cfg, observation_space, state["world_model"], + None, state["actor_task"], state["critic_task"], state["actor_exploration"], @@ -169,6 +172,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): world_optimizer, actor_task_optimizer, critic_task_optimizer ) + local_vars = locals() + # Metrics aggregator = None if not MetricAggregator.disabled: @@ -448,3 +453,23 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): player.actor = actor_task.module player.actor_type = "task" test(player, fabric, cfg, log_dir, "few-shot") + + if not cfg.model_manager.disabled and fabric.is_global_zero: + + def log_models( + run_id: str, experiment_id: str | None = None, run_name: str | None = None + ) -> Dict[str, ModelInfo]: + with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=True) as _: + model_info = {} + unwrapped_models = {} + models_keys = set(cfg.model_manager.models.keys()) + for k in models_keys: + if "exploration" not in k and k != "ensembles": + unwrapped_models[k] = unwrap_fabric(local_vars[k]) + model_info[k] = mlflow.pytorch.log_model(unwrapped_models[k], artifact_path=k) + else: + cfg.model_manager.models.pop(k, None) + mlflow.log_dict(cfg, "config.json") + return model_info + + register_model(fabric, log_models, cfg) diff --git a/sheeprl/algos/p2e_dv1/utils.py b/sheeprl/algos/p2e_dv1/utils.py index 8297f3c8..46ba98fd 100644 --- a/sheeprl/algos/p2e_dv1/utils.py +++ b/sheeprl/algos/p2e_dv1/utils.py @@ -1,4 +1,15 @@ +from __future__ import annotations + +from typing import Any, Dict, Sequence + +import gymnasium as gym +import mlflow +from lightning import Fabric +from mlflow.models.model import ModelInfo + from sheeprl.algos.dreamer_v1.utils import AGGREGATOR_KEYS as AGGREGATOR_KEYS_DV1 +from sheeprl.algos.p2e_dv1.agent import build_agent +from sheeprl.utils.utils import unwrap_fabric AGGREGATOR_KEYS = { "Rewards/rew_avg", @@ -28,3 +39,62 @@ "Grads/critic_exploration", "Grads/ensemble", }.union(AGGREGATOR_KEYS_DV1) +MODELS_TO_REGISTER = { + "world_model", + "ensembles", + "actor_exploration", + "critic_exploration", + "actor_task", + "critic_task", +} + + +def log_models_from_checkpoint( + fabric: Fabric, env: gym.Env | gym.Wrapper, cfg: Dict[str, Any], state: Dict[str, Any] +) -> Sequence[ModelInfo]: + # Create the models + is_continuous = 
isinstance(env.action_space, gym.spaces.Box) + is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete) + actions_dim = tuple( + env.action_space.shape + if is_continuous + else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) + ) + ( + world_model, + ensembles, + actor_task, + critic_task, + actor_exploration, + critic_exploration, + ) = build_agent( + fabric, + actions_dim, + is_continuous, + cfg, + env.observation_space, + state["world_model"], + state["ensembles"] if "exploration" in cfg.algo.name else None, + state["actor_task"], + state["critic_task"], + state["actor_exploration"] if "exploration" in cfg.algo.name else None, + state["critic_exploration"] if "exploration" in cfg.algo.name else None, + ) + + # Log the model, create a new run if `cfg.run_id` is None. + model_info = {} + with mlflow.start_run(run_id=cfg.run.id, experiment_id=cfg.experiment.id, run_name=cfg.run.name, nested=True) as _: + model_info["world_model"] = mlflow.pytorch.log_model(unwrap_fabric(world_model), artifact_path="world_model") + model_info["actor_task"] = mlflow.pytorch.log_model(unwrap_fabric(actor_task), artifact_path="actor_task") + model_info["critic_task"] = mlflow.pytorch.log_model(unwrap_fabric(critic_task), artifact_path="critic_task") + if "exploration" in cfg.algo.name: + model_info["ensembles"] = mlflow.pytorch.log_model(unwrap_fabric(ensembles), artifact_path="ensembles") + model_info["actor_exploration"] = mlflow.pytorch.log_model( + unwrap_fabric(actor_exploration), artifact_path="actor_exploration" + ) + model_info["critic_exploration"] = mlflow.pytorch.log_model( + unwrap_fabric(critic_exploration), artifact_path="critic_exploration" + ) + mlflow.log_dict(cfg.to_log, "config.json") + + return model_info diff --git a/sheeprl/algos/p2e_dv2/agent.py b/sheeprl/algos/p2e_dv2/agent.py index b40973d6..ffb146cd 100644 --- a/sheeprl/algos/p2e_dv2/agent.py +++ b/sheeprl/algos/p2e_dv2/agent.py @@ -6,12 +6,13 @@ import torch from lightning.fabric import Fabric from lightning.fabric.wrappers import _FabricModule +from lightning.pytorch.utilities.seed import isolate_rng from torch import nn from sheeprl.algos.dreamer_v2.agent import Actor as DV2Actor from sheeprl.algos.dreamer_v2.agent import MinedojoActor as DV2MinedojoActor from sheeprl.algos.dreamer_v2.agent import WorldModel -from sheeprl.algos.dreamer_v2.agent import build_models as dv2_build_models +from sheeprl.algos.dreamer_v2.agent import build_agent as dv2_build_agent from sheeprl.models.models import MLP from sheeprl.utils.utils import init_weights @@ -22,20 +23,21 @@ MinedojoActor = DV2MinedojoActor -def build_models( +def build_agent( fabric: Fabric, actions_dim: Sequence[int], is_continuous: bool, cfg: Dict[str, Any], obs_space: gymnasium.spaces.Dict, world_model_state: Optional[Dict[str, torch.Tensor]] = None, + ensembles_state: Optional[Dict[str, torch.Tensor]] = None, actor_task_state: Optional[Dict[str, torch.Tensor]] = None, critic_task_state: Optional[Dict[str, torch.Tensor]] = None, target_critic_task_state: Optional[Dict[str, torch.Tensor]] = None, actor_exploration_state: Optional[Dict[str, torch.Tensor]] = None, critic_exploration_state: Optional[Dict[str, torch.Tensor]] = None, target_critic_exploration_state: Optional[Dict[str, torch.Tensor]] = None, -) -> Tuple[WorldModel, _FabricModule, _FabricModule, nn.Module, _FabricModule, _FabricModule, nn.Module]: +) -> Tuple[WorldModel, _FabricModule, _FabricModule, _FabricModule, nn.Module, _FabricModule, _FabricModule, nn.Module]: 
"""Build the models and wrap them with Fabric. Args: @@ -46,6 +48,8 @@ def build_models( obs_space (Dict[str, Any]): The observations space of the environment. world_model_state (Dict[str, Tensor], optional): the state of the world model. Default to None. + ensembles_state (Dict[str, Tensor], optional): the state of the ensembles. + Default to None. actor_task_state (Dict[str, Tensor], optional): the state of the actor_task. Default to None. critic_task_state (Dict[str, Tensor], optional): the state of the critic_task. @@ -61,13 +65,14 @@ def build_models( Returns: The world model (WorldModel): composed by the encoder, rssm, observation and - reward models and the continue model. - The actor_task (_FabricModule). - The critic_task (_FabricModule). - The target_critic_task (nn.Module). - The actor_exploration (_FabricModule). - The critic_exploration (_FabricModule). - The target_critic_exploration (nn.Module). + reward models and the continue model. + The ensembles (_FabricModule): for estimating the intrinsic reward. + The actor_task (_FabricModule): for learning the task. + The critic_task (_FabricModule): for predicting the values of the task. + The target_critic_task (nn.Module): takes a EMA of the critic_task weights. + The actor_exploration (_FabricModule): for exploring the environment. + The critic_exploration (_FabricModule): for predicting the values of the exploration. + The target_critic_exploration (nn.Module): takes a EMA of the critic_exploration weights. """ world_model_cfg = cfg.algo.world_model actor_cfg = cfg.algo.actor @@ -78,7 +83,7 @@ def build_models( latent_state_size = stochastic_size + world_model_cfg.recurrent_model.recurrent_state_size # Create exploration models - world_model, actor_exploration, critic_exploration, target_critic_exploration = dv2_build_models( + world_model, actor_exploration, critic_exploration, target_critic_exploration = dv2_build_agent( fabric, actions_dim=actions_dim, is_continuous=is_continuous, @@ -131,8 +136,46 @@ def build_models( if target_critic_task_state: target_critic_task.load_state_dict(target_critic_task_state) + # initialize the ensembles with different seeds to be sure they have different weights + ens_list = [] + with isolate_rng(): + for i in range(cfg.algo.ensembles.n): + fabric.seed_everything(cfg.seed + i) + ens_list.append( + MLP( + input_dims=int( + sum(actions_dim) + + cfg.algo.world_model.recurrent_model.recurrent_state_size + + cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size + ), + output_dim=cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size, + hidden_sizes=[cfg.algo.ensembles.dense_units] * cfg.algo.ensembles.mlp_layers, + activation=eval(cfg.algo.ensembles.dense_act), + flatten_dim=None, + norm_layer=( + [nn.LayerNorm for _ in range(cfg.algo.ensembles.mlp_layers)] + if cfg.algo.ensembles.layer_norm + else None + ), + norm_args=( + [ + {"normalized_shape": cfg.algo.ensembles.dense_units} + for _ in range(cfg.algo.ensembles.mlp_layers) + ] + if cfg.algo.ensembles.layer_norm + else None + ), + ).apply(init_weights) + ) + ensembles = nn.ModuleList(ens_list) + if ensembles_state: + ensembles.load_state_dict(ensembles_state) + for i in range(len(ensembles)): + ensembles[i] = fabric.setup_module(ensembles[i]) + return ( world_model, + ensembles, actor_task, critic_task, target_critic_task, diff --git a/sheeprl/algos/p2e_dv2/evaluate.py b/sheeprl/algos/p2e_dv2/evaluate.py index c2ccf666..28757330 100644 --- a/sheeprl/algos/p2e_dv2/evaluate.py +++ 
b/sheeprl/algos/p2e_dv2/evaluate.py @@ -7,15 +7,15 @@ from sheeprl.algos.dreamer_v2.agent import PlayerDV2 from sheeprl.algos.dreamer_v2.utils import test -from sheeprl.algos.p2e_dv2.agent import build_models +from sheeprl.algos.p2e_dv2.agent import build_agent from sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.registry import register_evaluation @register_evaluation(algorithms=["p2e_dv2_exploration", "p2e_dv2_finetuning"]) def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): - logger = create_tensorboard_logger(fabric, cfg) + logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) @@ -40,17 +40,18 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) # Create the actor and critic models - world_model, actor_task, _, _, _, _, _ = build_models( + world_model, _, actor_task, _, _, _, _, _ = build_agent( fabric, actions_dim, is_continuous, cfg, observation_space, state["world_model"], + None, state["actor_task"], ) player = PlayerDV2( diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py index ab2599f5..ed24cea3 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py @@ -7,12 +7,13 @@ import gymnasium as gym import hydra +import mlflow import numpy as np import torch import torch.nn.functional as F from lightning.fabric import Fabric from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer -from lightning.pytorch.utilities.seed import isolate_rng +from mlflow.models.model import ModelInfo from tensordict import TensorDict from tensordict.tensordict import TensorDictBase from torch import Tensor, nn @@ -22,17 +23,16 @@ from sheeprl.algos.dreamer_v2.agent import PlayerDV2, WorldModel from sheeprl.algos.dreamer_v2.loss import reconstruction_loss -from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, init_weights, test -from sheeprl.algos.p2e_dv2.agent import build_models +from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, test +from sheeprl.algos.p2e_dv2.agent import build_agent from sheeprl.data.buffers import AsyncReplayBuffer, EpisodeBuffer -from sheeprl.models.models import MLP from sheeprl.utils.distribution import OneHotCategoricalValidateArgs from sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay +from sheeprl.utils.utils import polynomial_decay, register_model, unwrap_fabric # Decomment the following line if you are using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -540,9 +540,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): cfg.env.frame_stack = 1 cfg.algo.player.actor_type = "exploration" - # Create TensorBoardLogger. This will create the logger only on the + # Create Logger. 
This will create the logger only on the # rank-0 process - logger = create_tensorboard_logger(fabric, cfg) + logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) @@ -568,7 +568,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r @@ -599,19 +599,21 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ( world_model, + ensembles, actor_task, critic_task, target_critic_task, actor_exploration, critic_exploration, target_critic_exploration, - ) = build_models( + ) = build_agent( fabric, actions_dim, is_continuous, cfg, observation_space, state["world_model"] if cfg.checkpoint.resume_from else None, + state["ensembles"] if cfg.checkpoint.resume_from else None, state["actor_task"] if cfg.checkpoint.resume_from else None, state["critic_task"] if cfg.checkpoint.resume_from else None, state["target_critic_task"] if cfg.checkpoint.resume_from else None, @@ -620,41 +622,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): state["target_critic_exploration"] if cfg.checkpoint.resume_from else None, ) - # initialize the ensembles with different seeds to be sure they have different weights - ens_list = [] - with isolate_rng(): - for i in range(cfg.algo.ensembles.n): - fabric.seed_everything(cfg.seed + i) - ens_list.append( - MLP( - input_dims=int( - sum(actions_dim) - + cfg.algo.world_model.recurrent_model.recurrent_state_size - + cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size - ), - output_dim=cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size, - hidden_sizes=[cfg.algo.ensembles.dense_units] * cfg.algo.ensembles.mlp_layers, - activation=eval(cfg.algo.ensembles.dense_act), - flatten_dim=None, - norm_layer=( - [nn.LayerNorm for _ in range(cfg.algo.ensembles.mlp_layers)] - if cfg.algo.ensembles.layer_norm - else None - ), - norm_args=( - [ - {"normalized_shape": cfg.algo.ensembles.dense_units} - for _ in range(cfg.algo.ensembles.mlp_layers) - ] - if cfg.algo.ensembles.layer_norm - else None - ), - ).apply(init_weights) - ) - ensembles = nn.ModuleList(ens_list) - if cfg.checkpoint.resume_from: - ensembles.load_state_dict(state["ensembles"]) - fabric.setup_module(ensembles) player = PlayerDV2( world_model.encoder.module, world_model.rssm.recurrent_model.module, @@ -703,6 +670,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): critic_exploration_optimizer, ) + local_vars = locals() + # Metrics aggregator = None if not MetricAggregator.disabled: @@ -1063,3 +1032,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): player.actor = actor_task.module player.actor_type = "task" test(player, fabric, cfg, log_dir, "zero-shot") + + if not cfg.model_manager.disabled and fabric.is_global_zero: + + def log_models( + run_id: str, experiment_id: str | None = None, run_name: str | None = None + ) -> Dict[str, ModelInfo]: + with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=True) as _: + model_info = {} + unwrapped_models = {} + for k in cfg.model_manager.models.keys(): + unwrapped_models[k] = unwrap_fabric(local_vars[k]) + model_info[k] = mlflow.pytorch.log_model(unwrapped_models[k], artifact_path=k) + 
mlflow.log_dict(cfg, "config.json") + return model_info + + register_model(fabric, log_models, cfg) diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index 54d5fec1..adb97e60 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -8,9 +8,11 @@ import gymnasium as gym import hydra +import mlflow import numpy as np import torch from lightning.fabric import Fabric +from mlflow.models.model import ModelInfo from tensordict import TensorDict from torch import Tensor from torch.utils.data import BatchSampler @@ -19,14 +21,14 @@ from sheeprl.algos.dreamer_v2.agent import PlayerDV2 from sheeprl.algos.dreamer_v2.dreamer_v2 import train from sheeprl.algos.dreamer_v2.utils import test -from sheeprl.algos.p2e_dv2.agent import build_models +from sheeprl.algos.p2e_dv2.agent import build_agent from sheeprl.data.buffers import AsyncReplayBuffer, EpisodeBuffer from sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay +from sheeprl.utils.utils import polynomial_decay, register_model, unwrap_fabric # Decomment the following line if you are using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -80,9 +82,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): cfg.env.screen_size = 64 cfg.env.frame_stack = 1 - # Create TensorBoardLogger. This will create the logger only on the + # Create Logger. This will create the logger only on the # rank-0 process - logger = create_tensorboard_logger(fabric, cfg) + logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) @@ -108,7 +110,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r @@ -137,13 +139,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder - world_model, actor_task, critic_task, target_critic_task, actor_exploration, _, _ = build_models( + world_model, _, actor_task, critic_task, target_critic_task, actor_exploration, _, _ = build_agent( fabric, actions_dim, is_continuous, cfg, observation_space, state["world_model"], + None, state["actor_task"], state["critic_task"], state["target_critic_task"], @@ -175,6 +178,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): world_optimizer, actor_task_optimizer, critic_task_optimizer ) + local_vars = locals() + # Metrics aggregator = None if not MetricAggregator.disabled: @@ -504,3 +509,23 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): player.actor = actor_task.module player.actor_type = "task" test(player, fabric, cfg, log_dir, "few-shot") + + if not cfg.model_manager.disabled and 
fabric.is_global_zero: + + def log_models( + run_id: str, experiment_id: str | None = None, run_name: str | None = None + ) -> Dict[str, ModelInfo]: + with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=True) as _: + model_info = {} + unwrapped_models = {} + models_keys = set(cfg.model_manager.models.keys()) + for k in models_keys: + if "exploration" not in k and k != "ensembles": + unwrapped_models[k] = unwrap_fabric(local_vars[k]) + model_info[k] = mlflow.pytorch.log_model(unwrapped_models[k], artifact_path=k) + else: + cfg.model_manager.models.pop(k, None) + mlflow.log_dict(cfg, "config.json") + return model_info + + register_model(fabric, log_models, cfg) diff --git a/sheeprl/algos/p2e_dv2/utils.py b/sheeprl/algos/p2e_dv2/utils.py index c9aeb9ed..b0273535 100644 --- a/sheeprl/algos/p2e_dv2/utils.py +++ b/sheeprl/algos/p2e_dv2/utils.py @@ -1,4 +1,15 @@ +from __future__ import annotations + +from typing import Any, Dict, Sequence + +import gymnasium as gym +import mlflow +from lightning import Fabric +from mlflow.models.model import ModelInfo + from sheeprl.algos.dreamer_v2.utils import AGGREGATOR_KEYS as AGGREGATOR_KEYS_DV2 +from sheeprl.algos.p2e_dv2.agent import build_agent +from sheeprl.utils.utils import unwrap_fabric AGGREGATOR_KEYS = { "Rewards/rew_avg", @@ -28,3 +39,74 @@ "Grads/critic_exploration", "Grads/ensemble", }.union(AGGREGATOR_KEYS_DV2) +MODELS_TO_REGISTER = { + "world_model", + "ensembles", + "actor_exploration", + "critic_exploration", + "target_critic_exploration", + "actor_task", + "critic_task", + "target_critic_task", +} + + +def log_models_from_checkpoint( + fabric: Fabric, env: gym.Env | gym.Wrapper, cfg: Dict[str, Any], state: Dict[str, Any] +) -> Sequence[ModelInfo]: + # Create the models + is_continuous = isinstance(env.action_space, gym.spaces.Box) + is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete) + actions_dim = tuple( + env.action_space.shape + if is_continuous + else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) + ) + ( + world_model, + ensembles, + actor_task, + critic_task, + target_critic_task, + actor_exploration, + critic_exploration, + target_critic_exploration, + ) = build_agent( + fabric, + actions_dim, + is_continuous, + cfg, + env.observation_space, + state["world_model"], + state["ensembles"] if "exploration" in cfg.algo.name else None, + state["actor_task"], + state["critic_task"], + state["target_critic_task"], + state["actor_exploration"] if "exploration" in cfg.algo.name else None, + state["critic_exploration"] if "exploration" in cfg.algo.name else None, + state["target_critic_exploration"] if "exploration" in cfg.algo.name else None, + ) + + # Log the model, create a new run if `cfg.run_id` is None. 
+ model_info = {} + with mlflow.start_run(run_id=cfg.run.id, experiment_id=cfg.experiment.id, run_name=cfg.run.name, nested=True) as _: + model_info["world_model"] = mlflow.pytorch.log_model(unwrap_fabric(world_model), artifact_path="world_model") + model_info["actor_task"] = mlflow.pytorch.log_model(unwrap_fabric(actor_task), artifact_path="actor_task") + model_info["critic_task"] = mlflow.pytorch.log_model(unwrap_fabric(critic_task), artifact_path="critic_task") + model_info["target_critic_task"] = mlflow.pytorch.log_model( + target_critic_task, artifact_path="target_critic_task" + ) + if "exploration" in cfg.algo.name: + model_info["ensembles"] = mlflow.pytorch.log_model(unwrap_fabric(ensembles), artifact_path="ensembles") + model_info["actor_exploration"] = mlflow.pytorch.log_model( + unwrap_fabric(actor_exploration), artifact_path="actor_exploration" + ) + model_info["critic_exploration"] = mlflow.pytorch.log_model( + unwrap_fabric(critic_exploration), artifact_path="critic_exploration" + ) + model_info["target_critic_exploration"] = mlflow.pytorch.log_model( + target_critic_exploration, artifact_path="target_critic_exploration" + ) + mlflow.log_dict(cfg.to_log, "config.json") + + return model_info diff --git a/sheeprl/algos/p2e_dv3/agent.py b/sheeprl/algos/p2e_dv3/agent.py index c4019b1c..66547eaa 100644 --- a/sheeprl/algos/p2e_dv3/agent.py +++ b/sheeprl/algos/p2e_dv3/agent.py @@ -5,12 +5,13 @@ import torch from lightning.fabric import Fabric from lightning.fabric.wrappers import _FabricModule +from lightning.pytorch.utilities.seed import isolate_rng from torch import nn from sheeprl.algos.dreamer_v3.agent import Actor as DV3Actor from sheeprl.algos.dreamer_v3.agent import MinedojoActor as DV3MinedojoActor from sheeprl.algos.dreamer_v3.agent import WorldModel -from sheeprl.algos.dreamer_v3.agent import build_models as dv3_build_models +from sheeprl.algos.dreamer_v3.agent import build_agent as dv3_build_agent from sheeprl.algos.dreamer_v3.utils import init_weights, uniform_init_weights from sheeprl.models.models import MLP @@ -21,19 +22,20 @@ MinedojoActor = DV3MinedojoActor -def build_models( +def build_agent( fabric: Fabric, actions_dim: Sequence[int], is_continuous: bool, cfg: Dict[str, Any], obs_space: Dict[str, Any], world_model_state: Optional[Dict[str, torch.Tensor]] = None, + ensembles_state: Optional[Dict[str, torch.Tensor]] = None, actor_task_state: Optional[Dict[str, torch.Tensor]] = None, critic_task_state: Optional[Dict[str, torch.Tensor]] = None, target_critic_task_state: Optional[Dict[str, torch.Tensor]] = None, actor_exploration_state: Optional[Dict[str, torch.Tensor]] = None, critics_exploration_state: Optional[Dict[str, Dict[str, Any]]] = None, -) -> Tuple[WorldModel, _FabricModule, _FabricModule, nn.Module, _FabricModule, Dict[str, Any]]: +) -> Tuple[WorldModel, _FabricModule, _FabricModule, _FabricModule, nn.Module, _FabricModule, Dict[str, Any]]: """Build the models and wrap them with Fabric. Args: @@ -44,6 +46,8 @@ def build_models( obs_space (Dict[str, Any]): The observations space of the environment. world_model_state (Dict[str, Tensor], optional): the state of the world model. Default to None. + ensembles_state (Dict[str, Tensor], optional): the state of the ensembles. + Default to None. actor_task_state (Dict[str, Tensor], optional): the state of the actor_task. Default to None. critic_task_state (Dict[str, Tensor], optional): the state of the critic_task. 
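Note on the MODELS_TO_REGISTER sets added to the per-algorithm utils.py modules: the registration closures above resolve each key of cfg.model_manager.models against the local variables of main(), so the keys listed in the model-manager config are expected to stay within MODELS_TO_REGISTER. A minimal sketch of that relationship (the check_requested_models helper is hypothetical and not part of this patch):

from sheeprl.algos.p2e_dv2.utils import MODELS_TO_REGISTER

def check_requested_models(cfg) -> None:
    # Keys under cfg.model_manager.models name local model variables such as "world_model"
    # or "actor_task"; anything outside MODELS_TO_REGISTER cannot be resolved and logged.
    unknown = set(cfg.model_manager.models.keys()) - MODELS_TO_REGISTER
    if unknown:
        raise ValueError(f"Cannot register unknown models: {sorted(unknown)}")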
@@ -58,11 +62,15 @@ def build_models( Returns: The world model (WorldModel): composed by the encoder, rssm, observation and reward models and the continue model. - The actor_task (_FabricModule). - The critic_task (_FabricModule). - The target_critic_task (nn.Module). - The actor_exploration (_FabricModule). - The critics_exploration (Dict[str, Dict[str, Any]]). + + The ensembles (_FabricModule): for estimating the intrinsic reward. + The actor_task (_FabricModule): for learning the task. + The critic_task (_FabricModule): for predicting the values of the task. + The target_critic_task (nn.Module): takes an EMA of the critic_task weights. + The actor_exploration (_FabricModule): for exploring the environment. + The critics_exploration (Dict[str, Dict[str, Any]]): for predicting the values of the exploration. + A Python dictionary containing all the exploration critics: the critic is under the 'module' key, + whereas the target critic is under the 'target_module' key. """ world_model_cfg = cfg.algo.world_model actor_cfg = cfg.algo.actor @@ -73,7 +81,7 @@ def build_models( latent_state_size = stochastic_size + world_model_cfg.recurrent_model.recurrent_state_size # Create task models - world_model, actor_task, critic_task, target_critic_task = dv3_build_models( + world_model, actor_task, critic_task, target_critic_task = dv3_build_agent( fabric, actions_dim=actions_dim, is_continuous=is_continuous, @@ -152,8 +160,43 @@ def build_models( for c in critics_exploration.values(): c["target_module"].requires_grad_(False) + # initialize the ensembles with different seeds to be sure they have different weights + ens_list = [] + cfg_ensembles = cfg.algo.ensembles + with isolate_rng(): + for i in range(cfg_ensembles.n): + fabric.seed_everything(cfg.seed + i) + ens_list.append( + MLP( + input_dims=int( + sum(actions_dim) + + cfg.algo.world_model.recurrent_model.recurrent_state_size + + cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size + ), + output_dim=cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size, + hidden_sizes=[cfg_ensembles.dense_units] * cfg_ensembles.mlp_layers, + activation=eval(cfg_ensembles.dense_act), + flatten_dim=None, + layer_args={"bias": not cfg.algo.ensembles.layer_norm}, + norm_layer=( + [nn.LayerNorm for _ in range(cfg_ensembles.mlp_layers)] if cfg_ensembles.layer_norm else None + ), + norm_args=( + [{"normalized_shape": cfg_ensembles.dense_units} for _ in range(cfg_ensembles.mlp_layers)] + if cfg_ensembles.layer_norm + else None + ), + ).apply(init_weights) + ) + ensembles = nn.ModuleList(ens_list) + if ensembles_state: + ensembles.load_state_dict(ensembles_state) + for i in range(len(ensembles)): + ensembles[i] = fabric.setup_module(ensembles[i]) + return ( world_model, + ensembles, actor_task, critic_task, target_critic_task, diff --git a/sheeprl/algos/p2e_dv3/evaluate.py b/sheeprl/algos/p2e_dv3/evaluate.py index 97b86112..b99c2d28 100644 --- a/sheeprl/algos/p2e_dv3/evaluate.py +++ b/sheeprl/algos/p2e_dv3/evaluate.py @@ -7,15 +7,15 @@ from sheeprl.algos.dreamer_v3.agent import PlayerDV3 from sheeprl.algos.dreamer_v3.utils import test -from sheeprl.algos.p2e_dv3.agent import build_models +from sheeprl.algos.p2e_dv3.agent import build_agent from sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.registry import register_evaluation
@register_evaluation(algorithms=["p2e_dv3_exploration", "p2e_dv3_finetuning"]) def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): - logger = create_tensorboard_logger(fabric, cfg) + logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) @@ -40,17 +40,18 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) # Create the actor and critic models - world_model, actor, _, _, _, _ = build_models( + world_model, _, actor, _, _, _, _ = build_agent( fabric, actions_dim, is_continuous, cfg, observation_space, state["world_model"], + None, state["actor_task"], ) player = PlayerDV3( diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py index a5f1b1a6..e7ed5ca4 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py @@ -5,12 +5,13 @@ import gymnasium as gym import hydra +import mlflow import numpy as np import torch import torch.nn.functional as F from lightning.fabric import Fabric from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer -from lightning.pytorch.utilities.seed import isolate_rng +from mlflow.models.model import ModelInfo from omegaconf import DictConfig from tensordict import TensorDict from tensordict.tensordict import TensorDictBase @@ -21,10 +22,9 @@ from sheeprl.algos.dreamer_v3.agent import PlayerDV3, WorldModel from sheeprl.algos.dreamer_v3.loss import reconstruction_loss -from sheeprl.algos.dreamer_v3.utils import Moments, compute_lambda_values, init_weights, test -from sheeprl.algos.p2e_dv3.agent import build_models +from sheeprl.algos.dreamer_v3.utils import Moments, compute_lambda_values, test +from sheeprl.algos.p2e_dv3.agent import build_agent from sheeprl.data.buffers import AsyncReplayBuffer -from sheeprl.models.models import MLP from sheeprl.utils.distribution import ( MSEDistribution, OneHotCategoricalValidateArgs, @@ -32,11 +32,11 @@ TwoHotEncodingDistribution, ) from sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay +from sheeprl.utils.utils import polynomial_decay, register_model, unwrap_fabric # Decomment the following line if you are using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -570,9 +570,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): cfg.env.frame_stack = 1 cfg.algo.player.actor_type = "exploration" - # Create TensorBoardLogger. This will create the logger only on the + # Create Logger. 
This will create the logger only on the # rank-0 process - logger = create_tensorboard_logger(fabric, cfg) + logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) @@ -598,7 +598,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r @@ -629,18 +629,20 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ( world_model, + ensembles, actor_task, critic_task, target_critic_task, actor_exploration, critics_exploration, - ) = build_models( + ) = build_agent( fabric, actions_dim, is_continuous, cfg, observation_space, state["world_model"] if cfg.checkpoint.resume_from else None, + state["ensembles"] if cfg.checkpoint.resume_from else None, state["actor_task"] if cfg.checkpoint.resume_from else None, state["critic_task"] if cfg.checkpoint.resume_from else None, state["target_critic_task"] if cfg.checkpoint.resume_from else None, @@ -648,38 +650,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): state["critics_exploration"] if cfg.checkpoint.resume_from else None, ) - # initialize the ensembles with different seeds to be sure they have different weights - ens_list = [] - cfg_ensembles = cfg.algo.ensembles - with isolate_rng(): - for i in range(cfg_ensembles.n): - fabric.seed_everything(cfg.seed + i) - ens_list.append( - MLP( - input_dims=int( - sum(actions_dim) - + cfg.algo.world_model.recurrent_model.recurrent_state_size - + cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size - ), - output_dim=cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size, - hidden_sizes=[cfg_ensembles.dense_units] * cfg_ensembles.mlp_layers, - activation=eval(cfg_ensembles.dense_act), - flatten_dim=None, - layer_args={"bias": not cfg.algo.ensembles.layer_norm}, - norm_layer=( - [nn.LayerNorm for _ in range(cfg_ensembles.mlp_layers)] if cfg_ensembles.layer_norm else None - ), - norm_args=( - [{"normalized_shape": cfg_ensembles.dense_units} for _ in range(cfg_ensembles.mlp_layers)] - if cfg_ensembles.layer_norm - else None - ), - ).apply(init_weights) - ) - ensembles = nn.ModuleList(ens_list) - if cfg.checkpoint.resume_from: - ensembles.load_state_dict(state["ensembles"]) - fabric.setup_module(ensembles) player = PlayerDV3( world_model.encoder.module, world_model.rssm, @@ -749,6 +719,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): m.load_state_dict(state[f"moments_exploration_{k}"]) moments_task.load_state_dict(state["moments_task"]) + local_vars = locals() + # Metrics # Since there could be more exploration critics, the key of the critic is added # to the metrics that the user has selected. 
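In p2e_dv3 there may be several exploration critics (e.g. one per reward type, such as "intrinsic" and "extrinsic"); each entry of critics_exploration stores the trained critic under the 'module' key and its target network under 'target_module'. The registration closure added in the next hunk maps flat model names like "critic_exploration_intrinsic" back onto that structure; the helper below is an illustrative standalone sketch of the same key handling, not code from this patch:

from typing import Any, Dict

from torch import nn

def resolve_exploration_model(name: str, critics_exploration: Dict[str, Dict[str, Any]]) -> nn.Module:
    # "target_critic_exploration_<k>" -> critics_exploration["<k>"]["target_module"]
    # "critic_exploration_<k>"        -> critics_exploration["<k>"]["module"]
    if name.startswith("target_critic_exploration_"):
        return critics_exploration[name.replace("target_critic_exploration_", "")]["target_module"]
    if name.startswith("critic_exploration_"):
        return critics_exploration[name.replace("critic_exploration_", "")]["module"]
    raise KeyError(name)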
@@ -1120,3 +1092,30 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): player.actor = actor_task.module player.actor_type = "task" test(player, fabric, cfg, log_dir, "zero-shot") + + if not cfg.model_manager.disabled and fabric.is_global_zero: + + def log_models( + run_id: str, experiment_id: str | None = None, run_name: str | None = None + ) -> Dict[str, ModelInfo]: + with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=True) as _: + model_info = {} + unwrapped_models = {} + for k in cfg.model_manager.models.keys(): + if k.startswith("critic_exploration"): + unwrapped_models[k] = unwrap_fabric( + critics_exploration[k.replace("critic_exploration_", "")]["module"] + ) + elif k.startswith("target_critic_exploration"): + unwrapped_models[k] = critics_exploration[k.replace("target_critic_exploration_", "")][ + "target_module" + ] + elif k.startswith("moments_exploration"): + unwrapped_models[k] = unwrap_fabric(moments_exploration[k.replace("moments_exploration_", "")]) + else: + unwrapped_models[k] = unwrap_fabric(local_vars[k]) + model_info[k] = mlflow.pytorch.log_model(unwrapped_models[k], artifact_path=k) + mlflow.log_dict(cfg, "config.json") + return model_info + + register_model(fabric, log_models, cfg) diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py index 83debfe3..84d887c0 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py @@ -6,9 +6,11 @@ import gymnasium as gym import hydra +import mlflow import numpy as np import torch from lightning.fabric import Fabric +from mlflow.models.model import ModelInfo from tensordict import TensorDict from torch import Tensor from torch.utils.data import BatchSampler @@ -17,14 +19,14 @@ from sheeprl.algos.dreamer_v3.agent import PlayerDV3 from sheeprl.algos.dreamer_v3.dreamer_v3 import train from sheeprl.algos.dreamer_v3.utils import Moments, test -from sheeprl.algos.p2e_dv3.agent import build_models +from sheeprl.algos.p2e_dv3.agent import build_agent from sheeprl.data.buffers import AsyncReplayBuffer from sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay +from sheeprl.utils.utils import polynomial_decay, register_model, unwrap_fabric @register_algorithm() @@ -74,9 +76,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # These arguments cannot be changed cfg.env.frame_stack = 1 - # Create TensorBoardLogger. This will create the logger only on the + # Create Logger. 
This will create the logger only on the # rank-0 process - logger = create_tensorboard_logger(fabric, cfg) + logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) @@ -102,7 +104,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r @@ -133,18 +135,20 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): ( world_model, + _, actor_task, critic_task, target_critic_task, actor_exploration, _, - ) = build_models( + ) = build_agent( fabric, actions_dim, is_continuous, cfg, observation_space, state["world_model"], + None, state["actor_task"], state["critic_task"], state["target_critic_task"], @@ -185,6 +189,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): ) moments_task.load_state_dict(state["moments_task"]) + local_vars = locals() + # Metrics aggregator = None if not MetricAggregator.disabled: @@ -492,3 +498,23 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): player.actor = actor_task.module player.actor_type = "task" test(player, fabric, cfg, log_dir, "few-shot") + + if not cfg.model_manager.disabled and fabric.is_global_zero: + + def log_models( + run_id: str, experiment_id: str | None = None, run_name: str | None = None + ) -> Dict[str, ModelInfo]: + with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=True) as _: + model_info = {} + unwrapped_models = {} + models_keys = set(cfg.model_manager.models.keys()) + for k in models_keys: + if "exploration" not in k and k != "ensembles": + unwrapped_models[k] = unwrap_fabric(local_vars[k]) + model_info[k] = mlflow.pytorch.log_model(unwrapped_models[k], artifact_path=k) + else: + cfg.model_manager.models.pop(k, None) + mlflow.log_dict(cfg, "config.json") + return model_info + + register_model(fabric, log_models, cfg) diff --git a/sheeprl/algos/p2e_dv3/utils.py b/sheeprl/algos/p2e_dv3/utils.py index 454d683e..db3dbaf8 100644 --- a/sheeprl/algos/p2e_dv3/utils.py +++ b/sheeprl/algos/p2e_dv3/utils.py @@ -1,4 +1,16 @@ +from __future__ import annotations + +from typing import Any, Dict, Sequence + +import gymnasium as gym +import mlflow +from lightning import Fabric +from mlflow.models.model import ModelInfo + from sheeprl.algos.dreamer_v3.utils import AGGREGATOR_KEYS as AGGREGATOR_KEYS_DV3 +from sheeprl.algos.dreamer_v3.utils import Moments +from sheeprl.algos.p2e_dv3.agent import build_agent +from sheeprl.utils.utils import unwrap_fabric AGGREGATOR_KEYS = { "Rewards/rew_avg", @@ -29,3 +41,103 @@ "Grads/critic_exploration", "Rewards/intrinsic", }.union(AGGREGATOR_KEYS_DV3) +MODELS_TO_REGISTER = { + "world_model", + "ensembles", + "actor_exploration", + "critic_exploration_intrinsic", + "target_critic_exploration_intrinsic", + "moments_exploration_intrinsic", + "critic_exploration_extrinsic", + "target_critic_exploration_extrinsic", + "moments_exploration_extrinsic", + "actor_task", + "critic_task", + "target_critic_task", + "moments_task", +} + + +def log_models_from_checkpoint( + fabric: Fabric, env: gym.Env | gym.Wrapper, cfg: Dict[str, Any], state: Dict[str, Any] 
+) -> Sequence[ModelInfo]: + # Create the models + is_continuous = isinstance(env.action_space, gym.spaces.Box) + is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete) + actions_dim = tuple( + env.action_space.shape + if is_continuous + else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) + ) + ( + world_model, + ensembles, + actor_task, + critic_task, + target_critic_task, + actor_exploration, + critics_exploration, + ) = build_agent( + fabric, + actions_dim, + is_continuous, + cfg, + env.observation_space, + state["world_model"], + state["ensembles"] if "exploration" in cfg.algo.name else None, + state["actor_task"], + state["critic_task"], + state["target_critic_task"], + state["actor_exploration"] if "exploration" in cfg.algo.name else None, + state["critics_exploration"] if "exploration" in cfg.algo.name else None, + ) + moments_task = Moments( + fabric, + cfg.algo.actor.moments.decay, + cfg.algo.actor.moments.max, + cfg.algo.actor.moments.percentile.low, + cfg.algo.actor.moments.percentile.high, + ) + moments_task.load_state_dict(state["moments_task"]) + + if "exploration" in cfg.algo.name: + moments_exploration = { + k: Moments( + fabric, + cfg.algo.actor.moments.decay, + cfg.algo.actor.moments.max, + cfg.algo.actor.moments.percentile.low, + cfg.algo.actor.moments.percentile.high, + ) + for k in critics_exploration.keys() + } + for k, m in moments_exploration.items(): + m.load_state_dict(state[f"moments_exploration_{k}"]) + + # Log the model, create a new run if `cfg.run_id` is None. + model_info = {} + with mlflow.start_run(run_id=cfg.run.id, experiment_id=cfg.experiment.id, run_name=cfg.run.name, nested=True) as _: + model_info["world_model"] = mlflow.pytorch.log_model(unwrap_fabric(world_model), artifact_path="world_model") + model_info["actor_task"] = mlflow.pytorch.log_model(unwrap_fabric(actor_task), artifact_path="actor_task") + model_info["critic_task"] = mlflow.pytorch.log_model(unwrap_fabric(critic_task), artifact_path="critic_task") + model_info["target_critic_task"] = mlflow.pytorch.log_model( + target_critic_task, artifact_path="target_critic_task" + ) + model_info["moments_task"] = mlflow.pytorch.log_model(moments_task, artifact_path="moments_task") + if "exploration" in cfg.algo.name: + model_info["ensembles"] = mlflow.pytorch.log_model(unwrap_fabric(ensembles), artifact_path="ensembles") + model_info["actor_exploration"] = mlflow.pytorch.log_model( + unwrap_fabric(actor_exploration), artifact_path="actor_exploration" + ) + for k in critics_exploration.keys(): + model_info[f"critic_exploration_{k}"] = mlflow.pytorch.log_model( + critics_exploration[k]["module"], artifact_path=f"critic_exploration_{k}" + ) + model_info[f"target_critic_exploration_{k}"] = mlflow.pytorch.log_model( + critics_exploration[k]["target_module"], artifact_path=f"target_critic_exploration_{k}" + ) + model_info[f"moments_exploration_{k}"] = mlflow.pytorch.log_model( + moments_exploration[k], artifact_path=f"moments_exploration_{k}" + ) + mlflow.log_dict(cfg.to_log, "config.json") + return model_info diff --git a/sheeprl/algos/ppo/agent.py b/sheeprl/algos/ppo/agent.py index 4efcea6b..f780098a 100644 --- a/sheeprl/algos/ppo/agent.py +++ b/sheeprl/algos/ppo/agent.py @@ -4,6 +4,8 @@ import gymnasium import torch import torch.nn as nn +from lightning import Fabric +from lightning.fabric.wrappers import _FabricModule from torch import Tensor from torch.distributions import Distribution, Independent, Normal @@ -62,7 +64,7 @@ def forward(self, obs: 
Dict[str, Tensor]) -> Tensor: class PPOAgent(nn.Module): def __init__( self, - actions_dim: List[int], + actions_dim: Sequence[int], obs_space: gymnasium.spaces.Dict, encoder_cfg: Dict[str, Any], actor_cfg: Dict[str, Any], @@ -194,3 +196,30 @@ def get_greedy_actions(self, obs: Dict[str, Tensor]) -> Sequence[Tensor]: for logits in pre_dist ] ) + + +def build_agent( + fabric: Fabric, + actions_dim: Sequence[int], + is_continuous: bool, + cfg: Dict[str, Any], + obs_space: gymnasium.spaces.Dict, + agent_state: Optional[Dict[str, Tensor]] = None, +) -> _FabricModule: + agent = PPOAgent( + actions_dim=actions_dim, + obs_space=obs_space, + encoder_cfg=cfg.algo.encoder, + actor_cfg=cfg.algo.actor, + critic_cfg=cfg.algo.critic, + cnn_keys=cfg.algo.cnn_keys.encoder, + mlp_keys=cfg.algo.mlp_keys.encoder, + screen_size=cfg.env.screen_size, + distribution_cfg=cfg.distribution, + is_continuous=is_continuous, + ) + if agent_state: + agent.load_state_dict(agent_state) + agent = fabric.setup_module(agent) + + return agent diff --git a/sheeprl/algos/ppo/evaluate.py b/sheeprl/algos/ppo/evaluate.py index 35220f80..1d629938 100644 --- a/sheeprl/algos/ppo/evaluate.py +++ b/sheeprl/algos/ppo/evaluate.py @@ -5,16 +5,16 @@ import gymnasium as gym from lightning import Fabric -from sheeprl.algos.ppo.agent import PPOAgent +from sheeprl.algos.ppo.agent import build_agent from sheeprl.algos.ppo.utils import test from sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.registry import register_evaluation @register_evaluation(algorithms=["ppo"]) def evaluate_ppo(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): - logger = create_tensorboard_logger(fabric, cfg) + logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) @@ -42,26 +42,13 @@ def evaluate_ppo(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): is_continuous = isinstance(env.action_space, gym.spaces.Box) is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( env.action_space.shape if is_continuous else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) ) # Create the actor and critic models - agent = PPOAgent( - actions_dim=actions_dim, - obs_space=observation_space, - encoder_cfg=cfg.algo.encoder, - actor_cfg=cfg.algo.actor, - critic_cfg=cfg.algo.critic, - cnn_keys=cfg.algo.cnn_keys.encoder, - mlp_keys=cfg.algo.mlp_keys.encoder, - screen_size=cfg.env.screen_size, - distribution_cfg=cfg.distribution, - is_continuous=is_continuous, - ) - agent.load_state_dict(state["agent"]) - agent = fabric.setup_module(agent) + agent = build_agent(fabric, actions_dim, is_continuous, cfg, observation_space, state["agent"]) test(agent, fabric, cfg, log_dir) diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index 49e62ece..67ea2d8f 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -7,26 +7,28 @@ import gymnasium as gym import hydra +import mlflow import numpy as np import torch from lightning.fabric import Fabric from lightning.fabric.wrappers import _FabricModule +from mlflow.models.model import ModelInfo from tensordict import TensorDict, make_tensordict from tensordict.tensordict import TensorDictBase from torch import nn from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler from torchmetrics import 
SumMetric -from sheeprl.algos.ppo.agent import PPOAgent +from sheeprl.algos.ppo.agent import build_agent from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss -from sheeprl.algos.ppo.utils import test +from sheeprl.algos.ppo.utils import normalize_obs, test from sheeprl.data import ReplayBuffer from sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay +from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay, register_model, unwrap_fabric def train( @@ -56,10 +58,9 @@ def train( sampler.sampler.set_epoch(epoch) for batch_idxes in sampler: batch = data[batch_idxes] - normalized_obs = { - k: batch[k] / 255 - 0.5 if k in cfg.algo.cnn_keys.encoder else batch[k] - for k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder - } + normalized_obs = normalize_obs( + batch, cfg.algo.cnn_keys.encoder, cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder + ) _, logprobs, entropy, new_values = agent( normalized_obs, torch.split(batch["actions"], agent.actions_dim, dim=-1) ) @@ -130,9 +131,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): state = fabric.load(cfg.checkpoint.resume_from) cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size - # Create TensorBoardLogger. This will create the logger only on the + # Create Logger. This will create the logger only on the # rank-0 process - logger = create_tensorboard_logger(fabric, cfg) + logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) @@ -169,35 +170,31 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): is_continuous = isinstance(envs.single_action_space, gym.spaces.Box) is_multidiscrete = isinstance(envs.single_action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( envs.single_action_space.shape if is_continuous else (envs.single_action_space.nvec.tolist() if is_multidiscrete else [envs.single_action_space.n]) ) # Create the actor and critic models - agent = PPOAgent( - actions_dim=actions_dim, - obs_space=observation_space, - encoder_cfg=cfg.algo.encoder, - actor_cfg=cfg.algo.actor, - critic_cfg=cfg.algo.critic, - cnn_keys=cfg.algo.cnn_keys.encoder, - mlp_keys=cfg.algo.mlp_keys.encoder, - screen_size=cfg.env.screen_size, - distribution_cfg=cfg.distribution, - is_continuous=is_continuous, + agent = build_agent( + fabric, + actions_dim, + is_continuous, + cfg, + observation_space, + state["agent"] if cfg.checkpoint.resume_from else None, ) # Define the optimizer optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters()) + local_vars = locals() + # Load the state from the checkpoint if cfg.checkpoint.resume_from: - agent.load_state_dict(state["agent"]) optimizer.load_state_dict(state["optimizer"]) # Setup agent and optimizer with Fabric - agent = fabric.setup_module(agent) optimizer = fabric.setup_optimizers(optimizer) # Create a metric aggregator to log the metrics @@ -276,9 +273,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): with torch.no_grad(): # Sample an action given the observation received by the environment - normalized_obs = { - k: next_obs[k] / 255 - 0.5 if k in 
cfg.algo.cnn_keys.encoder else next_obs[k] for k in obs_keys - } + normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) actions, logprobs, _, values = agent.module(normalized_obs) if is_continuous: real_actions = torch.cat(actions, -1).cpu().numpy() @@ -349,9 +344,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) with torch.no_grad(): - normalized_obs = { - k: next_obs[k] / 255 - 0.5 if k in cfg.algo.cnn_keys.encoder else next_obs[k] for k in obs_keys - } + normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) next_values = agent.module.get_value(normalized_obs) returns, advantages = gae( rb["rewards"], @@ -452,3 +445,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero: test(agent.module, fabric, cfg, log_dir) + + if not cfg.model_manager.disabled and fabric.is_global_zero: + + def log_models( + run_id: str, experiment_id: str | None = None, run_name: str | None = None + ) -> Dict[str, ModelInfo]: + with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=True) as _: + model_info = {} + unwrapped_models = {} + for k in cfg.model_manager.models.keys(): + unwrapped_models[k] = unwrap_fabric(local_vars[k]) + model_info[k] = mlflow.pytorch.log_model(unwrapped_models[k], artifact_path=k) + mlflow.log_dict(cfg, "config.json") + return model_info + + register_model(fabric, log_models, cfg) diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index 98883e91..c0331515 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -6,12 +6,14 @@ import gymnasium as gym import hydra +import mlflow import numpy as np import torch from lightning.fabric import Fabric from lightning.fabric.plugins.collectives import TorchCollective from lightning.fabric.plugins.collectives.collective import CollectibleGroup from lightning.fabric.strategies import DDPStrategy +from mlflow.models.model import ModelInfo from tensordict import TensorDict from tensordict.tensordict import TensorDictBase, make_tensordict from torch.distributed.algorithms.join import Join @@ -20,14 +22,14 @@ from sheeprl.algos.ppo.agent import PPOAgent from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss -from sheeprl.algos.ppo.utils import test +from sheeprl.algos.ppo.utils import normalize_obs, test from sheeprl.data import ReplayBuffer from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay +from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay, register_model, unwrap_fabric @torch.no_grad() @@ -76,7 +78,7 @@ def player( is_continuous = isinstance(envs.single_action_space, gym.spaces.Box) is_multidiscrete = isinstance(envs.single_action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( envs.single_action_space.shape if is_continuous else (envs.single_action_space.nvec.tolist() if is_multidiscrete else [envs.single_action_space.n]) @@ -101,6 +103,8 @@ def player( } agent = PPOAgent(**agent_args).to(device) + local_vars = locals() + # Broadcast the parameters needed to the trainers to instantiate the PPOAgent world_collective.broadcast_object_list([agent_args], src=0) @@ -191,10 +195,7 @@ def player( 
with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): with torch.no_grad(): # Sample an action given the observation received by the environment - normalized_obs = { - k: next_obs[k] / 255.0 - 0.5 if k in cfg.algo.cnn_keys.encoder else next_obs[k] - for k in obs_keys - } + normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) actions, logprobs, _, values = agent(normalized_obs) if is_continuous: real_actions = torch.cat(actions, -1).cpu().numpy() @@ -263,9 +264,7 @@ def player( fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) - normalized_obs = { - k: next_obs[k] / 255.0 - 0.5 if k in cfg.algo.cnn_keys.encoder else next_obs[k] for k in obs_keys - } + normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) next_values = agent.get_value(normalized_obs) returns, advantages = gae( rb["rewards"], @@ -348,6 +347,22 @@ def player( if fabric.is_global_zero: test(agent, fabric, cfg, log_dir) + if not cfg.model_manager.disabled and fabric.is_global_zero: + + def log_models( + run_id: str, experiment_id: str | None = None, run_name: str | None = None + ) -> Dict[str, ModelInfo]: + with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=True) as _: + model_info = {} + unwrapped_models = {} + for k in cfg.model_manager.models.keys(): + unwrapped_models[k] = unwrap_fabric(local_vars[k]) + model_info[k] = mlflow.pytorch.log_model(unwrapped_models[k], artifact_path=k) + mlflow.log_dict(cfg, "config.json") + return model_info + + register_model(fabric, log_models, cfg) + def trainer( world_collective: TorchCollective, @@ -477,10 +492,9 @@ def trainer( for _ in range(cfg.algo.update_epochs): for batch_idxes in sampler: batch = data[batch_idxes] - normalized_obs = { - k: batch[k] / 255.0 - 0.5 if k in agent.feature_extractor.cnn_keys else batch[k] - for k in cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder - } + normalized_obs = normalize_obs( + batch, cfg.algo.cnn_keys.encoder, cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder + ) _, logprobs, entropy, new_values = agent( normalized_obs, torch.split(batch["actions"], agent.actions_dim, dim=-1) ) diff --git a/sheeprl/algos/ppo/utils.py b/sheeprl/algos/ppo/utils.py index a923a808..e339c792 100644 --- a/sheeprl/algos/ppo/utils.py +++ b/sheeprl/algos/ppo/utils.py @@ -1,12 +1,22 @@ +from __future__ import annotations + from typing import Any, Dict +import gymnasium as gym +import mlflow +import numpy as np import torch +from git import Sequence from lightning import Fabric +from mlflow.models.model import ModelInfo +from torch import Tensor -from sheeprl.algos.ppo.agent import PPOAgent +from sheeprl.algos.ppo.agent import PPOAgent, build_agent from sheeprl.utils.env import make_env +from sheeprl.utils.utils import unwrap_fabric AGGREGATOR_KEYS = {"Rewards/rew_avg", "Game/ep_len_avg", "Loss/value_loss", "Loss/policy_loss", "Loss/entropy_loss"} +MODELS_TO_REGISTER = {"agent"} @torch.no_grad() @@ -53,3 +63,30 @@ def test(agent: PPOAgent, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): if cfg.metric.log_level > 0: fabric.log_dict({"Test/cumulative_reward": cumulative_rew}, 0) env.close() + + +def normalize_obs( + obs: Dict[str, np.ndarray | Tensor], cnn_keys: Sequence[str], obs_keys: Sequence[str] +) -> Dict[str, np.ndarray | Tensor]: + return {k: obs[k] / 255 - 0.5 if k in cnn_keys else obs[k] for k in obs_keys} + + +def log_models_from_checkpoint( + fabric: 
Fabric, env: gym.Env | gym.Wrapper, cfg: Dict[str, Any], state: Dict[str, Any] +) -> Sequence[ModelInfo]: + # Create the models + is_continuous = isinstance(env.action_space, gym.spaces.Box) + is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete) + actions_dim = tuple( + env.action_space.shape + if is_continuous + else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) + ) + agent = build_agent(fabric, actions_dim, is_continuous, cfg, env.observation_space, state["agent"]) + + # Log the model, create a new run if `cfg.run_id` is None. + model_info = {} + with mlflow.start_run(run_id=cfg.run.id, experiment_id=cfg.experiment.id, run_name=cfg.run.name, nested=True) as _: + model_info["agent"] = mlflow.pytorch.log_model(unwrap_fabric(agent), artifact_path="agent") + mlflow.log_dict(cfg.to_log, "config.json") + return model_info diff --git a/sheeprl/algos/ppo_recurrent/agent.py b/sheeprl/algos/ppo_recurrent/agent.py index b498ec45..ac914e07 100644 --- a/sheeprl/algos/ppo_recurrent/agent.py +++ b/sheeprl/algos/ppo_recurrent/agent.py @@ -4,6 +4,7 @@ import gymnasium import torch import torch.nn as nn +from lightning import Fabric from torch import Tensor from torch.distributions import Independent, Normal @@ -278,7 +279,7 @@ def forward( logprobs (Tensor): the log probabilities of the actions w.r.t. their distributions. entropies (Tensor): the entropies of the actions distributions. values (Tensor): the state values. - hx (Tensor): the new recurrent state. + states (Tuple[Tensor, Tensor]): the new recurrent states (hx, cx). """ embedded_obs = self.feature_extractor(obs) out, states = self.rnn(torch.cat((embedded_obs, prev_actions), dim=-1), prev_states, mask) @@ -286,3 +287,33 @@ def forward( pre_dist = self.get_pre_dist(out) actions, logprobs, entropies = self.get_sampled_actions(pre_dist, actions) return actions, logprobs, entropies, values, states + + +def build_agent( + fabric: Fabric, + actions_dim: Sequence[int], + is_continuous: bool, + cfg: Dict[str, Any], + obs_space: gymnasium.spaces.Dict, + agent_state: Optional[Dict[str, Tensor]] = None, +) -> RecurrentPPOAgent: + agent = RecurrentPPOAgent( + actions_dim=actions_dim, + obs_space=obs_space, + encoder_cfg=cfg.algo.encoder, + rnn_cfg=cfg.algo.rnn, + actor_cfg=cfg.algo.actor, + critic_cfg=cfg.algo.critic, + cnn_keys=cfg.algo.cnn_keys.encoder, + mlp_keys=cfg.algo.mlp_keys.encoder, + is_continuous=is_continuous, + distribution_cfg=cfg.distribution, + num_envs=cfg.env.num_envs, + screen_size=cfg.env.screen_size, + device=fabric.device, + ) + if agent_state: + agent.load_state_dict(agent_state) + agent = fabric.setup_module(agent) + + return agent diff --git a/sheeprl/algos/ppo_recurrent/evaluate.py b/sheeprl/algos/ppo_recurrent/evaluate.py index 43321fd0..919a6c86 100644 --- a/sheeprl/algos/ppo_recurrent/evaluate.py +++ b/sheeprl/algos/ppo_recurrent/evaluate.py @@ -5,16 +5,16 @@ import gymnasium as gym from lightning import Fabric -from sheeprl.algos.ppo_recurrent.agent import RecurrentPPOAgent +from sheeprl.algos.ppo_recurrent.agent import build_agent from sheeprl.algos.ppo_recurrent.utils import test from sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.registry import register_evaluation @register_evaluation(algorithms="ppo_recurrent") def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): - logger = create_tensorboard_logger(fabric, cfg) + 
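`actions_dim` is now always materialized as a tuple before being handed to `build_agent`. A standalone sketch of that convention (the helper name `get_actions_dim` is only for illustration):

```python
import gymnasium as gym


def get_actions_dim(action_space) -> tuple:
    # One entry per action head: the Box shape for continuous control,
    # the nvec entries for MultiDiscrete, a single entry for Discrete.
    is_continuous = isinstance(action_space, gym.spaces.Box)
    is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
    return tuple(
        action_space.shape
        if is_continuous
        else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
    )


assert get_actions_dim(gym.spaces.Box(-1.0, 1.0, shape=(3,))) == (3,)
assert get_actions_dim(gym.spaces.MultiDiscrete([2, 3])) == (2, 3)
assert get_actions_dim(gym.spaces.Discrete(4)) == (4,)
```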
logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) @@ -42,27 +42,11 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): is_continuous = isinstance(env.action_space, gym.spaces.Box) is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( env.action_space.shape if is_continuous else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) ) # Create the actor and critic models - agent = RecurrentPPOAgent( - actions_dim=actions_dim, - obs_space=observation_space, - encoder_cfg=cfg.algo.encoder, - rnn_cfg=cfg.algo.rnn, - actor_cfg=cfg.algo.actor, - critic_cfg=cfg.algo.critic, - cnn_keys=cfg.algo.cnn_keys.encoder, - mlp_keys=cfg.algo.mlp_keys.encoder, - is_continuous=is_continuous, - distribution_cfg=cfg.distribution, - num_envs=cfg.env.num_envs, - screen_size=cfg.env.screen_size, - device=fabric.device, - ) - agent.load_state_dict(state["agent"]) - agent = fabric.setup_module(agent) + agent = build_agent(fabric, actions_dim, is_continuous, cfg, observation_space, state["agent"]) test(agent, fabric, cfg, log_dir) diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index e270c673..6adede10 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -9,9 +9,11 @@ import gymnasium as gym import hydra +import mlflow import numpy as np import torch from lightning.fabric import Fabric +from mlflow.models.model import ModelInfo from tensordict import TensorDict, pad_sequence from tensordict.tensordict import TensorDictBase from torch.distributed.algorithms.join import Join @@ -19,15 +21,16 @@ from torchmetrics import SumMetric from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss -from sheeprl.algos.ppo_recurrent.agent import RecurrentPPOAgent +from sheeprl.algos.ppo.utils import normalize_obs +from sheeprl.algos.ppo_recurrent.agent import RecurrentPPOAgent, build_agent from sheeprl.algos.ppo_recurrent.utils import test from sheeprl.data.buffers import ReplayBuffer from sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay +from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay, register_model, unwrap_fabric def train( @@ -136,9 +139,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): state = fabric.load(cfg.checkpoint.resume_from) cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size - # Create TensorBoardLogger. This will create the logger only on the + # Create Logger. 
This will create the logger only on the # rank-0 process - logger = create_tensorboard_logger(fabric, cfg) + logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) @@ -175,39 +178,31 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): is_continuous = isinstance(envs.single_action_space, gym.spaces.Box) is_multidiscrete = isinstance(envs.single_action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( envs.single_action_space.shape if is_continuous else (envs.single_action_space.nvec.tolist() if is_multidiscrete else [envs.single_action_space.n]) ) # Define the agent and the optimizer - agent = RecurrentPPOAgent( - actions_dim=actions_dim, - obs_space=observation_space, - encoder_cfg=cfg.algo.encoder, - rnn_cfg=cfg.algo.rnn, - actor_cfg=cfg.algo.actor, - critic_cfg=cfg.algo.critic, - cnn_keys=cfg.algo.cnn_keys.encoder, - mlp_keys=cfg.algo.mlp_keys.encoder, - is_continuous=is_continuous, - distribution_cfg=cfg.distribution, - num_envs=cfg.env.num_envs, - screen_size=cfg.env.screen_size, - device=device, + agent = build_agent( + fabric, + actions_dim, + is_continuous, + cfg, + observation_space, + state["agent"] if cfg.checkpoint.resume_from else None, ) optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters()) # Load the state from the checkpoint if cfg.checkpoint.resume_from: - agent.load_state_dict(state["agent"]) optimizer.load_state_dict(state["optimizer"]) - # Setup agent and optimizer with Fabric - agent = fabric.setup_module(agent) optimizer = fabric.setup_optimizers(optimizer) + local_vars = locals() + # Create a metric aggregator to log the metrics aggregator = None if not MetricAggregator.disabled: @@ -271,7 +266,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): elif k in cfg.algo.mlp_keys.encoder: torch_obs = torch_obs.float() step_data[k] = torch_obs[None] # [Seq_len, Batch_size, D] --> [1, num_envs, D] - obs[k] = torch_obs + obs[k] = torch_obs[None] # Get the resetted recurrent states from the agent prev_states = agent.initial_states @@ -286,10 +281,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): with torch.no_grad(): # Sample an action given the observation received by the environment - normalized_obs = { - k: obs[k][None] / 255.0 - 0.5 if k in cfg.algo.cnn_keys.encoder else obs[k][None] - for k in obs_keys - } # [Seq_len, Batch_size, D] --> [1, num_envs, D] + # [Seq_len, Batch_size, D] --> [1, num_envs, D] + normalized_obs = normalize_obs(obs, cfg.algo.cnn_keys.encoder, obs_keys) actions, logprobs, _, values, states = agent.module( normalized_obs, prev_actions=prev_actions, prev_states=prev_states ) @@ -358,7 +351,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): elif k in cfg.algo.mlp_keys.encoder: torch_obs = torch.as_tensor(next_obs[k], device=device, dtype=torch.float32) step_data[k] = torch_obs[None] - obs[k] = torch_obs + obs[k] = torch_obs[None] # Reset the states if the episode is done if cfg.algo.reset_recurrent_state_on_done: @@ -378,9 +371,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) with torch.no_grad(): - normalized_obs = { - k: obs[k][None] / 255.0 - 0.5 if k in cfg.algo.cnn_keys.encoder else obs[k][None] for k in obs_keys - } + normalized_obs = normalize_obs(obs, cfg.algo.cnn_keys.encoder, obs_keys) feat = agent.module.feature_extractor(normalized_obs) rnn_out, _ = 
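`local_vars = locals()` is taken right after the agent and optimizers are set up so that the `log_models` closure defined at the end of training can look the models up by the names listed in `cfg.model_manager.models`. A toy sketch of the pattern (all names here are illustrative):

```python
def build_training_objects(models_to_register):
    agent = object()       # stand-in for the Fabric-wrapped agent
    optimizer = object()   # present in the snapshot but never registered
    local_vars = locals()  # snapshot of everything defined so far, keyed by name
    return {k: local_vars[k] for k in models_to_register}


logged = build_training_objects({"agent"})
assert set(logged) == {"agent"}
```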
agent.module.rnn(torch.cat((feat, actions), dim=-1), states) next_values = agent.module.get_values(rnn_out) @@ -500,3 +491,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero: test(agent.module, fabric, cfg, log_dir) + + if not cfg.model_manager.disabled and fabric.is_global_zero: + + def log_models( + run_id: str, experiment_id: str | None = None, run_name: str | None = None + ) -> Dict[str, ModelInfo]: + with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=True) as _: + model_info = {} + unwrapped_models = {} + for k in cfg.model_manager.models.keys(): + unwrapped_models[k] = unwrap_fabric(local_vars[k]) + model_info[k] = mlflow.pytorch.log_model(unwrapped_models[k], artifact_path=k) + mlflow.log_dict(cfg, "config.json") + return model_info + + register_model(fabric, log_models, cfg) diff --git a/sheeprl/algos/ppo_recurrent/utils.py b/sheeprl/algos/ppo_recurrent/utils.py index 18b0cc4e..7ab7a9d4 100644 --- a/sheeprl/algos/ppo_recurrent/utils.py +++ b/sheeprl/algos/ppo_recurrent/utils.py @@ -1,16 +1,21 @@ -from typing import TYPE_CHECKING, Any, Dict +from __future__ import annotations +from typing import Any, Dict, Sequence + +import gymnasium as gym +import mlflow import torch from lightning import Fabric - -from sheeprl.utils.env import make_env - -if TYPE_CHECKING: - from sheeprl.algos.ppo_recurrent.agent import RecurrentPPOAgent +from mlflow.models.model import ModelInfo from sheeprl.algos.ppo.utils import AGGREGATOR_KEYS as ppo_aggregator_keys +from sheeprl.algos.ppo.utils import MODELS_TO_REGISTER as ppo_models_to_register +from sheeprl.algos.ppo_recurrent.agent import RecurrentPPOAgent, build_agent +from sheeprl.utils.env import make_env +from sheeprl.utils.utils import unwrap_fabric AGGREGATOR_KEYS = ppo_aggregator_keys +MODELS_TO_REGISTER = ppo_models_to_register @torch.no_grad() @@ -66,3 +71,24 @@ def test(agent: "RecurrentPPOAgent", fabric: Fabric, cfg: Dict[str, Any], log_di if cfg.metric.log_level > 0: fabric.log_dict({"Test/cumulative_reward": cumulative_rew}, 0) env.close() + + +def log_models_from_checkpoint( + fabric: Fabric, env: gym.Env | gym.Wrapper, cfg: Dict[str, Any], state: Dict[str, Any] +) -> Sequence[ModelInfo]: + # Create the models + is_continuous = isinstance(env.action_space, gym.spaces.Box) + is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete) + actions_dim = tuple( + env.action_space.shape + if is_continuous + else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) + ) + agent = build_agent(fabric, actions_dim, is_continuous, cfg, env.observation_space, state["agent"]) + + # Log the model, create a new run if `cfg.run_id` is None. 
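The `log_models` closure above is repeated verbatim in `ppo.py`, `ppo_recurrent.py`, `sac.py` and `sac_ae.py`. A hedged sketch of the shared shape, factored into a helper (`make_log_models` is a hypothetical name, not part of the patch); the returned callable could then be passed to `register_model` in place of the inline definition:

```python
from __future__ import annotations

from typing import Any, Callable, Dict

import mlflow
from mlflow.models.model import ModelInfo

from sheeprl.utils.utils import unwrap_fabric


def make_log_models(cfg: Dict[str, Any], local_vars: Dict[str, Any]) -> Callable[..., Dict[str, ModelInfo]]:
    def log_models(
        run_id: str, experiment_id: str | None = None, run_name: str | None = None
    ) -> Dict[str, ModelInfo]:
        with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=True):
            model_info = {}
            for k in cfg.model_manager.models.keys():
                # Log the module without its Fabric wrapper, under its own artifact path
                model_info[k] = mlflow.pytorch.log_model(unwrap_fabric(local_vars[k]), artifact_path=k)
            mlflow.log_dict(cfg, "config.json")
            return model_info

    return log_models
```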
+ model_info = {} + with mlflow.start_run(run_id=cfg.run.id, experiment_id=cfg.experiment.id, run_name=cfg.run.name, nested=True) as _: + model_info["agent"] = mlflow.pytorch.log_model(unwrap_fabric(agent), artifact_path="agent") + mlflow.log_dict(cfg.to_log, "config.json") + return model_info diff --git a/sheeprl/algos/sac/agent.py b/sheeprl/algos/sac/agent.py index 4e52a08f..a6687617 100644 --- a/sheeprl/algos/sac/agent.py +++ b/sheeprl/algos/sac/agent.py @@ -1,8 +1,11 @@ import copy -from typing import Any, Dict, Sequence, SupportsFloat, Tuple, Union +from math import prod +from typing import Any, Dict, Optional, Sequence, SupportsFloat, Tuple, Union +import gymnasium import torch import torch.nn as nn +from lightning import Fabric from lightning.fabric.wrappers import _FabricModule from numpy.typing import NDArray from torch import Tensor @@ -273,3 +276,34 @@ def get_next_target_q_values(self, next_obs: Tensor, rewards: Tensor, dones: Ten def qfs_target_ema(self) -> None: for param, target_param in zip(self.qfs_unwrapped.parameters(), self.qfs_target.parameters()): target_param.data.copy_(self._tau * param.data + (1 - self._tau) * target_param.data) + + +def build_agent( + fabric: Fabric, + cfg: Dict[str, Any], + obs_space: gymnasium.spaces.Dict, + action_space: gymnasium.spaces.Box, + agent_state: Optional[Dict[str, Tensor]] = None, +) -> SACAgent: + act_dim = prod(action_space.shape) + obs_dim = sum([prod(obs_space[k].shape) for k in cfg.algo.mlp_keys.encoder]) + actor = SACActor( + observation_dim=obs_dim, + action_dim=act_dim, + distribution_cfg=cfg.distribution, + hidden_size=cfg.algo.actor.hidden_size, + action_low=action_space.low, + action_high=action_space.high, + ) + critics = [ + SACCritic(observation_dim=obs_dim + act_dim, hidden_size=cfg.algo.critic.hidden_size, num_critics=1) + for _ in range(cfg.algo.critic.n) + ] + target_entropy = -act_dim + agent = SACAgent(actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device) + if agent_state: + agent.load_state_dict(agent_state) + agent.actor = fabric.setup_module(agent.actor) + agent.critics = [fabric.setup_module(critic) for critic in agent.critics] + + return agent diff --git a/sheeprl/algos/sac/evaluate.py b/sheeprl/algos/sac/evaluate.py index 3fbbdc35..68f05e46 100644 --- a/sheeprl/algos/sac/evaluate.py +++ b/sheeprl/algos/sac/evaluate.py @@ -1,21 +1,20 @@ from __future__ import annotations -from math import prod from typing import Any, Dict import gymnasium as gym from lightning import Fabric -from sheeprl.algos.sac.agent import SACActor, SACAgent, SACCritic +from sheeprl.algos.sac.agent import build_agent from sheeprl.algos.sac.utils import test from sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.registry import register_evaluation @register_evaluation(algorithms=["sac", "sac_decoupled"]) def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): - logger = create_tensorboard_logger(fabric, cfg) + logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) @@ -45,22 +44,5 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): ) fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) - act_dim = prod(action_space.shape) - obs_dim = sum([prod(observation_space[k].shape) for k in cfg.algo.mlp_keys.encoder]) - actor = SACActor( - 
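A hedged usage sketch of the new SAC `build_agent`: rebuilding an agent from a checkpoint outside the training loop. The checkpoint path and the spaces are placeholders, `cfg` is assumed to be the training config loaded from the run's `.hydra/config.yaml`, and the observation key must match the `mlp_keys.encoder` the agent was trained with:

```python
import gymnasium as gym
from lightning.fabric import Fabric

from sheeprl.algos.sac.agent import build_agent


def restore_sac_agent(cfg, ckpt_path: str):
    fabric = Fabric(accelerator="cpu", devices=1)
    fabric.launch()
    state = fabric.load(ckpt_path)
    # The spaces must mirror the training environment; a single 3-dimensional
    # vector observation under the "state" key is assumed here.
    obs_space = gym.spaces.Dict({"state": gym.spaces.Box(-1.0, 1.0, shape=(3,))})
    act_space = gym.spaces.Box(-1.0, 1.0, shape=(1,))
    return build_agent(fabric, cfg, obs_space, act_space, state["agent"])
```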
observation_dim=obs_dim, - action_dim=act_dim, - distribution_cfg=cfg.distribution, - hidden_size=cfg.algo.actor.hidden_size, - action_low=action_space.low, - action_high=action_space.high, - ) - critics = [ - SACCritic(observation_dim=obs_dim + act_dim, hidden_size=cfg.algo.critic.hidden_size, num_critics=1) - for _ in range(cfg.algo.critic.n) - ] - target_entropy = -act_dim - agent = SACAgent(actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device) - agent.load_state_dict(state["agent"]) - agent = fabric.setup_module(agent) + agent = build_agent(fabric, cfg, observation_space, action_space, state["agent"]) test(agent.actor, fabric, cfg, log_dir) diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index 96894664..d73f70a5 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -3,15 +3,16 @@ import copy import os import warnings -from math import prod from typing import Any, Dict, Optional import gymnasium as gym import hydra +import mlflow import numpy as np import torch from lightning.fabric import Fabric from lightning.fabric.plugins.collectives.collective import CollectibleGroup +from mlflow.models.model import ModelInfo from tensordict import TensorDict, make_tensordict from tensordict.tensordict import TensorDictBase from torch.optim import Optimizer @@ -19,15 +20,16 @@ from torch.utils.data.sampler import BatchSampler from torchmetrics import SumMetric -from sheeprl.algos.sac.agent import SACActor, SACAgent, SACCritic +from sheeprl.algos.sac.agent import SACAgent, build_agent from sheeprl.algos.sac.loss import critic_loss, entropy_loss, policy_loss from sheeprl.algos.sac.utils import test from sheeprl.data.buffers import ReplayBuffer from sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer +from sheeprl.utils.utils import register_model, unwrap_fabric def train( @@ -104,9 +106,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): warnings.warn("SAC algorithm cannot allow to use images as observations, the CNN keys will be ignored") cfg.algo.cnn_keys.encoder = [] - # Create TensorBoardLogger. This will create the logger only on the + # Create Logger. 
This will create the logger only on the # rank-0 process - logger = create_tensorboard_logger(fabric, cfg) + logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) @@ -145,26 +147,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) # Define the agent and the optimizer and setup sthem with Fabric - act_dim = prod(action_space.shape) - obs_dim = sum([prod(observation_space[k].shape) for k in cfg.algo.mlp_keys.encoder]) - actor = SACActor( - observation_dim=obs_dim, - action_dim=act_dim, - distribution_cfg=cfg.distribution, - hidden_size=cfg.algo.actor.hidden_size, - action_low=action_space.low, - action_high=action_space.high, + agent = build_agent( + fabric, cfg, observation_space, action_space, state["agent"] if cfg.checkpoint.resume_from else None ) - critics = [ - SACCritic(observation_dim=obs_dim + act_dim, hidden_size=cfg.algo.critic.hidden_size, num_critics=1) - for _ in range(cfg.algo.critic.n) - ] - target_entropy = -act_dim - agent = SACAgent(actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device) - if cfg.checkpoint.resume_from: - agent.load_state_dict(state["agent"]) - agent.actor = fabric.setup_module(agent.actor) - agent.critics = [fabric.setup_module(critic) for critic in agent.critics] # Optimizers qf_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.qfs.parameters()) @@ -178,6 +163,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): qf_optimizer, actor_optimizer, alpha_optimizer ) + local_vars = locals() + # Create a metric aggregator to log the metrics aggregator = None if not MetricAggregator.disabled: @@ -396,3 +383,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero: test(agent.actor.module, fabric, cfg, log_dir) + + if not cfg.model_manager.disabled and fabric.is_global_zero: + + def log_models( + run_id: str, experiment_id: str | None = None, run_name: str | None = None + ) -> Dict[str, ModelInfo]: + with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=True) as _: + model_info = {} + unwrapped_models = {} + for k in cfg.model_manager.models.keys(): + unwrapped_models[k] = unwrap_fabric(local_vars[k]) + model_info[k] = mlflow.pytorch.log_model(unwrapped_models[k], artifact_path=k) + mlflow.log_dict(cfg, "config.json") + return model_info + + register_model(fabric, log_models, cfg) diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index 2752057f..34d2dfb5 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -7,18 +7,20 @@ import gymnasium as gym import hydra +import mlflow import numpy as np import torch from lightning.fabric import Fabric from lightning.fabric.plugins.collectives import TorchCollective from lightning.fabric.plugins.collectives.collective import CollectibleGroup from lightning.fabric.strategies import DDPStrategy +from mlflow.models.model import ModelInfo from tensordict import TensorDict, make_tensordict from tensordict.tensordict import TensorDictBase from torch.utils.data.sampler import BatchSampler from torchmetrics import SumMetric -from sheeprl.algos.sac.agent import SACActor, SACAgent, SACCritic +from sheeprl.algos.sac.agent import SACActor, SACAgent, SACCritic, build_agent from sheeprl.algos.sac.sac import train from sheeprl.algos.sac.utils import test from sheeprl.data.buffers import 
ReplayBuffer @@ -27,6 +29,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer +from sheeprl.utils.utils import register_model, unwrap_fabric @torch.no_grad() @@ -305,6 +308,37 @@ def player( if fabric.is_global_zero: test(actor, fabric, cfg, log_dir) + if not cfg.model_manager.disabled and fabric.is_global_zero: + critics = [ + SACCritic(observation_dim=obs_dim + act_dim, hidden_size=cfg.algo.critic.hidden_size, num_critics=1) + for _ in range(cfg.algo.critic.n) + ] + target_entropy = -act_dim + agent = SACAgent( + actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device + ) + flattened_parameters = torch.empty_like( + torch.nn.utils.convert_parameters.parameters_to_vector(agent.parameters()), device=device + ) + player_trainer_collective.broadcast(flattened_parameters, src=1) + torch.nn.utils.convert_parameters.vector_to_parameters(flattened_parameters, agent.parameters()) + + local_vars = locals() + + def log_models( + run_id: str, experiment_id: str | None = None, run_name: str | None = None + ) -> Dict[str, ModelInfo]: + with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=True) as _: + model_info = {} + unwrapped_models = {} + for k in cfg.model_manager.models.keys(): + unwrapped_models[k] = unwrap_fabric(local_vars[k]) + model_info[k] = mlflow.pytorch.log_model(unwrapped_models[k], artifact_path=k) + mlflow.log_dict(cfg, "config.json") + return model_info + + register_model(fabric, log_models, cfg) + def trainer( world_collective: TorchCollective, @@ -340,27 +374,13 @@ def trainer( assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" # Define the agent and the optimizer and setup them with Fabric - act_dim = prod(envs.single_action_space.shape) - obs_dim = sum([prod(envs.single_observation_space[k].shape) for k in cfg.algo.mlp_keys.encoder]) - - actor = SACActor( - observation_dim=obs_dim, - action_dim=act_dim, - distribution_cfg=cfg.distribution, - hidden_size=cfg.algo.actor.hidden_size, - action_low=envs.single_action_space.low, - action_high=envs.single_action_space.high, + agent = build_agent( + fabric, + cfg, + envs.single_observation_space, + envs.single_action_space, + state["agent"] if cfg.checkpoint.resume_from else None, ) - critics = [ - SACCritic(observation_dim=obs_dim + act_dim, hidden_size=cfg.algo.critic.hidden_size, num_critics=1) - for _ in range(cfg.algo.critic.n) - ] - target_entropy = -act_dim - agent = SACAgent(actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device) - if cfg.checkpoint.resume_from: - agent.load_state_dict(state["agent"]) - agent.actor = fabric.setup_module(agent.actor) - agent.critics = [fabric.setup_module(critic) for critic in agent.critics] # Optimizers qf_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.qfs.parameters()) @@ -429,6 +449,10 @@ def trainer( ckpt_path=ckpt_path, state=state, ) + if not cfg.model_manager.disabled: + player_trainer_collective.broadcast( + torch.nn.utils.convert_parameters.parameters_to_vector(agent.parameters()), src=1 + ) return data = make_tensordict(data, device=device) sampler = BatchSampler(range(len(data)), batch_size=cfg.algo.per_rank_batch_size, drop_last=False) diff --git a/sheeprl/algos/sac/utils.py b/sheeprl/algos/sac/utils.py index d10122e2..97b090dc 100644 --- a/sheeprl/algos/sac/utils.py +++ 
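In the decoupled SAC above, the trainer ships its weights to the player as a single flat tensor before registration: the trainer broadcasts `parameters_to_vector(agent.parameters())` over the player-trainer collective, and the player writes the received vector back into a freshly built agent with `vector_to_parameters`. A single-process sketch of that round trip:

```python
import torch
from torch.nn.utils.convert_parameters import parameters_to_vector, vector_to_parameters

trainer_net = torch.nn.Linear(4, 2)
player_net = torch.nn.Linear(4, 2)  # same architecture, different random init

flat = parameters_to_vector(trainer_net.parameters())  # what the trainer broadcasts (src=1)
vector_to_parameters(flat, player_net.parameters())    # what the player applies on rank 0

assert all(torch.equal(p, q) for p, q in zip(trainer_net.parameters(), player_net.parameters()))
```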
b/sheeprl/algos/sac/utils.py @@ -1,10 +1,16 @@ -from typing import Any, Dict +from __future__ import annotations +from typing import Any, Dict, Sequence + +import gymnasium as gym +import mlflow import torch from lightning import Fabric +from mlflow.models.model import ModelInfo -from sheeprl.algos.sac.agent import SACActor +from sheeprl.algos.sac.agent import SACActor, build_agent from sheeprl.utils.env import make_env +from sheeprl.utils.utils import unwrap_fabric AGGREGATOR_KEYS = { "Rewards/rew_avg", @@ -13,6 +19,7 @@ "Loss/policy_loss", "Loss/alpha_loss", } +MODELS_TO_REGISTER = {"agent"} @torch.no_grad() @@ -47,3 +54,17 @@ def test(actor: SACActor, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): if cfg.metric.log_level > 0: fabric.logger.log_metrics({"Test/cumulative_reward": cumulative_rew}, 0) env.close() + + +def log_models_from_checkpoint( + fabric: Fabric, env: gym.Env | gym.Wrapper, cfg: Dict[str, Any], state: Dict[str, Any] +) -> Sequence[ModelInfo]: + # Create the models + agent = build_agent(fabric, cfg, env.observation_space, env.action_space, state["agent"]) + + # Log the model, create a new run if `cfg.run_id` is None. + model_info = {} + with mlflow.start_run(run_id=cfg.run.id, experiment_id=cfg.experiment.id, run_name=cfg.run.name, nested=True) as _: + model_info["agent"] = mlflow.pytorch.log_model(unwrap_fabric(agent), artifact_path="agent") + mlflow.log_dict(cfg.to_log, "config.json") + return model_info diff --git a/sheeprl/algos/sac_ae/agent.py b/sheeprl/algos/sac_ae/agent.py index 0265cbe4..df91720a 100644 --- a/sheeprl/algos/sac_ae/agent.py +++ b/sheeprl/algos/sac_ae/agent.py @@ -1,16 +1,18 @@ import copy from math import prod -from typing import Any, Dict, List, Sequence, SupportsFloat, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, SupportsFloat, Tuple, Union +import gymnasium import numpy as np import torch import torch.nn as nn +from lightning import Fabric from lightning.fabric.wrappers import _FabricModule from numpy.typing import NDArray from torch import Size, Tensor from sheeprl.algos.sac_ae.utils import weight_init -from sheeprl.models.models import CNN, MLP, DeCNN, MultiEncoder +from sheeprl.models.models import CNN, MLP, DeCNN, MultiDecoder, MultiEncoder LOG_STD_MAX = 2 LOG_STD_MIN = -10 @@ -448,3 +450,115 @@ def critic_encoder_target_ema(self) -> None: self.critic_unwrapped.encoder.parameters(), self.critic_target.encoder.parameters() ): target_param.data.copy_(self._encoder_tau * param.data + (1 - self._encoder_tau) * target_param.data) + + +def build_agent( + fabric: Fabric, + cfg: Dict[str, Any], + obs_space: gymnasium.spaces.Dict, + action_space: gymnasium.spaces.Box, + agent_state: Optional[Dict[str, Tensor]] = None, + encoder_state: Optional[Dict[str, Tensor]] = None, + decoder_sate: Optional[Dict[str, Tensor]] = None, +) -> Tuple[SACAEAgent, _FabricModule, _FabricModule]: + act_dim = prod(action_space.shape) + target_entropy = -act_dim + + # Define the encoder and decoder and setup them with fabric. 
+ # Then we will set the critic encoder and actor decoder as the unwrapped encoder module: + # we do not need it wrapped with the strategy inside actor and critic + cnn_channels = [prod(obs_space[k].shape[:-2]) for k in cfg.algo.cnn_keys.encoder] + mlp_dims = [obs_space[k].shape[0] for k in cfg.algo.mlp_keys.encoder] + cnn_encoder = ( + CNNEncoder( + in_channels=sum(cnn_channels), + features_dim=cfg.algo.encoder.features_dim, + keys=cfg.algo.cnn_keys.encoder, + screen_size=cfg.env.screen_size, + cnn_channels_multiplier=cfg.algo.encoder.cnn_channels_multiplier, + ) + if cfg.algo.cnn_keys.encoder is not None and len(cfg.algo.cnn_keys.encoder) > 0 + else None + ) + mlp_encoder = ( + MLPEncoder( + sum(mlp_dims), + cfg.algo.mlp_keys.encoder, + cfg.algo.encoder.dense_units, + cfg.algo.encoder.mlp_layers, + eval(cfg.algo.encoder.dense_act), + cfg.algo.encoder.layer_norm, + ) + if cfg.algo.mlp_keys.encoder is not None and len(cfg.algo.mlp_keys.encoder) > 0 + else None + ) + encoder = MultiEncoder(cnn_encoder, mlp_encoder) + cnn_decoder = ( + CNNDecoder( + cnn_encoder.conv_output_shape, + features_dim=encoder.output_dim, + keys=cfg.algo.cnn_keys.decoder, + channels=cnn_channels, + screen_size=cfg.env.screen_size, + cnn_channels_multiplier=cfg.algo.decoder.cnn_channels_multiplier, + ) + if cfg.algo.cnn_keys.decoder is not None and len(cfg.algo.cnn_keys.decoder) > 0 + else None + ) + mlp_decoder = ( + MLPDecoder( + encoder.output_dim, + mlp_dims, + cfg.algo.mlp_keys.decoder, + cfg.algo.decoder.dense_units, + cfg.algo.decoder.mlp_layers, + eval(cfg.algo.decoder.dense_act), + cfg.algo.decoder.layer_norm, + ) + if cfg.algo.mlp_keys.decoder is not None and len(cfg.algo.mlp_keys.decoder) > 0 + else None + ) + decoder = MultiDecoder(cnn_decoder, mlp_decoder) + if encoder_state: + encoder.load_state_dict(encoder_state) + if decoder_sate: + decoder.load_state_dict(decoder_sate) + + # Setup actor and critic. 
Those will initialize with orthogonal weights + # both the actor and critic + actor = SACAEContinuousActor( + encoder=copy.deepcopy(encoder), + action_dim=act_dim, + distribution_cfg=cfg.distribution, + hidden_size=cfg.algo.actor.hidden_size, + action_low=action_space.low, + action_high=action_space.high, + ) + qfs = [ + SACAEQFunction( + input_dim=encoder.output_dim, action_dim=act_dim, hidden_size=cfg.algo.critic.hidden_size, output_dim=1 + ) + for _ in range(cfg.algo.critic.n) + ] + critic = SACAECritic(encoder=encoder, qfs=qfs) + + # The agent will tied convolutional and linear weights between the encoder actor and critic + agent = SACAEAgent( + actor, + critic, + target_entropy, + alpha=cfg.algo.alpha.alpha, + tau=cfg.algo.tau, + encoder_tau=cfg.algo.encoder.tau, + device=fabric.device, + ) + + if agent_state: + agent.load_state_dict(agent_state) + + encoder = fabric.setup_module(encoder) + decoder = fabric.setup_module(decoder) + agent.actor = fabric.setup_module(agent.actor) + agent.critic = fabric.setup_module(agent.critic) + + return agent, encoder, decoder diff --git a/sheeprl/algos/sac_ae/evaluate.py b/sheeprl/algos/sac_ae/evaluate.py index fc87e50c..9489f4ab 100644 --- a/sheeprl/algos/sac_ae/evaluate.py +++ b/sheeprl/algos/sac_ae/evaluate.py @@ -1,30 +1,20 @@ from __future__ import annotations -import copy -from math import prod from typing import Any, Dict import gymnasium as gym from lightning import Fabric -from sheeprl.algos.sac_ae.agent import ( - CNNEncoder, - MLPEncoder, - SACAEAgent, - SACAEContinuousActor, - SACAECritic, - SACAEQFunction, -) +from sheeprl.algos.sac_ae.agent import SACAEAgent, build_agent from sheeprl.algos.sac_ae.utils import test_sac_ae -from sheeprl.models.models import MultiEncoder from sheeprl.utils.env import make_env -from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.registry import register_evaluation @register_evaluation(algorithms="sac_ae") def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): - logger = create_tensorboard_logger(fabric, cfg) + logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) @@ -47,68 +37,8 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) - act_dim = prod(action_space.shape) - target_entropy = -act_dim - - # Define the encoder and decoder and setup them with fabric. 
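The SAC-AE builder returns the agent together with the Fabric-wrapped encoder and decoder, so callers that only need the policy can simply discard the autoencoder. A hedged sketch of how evaluation consumes it, with the state dicts passed positionally as in the hunks above (`fabric`, `cfg`, the spaces and `state` come from the surrounding evaluate code):

```python
from sheeprl.algos.sac_ae.agent import build_agent


def restore_sac_ae_actor(fabric, cfg, obs_space, act_space, state):
    agent, encoder, decoder = build_agent(
        fabric, cfg, obs_space, act_space, state["agent"], state["encoder"], state["decoder"]
    )
    return agent.actor  # evaluation only needs the (Fabric-wrapped) actor
```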
- # Then we will set the critic encoder and actor decoder as the unwrapped encoder module: - # we do not need it wrapped with the strategy inside actor and critic - cnn_channels = [prod(observation_space[k].shape[:-2]) for k in cfg.algo.cnn_keys.encoder] - mlp_dims = [observation_space[k].shape[0] for k in cfg.algo.mlp_keys.encoder] - cnn_encoder = ( - CNNEncoder( - in_channels=sum(cnn_channels), - features_dim=cfg.algo.encoder.features_dim, - keys=cfg.algo.cnn_keys.encoder, - screen_size=cfg.env.screen_size, - cnn_channels_multiplier=cfg.algo.encoder.cnn_channels_multiplier, - ) - if cfg.algo.cnn_keys.encoder is not None and len(cfg.algo.cnn_keys.encoder) > 0 - else None - ) - mlp_encoder = ( - MLPEncoder( - sum(mlp_dims), - cfg.algo.mlp_keys.encoder, - cfg.algo.encoder.dense_units, - cfg.algo.encoder.mlp_layers, - eval(cfg.algo.encoder.dense_act), - cfg.algo.encoder.layer_norm, - ) - if cfg.algo.mlp_keys.encoder is not None and len(cfg.algo.mlp_keys.encoder) > 0 - else None - ) - encoder = MultiEncoder(cnn_encoder, mlp_encoder) - encoder.load_state_dict(state["encoder"]) - - # Setup actor and critic. Those will initialize with orthogonal weights - # both the actor and critic - actor = SACAEContinuousActor( - encoder=copy.deepcopy(encoder), - action_dim=act_dim, - distribution_cfg=cfg.distribution, - hidden_size=cfg.algo.actor.hidden_size, - action_low=action_space.low, - action_high=action_space.high, - ) - qfs = [ - SACAEQFunction( - input_dim=encoder.output_dim, action_dim=act_dim, hidden_size=cfg.algo.critic.hidden_size, output_dim=1 - ) - for _ in range(cfg.algo.critic.n) - ] - critic = SACAECritic(encoder=encoder, qfs=qfs) - - # The agent will tied convolutional and linear weights between the encoder actor and critic - agent = SACAEAgent( - actor, - critic, - target_entropy, - alpha=cfg.algo.alpha.alpha, - tau=cfg.algo.tau, - encoder_tau=cfg.algo.encoder.tau, - device=fabric.device, + agent: SACAEAgent + agent, _, _ = build_agent( + fabric, cfg, observation_space, action_space, state["agent"], state["encoder"], state["decoder"] ) - agent.load_state_dict(state["agent"]) - agent.actor = fabric.setup_module(agent.actor) test_sac_ae(agent.actor, fabric, cfg, log_dir) diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index 9e18abbe..d93131b8 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -4,17 +4,18 @@ import os import time import warnings -from math import prod from typing import Any, Dict, Optional, Union import gymnasium as gym import hydra +import mlflow import numpy as np import torch import torch.nn.functional as F from lightning.fabric import Fabric from lightning.fabric.plugins.collectives.collective import CollectibleGroup from lightning.fabric.wrappers import _FabricModule +from mlflow.models.model import ModelInfo from tensordict import TensorDict, make_tensordict from tensordict.tensordict import TensorDictBase from torch.optim import Optimizer @@ -23,24 +24,16 @@ from torchmetrics import SumMetric from sheeprl.algos.sac.loss import critic_loss, entropy_loss, policy_loss -from sheeprl.algos.sac_ae.agent import ( - CNNDecoder, - CNNEncoder, - MLPDecoder, - MLPEncoder, - SACAEAgent, - SACAEContinuousActor, - SACAECritic, - SACAEQFunction, -) +from sheeprl.algos.sac_ae.agent import SACAEAgent, build_agent from sheeprl.algos.sac_ae.utils import preprocess_obs, test_sac_ae from sheeprl.data.buffers import ReplayBuffer from sheeprl.models.models import MultiDecoder, MultiEncoder from sheeprl.utils.env import make_env 
-from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer +from sheeprl.utils.utils import register_model, unwrap_fabric def train( @@ -156,9 +149,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # These arguments cannot be changed cfg.env.screen_size = 64 - # Create TensorBoardLogger. This will create the logger only on the + # Create Logger. This will create the logger only on the # rank-0 process - logger = create_tensorboard_logger(fabric, cfg) + logger = get_logger(fabric, cfg) if logger and fabric.is_global_zero: fabric._loggers = [logger] fabric.logger.log_hyperparams(cfg) @@ -206,95 +199,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) # Define the agent and the optimizer and setup them with Fabric - act_dim = prod(envs.single_action_space.shape) - target_entropy = -act_dim - - # Define the encoder and decoder and setup them with fabric. - # Then we will set the critic encoder and actor decoder as the unwrapped encoder module: - # we do not need it wrapped with the strategy inside actor and critic - cnn_channels = [prod(envs.single_observation_space[k].shape[:-2]) for k in cfg.algo.cnn_keys.encoder] - mlp_dims = [envs.single_observation_space[k].shape[0] for k in cfg.algo.mlp_keys.encoder] - cnn_encoder = ( - CNNEncoder( - in_channels=sum(cnn_channels), - features_dim=cfg.algo.encoder.features_dim, - keys=cfg.algo.cnn_keys.encoder, - screen_size=cfg.env.screen_size, - cnn_channels_multiplier=cfg.algo.encoder.cnn_channels_multiplier, - ) - if cfg.algo.cnn_keys.encoder is not None and len(cfg.algo.cnn_keys.encoder) > 0 - else None - ) - mlp_encoder = ( - MLPEncoder( - sum(mlp_dims), - cfg.algo.mlp_keys.encoder, - cfg.algo.encoder.dense_units, - cfg.algo.encoder.mlp_layers, - eval(cfg.algo.encoder.dense_act), - cfg.algo.encoder.layer_norm, - ) - if cfg.algo.mlp_keys.encoder is not None and len(cfg.algo.mlp_keys.encoder) > 0 - else None - ) - encoder = MultiEncoder(cnn_encoder, mlp_encoder) - cnn_decoder = ( - CNNDecoder( - cnn_encoder.conv_output_shape, - features_dim=encoder.output_dim, - keys=cfg.algo.cnn_keys.decoder, - channels=cnn_channels, - screen_size=cfg.env.screen_size, - cnn_channels_multiplier=cfg.algo.decoder.cnn_channels_multiplier, - ) - if cfg.algo.cnn_keys.decoder is not None and len(cfg.algo.cnn_keys.decoder) > 0 - else None - ) - mlp_decoder = ( - MLPDecoder( - encoder.output_dim, - mlp_dims, - cfg.algo.mlp_keys.decoder, - cfg.algo.decoder.dense_units, - cfg.algo.decoder.mlp_layers, - eval(cfg.algo.decoder.dense_act), - cfg.algo.decoder.layer_norm, - ) - if cfg.algo.mlp_keys.decoder is not None and len(cfg.algo.mlp_keys.decoder) > 0 - else None - ) - decoder = MultiDecoder(cnn_decoder, mlp_decoder) - if cfg.checkpoint.resume_from: - encoder.load_state_dict(state["encoder"]) - decoder.load_state_dict(state["decoder"]) - - # Setup actor and critic. 
Those will initialize with orthogonal weights - # both the actor and critic - actor = SACAEContinuousActor( - encoder=copy.deepcopy(encoder), - action_dim=act_dim, - distribution_cfg=cfg.distribution, - hidden_size=cfg.algo.actor.hidden_size, - action_low=envs.single_action_space.low, - action_high=envs.single_action_space.high, - ) - qfs = [ - SACAEQFunction( - input_dim=encoder.output_dim, action_dim=act_dim, hidden_size=cfg.algo.critic.hidden_size, output_dim=1 - ) - for _ in range(cfg.algo.critic.n) - ] - critic = SACAECritic(encoder=encoder, qfs=qfs) - - # The agent will tied convolutional and linear weights between the encoder actor and critic - agent = SACAEAgent( - actor, - critic, - target_entropy, - alpha=cfg.algo.alpha.alpha, - tau=cfg.algo.tau, - encoder_tau=cfg.algo.encoder.tau, - device=fabric.device, + agent, encoder, decoder = build_agent( + fabric, + cfg, + observation_space, + envs.single_action_space, + state["agent"] if cfg.checkpoint.resume_from else None, + state["encoder"] if cfg.checkpoint.resume_from else None, + state["decoder"] if cfg.checkpoint.resume_from else None, ) # Optimizers @@ -305,22 +217,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): decoder_optimizer = hydra.utils.instantiate(cfg.algo.decoder.optimizer, params=decoder.parameters()) if cfg.checkpoint.resume_from: - agent.load_state_dict(state["agent"]) qf_optimizer.load_state_dict(state["qf_optimizer"]) actor_optimizer.load_state_dict(state["actor_optimizer"]) alpha_optimizer.load_state_dict(state["alpha_optimizer"]) encoder_optimizer.load_state_dict(state["encoder_optimizer"]) decoder_optimizer.load_state_dict(state["decoder_optimizer"]) - encoder = fabric.setup_module(encoder) - decoder = fabric.setup_module(decoder) - agent.actor = fabric.setup_module(agent.actor) - agent.critic = fabric.setup_module(agent.critic) - qf_optimizer, actor_optimizer, alpha_optimizer, encoder_optimizer, decoder_optimizer = fabric.setup_optimizers( qf_optimizer, actor_optimizer, alpha_optimizer, encoder_optimizer, decoder_optimizer ) + local_vars = locals() + # Metrics aggregator = None if not MetricAggregator.disabled: @@ -560,3 +468,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero: test_sac_ae(agent.actor.module, fabric, cfg, log_dir) + + if not cfg.model_manager.disabled and fabric.is_global_zero: + + def log_models( + run_id: str, experiment_id: str | None = None, run_name: str | None = None + ) -> Dict[str, ModelInfo]: + with mlflow.start_run(run_id=run_id, experiment_id=experiment_id, run_name=run_name, nested=True) as _: + model_info = {} + unwrapped_models = {} + for k in cfg.model_manager.models.keys(): + unwrapped_models[k] = unwrap_fabric(local_vars[k]) + model_info[k] = mlflow.pytorch.log_model(unwrapped_models[k], artifact_path=k) + mlflow.log_dict(cfg, "config.json") + return model_info + + register_model(fabric, log_models, cfg) diff --git a/sheeprl/algos/sac_ae/utils.py b/sheeprl/algos/sac_ae/utils.py index dc6e5943..108091e7 100644 --- a/sheeprl/algos/sac_ae/utils.py +++ b/sheeprl/algos/sac_ae/utils.py @@ -1,17 +1,24 @@ -from typing import TYPE_CHECKING, Any, Dict +from __future__ import annotations +from typing import TYPE_CHECKING, Any, Dict, Sequence + +import gymnasium as gym +import mlflow import torch import torch.nn as nn from lightning import Fabric +from mlflow.models.model import ModelInfo from torch import Tensor from sheeprl.algos.sac.utils import AGGREGATOR_KEYS from sheeprl.utils.env import make_env +from sheeprl.utils.utils import 
unwrap_fabric if TYPE_CHECKING: from sheeprl.algos.sac_ae.agent import SACAEContinuousActor AGGREGATOR_KEYS = AGGREGATOR_KEYS.union({"Loss/reconstruction_loss"}) +MODELS_TO_REGISTER = {"agent", "encoder", "decoder"} @torch.no_grad() @@ -84,3 +91,23 @@ def weight_init(m: nn.Module): mid = m.weight.size(2) // 2 gain = nn.init.calculate_gain("relu") nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain) + + +def log_models_from_checkpoint( + fabric: Fabric, env: gym.Env | gym.Wrapper, cfg: Dict[str, Any], state: Dict[str, Any] +) -> Sequence[ModelInfo]: + from sheeprl.algos.sac_ae.agent import build_agent + + # Create the models + agent, encoder, decoder = build_agent( + fabric, cfg, env.observation_space, env.action_space, state["agent"], state["encoder"], state["decoder"] + ) + + # Log the model, create a new run if `cfg.run_id` is None. + model_info = {} + with mlflow.start_run(run_id=cfg.run.id, experiment_id=cfg.experiment.id, run_name=cfg.run.name, nested=True) as _: + model_info["agent"] = mlflow.pytorch.log_model(unwrap_fabric(agent), artifact_path="agent") + model_info["encoder"] = mlflow.pytorch.log_model(unwrap_fabric(encoder), artifact_path="encoder") + model_info["decoder"] = mlflow.pytorch.log_model(unwrap_fabric(decoder), artifact_path="decoder") + mlflow.log_dict(cfg.to_log, "config.json") + return model_info diff --git a/sheeprl/cli.py b/sheeprl/cli.py index 3ca73e5f..1dd2af89 100644 --- a/sheeprl/cli.py +++ b/sheeprl/cli.py @@ -1,22 +1,20 @@ -import datetime import importlib import os import pathlib -import time import warnings from pathlib import Path from typing import Any, Dict import hydra from lightning import Fabric -from lightning.fabric.loggers.tensorboard import TensorBoardLogger from lightning.fabric.strategies import STRATEGY_REGISTRY, DDPStrategy, SingleDeviceStrategy, Strategy -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig, OmegaConf, open_dict +from sheeprl.utils.logger import get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import algorithm_registry, evaluation_registry from sheeprl.utils.timer import timer -from sheeprl.utils.utils import dotdict, print_config +from sheeprl.utils.utils import dotdict, print_config, register_model_from_checkpoint def resume_from_checkpoint(cfg: DictConfig) -> Dict[str, Any]: @@ -79,21 +77,11 @@ def run_algorithm(cfg: Dict[str, Any]): command = task.__dict__[entrypoint] kwargs = {} if decoupled: - root_dir = ( - os.path.join("logs", "runs", cfg.root_dir) - if cfg.root_dir is not None - else os.path.join("logs", "runs", algo_name, datetime.today().strftime("%Y-%m-%d_%H-%M-%S")) - ) - run_name = ( - cfg.run_name if cfg.run_name is not None else f"{cfg.env.id}_{cfg.exp_name}_{cfg.seed}_{int(time.time())}" - ) - logger = None - if cfg.metric.log_level > 0: - logger = TensorBoardLogger(root_dir=root_dir, name=run_name) - logger.log_hyperparams(cfg) fabric: Fabric = hydra.utils.instantiate(cfg.fabric, _convert_="all") - if logger is not None: - fabric._loggers.extend([logger]) + logger = get_logger(fabric, cfg) + if logger and fabric.is_global_zero: + fabric._loggers = [logger] + fabric.logger.log_hyperparams(cfg) else: strategy = cfg.fabric.pop("strategy", "auto") if "sac_ae" in module: @@ -108,8 +96,6 @@ def run_algorithm(cfg: Dict[str, Any]): # Load exploration configurations ckpt_path = pathlib.Path(cfg.checkpoint.exploration_ckpt_path) exploration_cfg = OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml") - 
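Because each model is passed through `unwrap_fabric` before `mlflow.pytorch.log_model`, the registered artifact is presumably the plain `nn.Module` rather than a Fabric-wrapped one. A hedged sketch of pulling a logged agent back for inference (the run id in the URI is a placeholder):

```python
import mlflow

model_uri = "runs:/<run_id>/agent"  # placeholder: run id and artifact path of a logged agent
agent = mlflow.pytorch.load_model(model_uri, map_location="cpu")
agent.eval()
```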
exploration_cfg.pop("root_dir", None) - exploration_cfg.pop("run_name", None) exploration_cfg = dotdict(OmegaConf.to_container(exploration_cfg, resolve=True, throw_on_missing=True)) if exploration_cfg.env.id != cfg.env.id: raise ValueError( @@ -154,6 +140,22 @@ def run_algorithm(cfg: Dict[str, Any]): for k in keys_to_remove: cfg.metric.aggregator.metrics.pop(k, None) MetricAggregator.disabled = cfg.metric.log_level == 0 or len(cfg.metric.aggregator.metrics) == 0 + + # Model Manager + if hasattr(cfg, "model_manager") and not cfg.model_manager.disabled and cfg.model_manager.models is not None: + predefined_models_keys = set() + if not hasattr(utils, "MODELS_TO_REGISTER"): + warnings.warn( + f"No 'MODELS_TO_REGISTER' set found for the {algo_name} algorithm under the {module} module. " + "No model will be registered.", + UserWarning, + ) + else: + predefined_models_keys = utils.MODELS_TO_REGISTER + keys_to_remove = set(cfg.model_manager.models.keys()) - predefined_models_keys + for k in keys_to_remove: + cfg.model_manager.models.pop(k, None) + cfg.model_manager.disabled == cfg.model_manager.disabled or len(cfg.model_manager.models) == 0 fabric.launch(command, cfg, **kwargs) @@ -281,8 +283,6 @@ def evaluation(cfg: DictConfig): ckpt_cfg = OmegaConf.load(checkpoint_path.parent.parent.parent / ".hydra" / "config.yaml") # Merge the two configs - from omegaconf import open_dict - with open_dict(cfg): capture_video = getattr(cfg.env, "capture_video", True) cfg.env = {"capture_video": capture_video, "num_envs": 1} @@ -311,3 +311,46 @@ def evaluation(cfg: DictConfig): # Check the validity of the configuration and run the evaluation check_configs_evaluation(ckpt_cfg) eval_algorithm(ckpt_cfg) + + +@hydra.main(version_base="1.3", config_path="configs", config_name="model_manager_config") +def registration(cfg: DictConfig): + checkpoint_path = Path(cfg.checkpoint_path) + ckpt_cfg = OmegaConf.load(checkpoint_path.parent.parent.parent / ".hydra" / "config.yaml") + + # Merge the two configs + with open_dict(cfg): + cfg.env = ckpt_cfg.env + cfg.exp_name = ckpt_cfg.exp_name + cfg.algo = ckpt_cfg.algo + cfg.distribution = ckpt_cfg.distribution + cfg.seed = ckpt_cfg.seed + + cfg = dotdict(OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)) + cfg.to_log = dotdict(OmegaConf.to_container(ckpt_cfg, resolve=True, throw_on_missing=True)) + + precision = getattr(ckpt_cfg.fabric, "precision", None) + fabric = Fabric(devices=1, accelerator="cpu", num_nodes=1, precision=precision) + + # Load the checkpoint + state = fabric.load(cfg.checkpoint_path) + # Retrieve the algorithm name, used to import the custom + # log_models_from_checkpoint function. + algo_name = cfg.algo.name + if "decoupled" in cfg.algo.name: + algo_name = algo_name.replace("_decoupled", "") + if algo_name.startswith("p2e_dv"): + algo_name = "_".join(algo_name.split("_")[:2]) + try: + log_models_from_checkpoint = importlib.import_module( + f"sheeprl.algos.{algo_name}.utils" + ).log_models_from_checkpoint + except Exception as e: + print(e) + raise RuntimeError( + f"Make sure that the algorithm is defined in the `./sheeprl/algos/{algo_name}` folder " + "and that the `log_models_from_checkpoint` function is defined " + f"in the `./sheeprl/algos/{algo_name}/utils.py` file." 
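`run_algorithm` reconciles `cfg.model_manager.models` with the algorithm's `MODELS_TO_REGISTER` set: entries the algorithm cannot provide are dropped, and the manager is meant to end up disabled when nothing remains to register. A standalone sketch of that reconciliation:

```python
MODELS_TO_REGISTER = {"agent"}  # e.g. sheeprl.algos.ppo.utils.MODELS_TO_REGISTER
models_cfg = {
    "agent": {"model_name": "ppo_CartPole-v1"},
    "world_model": {"model_name": "never_built_by_ppo"},
}  # stand-in for cfg.model_manager.models

for k in set(models_cfg) - MODELS_TO_REGISTER:  # keys the algorithm does not expose
    models_cfg.pop(k, None)

manager_disabled = len(models_cfg) == 0  # the manager is disabled when nothing is left to register
assert set(models_cfg) == {"agent"} and not manager_disabled
```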
+ ) + + fabric.launch(register_model_from_checkpoint, cfg, state, log_models_from_checkpoint) diff --git a/sheeprl/configs/config.yaml b/sheeprl/configs/config.yaml index 437fdfb7..a9e795cf 100644 --- a/sheeprl/configs/config.yaml +++ b/sheeprl/configs/config.yaml @@ -10,6 +10,7 @@ defaults: - env: default.yaml - fabric: default.yaml - metric: default.yaml + - model_manager: default.yaml - hydra: default.yaml - exp: ??? @@ -23,6 +24,6 @@ seed: 42 torch_deterministic: False # Output folders -exp_name: "default" +exp_name: ${algo.name}_${env.id} run_name: ${now:%Y-%m-%d_%H-%M-%S}_${exp_name}_${seed} root_dir: ${algo.name}/${env.id} diff --git a/sheeprl/configs/exp/dreamer_v1.yaml b/sheeprl/configs/exp/dreamer_v1.yaml index 75723073..956957bd 100644 --- a/sheeprl/configs/exp/dreamer_v1.yaml +++ b/sheeprl/configs/exp/dreamer_v1.yaml @@ -3,6 +3,7 @@ defaults: - override /algo: dreamer_v1 - override /env: atari + - override /model_manager: dreamer_v1 - _self_ # Algorithm diff --git a/sheeprl/configs/exp/dreamer_v2.yaml b/sheeprl/configs/exp/dreamer_v2.yaml index c06b01ad..66faf0c9 100644 --- a/sheeprl/configs/exp/dreamer_v2.yaml +++ b/sheeprl/configs/exp/dreamer_v2.yaml @@ -3,6 +3,7 @@ defaults: - override /algo: dreamer_v2 - override /env: atari + - override /model_manager: dreamer_v2 - _self_ # Algorithm diff --git a/sheeprl/configs/exp/dreamer_v3.yaml b/sheeprl/configs/exp/dreamer_v3.yaml index ea108d6c..9f7b915d 100644 --- a/sheeprl/configs/exp/dreamer_v3.yaml +++ b/sheeprl/configs/exp/dreamer_v3.yaml @@ -3,6 +3,7 @@ defaults: - override /algo: dreamer_v3 - override /env: atari + - override /model_manager: dreamer_v3 - _self_ # Algorithm diff --git a/sheeprl/configs/exp/droq.yaml b/sheeprl/configs/exp/droq.yaml index 4f484446..374fd371 100644 --- a/sheeprl/configs/exp/droq.yaml +++ b/sheeprl/configs/exp/droq.yaml @@ -3,4 +3,5 @@ defaults: - sac - override /algo: droq - - _self_ \ No newline at end of file + - override /model_manager: droq + - _self_ diff --git a/sheeprl/configs/exp/p2e_dv1_exploration.yaml b/sheeprl/configs/exp/p2e_dv1_exploration.yaml index 01c687ae..34b95034 100644 --- a/sheeprl/configs/exp/p2e_dv1_exploration.yaml +++ b/sheeprl/configs/exp/p2e_dv1_exploration.yaml @@ -3,6 +3,7 @@ defaults: - dreamer_v1 - override /algo: p2e_dv1 + - override /model_manager: p2e_dv1_exploration - _self_ algo: diff --git a/sheeprl/configs/exp/p2e_dv1_finetuning.yaml b/sheeprl/configs/exp/p2e_dv1_finetuning.yaml index 70ca4e61..c2d4abfb 100644 --- a/sheeprl/configs/exp/p2e_dv1_finetuning.yaml +++ b/sheeprl/configs/exp/p2e_dv1_finetuning.yaml @@ -3,6 +3,7 @@ defaults: - dreamer_v1 - override /algo: p2e_dv1 + - override /model_manager: p2e_dv1_finetuning - _self_ algo: diff --git a/sheeprl/configs/exp/p2e_dv2_exploration.yaml b/sheeprl/configs/exp/p2e_dv2_exploration.yaml index a47d2e71..bae53323 100644 --- a/sheeprl/configs/exp/p2e_dv2_exploration.yaml +++ b/sheeprl/configs/exp/p2e_dv2_exploration.yaml @@ -3,6 +3,7 @@ defaults: - dreamer_v2 - override /algo: p2e_dv2 + - override /model_manager: p2e_dv2_exploration - _self_ algo: diff --git a/sheeprl/configs/exp/p2e_dv2_finetuning.yaml b/sheeprl/configs/exp/p2e_dv2_finetuning.yaml index f3ff1d46..1d315969 100644 --- a/sheeprl/configs/exp/p2e_dv2_finetuning.yaml +++ b/sheeprl/configs/exp/p2e_dv2_finetuning.yaml @@ -3,6 +3,7 @@ defaults: - dreamer_v2 - override /algo: p2e_dv2 + - override /model_manager: p2e_dv2_finetuning - _self_ algo: diff --git a/sheeprl/configs/exp/p2e_dv3_exploration.yaml b/sheeprl/configs/exp/p2e_dv3_exploration.yaml index 
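Before importing `sheeprl.algos.<name>.utils`, the `registration` entry point normalizes the algorithm name so that decoupled variants and the P2E exploration/finetuning stages resolve to their base algorithm's module. A standalone sketch of that normalization and the dynamic import it feeds:

```python
import importlib


def utils_module_name(algo_name: str) -> str:
    if "decoupled" in algo_name:
        algo_name = algo_name.replace("_decoupled", "")
    if algo_name.startswith("p2e_dv"):
        algo_name = "_".join(algo_name.split("_")[:2])
    return f"sheeprl.algos.{algo_name}.utils"


assert utils_module_name("sac_decoupled") == "sheeprl.algos.sac.utils"
assert utils_module_name("p2e_dv3_exploration") == "sheeprl.algos.p2e_dv3.utils"
log_models_from_checkpoint = importlib.import_module(utils_module_name("ppo")).log_models_from_checkpoint
```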
bc82e96d..009c2d99 100644 --- a/sheeprl/configs/exp/p2e_dv3_exploration.yaml +++ b/sheeprl/configs/exp/p2e_dv3_exploration.yaml @@ -3,6 +3,7 @@ defaults: - dreamer_v3 - override /algo: p2e_dv3 + - override /model_manager: p2e_dv3_exploration - _self_ algo: @@ -75,10 +76,10 @@ metric: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - # There could be more exploration critics, so here the general metrics - # are defined, all the metrics for the exploration critics + # There could be more exploration critics so here the general metrics + # are defined all the metrics for the exploration critics # will be instantiated with the key: _. - # For instance, if 'intrinsic' is the key of an exploration critic, + # For instance if 'intrinsic' is the key of an exploration critic # then its 'Loss/value_loss_exploration' metric will be logged under the # 'Loss/value_loss_exploration_intrinsic' key. # NOTE: Remove from here the metrics you do not want to plot for ALL diff --git a/sheeprl/configs/exp/p2e_dv3_finetuning.yaml b/sheeprl/configs/exp/p2e_dv3_finetuning.yaml index 502f8fcd..3d67448a 100644 --- a/sheeprl/configs/exp/p2e_dv3_finetuning.yaml +++ b/sheeprl/configs/exp/p2e_dv3_finetuning.yaml @@ -3,6 +3,7 @@ defaults: - dreamer_v3 - override /algo: p2e_dv3 + - override /model_manager: p2e_dv3_finetuning - _self_ algo: diff --git a/sheeprl/configs/exp/p2e_dv3_fntn_L_doapp_64px_gray_combo_discrete_5Mstps.yaml b/sheeprl/configs/exp/p2e_dv3_fntn_L_doapp_64px_gray_combo_discrete_5Mstps.yaml index c35a9a98..8dcad491 100644 --- a/sheeprl/configs/exp/p2e_dv3_fntn_L_doapp_64px_gray_combo_discrete_5Mstps.yaml +++ b/sheeprl/configs/exp/p2e_dv3_fntn_L_doapp_64px_gray_combo_discrete_5Mstps.yaml @@ -28,7 +28,7 @@ env: # Checkpoint checkpoint: every: 100000 - exploration_ckpt_path: /home/michele.milesi/repos/sheeprl/logs/runs/p2e_dv3/2023-10-04_11-21-17/doapp_default_0/version_0/checkpoint/ckpt_14900000_0.ckpt + exploration_ckpt_path: ??? 
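The comment above describes how the general exploration-critic metrics are expanded per critic: each metric name is suffixed with the critic's key. A tiny sketch of the naming convention, using the `intrinsic` key mentioned in the comment:

```python
general_metrics = ["Loss/value_loss_exploration"]
critic_keys = ["intrinsic"]  # keys of the exploration critics defined in the algo config

per_critic_metrics = [f"{metric}_{key}" for metric in general_metrics for key in critic_keys]
assert per_critic_metrics == ["Loss/value_loss_exploration_intrinsic"]
```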
# Buffer buffer: diff --git a/sheeprl/configs/exp/ppo.yaml b/sheeprl/configs/exp/ppo.yaml index c671bb31..c5c05719 100644 --- a/sheeprl/configs/exp/ppo.yaml +++ b/sheeprl/configs/exp/ppo.yaml @@ -3,6 +3,7 @@ defaults: - override /algo: ppo - override /env: gym + - override /model_manager: ppo - _self_ # Algorithm diff --git a/sheeprl/configs/exp/ppo_recurrent.yaml b/sheeprl/configs/exp/ppo_recurrent.yaml index 56a3720a..4664cc39 100644 --- a/sheeprl/configs/exp/ppo_recurrent.yaml +++ b/sheeprl/configs/exp/ppo_recurrent.yaml @@ -3,6 +3,7 @@ defaults: - ppo - override /algo: ppo_recurrent + - override /model_manager: ppo_recurrent - _self_ algo: diff --git a/sheeprl/configs/exp/sac.yaml b/sheeprl/configs/exp/sac.yaml index 065612e0..223377ea 100644 --- a/sheeprl/configs/exp/sac.yaml +++ b/sheeprl/configs/exp/sac.yaml @@ -3,6 +3,7 @@ defaults: - override /algo: sac - override /env: gym + - override /model_manager: sac - _self_ # Algorithm diff --git a/sheeprl/configs/exp/sac_ae.yaml b/sheeprl/configs/exp/sac_ae.yaml index de7900ba..4ea605be 100644 --- a/sheeprl/configs/exp/sac_ae.yaml +++ b/sheeprl/configs/exp/sac_ae.yaml @@ -3,6 +3,7 @@ defaults: - sac - override /algo: sac_ae + - override /model_manager: sac_ae - _self_ # Algorithm diff --git a/sheeprl/configs/logger/mlflow.yaml b/sheeprl/configs/logger/mlflow.yaml new file mode 100644 index 00000000..0a3951b0 --- /dev/null +++ b/sheeprl/configs/logger/mlflow.yaml @@ -0,0 +1,10 @@ +_target_: lightning.pytorch.loggers.MLFlowLogger +experiment_name: ${exp_name} +tracking_uri: ${oc.env:MLFLOW_TRACKING_URI} +run_name: ${algo.name}_${env.id}_${now:%Y-%m-%d_%H-%M-%S} +tags: null +save_dir: null +prefix: "" +artifact_location: null +run_id: null +log_model: false \ No newline at end of file diff --git a/sheeprl/configs/logger/tensorboard.yaml b/sheeprl/configs/logger/tensorboard.yaml new file mode 100644 index 00000000..3d64cae6 --- /dev/null +++ b/sheeprl/configs/logger/tensorboard.yaml @@ -0,0 +1,7 @@ +_target_: lightning.fabric.loggers.TensorBoardLogger +name: ${run_name} +root_dir: logs/runs/${root_dir} +version: null +default_hp_metric: True +prefix: "" +sub_dir: null \ No newline at end of file diff --git a/sheeprl/configs/metric/default.yaml b/sheeprl/configs/metric/default.yaml index b92de6d8..f18c9ab3 100644 --- a/sheeprl/configs/metric/default.yaml +++ b/sheeprl/configs/metric/default.yaml @@ -1,3 +1,7 @@ +defaults: + - _self_ + - /logger@logger: tensorboard + log_every: 5000 disable_timer: False diff --git a/sheeprl/configs/model_manager/default.yaml b/sheeprl/configs/model_manager/default.yaml new file mode 100644 index 00000000..e397b00a --- /dev/null +++ b/sheeprl/configs/model_manager/default.yaml @@ -0,0 +1,2 @@ +disabled: True +models: {} diff --git a/sheeprl/configs/model_manager/dreamer_v1.yaml b/sheeprl/configs/model_manager/dreamer_v1.yaml new file mode 100644 index 00000000..f16abd6e --- /dev/null +++ b/sheeprl/configs/model_manager/dreamer_v1.yaml @@ -0,0 +1,17 @@ +defaults: + - default + - _self_ + +models: + world_model: + model_name: "${exp_name}_world_model" + description: "DreamerV1 World Model used in ${env.id} Environment" + tags: {} + actor: + model_name: "${exp_name}_actor" + description: "DreamerV1 Actor used in ${env.id} Environment" + tags: {} + critic: + model_name: "${exp_name}_critic" + description: "DreamerV1 Critic used in ${env.id} Environment" + tags: {} diff --git a/sheeprl/configs/model_manager/dreamer_v2.yaml b/sheeprl/configs/model_manager/dreamer_v2.yaml new file mode 100644 index 00000000..0c084ac9 
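The new `logger` config group is pulled into the metric config through the `/logger@logger: tensorboard` default, so the selected logger config ends up under `cfg.metric.logger`. A hedged sketch of materializing such a config with Hydra; this mirrors the shape of `sheeprl/configs/logger/tensorboard.yaml` with the interpolations resolved by hand, not necessarily the exact body of `get_logger`:

```python
import hydra
from omegaconf import OmegaConf

logger_cfg = OmegaConf.create(
    {
        "_target_": "lightning.fabric.loggers.TensorBoardLogger",
        "name": "2023-11-28_15-43-57_ppo_CartPole-v1_42",
        "root_dir": "logs/runs/ppo/CartPole-v1",
    }
)
logger = hydra.utils.instantiate(logger_cfg)
logger.log_hyperparams({"seed": 42})
```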
--- /dev/null +++ b/sheeprl/configs/model_manager/dreamer_v2.yaml @@ -0,0 +1,21 @@ +defaults: + - default + - _self_ + +models: + world_model: + model_name: "${exp_name}_world_model" + description: "DreamerV2 World Model used in ${env.id} Environment" + tags: {} + actor: + model_name: "${exp_name}_actor" + description: "DreamerV2 Actor used in ${env.id} Environment" + tags: {} + critic: + model_name: "${exp_name}_critic" + description: "DreamerV2 Critic used in ${env.id} Environment" + tags: {} + target_critic: + model_name: "${exp_name}_target_critic" + description: "DreamerV2 Target Critic used in ${env.id} Environment" + tags: {} \ No newline at end of file diff --git a/sheeprl/configs/model_manager/dreamer_v3.yaml b/sheeprl/configs/model_manager/dreamer_v3.yaml new file mode 100644 index 00000000..b90637bf --- /dev/null +++ b/sheeprl/configs/model_manager/dreamer_v3.yaml @@ -0,0 +1,25 @@ +defaults: + - default + - _self_ + +models: + world_model: + model_name: "${exp_name}_world_model" + description: "DreamerV3 World Model used in ${env.id} Environment" + tags: {} + actor: + model_name: "${exp_name}_actor" + description: "DreamerV3 Actor used in ${env.id} Environment" + tags: {} + critic: + model_name: "${exp_name}_critic" + description: "DreamerV3 Critic used in ${env.id} Environment" + tags: {} + target_critic: + model_name: "${exp_name}_target_critic" + description: "DreamerV3 Target Critic used in ${env.id} Environment" + tags: {} + moments: + model_name: "${exp_name}_moments" + description: "DreamerV3 Moments used in ${env.id} Environment" + tags: {} \ No newline at end of file diff --git a/sheeprl/configs/model_manager/droq.yaml b/sheeprl/configs/model_manager/droq.yaml new file mode 100644 index 00000000..3234c90c --- /dev/null +++ b/sheeprl/configs/model_manager/droq.yaml @@ -0,0 +1,9 @@ +defaults: + - default + - _self_ + +models: + agent: + model_name: "${exp_name}" + description: "DroQ Agent in ${env.id} Environment" + tags: {} \ No newline at end of file diff --git a/sheeprl/configs/model_manager/p2e_dv1_exploration.yaml b/sheeprl/configs/model_manager/p2e_dv1_exploration.yaml new file mode 100644 index 00000000..9cf7a535 --- /dev/null +++ b/sheeprl/configs/model_manager/p2e_dv1_exploration.yaml @@ -0,0 +1,29 @@ +defaults: + - default + - _self_ + +models: + world_model: + model_name: "${exp_name}_world_model" + description: "P2E_DV1 World Model used in ${env.id} Environment" + tags: {} + ensembles: + model_name: "${exp_name}_ensembles" + description: "P2E_DV1 Ensembles used in ${env.id} Environment" + tags: {} + actor_exploration: + model_name: "${exp_name}_actor_exploration" + description: "P2E_DV1 Actor Exploration used in ${env.id} Environment" + tags: {} + critic_exploration: + model_name: "${exp_name}_critic_exploration" + description: "P2E_DV1 Critic Exploration used in ${env.id} Environment" + tags: {} + actor_task: + model_name: "${exp_name}_actor_task" + description: "P2E_DV1 Actor Task used in ${env.id} Environment" + tags: {} + critic_task: + model_name: "${exp_name}_critic_task" + description: "P2E_DV1 Critic Task used in ${env.id} Environment" + tags: {} \ No newline at end of file diff --git a/sheeprl/configs/model_manager/p2e_dv1_finetuning.yaml b/sheeprl/configs/model_manager/p2e_dv1_finetuning.yaml new file mode 100644 index 00000000..c9c118b8 --- /dev/null +++ b/sheeprl/configs/model_manager/p2e_dv1_finetuning.yaml @@ -0,0 +1,17 @@ +defaults: + - default + - _self_ + +models: + world_model: + model_name: "${exp_name}_world_model" +
description: "P2E_DV1 World Model used in ${env.id} Environment" + tags: {} + actor_task: + model_name: "${exp_name}_actor_task" + description: "P2E_DV1 Actor used in ${env.id} Environment" + tags: {} + critic_task: + model_name: "${exp_name}_critic_task" + description: "P2E_DV1 Critic in ${env.id} Environment" + tags: {} diff --git a/sheeprl/configs/model_manager/p2e_dv2_exploration.yaml b/sheeprl/configs/model_manager/p2e_dv2_exploration.yaml new file mode 100644 index 00000000..dfed6842 --- /dev/null +++ b/sheeprl/configs/model_manager/p2e_dv2_exploration.yaml @@ -0,0 +1,37 @@ +defaults: + - default + - _self_ + +models: + world_model: + model_name: "${exp_name}_world_model" + description: "P2E_DV2 World Model used in ${env.id} Environment" + tags: {} + ensembles: + model_name: "${exp_name}_ensembles" + description: "P2E_DV2 Ensembles used in ${env.id} Environment" + tags: {} + actor_exploration: + model_name: "${exp_name}_actor_exploration" + description: "P2E_DV2 Actor Exploration used in ${env.id} Environment" + tags: {} + critic_exploration: + model_name: "${exp_name}_critic_exploration" + description: "P2E_DV2 Critic Exploration in ${env.id} Environment" + tags: {} + target_critic_exploration: + model_name: "${exp_name}_target_critic_exploration" + description: "P2E_DV2 Target Critic Exploration in ${env.id} Environment" + tags: {} + actor_task: + model_name: "${exp_name}_actor_task" + description: "P2E_DV2 Actor Task used in ${env.id} Environment" + tags: {} + critic_task: + model_name: "${exp_name}_critic_task" + description: "P2E_DV2 Critic Task in ${env.id} Environment" + tags: {} + target_critic_task: + model_name: "${exp_name}_target_critic_task" + description: "P2E_DV2 Target Critic Task in ${env.id} Environment" + tags: {} \ No newline at end of file diff --git a/sheeprl/configs/model_manager/p2e_dv2_finetuning.yaml b/sheeprl/configs/model_manager/p2e_dv2_finetuning.yaml new file mode 100644 index 00000000..23f4f98a --- /dev/null +++ b/sheeprl/configs/model_manager/p2e_dv2_finetuning.yaml @@ -0,0 +1,21 @@ +defaults: + - default + - _self_ + +models: + world_model: + model_name: "${exp_name}_world_model" + description: "P2E_DV2 World Model used in ${env.id} Environment" + tags: {} + actor_task: + model_name: "${exp_name}_actor_task" + description: "P2E_DV2 Actor Task used in ${env.id} Environment" + tags: {} + critic_task: + model_name: "${exp_name}_critic_task" + description: "P2E_DV2 Critic Task used in ${env.id} Environment" + tags: {} + target_critic_task: + model_name: "${exp_name}_target_critic_task" + description: "P2E_DV2 Target Critic Task used in ${env.id} Environment" + tags: {} \ No newline at end of file diff --git a/sheeprl/configs/model_manager/p2e_dv3_exploration.yaml b/sheeprl/configs/model_manager/p2e_dv3_exploration.yaml new file mode 100644 index 00000000..5ca08192 --- /dev/null +++ b/sheeprl/configs/model_manager/p2e_dv3_exploration.yaml @@ -0,0 +1,57 @@ +defaults: + - default + - _self_ + +models: + world_model: + model_name: "${exp_name}_world_model" + description: "P2E_DV3 World Model used in ${env.id} Environment" + tags: {} + ensembles: + model_name: "${exp_name}_ensembles" + description: "P2E_DV1 Ensembles used in ${env.id} Environment" + tags: {} + actor_exploration: + model_name: "${exp_name}_actor_exploration" + description: "P2E_DV3 Actor Exploration used in ${env.id} Environment" + tags: {} + critic_exploration_intrinsic: + model_name: "${exp_name}_critic_exploration_intrinsic" + description: "P2E_DV3 Critic Exploration used in ${env.id} 
Environment" + tags: {} + target_critic_exploration_intrinsic: + model_name: "${exp_name}_target_critic_exploration_intrinsic" + description: "P2E_DV3 Target Critic Exploration used in ${env.id} Environment" + tags: {} + moments_exploration_intrinsic: + model_name: "${exp_name}_moments_exploration_intrinsic" + description: "P2E_DV3 Moments Exploration used in ${env.id} Environment" + tags: {} + critic_exploration_extrinsic: + model_name: "${exp_name}_critic_exploration_extrinsic" + description: "P2E_DV3 Critic Exploration used in ${env.id} Environment" + tags: {} + target_critic_exploration_extrinsic: + model_name: "${exp_name}_target_critic_exploration_extrinsic" + description: "P2E_DV3 Target Critic Exploration used in ${env.id} Environment" + tags: {} + moments_exploration_extrinsic: + model_name: "${exp_name}_moments_exploration_extrinsic" + description: "P2E_DV3 Moments Exploration used in ${env.id} Environment" + tags: {} + actor_task: + model_name: "${exp_name}_actor_task" + description: "P2E_DV3 Actor Task used in ${env.id} Environment" + tags: {} + critic_task: + model_name: "${exp_name}_critic_task" + description: "P2E_DV3 Critic Task used in ${env.id} Environment" + tags: {} + target_critic_task: + model_name: "${exp_name}_target_critic_task" + description: "P2E_DV3 Target Critic Task used in ${env.id} Environment" + tags: {} + moments_task: + model_name: "${exp_name}_moments_task" + description: "P2E_DV3 Moments Task used in ${env.id} Environment" + tags: {} \ No newline at end of file diff --git a/sheeprl/configs/model_manager/p2e_dv3_finetuning.yaml b/sheeprl/configs/model_manager/p2e_dv3_finetuning.yaml new file mode 100644 index 00000000..a600179c --- /dev/null +++ b/sheeprl/configs/model_manager/p2e_dv3_finetuning.yaml @@ -0,0 +1,25 @@ +defaults: + - default + - _self_ + +models: + world_model: + model_name: "${exp_name}_world_model" + description: "P2E_DV3 World Model used in ${env.id} Environment" + tags: {} + actor_task: + model_name: "${exp_name}_actor_task" + description: "P2E_DV3 Actor Task used in ${env.id} Environment" + tags: {} + critic_task: + model_name: "${exp_name}_critic_task" + description: "P2E_DV3 Critic Task used in ${env.id} Environment" + tags: {} + target_critic_task: + model_name: "${exp_name}_target_critic_task" + description: "P2E_DV3 Target Critic Task used in ${env.id} Environment" + tags: {} + moments_task: + model_name: "${exp_name}_moments_task" + description: "P2E_DV3 Moments Task used in ${env.id} Environment" + tags: {} \ No newline at end of file diff --git a/sheeprl/configs/model_manager/ppo.yaml b/sheeprl/configs/model_manager/ppo.yaml new file mode 100644 index 00000000..b9061972 --- /dev/null +++ b/sheeprl/configs/model_manager/ppo.yaml @@ -0,0 +1,9 @@ +defaults: + - default + - _self_ + +models: + agent: + model_name: "${exp_name}" + description: "PPO Agent in ${env.id} Environment" + tags: {} \ No newline at end of file diff --git a/sheeprl/configs/model_manager/ppo_recurrent.yaml b/sheeprl/configs/model_manager/ppo_recurrent.yaml new file mode 100644 index 00000000..56aca5b0 --- /dev/null +++ b/sheeprl/configs/model_manager/ppo_recurrent.yaml @@ -0,0 +1,9 @@ +defaults: + - default + - _self_ + +models: + agent: + model_name: "${exp_name}" + description: "PPO Recurrent Agent in ${env.id} Environment" + tags: {} \ No newline at end of file diff --git a/sheeprl/configs/model_manager/sac.yaml b/sheeprl/configs/model_manager/sac.yaml new file mode 100644 index 00000000..3a89905e --- /dev/null +++ b/sheeprl/configs/model_manager/sac.yaml 
@@ -0,0 +1,9 @@ +defaults: + - default + - _self_ + +models: + agent: + model_name: "${exp_name}" + description: "SAC Agent in ${env.id} Environment" + tags: {} \ No newline at end of file diff --git a/sheeprl/configs/model_manager/sac_ae.yaml b/sheeprl/configs/model_manager/sac_ae.yaml new file mode 100644 index 00000000..9c278f2a --- /dev/null +++ b/sheeprl/configs/model_manager/sac_ae.yaml @@ -0,0 +1,17 @@ +defaults: + - default + - _self_ + +models: + encoder: + model_name: "${exp_name}_encoder" + description: "SAC-AE Encoder used in ${env.id} Environment" + tags: {} + decoder: + model_name: "${exp_name}_decoder" + description: "SAC-AE Decoder used in ${env.id} Environment" + tags: {} + agent: + model_name: "${exp_name}_agent" + description: "SAC-AE Agent in ${env.id} Environment" + tags: {} \ No newline at end of file diff --git a/sheeprl/configs/model_manager_config.yaml b/sheeprl/configs/model_manager_config.yaml new file mode 100644 index 00000000..21562069 --- /dev/null +++ b/sheeprl/configs/model_manager_config.yaml @@ -0,0 +1,22 @@ +# @package _global_ + +# specify here default training configuration +defaults: + - _self_ + - model_manager: ??? + - override hydra/hydra_logging: disabled + - override hydra/job_logging: disabled + +hydra: + output_subdir: null + run: + dir: . + +checkpoint_path: ??? +run: + id: null + name: ${now:%Y-%m-%d_%H-%M-%S}_${exp_name} +experiment: + id: null + name: ${exp_name}_${now:%Y-%m-%d_%H-%M-%S} +tracking_uri: ${oc.env:MLFLOW_TRACKING_URI} diff --git a/sheeprl/utils/logger.py b/sheeprl/utils/logger.py index 8f83e3d5..45368023 100644 --- a/sheeprl/utils/logger.py +++ b/sheeprl/utils/logger.py @@ -1,23 +1,38 @@ import os import warnings -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional +import hydra from lightning import Fabric -from lightning.fabric.loggers import TensorBoardLogger +from lightning.fabric.loggers.logger import Logger from lightning.fabric.plugins.collectives import TorchCollective from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem -def create_tensorboard_logger(fabric: Fabric, cfg: Dict[str, Any]) -> Tuple[Optional[TensorBoardLogger]]: +def get_logger(fabric: Fabric, cfg: Dict[str, Any]) -> Optional[Logger]: # Set logger only on rank-0 but share the logger directory: since we don't know # what is happening during the `fabric.save()` method, at least we assure that all # ranks save under the same named folder. 
# As a plus, rank-0 sets the time uniquely for everyone logger = None - if fabric.is_global_zero: - root_dir = os.path.join("logs", "runs", cfg.root_dir) - if cfg.metric.log_level > 0: - logger = TensorBoardLogger(root_dir=root_dir, name=cfg.run_name) + if fabric.is_global_zero and cfg.metric.log_level > 0: + if "tensorboard" in cfg.metric.logger._target_.lower(): + root_dir = os.path.join("logs", "runs", cfg.root_dir) + if root_dir != cfg.metric.logger.root_dir: + warnings.warn( + "The specified root directory for the TensorBoardLogger is different from the experiment one, " + "so the logger one will be ignored and replaced with the experiment root directory", + UserWarning, + ) + if cfg.run_name != cfg.metric.logger.name: + warnings.warn( + "The specified name for the TensorBoardLogger is different from the `run_name` of the experiment, " + "so the logger one will be ignored and replaced with the experiment `run_name`", + UserWarning, + ) + logger = hydra.utils.instantiate(cfg.metric.logger, root_dir=root_dir, name=cfg.run_name) + else: + logger = hydra.utils.instantiate(cfg.metric.logger) return logger @@ -40,7 +55,7 @@ def get_log_dir(fabric: Fabric, root_dir: str, run_name: str, share: bool = True world_collective.create_group() if fabric.is_global_zero: # If the logger was instantiated, then take the log_dir from it - if len(fabric.loggers) > 0: + if len(fabric.loggers) > 0 and fabric.logger.log_dir is not None: log_dir = fabric.logger.log_dir else: # Otherwise the rank-zero process creates the log_dir diff --git a/sheeprl/utils/model_manager.py b/sheeprl/utils/model_manager.py new file mode 100644 index 00000000..b5e44b28 --- /dev/null +++ b/sheeprl/utils/model_manager.py @@ -0,0 +1,323 @@ +"""Thank you @lorenzomammana: https://github.com/orobix/quadra/blob/main/quadra/utils/model_manager.py""" + +from __future__ import annotations + +import getpass +import os +import warnings +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Any, Dict, Literal, Sequence, Set + +from lightning import Fabric + +try: + import mlflow # noqa + from mlflow.entities import Run # noqa + from mlflow.entities.model_registry import ModelVersion # noqa + from mlflow.exceptions import RestException # noqa + from mlflow.tracking import MlflowClient # noqa + + MLFLOW_AVAILABLE = True +except ImportError: + MLFLOW_AVAILABLE = False + + +VERSION_MD_TEMPLATE = "## **Version {}**\n" +DESCRIPTION_MD_TEMPLATE = "### Description: \n{}\n" + + +class AbstractModelManager(ABC): + """Abstract class for model managers.""" + + @abstractmethod + def __init__(self, fabric: Fabric) -> None: + self.fabric = fabric + + @abstractmethod + def register_model( + self, model_location: str, model_name: str, description: str, tags: Dict[str, Any] | None = None + ) -> Any: + """Register a model in the model registry.""" + + @abstractmethod + def get_latest_version(self, model_name: str) -> Any: + """Get the latest version of a model for all the possible stages or filtered by stage.""" + + @abstractmethod + def transition_model(self, model_name: str, version: int, stage: str, description: str | None = None) -> Any: + """Transition the model with the given version to a new stage.""" + + @abstractmethod + def delete_model(self, model_name: str, version: int, description: str | None = None) -> None: + """Delete a model with the given version.""" + + @abstractmethod + def register_best_models( + self, + experiment_name: str, + models_info: Dict[str, Dict[str, Any]], + metric: str =
"Test/cumulative_reward", + mode: Literal["max", "min"] = "max", + ) -> Any: + """Register the best models from an experiment.""" + + @abstractmethod + def download_model(self, model_name: str, version: int, output_path: str) -> None: + """Download the model with the given version to the given output path.""" + + +class MlflowModelManager(AbstractModelManager): + """Model manager for Mlflow.""" + + def __init__(self, fabric: Fabric, tracking_uri: str): + if not MLFLOW_AVAILABLE: + raise ImportError("Mlflow is not available, please install it with pip install mlflow.") + + super().__init__(fabric) + self.tracking_uri = tracking_uri + + mlflow.set_tracking_uri(self.tracking_uri) + self.client = MlflowClient() + + def register_model( + self, model_location: str, model_name: str, description: str | None = None, tags: Dict[str, Any] | None = None + ) -> ModelVersion: + """Register a model in the model registry. + + Args: + model_location (str): The model uri. + model_name (str): The name of the model after it is registered. + description (str, optional): A description of the model, this will be added to the model changelog. + Default to None. + tags (Dict[str, Any], optional): A dictionary of tags to add to the model. + Default to None. + + Returns: + The model version. + """ + model_version = mlflow.register_model(model_uri=model_location, name=model_name, tags=tags) + self.fabric.print(f"Registered model {model_name} with version {model_version.version}") + registered_model_description = self.client.get_registered_model(model_name).description + + if model_version.version == "1": + header = "# MODEL CHANGELOG\n" + else: + header = "" + + new_model_description = VERSION_MD_TEMPLATE.format(model_version.version) + new_model_description += self._get_author_and_date() + new_model_description += self._generate_description(description) + + self.client.update_registered_model(model_name, header + registered_model_description + new_model_description) + + self.client.update_model_version( + model_name, model_version.version, "# MODEL CHANGELOG\n" + new_model_description + ) + + return model_version + + def get_latest_version(self, model_name: str) -> ModelVersion: + """Get the latest version of a model. + + Args: + model_name (str): The name of the model. + + Returns: + The model version. + """ + latest_version = max(int(x.version) for x in self.client.get_latest_versions(model_name)) + model_version = self.client.get_model_version(model_name, latest_version) + + return model_version + + def transition_model( + self, model_name: str, version: int, stage: str, description: str | None = None + ) -> ModelVersion | None: + """Transition a model to a new stage. + + Args: + model_name (str): The name of the model. + version (int): The version of the model + stage (str): The stage of the model. + description (str, optional): A description of the transition, this will be added to the model changelog. + Default to None. 
+ """ + previous_stage = self._safe_get_stage(model_name, version) + + if previous_stage is None: + return None + + if previous_stage.lower() == stage.lower(): + warnings.warn(f"Model {model_name} version {version} is already in stage {stage}") + return self.client.get_model_version(model_name, version) + + self.fabric.print(f"Transitioning model {model_name} version {version} from {previous_stage} to {stage}") + model_version = self.client.transition_model_version_stage(name=model_name, version=version, stage=stage) + new_stage = model_version.current_stage + registered_model_description = self.client.get_registered_model(model_name).description + single_model_description = self.client.get_model_version(model_name, version).description + + new_model_description = "## **Transition:**\n" + new_model_description += f"### Version {model_version.version} from {previous_stage} to {new_stage}\n" + new_model_description += self._get_author_and_date() + new_model_description += self._generate_description(description) + + self.client.update_registered_model(model_name, registered_model_description + new_model_description) + self.client.update_model_version( + model_name, model_version.version, single_model_description + new_model_description + ) + + return model_version + + def delete_model(self, model_name: str, version: int, description: str | None = None) -> None: + """Delete a model. + + Args: + model_name (str): The name of the model, + version (int): The version of the model. + description (str, optional): Why the model was deleted, this will be added to the model changelog. + Default to None. + """ + model_stage = self._safe_get_stage(model_name, version) + + if model_stage is None: + return + + if ( + input( + f"Model named `{model_name}`, version {version} is in stage {model_stage}, " + "type the model name to continue deletion:" + ) + != model_name + ): + warnings.warn("Model name did not match, aborting deletion") + return + + self.fabric.print(f"Deleting model {model_name} version {version}") + self.client.delete_model_version(model_name, version) + + registered_model_description = self.client.get_registered_model(model_name).description + + new_model_description = "## **Deletion:**\n" + new_model_description += f"### Version {version} from stage: {model_stage}\n" + new_model_description += self._get_author_and_date() + new_model_description += self._generate_description(description) + + self.client.update_registered_model(model_name, registered_model_description + new_model_description) + + def register_best_models( + self, + experiment_name: str, + models_info: Dict[str, Dict[str, Any]], + metric: str = "Test/cumulative_reward", + mode: Literal["max", "min"] = "max", + ) -> Dict[str, ModelVersion] | None: + """Register the best model from an experiment. + + Args: + experiment_name (str): The name of the experiment. + models_info (Dict[str, Dict[str, Any]]): A dictionary containing models information + (path, description and tags). + metric (str): The metric to use to determine the best model. + Default to "Test/cumulative_reward". + mode (Literal["max", "min"]): The mode to use to determine the best model, either "max" or "min". + Defaulto to "max". + + Returns: + The registered models version if successful, otherwise None. 
+ """ + if mode not in ["max", "min"]: + raise ValueError(f"Mode must be either 'max' or 'min', got {mode}") + + experiment_id = self.client.get_experiment_by_name(experiment_name).experiment_id + runs = self.client.search_runs(experiment_ids=[experiment_id]) + + if len(runs) == 0: + self.fabric.print(f"No runs found for experiment {experiment_name}") + return None + + best_run: Run | None = None + best_run_artifacts: Sequence[str] | Set[str] | None = None + models_path = [v["path"] for v in models_info.values()] + for run in runs: + run_artifacts = [x.path for x in self.client.list_artifacts(run.info.run_id) if x.path in models_path] + + if len(run_artifacts) == 0 or run.data.metrics.get(metric) is None: + # If we don't find the given model path, skip this run + # If the run has not the target metric, skip this run + continue + + if best_run is None: + best_run = run + best_run_artifacts = set(run_artifacts) + continue + if mode == "max": + if run.data.metrics[metric] > best_run.data.metrics[metric]: + best_run = run + else: + if run.data.metrics[metric] < best_run.data.metrics[metric]: + best_run = run + + if best_run is None: + self.fabric.print(f"No runs found for experiment {experiment_name} with the given metric") + return None + + models_version = {} + for k, v in models_info.items(): + if v["path"] in best_run_artifacts: + best_model_uri = f"runs:/{best_run.info.run_id}/{v['path']}" + models_version[k] = self.register_model( + model_location=best_model_uri, model_name=v["name"], tags=v["tags"], description=v["description"] + ) + + return models_version + + def download_model(self, model_name: str, version: int, output_path: str) -> None: + """Download the model with the given version to the given output path. + + Args: + model_name (str): The name of the model. + version (int): The version of the model. + output_path (str): The path to save the model to. + """ + artifact_uri = self.client.get_model_version_download_uri(model_name, version) + self.fabric.print(f"Downloading model {model_name} version {version} from {artifact_uri} to {output_path}") + if not os.path.exists(output_path): + self.fabric.print(f"Creating output path {output_path}") + os.makedirs(output_path) + mlflow.artifacts.download_artifacts(artifact_uri=artifact_uri, dst_path=output_path) + + @staticmethod + def _generate_description(description: str | None = None) -> str: + """Generate the description markdown template.""" + if description is None: + return "" + + return DESCRIPTION_MD_TEMPLATE.format(description) + + @staticmethod + def _get_author_and_date() -> str: + """Get the author and date markdown template.""" + author_and_date = f"### Author: {getpass.getuser()}\n" + author_and_date += f"### Date: {datetime.now().astimezone().strftime('%d/%m/%Y %H:%M:%S %Z')}\n" + + return author_and_date + + def _safe_get_stage(self, model_name: str, version: int) -> str | None: + """Get the stage of a model version. + + Args: + model_name (str): The name of the model. + version (int): The version of the model + + Returns: + The stage of the model version if it exists, otherwise None. 
+ """ + try: + model_stage = self.client.get_model_version(model_name, version).current_stage + return model_stage + except RestException: + self.fabric.print(f"Model named {model_name} with version {version} does not exist") + return None diff --git a/sheeprl/utils/utils.py b/sheeprl/utils/utils.py index e70a9df6..9fa3428c 100644 --- a/sheeprl/utils/utils.py +++ b/sheeprl/utils/utils.py @@ -1,16 +1,26 @@ from __future__ import annotations +import copy import os -from typing import Optional, Sequence, Tuple, Union +from datetime import datetime +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union +import gymnasium as gym +import mlflow import rich.syntax import rich.tree import torch import torch.nn as nn +from lightning import Fabric +from lightning.fabric.wrappers import _FabricModule +from mlflow.models.model import ModelInfo from omegaconf import DictConfig, OmegaConf from pytorch_lightning.utilities import rank_zero_only from torch import Tensor +from sheeprl.utils.env import make_env +from sheeprl.utils.model_manager import MlflowModelManager + class dotdict(dict): """ @@ -157,3 +167,99 @@ def print_config( if cfg_save_path is not None: with open(os.path.join(os.getcwd(), "config_tree.txt"), "w") as fp: rich.print(tree, file=fp) + + +def unwrap_fabric(model: _FabricModule | nn.Module) -> nn.Module: + model = copy.deepcopy(model) + if isinstance(model, _FabricModule): + model = model.module + for name, child in model.named_children(): + setattr(model, name, unwrap_fabric(child)) + return model + + +def register_model(fabric: Fabric, log_models: Callable[[str], Dict[str, ModelInfo]], cfg: Dict[str, Any]): + tracking_uri = getattr(fabric.logger, "_tracking_uri", None) or os.getenv("MLFLOW_TRACKING_URI", None) + if tracking_uri is None: + raise ValueError( + "The tracking uri is not defined, use an mlflow logger with a tracking uri or define the " + "MLFLOW_TRACKING_URI environment variable." + ) + cfg_model_manager = cfg.model_manager + # Retrieve run_id, if None, create a new run + run_id = None + if len(fabric.loggers) > 0: + run_id = getattr(fabric.logger, "run_id", None) + mlflow.set_tracking_uri(tracking_uri) + experiment_id = None + run_name = None + if run_id is None: + experiment = mlflow.get_experiment_by_name(cfg.exp_name) + experiment_id = mlflow.create_experiment(cfg.exp_name) if experiment is None else experiment.experiment_id + run_name = f"{cfg.algo.name}_{cfg.env.id}_{datetime.today().strftime('%Y-%m-%d %H:%M:%S')}" + models_info = log_models(run_id, experiment_id, run_name) + model_manager = MlflowModelManager(fabric, tracking_uri) + if len(models_info) != len(cfg_model_manager.models): + raise RuntimeError( + f"The number of models of the {cfg.algo.name} agent must be equal to the number " + f"of models you want to register. {len(cfg_model_manager.models)} model registration " + f"configs are given, but the agent has {len(models_info)} models." 
+ ) + for k, cfg_model in cfg_model_manager.models.items(): + model_manager.register_model( + models_info[k]._model_uri, cfg_model["model_name"], cfg_model["description"], cfg_model["tags"] + ) + + +def register_model_from_checkpoint( + fabric: Fabric, + cfg: Dict[str, Any], + state: Dict[str, Any], + log_models_from_checkpoint: Callable[ + [Fabric, gym.Env | gym.Wrapper, Dict[str, Any], Dict[str, Any]], Dict[str, ModelInfo] + ], +): + tracking_uri = getattr(cfg, "tracking_uri", None) or os.getenv("MLFLOW_TRACKING_URI", None) + if tracking_uri is None: + raise ValueError( + "The tracking uri is not defined, set the `tracking_uri` config value or define the " + "MLFLOW_TRACKING_URI environment variable." + ) + # Creating the environment for agent instantiation + env = make_env( + cfg, + cfg.seed, + 0, + None, + "test", + vector_env_idx=0, + )() + observation_space = env.observation_space + if not isinstance(observation_space, gym.spaces.Dict): + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") + if cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder == []: + raise RuntimeError( + "You should specify at least one CNN key or MLP key from the CLI: " + "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" + ) + + mlflow.set_tracking_uri(tracking_uri) + # If the user does not specify the experiment, then create a new one + if cfg.run.id is None and cfg.experiment.id is None: + cfg.experiment.id = mlflow.create_experiment(cfg.experiment.name) + # Log the models + models_info = log_models_from_checkpoint(fabric, env, cfg, state) + model_manager = MlflowModelManager(fabric, tracking_uri) + if not set(cfg.model_manager.models.keys()).issubset(models_info.keys()): + raise RuntimeError( + f"The models you want to register must be a subset of the models of the {cfg.algo.name} agent. " + f"{len(cfg.model_manager.models)} model registration " + f"configs are given, but the agent has {len(models_info)} models. " + f"\nModels specified in the configs: {cfg.model_manager.models.keys()}." + f"\nModels of the {cfg.algo.name} agent: {models_info.keys()}." + ) + # Register the models specified in the configs + for k, cfg_model in cfg.model_manager.models.items(): + model_manager.register_model( + models_info[k]._model_uri, cfg_model["model_name"], cfg_model["description"], cfg_model["tags"] + ) diff --git a/sheeprl_model_manager.py b/sheeprl_model_manager.py new file mode 100644 index 00000000..0a601ba0 --- /dev/null +++ b/sheeprl_model_manager.py @@ -0,0 +1,4 @@ +from sheeprl.cli import registration + +if __name__ == "__main__": + registration()
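The new sheeprl_model_manager.py entry point drives registration from a checkpoint through sheeprl/configs/model_manager_config.yaml, whose checkpoint_path and model_manager entries are mandatory (both default to ???), so a launch presumably looks like python sheeprl_model_manager.py checkpoint_path=/path/to/ckpt.ckpt model_manager=ppo. The MlflowModelManager can also be used on its own; the following is a rough, self-contained sketch where the tracking URI, the run name, the registered model name, and the nn.Linear stand-in for a trained agent are all assumptions made for illustration:

import mlflow
import torch.nn as nn
from lightning.fabric import Fabric

from sheeprl.utils.model_manager import MlflowModelManager

# Assumed values, for illustration only.
TRACKING_URI = "http://localhost:5000"

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()
manager = MlflowModelManager(fabric, TRACKING_URI)  # also sets the mlflow tracking URI

agent = nn.Linear(4, 2)  # stand-in for a trained sheeprl agent

with mlflow.start_run(run_name="manual_registration") as run:
    # Log the agent as a run artifact, then register that artifact by its runs:/ URI.
    mlflow.pytorch.log_model(agent, "agent")
    model_uri = f"runs:/{run.info.run_id}/agent"

version = manager.register_model(model_uri, "manual_agent", description="Manually registered agent")
print(version.version)

The registered model then appears in the MLflow model registry together with the changelog entries that register_model appends to its description.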