add single encoder layer
andyj29 committed Apr 11, 2023
1 parent f78d36b commit a87eb00
Showing 5 changed files with 58 additions and 19 deletions.
3 changes: 2 additions & 1 deletion transformer/common/__init__.py
@@ -1,4 +1,5 @@
 from .attention import MultiHeadAttention
 from .ffn import FeedForwardNetwork
 from .embedding import Embeddings
-from .position import PositionalEncoding
+from .position import PositionalEncoding
+from .residual import ResidualConnection
18 changes: 9 additions & 9 deletions transformer/common/attention.py
@@ -4,35 +4,35 @@


 class MultiHeadAttention(nn.Module):
-    def __init__(self, cfg):
+    def __init__(self, d_model, n_head, dropout):
         super(MultiHeadAttention, self).__init__()
         # embedding dimension must be divisible by number of heads
-        assert cfg.emb_dim % cfg.n_head == 0
+        assert d_model % n_head == 0

         # key, query, value projections for all heads
-        self.c_attn = nn.Linear(cfg.emb_dim, 3 * cfg.emb_dim)
+        self.c_attn = nn.Linear(d_model, 3 * d_model)
         # output projection
-        self.c_proj = nn.Linear(cfg.emb_dim, cfg.emb_dim)
+        self.c_proj = nn.Linear(d_model, d_model)

-        self.n_head = cfg.n_head
-        self.emb_dim = cfg.emb_dim
+        self.n_head = n_head
+        self.d_model = d_model

         # regularization
-        self.dropout = cfg.dropout
+        self.dropout = dropout
         self.resid_dropout = nn.Dropout(self.dropout)


     def forward(self, x, mask=None):
         # B: batch size, S: sequence length, E: embedding dimension
         B, S, E = x.size()
         # pull out the query, key, value from the concatenated projection
-        q, k, v = self.c_attn(x).split(self.emb_dim, dim=2)
+        q, k, v = self.c_attn(x).split(self.d_model, dim=2)
         # split heads and transpose to (B, n_head, S, E // n_head)
         q = q.view(B, S, self.n_head, E // self.n_head).transpose(1, 2)
         k = k.view(B, S, self.n_head, E // self.n_head).transpose(1, 2)
         v = v.view(B, S, self.n_head, E // self.n_head).transpose(1, 2)
         # apply attention
-        y = F.scaled_dot_product_attention(q, k, v, dropout=self.attn_dropout, is_causal=mask)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.dropout)
         # concatenate heads and transpose to (B, S, E)
         y = y.transpose(1, 2).contiguous().view(B, S, E)
         # apply drop out to final linear projection
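For context, a minimal usage sketch (not part of the commit) of the refactored MultiHeadAttention, which now takes explicit d_model, n_head, and dropout arguments instead of a cfg object; the hyperparameter values below are illustrative only.

import torch
from transformer.common import MultiHeadAttention

# illustrative hyperparameters, not taken from the repository config
attn = MultiHeadAttention(d_model=512, n_head=8, dropout=0.1)
x = torch.randn(2, 16, 512)   # (batch, seq_len, d_model)
out = attn(x, mask=None)      # self-attention over the sequence, output shape (2, 16, 512)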
11 changes: 6 additions & 5 deletions transformer/common/ffn.py
@@ -1,13 +1,14 @@
 import torch.nn as nn

 class FeedForwardNetwork(nn.Module):
-    def __init__(self, cfg):
+    def __init__(self, d_model, d_ffn_hidden, dropout):
         super(FeedForwardNetwork, self).__init__()
         self.layers = nn.ModuleList(
-            [nn.Linear(cfg.emb_dim, cfg.ffn_dim),
-             nn.ReLU(),
-             nn.Dropout(cfg.dropout),
-             nn.Linear(cfg.ffn_dim, cfg.emb_dim)]
+            [nn.Linear(d_model, d_ffn_hidden),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+             nn.Linear(d_ffn_hidden, d_model)
+            ]
         )

     def forward(self, x):
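As with the attention module, a brief sketch (not from the commit) of constructing the position-wise feed-forward block with the new explicit arguments; the sizes are illustrative.

import torch
from transformer.common import FeedForwardNetwork

# illustrative sizes; d_ffn_hidden is the inner expansion dimension
ffn = FeedForwardNetwork(d_model=512, d_ffn_hidden=2048, dropout=0.1)
x = torch.randn(2, 16, 512)   # (batch, seq_len, d_model)
out = ffn(x)                  # applied position-wise, shape stays (2, 16, 512)

Because the sub-modules live in an nn.ModuleList rather than nn.Sequential, the truncated forward presumably loops over self.layers and threads x through each one.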
2 changes: 1 addition & 1 deletion transformer/common/position.py
@@ -4,7 +4,7 @@
 from torch.autograd import Variable

 class PositionalEncoding(nn.Module):
-    def __init__(self, d_model, dropout, max_len=5000):
+    def __init__(self, d_model, dropout, max_len):
         super(PositionalEncoding, self).__init__()
         self.dropout = nn.Dropout(p=dropout)

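One behavioural consequence of this hunk: max_len no longer defaults to 5000, so callers must size the encoding table explicitly (the new Encoder passes corpus_len). A minimal sketch, assuming the usual sinusoidal encoding whose forward adds position information to its input; the values are illustrative.

import torch
from transformer.common import PositionalEncoding

pos = PositionalEncoding(d_model=512, dropout=0.1, max_len=1024)  # max_len is now required
x = torch.randn(2, 16, 512)   # (batch, seq_len, d_model)
out = pos(x)                  # adds position information and applies dropout; shape unchanged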
43 changes: 40 additions & 3 deletions transformer/encoder/encoder.py
@@ -1,10 +1,47 @@
 import torch.nn as nn
 import torch.nn.functional as F
-from transformer.common import MultiHeadAttention, FeedForwardNetwork
+from transformer.common import \
+    (
+        MultiHeadAttention,
+        FeedForwardNetwork,
+        ResidualConnection,
+        Embeddings,
+        PositionalEncoding
+    )


 class EncoderLayer(nn.Module):
-    def __init__(self, attn, ffn, dropout):
+    def __init__(self, d_model, n_head, d_ffn_hidden, dropout=0.1):
         super(EncoderLayer, self).__init__()
+        self.attn = MultiHeadAttention(d_model, n_head, dropout)
+        self.ffn = FeedForwardNetwork(d_model, d_ffn_hidden, dropout)
+        self.residual = nn.ModuleList([
+            ResidualConnection(self.attn),
+            ResidualConnection(self.ffn),
+        ])
+
+    def forward(self, x):
+        for layer in self.residual:
+            x = layer(x)
+
+        return x
+
+
+
+class Encoder(nn.Module):
+    def __init__(self, d_model, n_stack, n_head, d_ffn_hidden, corpus_len, dropout):
+        super(Encoder, self).__init__()
+        self.layers = nn.ModuleList([
+            EncoderLayer(d_model, n_head, d_ffn_hidden, dropout)
+            for _ in range(n_stack)
+            ]
+        )
+        self.emb = Embeddings(d_model)
+        self.pos = PositionalEncoding(d_model, dropout, max_len=corpus_len)
+        self.dropout = nn.Dropout(dropout)
+        self.d_model = d_model
+
+
+    def forward(self, x):
+        pass
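Encoder.forward is still a stub, but the new EncoderLayer is usable on its own. A minimal sketch (not part of the commit) of driving one layer, assuming EncoderLayer is importable from transformer.encoder.encoder and that ResidualConnection (not shown in this diff) wraps each sub-layer with a skip connection; the hyperparameters are illustrative.

import torch
from transformer.encoder.encoder import EncoderLayer  # assumed import path

# illustrative hyperparameters
layer = EncoderLayer(d_model=512, n_head=8, d_ffn_hidden=2048, dropout=0.1)
x = torch.randn(2, 16, 512)   # (batch, seq_len, d_model)
out = layer(x)                # residual(self-attention) then residual(FFN), shape preserved

Once Encoder.forward is filled in, it would presumably run the input through self.emb, self.pos, and self.dropout before threading it through the stacked layers.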

