Fix/dv3 layer norm #257

Merged · 2 commits · Apr 5, 2024
Changes from 1 commit
Preserve input dtype after LayerNorm (pytorch/pytorch#66707 (comment))
belerico committed Apr 5, 2024
commit d03f756d453683c929b0cdda3cbda60fd2dce73b
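Context for the change: under torch.autocast, layer_norm is one of the ops that runs in float32, so nn.LayerNorm returns float32 outputs even for float16 inputs and silently promotes downstream activations (the behavior discussed in the linked pytorch/pytorch#66707 comment). A minimal sketch of the issue, assuming a CUDA device:

import torch
import torch.nn as nn

# Autocast executes layer_norm in float32, so the half-precision input
# dtype is lost on the way out.
ln = nn.LayerNorm(8).cuda()
x = torch.randn(2, 8, device="cuda", dtype=torch.float16)
with torch.autocast(device_type="cuda", dtype=torch.float16):
    y = ln(x)
print(y.dtype)  # torch.float32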
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v2/agent.py
@@ -22,10 +22,10 @@
)

from sheeprl.algos.dreamer_v2.utils import compute_stochastic_state, init_weights
-from sheeprl.models.models import CNN, MLP, DeCNN, LayerNormGRUCell, MultiDecoder, MultiEncoder
+from sheeprl.models.models import CNN, MLP, DeCNN, LayerNormChannelLast, LayerNormGRUCell, MultiDecoder, MultiEncoder
from sheeprl.utils.distribution import TruncatedNormal
from sheeprl.utils.fabric import get_single_device_fabric
-from sheeprl.utils.model import LayerNormChannelLast, ModuleType, cnn_forward
+from sheeprl.utils.model import ModuleType, cnn_forward


class CNNEncoder(nn.Module):
39 changes: 24 additions & 15 deletions sheeprl/algos/dreamer_v3/agent.py
@@ -24,9 +24,18 @@
from sheeprl.algos.dreamer_v2.agent import WorldModel
from sheeprl.algos.dreamer_v2.utils import compute_stochastic_state
from sheeprl.algos.dreamer_v3.utils import init_weights, uniform_init_weights
-from sheeprl.models.models import CNN, MLP, DeCNN, LayerNormGRUCell, MultiDecoder, MultiEncoder
+from sheeprl.models.models import (
+    CNN,
+    MLP,
+    DeCNN,
+    LayerNorm,
+    LayerNormChannelLast,
+    LayerNormGRUCell,
+    MultiDecoder,
+    MultiEncoder,
+)
from sheeprl.utils.fabric import get_single_device_fabric
-from sheeprl.utils.model import LayerNormChannelLastFP32, LayerNormFP32, ModuleType, cnn_forward
+from sheeprl.utils.model import ModuleType, cnn_forward
from sheeprl.utils.utils import symlog


@@ -44,7 +53,7 @@ class CNNEncoder(nn.Module):
channels_multiplier (int): the multiplier for the output channels. Given the 4 stages, the 4 output channels
will be [1, 2, 4, 8] * `channels_multiplier`.
layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
-Defaults to LayerNormChannelLastFP32.
+Defaults to LayerNormChannelLast.
layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
Default to {"eps": 1e-3}.
activation (ModuleType, optional): the activation function.
@@ -58,7 +67,7 @@ def __init__(
input_channels: Sequence[int],
image_size: Tuple[int, int],
channels_multiplier: int,
-layer_norm_cls: Callable[..., nn.Module] = LayerNormChannelLastFP32,
+layer_norm_cls: Callable[..., nn.Module] = LayerNormChannelLast,
layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
activation: ModuleType = nn.SiLU,
stages: int = 4,
@@ -102,7 +111,7 @@ class MLPEncoder(nn.Module):
dense_units (int, optional): the dimension of every mlp.
Defaults to 512.
layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
-Defaults to LayerNormFP32.
+Defaults to LayerNorm.
layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
Default to {"eps": 1e-3}.
activation (ModuleType, optional): the activation function after every layer.
@@ -117,7 +126,7 @@ def __init__(
input_dims: Sequence[int],
mlp_layers: int = 4,
dense_units: int = 512,
-layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32,
+layer_norm_cls: Callable[..., nn.Module] = LayerNorm,
layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
activation: ModuleType = nn.SiLU,
symlog_inputs: bool = True,
@@ -162,7 +171,7 @@ class CNNDecoder(nn.Module):
activation (nn.Module, optional): the activation function.
Defaults to nn.SiLU.
layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
-Defaults to LayerNormChannelLastFP32.
+Defaults to LayerNormChannelLast.
layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
Default to {"eps": 1e-3}.
stages (int): how many stages in the CNN decoder.
@@ -177,7 +186,7 @@ def __init__(
cnn_encoder_output_dim: int,
image_size: Tuple[int, int],
activation: nn.Module = nn.SiLU,
-layer_norm_cls: Callable[..., nn.Module] = LayerNormChannelLastFP32,
+layer_norm_cls: Callable[..., nn.Module] = LayerNormChannelLast,
layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
stages: int = 4,
) -> None:
@@ -232,7 +241,7 @@ class MLPDecoder(nn.Module):
dense_units (int, optional): the dimension of every mlp.
Defaults to 512.
layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
-Defaults to LayerNormFP32.
+Defaults to LayerNorm.
layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
Default to {"eps": 1e-3}.
activation (ModuleType, optional): the activation function after every layer.
@@ -247,7 +256,7 @@ def __init__(
mlp_layers: int = 4,
dense_units: int = 512,
activation: ModuleType = nn.SiLU,
-layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32,
+layer_norm_cls: Callable[..., nn.Module] = LayerNorm,
layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
) -> None:
super().__init__()
@@ -282,7 +291,7 @@ class RecurrentModel(nn.Module):
activation_fn (nn.Module): the activation function.
Default to SiLU.
layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
-Defaults to LayerNormFP32.
+Defaults to LayerNorm.
layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
Default to {"eps": 1e-3}.
"""
@@ -293,7 +302,7 @@ def __init__(
recurrent_state_size: int,
dense_units: int,
activation_fn: nn.Module = nn.SiLU,
-layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32,
+layer_norm_cls: Callable[..., nn.Module] = LayerNorm,
layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
) -> None:
super().__init__()
@@ -710,7 +719,7 @@ class Actor(nn.Module):
mlp_layers (int): the number of dense layers.
Default to 5.
layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
-Defaults to LayerNormFP32.
+Defaults to LayerNorm.
layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
Default to {"eps": 1e-3}.
unimix: (float, optional): the percentage of uniform distribution to inject into the categorical
@@ -734,7 +743,7 @@ def __init__(
dense_units: int = 1024,
activation: nn.Module = nn.SiLU,
mlp_layers: int = 5,
-layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32,
+layer_norm_cls: Callable[..., nn.Module] = LayerNorm,
layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
unimix: float = 0.01,
action_clip: float = 1.0,
@@ -853,7 +862,7 @@ def __init__(
dense_units: int = 1024,
activation: nn.Module = nn.SiLU,
mlp_layers: int = 5,
-layer_norm_cls: Callable[..., nn.Module] = LayerNormFP32,
+layer_norm_cls: Callable[..., nn.Module] = LayerNorm,
layer_norm_kw: Dict[str, Any] = {"eps": 1e-3},
unimix: float = 0.01,
action_clip: float = 1.0,
4 changes: 2 additions & 2 deletions sheeprl/configs/algo/dreamer_v3.yaml
@@ -26,11 +26,11 @@ mlp_keys:

# Model related parameters
cnn_layer_norm:
-cls: sheeprl.utils.model.LayerNormChannelLastFP32
+cls: sheeprl.utils.model.LayerNormChannelLast
kw:
eps: 1e-3
mlp_layer_norm:
-cls: sheeprl.utils.model.LayerNormFP32
+cls: sheeprl.utils.model.LayerNorm
kw:
eps: 1e-3
dense_units: 1024
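For reference, each *_layer_norm entry is a dotted class path (cls) plus constructor kwargs (kw). A minimal sketch of how such an entry can be resolved, assuming hydra.utils.get_class does the lookup (sheeprl's actual wiring may differ):

from hydra.utils import get_class

# Assumed plumbing, not sheeprl's actual code: resolve the dotted `cls`
# path to a class, then construct it with the `kw` mapping from the YAML.
cfg = {"cls": "torch.nn.LayerNorm", "kw": {"eps": 1e-3}}
layer_norm_cls = get_class(cfg["cls"])
layer_norm = layer_norm_cls(64, **cfg["kw"])  # 64 = normalized_shape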
21 changes: 21 additions & 0 deletions sheeprl/models/models.py
@@ -502,3 +502,24 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
if self.mlp_decoder is not None:
reconstructed_obs.update(self.mlp_decoder(x))
return reconstructed_obs
+
+
+class LayerNormChannelLast(nn.LayerNorm):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x: Tensor) -> Tensor:
+        if x.dim() != 4:
+            raise ValueError(f"Input tensor must be 4D (NCHW), received {len(x.shape)}D instead: {x.shape}")
+        input_dtype = x.dtype
+        x = x.permute(0, 2, 3, 1)
+        x = super().forward(x)
+        x = x.permute(0, 3, 1, 2)
+        return x.to(input_dtype)
+
+
+class LayerNorm(nn.LayerNorm):
+    def forward(self, x: Tensor) -> Tensor:
+        input_dtype = x.dtype
+        out = super().forward(x)
+        return out.to(input_dtype)
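A quick, hypothetical usage sketch of the new wrappers (assuming a CUDA device with autocast): they record the incoming dtype and cast the normalized output back, so autocast's float32 layer_norm result no longer leaks downstream.

import torch

ln = LayerNorm(8).cuda()  # the subclass added above
x = torch.randn(2, 8, device="cuda", dtype=torch.float16)
with torch.autocast(device_type="cuda", dtype=torch.float16):
    y = ln(x)
assert y.dtype == torch.float16  # output cast back to the input dtype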
29 changes: 0 additions & 29 deletions sheeprl/utils/model.py
@@ -221,32 +221,3 @@ def cnn_forward(
flatten_input = input.reshape(-1, *input_dim)
model_out = model(flatten_input)
return model_out.reshape(*batch_shapes, *output_dim)
-
-
-class LayerNormChannelLast(nn.LayerNorm):
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-
-    def forward(self, x: Tensor) -> Tensor:
-        if x.dim() != 4:
-            raise ValueError(f"Input tensor must be 4D (NCHW), received {len(x.shape)}D instead: {x.shape}")
-        x = x.permute(0, 2, 3, 1)
-        x = super().forward(x)
-        x = x.permute(0, 3, 1, 2)
-        return x
-
-
-class LayerNormChannelLastFP32(LayerNormChannelLast):
-    def forward(self, x: Tensor) -> Tensor:
-        input_dtype = x.dtype
-        x = x.to(torch.float32)
-        out = super().forward(x)
-        return out.to(input_dtype)
-
-
-class LayerNormFP32(nn.LayerNorm):
-    def forward(self, x: Tensor) -> Tensor:
-        input_dtype = x.dtype
-        x = x.to(torch.float32)
-        out = super().forward(x)
-        return out.to(input_dtype)
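One behavioral note (an observation from the diff, not stated in the PR): the removed FP32 variants explicitly upcast inputs to float32 before normalizing, guaranteeing float32 compute even outside autocast; the new LayerNorm and LayerNormChannelLast leave the compute dtype to the caller (under autocast, layer_norm still runs in float32) and only restore the input dtype on the output.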