Fix/dv3 layer norm #257

Merged (2 commits, Apr 5, 2024)
sheeprl/algos/dreamer_v2/agent.py (4 changes: 2 additions & 2 deletions)

@@ -22,10 +22,10 @@
 )

 from sheeprl.algos.dreamer_v2.utils import compute_stochastic_state, init_weights
-from sheeprl.models.models import CNN, MLP, DeCNN, LayerNormGRUCell, MultiDecoder, MultiEncoder
+from sheeprl.models.models import CNN, MLP, DeCNN, LayerNormChannelLast, LayerNormGRUCell, MultiDecoder, MultiEncoder
 from sheeprl.utils.distribution import TruncatedNormal
 from sheeprl.utils.fabric import get_single_device_fabric
-from sheeprl.utils.model import LayerNormChannelLast, ModuleType, cnn_forward
+from sheeprl.utils.model import ModuleType, cnn_forward


 class CNNEncoder(nn.Module):
sheeprl/algos/dreamer_v3/agent.py (39 changes: 24 additions & 15 deletions)

@@ -24,9 +24,18 @@
 from sheeprl.algos.dreamer_v2.agent import WorldModel
 from sheeprl.algos.dreamer_v2.utils import compute_stochastic_state
 from sheeprl.algos.dreamer_v3.utils import init_weights, uniform_init_weights
-from sheeprl.models.models import CNN, MLP, DeCNN, LayerNormGRUCell, MultiDecoder, MultiEncoder
+from sheeprl.models.models import (
+    CNN,
+    MLP,
+    DeCNN,
+    LayerNorm,
+    LayerNormChannelLast,
+    LayerNormGRUCell,
+    MultiDecoder,
+    MultiEncoder,
+)
 from sheeprl.utils.fabric import get_single_device_fabric
-from sheeprl.utils.model import LayerNormChannelLastFP32, LayerNormFP32, ModuleType, cnn_forward
+from sheeprl.utils.model import ModuleType, cnn_forward
 from sheeprl.utils.utils import symlog

@@ -44,7 +53,7 @@ class CNNEncoder(nn.Module):
         channels_multiplier (int): the multiplier for the output channels. Given the 4 stages, the 4 output channels
             will be [1, 2, 4, 8] * `channels_multiplier`.
         layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
-            Defaults to LayerNormChannelLastFP32.
+            Defaults to LayerNormChannelLast.
         layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
             Default to {"eps": 1e-3}.
         activation (ModuleType, optional): the activation function.
@@ -58,7 +67,7 @@ def __init__(
         input_channels: Sequence[int],
         image_size: Tuple[int, int],
         channels_multiplier: int,
-        layer_norm_cls: Callable[..., nn.Module] = LayerNormChannelLastFP32,
+        layer_norm_cls: Callable[..., nn.Module] = LayerNormChannelLast,
         layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
         activation: ModuleType = nn.SiLU,
         stages: int = 4,
@@ -102,7 +111,7 @@ class MLPEncoder(nn.Module):
         dense_units (int, optional): the dimension of every mlp.
             Defaults to 512.
         layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
-            Defaults to LayerNormFP32.
+            Defaults to LayerNorm.
         layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
             Default to {"eps": 1e-3}.
         activation (ModuleType, optional): the activation function after every layer.
@@ -117,7 +126,7 @@ def __init__(
         input_dims: Sequence[int],
         mlp_layers: int = 4,
         dense_units: int = 512,
-        layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32,
+        layer_norm_cls: Callable[..., nn.Module] = LayerNorm,
         layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
         activation: ModuleType = nn.SiLU,
         symlog_inputs: bool = True,
@@ -162,7 +171,7 @@ class CNNDecoder(nn.Module):
         activation (nn.Module, optional): the activation function.
             Defaults to nn.SiLU.
         layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
-            Defaults to LayerNormChannelLastFP32.
+            Defaults to LayerNormChannelLast.
         layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
             Default to {"eps": 1e-3}.
         stages (int): how many stages in the CNN decoder.
@@ -177,7 +186,7 @@ def __init__(
         cnn_encoder_output_dim: int,
         image_size: Tuple[int, int],
         activation: nn.Module = nn.SiLU,
-        layer_norm_cls: Callable[..., nn.Module] = LayerNormChannelLastFP32,
+        layer_norm_cls: Callable[..., nn.Module] = LayerNormChannelLast,
         layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
         stages: int = 4,
     ) -> None:
@@ -232,7 +241,7 @@ class MLPDecoder(nn.Module):
         dense_units (int, optional): the dimension of every mlp.
             Defaults to 512.
         layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
-            Defaults to LayerNormFP32.
+            Defaults to LayerNorm.
         layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
             Default to {"eps": 1e-3}.
         activation (ModuleType, optional): the activation function after every layer.
@@ -247,7 +256,7 @@ def __init__(
         mlp_layers: int = 4,
         dense_units: int = 512,
         activation: ModuleType = nn.SiLU,
-        layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32,
+        layer_norm_cls: Callable[..., nn.Module] = LayerNorm,
         layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
     ) -> None:
         super().__init__()
@@ -282,7 +291,7 @@ class RecurrentModel(nn.Module):
         activation_fn (nn.Module): the activation function.
             Default to SiLU.
         layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
-            Defaults to LayerNormFP32.
+            Defaults to LayerNorm.
         layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
             Default to {"eps": 1e-3}.
     """
@@ -293,7 +302,7 @@ def __init__(
         recurrent_state_size: int,
         dense_units: int,
         activation_fn: nn.Module = nn.SiLU,
-        layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32,
+        layer_norm_cls: Callable[..., nn.Module] = LayerNorm,
         layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
     ) -> None:
         super().__init__()
@@ -710,7 +719,7 @@ class Actor(nn.Module):
         mlp_layers (int): the number of dense layers.
             Default to 5.
         layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
-            Defaults to LayerNormFP32.
+            Defaults to LayerNorm.
         layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
             Default to {"eps": 1e-3}.
         unimix: (float, optional): the percentage of uniform distribution to inject into the categorical
@@ -734,7 +743,7 @@ def __init__(
         dense_units: int = 1024,
         activation: nn.Module = nn.SiLU,
         mlp_layers: int = 5,
-        layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32,
+        layer_norm_cls: Callable[..., nn.Module] = LayerNorm,
         layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
         unimix: float = 0.01,
         action_clip: float = 1.0,
@@ -853,7 +862,7 @@ def __init__(
         dense_units: int = 1024,
         activation: nn.Module = nn.SiLU,
         mlp_layers: int = 5,
-        layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32,
+        layer_norm_cls: Callable[..., nn.Module] = LayerNorm,
         layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
         unimix: float = 0.01,
         action_clip: float = 1.0,
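Every signature above follows the same pattern: the normalization layer is injected as a class plus kwargs rather than hard-coded, which is what lets this PR swap the FP32 variants for the new classes without touching the module logic. A minimal sketch of that pattern, assuming nothing beyond the diff (the DenseBlock name and its structure are illustrative, not actual sheeprl code):

from typing import Any, Callable, Dict

import torch.nn as nn


class DenseBlock(nn.Module):
    """Hypothetical block showing how `layer_norm_cls`/`layer_norm_kw` are consumed."""

    def __init__(
        self,
        input_dim: int,
        dense_units: int = 512,
        layer_norm_cls: Callable[..., nn.Module] = nn.LayerNorm,
        layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
    ) -> None:
        super().__init__()
        self.proj = nn.Linear(input_dim, dense_units)
        # The norm applied after the input projection is fully configurable:
        # changing the default from LayerNormFP32 to LayerNorm is just a
        # different `layer_norm_cls` argument.
        self.norm = layer_norm_cls(dense_units, **layer_norm_kw)
        self.act = nn.SiLU()

    def forward(self, x):
        return self.act(self.norm(self.proj(x)))
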
sheeprl/configs/algo/dreamer_v3.yaml (4 changes: 2 additions & 2 deletions)

@@ -26,11 +26,11 @@ mlp_keys:

 # Model related parameters
 cnn_layer_norm:
-  cls: sheeprl.utils.model.LayerNormChannelLastFP32
+  cls: sheeprl.models.models.LayerNormChannelLast
   kw:
     eps: 1e-3
 mlp_layer_norm:
-  cls: sheeprl.utils.model.LayerNormFP32
+  cls: sheeprl.models.models.LayerNorm
   kw:
     eps: 1e-3
 dense_units: 1024
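The `cls` entries are dotted import paths, which is why moving the classes from sheeprl.utils.model to sheeprl.models.models has to touch this config too. A minimal sketch of how such a `cls`/`kw` pair can be resolved at build time, assuming sheeprl is installed and using Hydra's real `hydra.utils.get_class` helper (the exact call site inside sheeprl may differ):

import hydra.utils
from omegaconf import OmegaConf

# Illustrative fragment mirroring the config above.
cfg = OmegaConf.create(
    {
        "cnn_layer_norm": {"cls": "sheeprl.models.models.LayerNormChannelLast", "kw": {"eps": 1e-3}},
        "mlp_layer_norm": {"cls": "sheeprl.models.models.LayerNorm", "kw": {"eps": 1e-3}},
    }
)

# Resolve the dotted path to a class object, then bind its kwargs.
cnn_norm_cls = hydra.utils.get_class(cfg.cnn_layer_norm.cls)
cnn_norm = cnn_norm_cls(32, **cfg.cnn_layer_norm.kw)  # LayerNormChannelLast(32, eps=1e-3)
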
sheeprl/models/models.py (21 changes: 21 additions & 0 deletions)

@@ -502,3 +502,24 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
         if self.mlp_decoder is not None:
             reconstructed_obs.update(self.mlp_decoder(x))
         return reconstructed_obs
+
+
+class LayerNormChannelLast(nn.LayerNorm):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x: Tensor) -> Tensor:
+        if x.dim() != 4:
+            raise ValueError(f"Input tensor must be 4D (NCHW), received {len(x.shape)}D instead: {x.shape}")
+        input_dtype = x.dtype
+        x = x.permute(0, 2, 3, 1)
+        x = super().forward(x)
+        x = x.permute(0, 3, 1, 2)
+        return x.to(input_dtype)
+
+
+class LayerNorm(nn.LayerNorm):
+    def forward(self, x: Tensor) -> Tensor:
+        input_dtype = x.dtype
+        out = super().forward(x)
+        return out.to(input_dtype)
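A quick usage sketch of the two new classes (shapes and the autocast device are illustrative). LayerNormChannelLast permutes an NCHW feature map to NHWC so that nn.LayerNorm normalizes over the channel dimension, then permutes back; both classes end with `.to(input_dtype)` so the output dtype matches the input even when the surrounding context, e.g. torch.autocast (under which layer_norm computes in float32), would otherwise promote it:

import torch

from sheeprl.models.models import LayerNorm, LayerNormChannelLast

# Channel-first feature map, as produced by a CNN encoder stage.
x = torch.randn(8, 32, 16, 16)  # (N, C, H, W)
norm = LayerNormChannelLast(32, eps=1e-3)  # normalized_shape = number of channels
assert norm(x).shape == x.shape

# Dtype preservation under autocast, where layer_norm runs in float32.
if torch.cuda.is_available():
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        xh = torch.randn(8, 32, 16, 16, device="cuda", dtype=torch.float16)
        yh = LayerNormChannelLast(32, eps=1e-3).cuda()(xh)
        assert yh.dtype == torch.float16  # pinned back to the input dtype
        zh = LayerNorm(32, eps=1e-3).cuda()(torch.randn(8, 32, device="cuda", dtype=torch.float16))
        assert zh.dtype == torch.float16
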
sheeprl/utils/model.py (29 changes: 0 additions & 29 deletions)

@@ -221,32 +221,3 @@ def cnn_forward(
     flatten_input = input.reshape(-1, *input_dim)
     model_out = model(flatten_input)
     return model_out.reshape(*batch_shapes, *output_dim)
-
-
-class LayerNormChannelLast(nn.LayerNorm):
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-
-    def forward(self, x: Tensor) -> Tensor:
-        if x.dim() != 4:
-            raise ValueError(f"Input tensor must be 4D (NCHW), received {len(x.shape)}D instead: {x.shape}")
-        x = x.permute(0, 2, 3, 1)
-        x = super().forward(x)
-        x = x.permute(0, 3, 1, 2)
-        return x
-
-
-class LayerNormChannelLastFP32(LayerNormChannelLast):
-    def forward(self, x: Tensor) -> Tensor:
-        input_dtype = x.dtype
-        x = x.to(torch.float32)
-        out = super().forward(x)
-        return out.to(input_dtype)
-
-
-class LayerNormFP32(nn.LayerNorm):
-    def forward(self, x: Tensor) -> Tensor:
-        input_dtype = x.dtype
-        x = x.to(torch.float32)
-        out = super().forward(x)
-        return out.to(input_dtype)
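The net effect of the PR: the removed *FP32 variants unconditionally upcast to float32 before normalizing, while the replacements in sheeprl.models.models normalize in whatever precision context they are called from and only pin the output dtype back to the input's. A small sketch of that difference, not code from the PR (bfloat16 is chosen only because it is broadly supported on CPU):

import torch
import torch.nn as nn

x = torch.randn(4, 8, dtype=torch.bfloat16)

# Removed behavior (LayerNormFP32): force float32 statistics, then downcast.
ln = nn.LayerNorm(8, eps=1e-3)
out_old = ln(x.to(torch.float32)).to(x.dtype)

# New behavior (LayerNorm): run in the ambient dtype, pin the output dtype.
ln_bf16 = nn.LayerNorm(8, eps=1e-3).to(torch.bfloat16)
out_new = ln_bf16(x).to(x.dtype)

# Same output dtype either way; only the intermediate precision differs.
assert out_old.dtype == out_new.dtype == torch.bfloat16
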
tests/test_algos/test_algos.py (12 changes: 6 additions & 6 deletions)

@@ -474,8 +474,8 @@ def test_dreamer_v3(standard_args, env_id, start_time):
         "algo.cnn_keys.decoder=[rgb]",
         "algo.mlp_keys.encoder=[state]",
         "algo.mlp_keys.decoder=[state]",
-        "algo.mlp_layer_norm.cls=torch.nn.LayerNorm",
-        "algo.cnn_layer_norm.cls=sheeprl.utils.model.LayerNormChannelLast",
+        "algo.mlp_layer_norm.cls=sheeprl.models.models.LayerNorm",
+        "algo.cnn_layer_norm.cls=sheeprl.models.models.LayerNormChannelLast",
     ]

     with mock.patch.object(sys, "argv", args):
@@ -513,8 +513,8 @@ def test_p2e_dv3(standard_args, env_id, start_time):
         "algo.mlp_keys.encoder=[state]",
         "algo.mlp_keys.decoder=[state]",
         "checkpoint.save_last=True",
-        "algo.mlp_layer_norm.cls=torch.nn.LayerNorm",
-        "algo.cnn_layer_norm.cls=sheeprl.utils.model.LayerNormChannelLast",
+        "algo.mlp_layer_norm.cls=sheeprl.models.models.LayerNorm",
+        "algo.cnn_layer_norm.cls=sheeprl.models.models.LayerNormChannelLast",
     ]

     with mock.patch.object(sys, "argv", args):
@@ -557,8 +557,8 @@ def test_p2e_dv3(standard_args, env_id, start_time):
         "algo.cnn_keys.decoder=[rgb]",
         "algo.mlp_keys.encoder=[state]",
         "algo.mlp_keys.decoder=[state]",
-        "algo.mlp_layer_norm.cls=torch.nn.LayerNorm",
-        "algo.cnn_layer_norm.cls=sheeprl.utils.model.LayerNormChannelLast",
+        "algo.mlp_layer_norm.cls=sheeprl.models.models.LayerNorm",
+        "algo.cnn_layer_norm.cls=sheeprl.models.models.LayerNormChannelLast",
     ]
     with mock.patch.object(sys, "argv", args):
         run()
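The overrides the tests patch into sys.argv can be exercised the same way outside the test suite; a minimal sketch in the tests' own style (the `from sheeprl.cli import run` import and the `exp=dreamer_v3` override are assumptions based on the `run()` calls above, not shown in this diff):

import sys
from unittest import mock

from sheeprl.cli import run  # assumed entry point, mirroring the tests' `run()`

args = [
    "sheeprl.py",
    "exp=dreamer_v3",
    "algo.mlp_layer_norm.cls=sheeprl.models.models.LayerNorm",
    "algo.cnn_layer_norm.cls=sheeprl.models.models.LayerNormChannelLast",
]
with mock.patch.object(sys, "argv", args):
    run()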