# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, List, Optional, Union

import torch
from torch import nn


class MLP(nn.Module):
    """A multi-layer perceptron module.

    This module is a sequence of linear layers plus activation functions.
    The user can optionally add normalization and/or dropout to each of the layers.

    Args:
        in_dim (int): Input dimension.
        out_dim (int): Output dimension.
        hidden_dims (Optional[Union[int, List[int]]]): Output dimension for each
            hidden layer. An int is treated as a single hidden layer of that
            width; None (the default) means no hidden layers.
        dropout (float): Probability for each dropout layer. Defaults to 0.5.
        activation (Callable[..., nn.Module]): Which activation
            function to use. Supports module type or partial.
        normalization (Optional[Callable[..., nn.Module]]): Which
            normalization layer to use (None for no normalization).
            Supports module type or partial.

    Inputs:
        x (Tensor): Tensor containing a batch of input sequences.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        hidden_dims: Optional[Union[int, List[int]]] = None,
        dropout: float = 0.5,
        activation: Callable[..., nn.Module] = nn.ReLU,
        normalization: Optional[Callable[..., nn.Module]] = None,
        **kwargs,  # Unused; accepted for call-site compatibility.
    ) -> None:
        super().__init__()
        layers = nn.ModuleList()

        # Normalize hidden_dims to a list: None means no hidden layers,
        # and an int is treated as a single hidden layer of that width.
        if hidden_dims is None:
            hidden_dims = []
        if isinstance(hidden_dims, int):
            hidden_dims = [hidden_dims]

        # Each hidden block: Linear -> (optional) normalization -> activation -> dropout.
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(in_dim, hidden_dim))
            if normalization:
                layers.append(normalization(hidden_dim))
            layers.append(activation())
            layers.append(nn.Dropout(dropout))
            in_dim = hidden_dim

        # Final projection to out_dim, with no activation, normalization, or dropout.
        layers.append(nn.Linear(in_dim, out_dim))
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)
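

# ---------------------------------------------------------------------------
# Usage sketch (not part of the module above): a minimal example of building
# the MLP with two hidden layers, passing normalization as a partial, and
# running a forward pass. The dimensions, dropout value, and LayerNorm eps
# here are illustrative assumptions, not values prescribed by the module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from functools import partial

    mlp = MLP(
        in_dim=16,
        out_dim=4,
        hidden_dims=[32, 32],
        dropout=0.1,
        activation=nn.ReLU,
        # Called as normalization(hidden_dim) inside __init__, so a partial
        # lets us pin extra constructor arguments such as eps.
        normalization=partial(nn.LayerNorm, eps=1e-6),
    )
    x = torch.randn(8, 16)  # batch of 8 input vectors of dimension 16
    out = mlp(x)
    print(out.shape)  # torch.Size([8, 4])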