📝 add comments for task 1
lkaesberg committed Jul 30, 2023
1 parent 9893c66 commit de1845a
Showing 2 changed files with 12 additions and 4 deletions.
9 changes: 7 additions & 2 deletions bert.py
@@ -46,9 +46,14 @@ def attention(self, key: Tensor, query: Tensor, value: Tensor, attention_mask: T
# multiply the attention scores to the value and get back V'
# next, we need to concat multi-heads and recover the original shape [bs, seq_len, num_attention_heads * attention_head_size = hidden_size]

# key, query, value: [bs, num_attention_heads, seq_len, attention_head_size]
# attention_mask: [bs, 1, 1, seq_len]
# output: [bs, seq_len, num_attention_heads * attention_head_size = hidden_size]
# Note: the attention_mask is used to mask out the padding tokens
bs, h, seq_len, d_k = key.shape
S = query @ torch.transpose(key, 2, 3) + attention_mask

# scale by sqrt(d_k) and normalize the scores with a softmax over the key positions
result = torch.softmax((S / math.sqrt(d_k)), 3) @ value
return result.transpose(1, 2).reshape(bs, seq_len, h * d_k)
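
For reference, a minimal self-contained sketch of the same computation (the function name and toy shapes are illustrative, not taken from the repository). This sketch scales the raw scores before adding the mask; the committed lines above add the mask first and divide everything by sqrt(d_k), which still works as long as the mask entries are large negative numbers (e.g. -10000).

import math
import torch

def scaled_dot_product_attention(key, query, value, attention_mask):
    # key, query, value: [bs, num_heads, seq_len, head_size]
    # attention_mask:    [bs, 1, 1, seq_len], 0 for real tokens, -10000 for padding
    bs, num_heads, seq_len, head_size = key.shape
    scores = query @ key.transpose(2, 3) / math.sqrt(head_size)  # [bs, num_heads, seq_len, seq_len]
    scores = scores + attention_mask                  # push padding columns toward -inf
    probs = torch.softmax(scores, dim=-1)             # normalize over the key positions
    context = probs @ value                           # [bs, num_heads, seq_len, head_size]
    # concatenate the heads back to [bs, seq_len, num_heads * head_size = hidden_size]
    return context.transpose(1, 2).reshape(bs, seq_len, num_heads * head_size)

# toy usage
k = torch.randn(2, 4, 8, 16)
q = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)
mask = torch.zeros(2, 1, 1, 8)
mask[:, :, :, -2:] = -10000.0                         # pretend the last two positions are padding
print(scaled_dot_product_attention(k, q, v, mask).shape)   # torch.Size([2, 8, 64])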

@@ -94,7 +99,7 @@ def add_norm(self, input, output, dense_layer, dropout, ln_layer):
dropout: the dropout to be applied
ln_layer: the layer norm to be applied
"""
# Hint: Remember that BERT applies dropout to the output of each sub-layer, before it is added to the sub-layer input and normalized
# transform the sub-layer output with dropout and the dense layer, add the skip connection, then apply layer norm
return ln_layer(input + dense_layer(dropout(output)))
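
A small runnable sketch of this add-norm step with stand-in modules (the sizes and module instances below are made up for illustration):

import torch
import torch.nn as nn

def add_norm(input, output, dense_layer, dropout, ln_layer):
    # dropout + dense on the sub-layer output, add the residual input, then layer norm
    return ln_layer(input + dense_layer(dropout(output)))

hidden_size = 768
dense = nn.Linear(hidden_size, hidden_size)
drop = nn.Dropout(p=0.1)
ln = nn.LayerNorm(hidden_size)

x = torch.randn(2, 8, hidden_size)   # sub-layer input  [bs, seq_len, hidden_size]
y = torch.randn(2, 8, hidden_size)   # sub-layer output [bs, seq_len, hidden_size]
print(add_norm(x, y, dense, drop, ln).shape)   # torch.Size([2, 8, 768])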

def forward(self, hidden_states, attention_mask):
@@ -107,7 +112,7 @@ def forward(self, hidden_states, attention_mask):
3. a feed forward layer
4. a add-norm that takes the input and output of the feed forward layer
"""
### TODO
# apply multi-head attention
multi_head = self.self_attention(hidden_states, attention_mask)

add_norm_1 = self.add_norm(hidden_states, multi_head, self.attention_dense, self.attention_dropout,
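
The rest of the layer's forward pass follows the four steps listed in the docstring. A self-contained sketch of that flow with stand-in sub-modules (the names interm_dense, out_dense, etc. and the use of nn.MultiheadAttention are assumptions for illustration, not the repository's attributes):

import torch
import torch.nn as nn

hidden_size, intermediate_size, num_heads = 768, 3072, 12

self_attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
attention_dense = nn.Linear(hidden_size, hidden_size)
attention_dropout = nn.Dropout(0.1)
attention_layer_norm = nn.LayerNorm(hidden_size)
interm_dense = nn.Linear(hidden_size, intermediate_size)
interm_act = nn.GELU()
out_dense = nn.Linear(intermediate_size, hidden_size)
out_dropout = nn.Dropout(0.1)
out_layer_norm = nn.LayerNorm(hidden_size)

def add_norm(sub_input, sub_output, dense, dropout, ln):
    return ln(sub_input + dense(dropout(sub_output)))

def layer_forward(hidden_states, key_padding_mask):
    # 1. multi-head attention on the current hidden states
    attn_out, _ = self_attention(hidden_states, hidden_states, hidden_states,
                                 key_padding_mask=key_padding_mask)
    # 2. add-norm around the attention sub-layer
    x = add_norm(hidden_states, attn_out,
                 attention_dense, attention_dropout, attention_layer_norm)
    # 3. position-wise feed-forward layer
    ff_out = interm_act(interm_dense(x))
    # 4. add-norm around the feed-forward sub-layer
    return add_norm(x, ff_out, out_dense, out_dropout, out_layer_norm)

hidden = torch.randn(2, 8, hidden_size)
pad = torch.zeros(2, 8, dtype=torch.bool)   # True would mark a padded position
print(layer_forward(hidden, pad).shape)     # torch.Size([2, 8, 768])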
7 changes: 5 additions & 2 deletions classifier.py
@@ -51,6 +51,7 @@ def __init__(self, config):
param.requires_grad = True

# self.dropout = nn.Dropout(config.hidden_dropout_prob)
# linear layer mapping the pooled output to class logits
self.linear_layer = nn.Linear(config.hidden_size, self.num_labels)

def forward(self, input_ids, attention_mask):
@@ -59,6 +60,7 @@ def forward(self, input_ids, attention_mask):
# HINT: you should consider what is the appropriate output to return given that
# the training loop currently uses F.cross_entropy as the loss function.
# F.cross_entropy already applies a log-softmax internally, so returning raw logits here is fine

# no dropout here: this is the last layer before the softmax, and dropout at this point hurt performance
result = self.bert(input_ids, attention_mask)
return self.linear_layer(result['pooler_output'])
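
Returning raw logits is deliberate here, since F.cross_entropy combines log-softmax and the negative log-likelihood loss. A minimal sketch of this classification head (the batch size, label count, and random pooled vectors are made up; pooler_output stands in for the BERT pooled [CLS] representation):

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, num_labels = 768, 5
linear_layer = nn.Linear(hidden_size, num_labels)   # maps the pooled output to class logits

pooler_output = torch.randn(8, hidden_size)         # stand-in for bert(...)['pooler_output']
labels = torch.randint(0, num_labels, (8,))

logits = linear_layer(pooler_output)                # raw logits, no softmax in forward
loss = F.cross_entropy(logits, labels)              # softmax happens inside the loss
print(logits.shape, loss.item())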
@@ -266,6 +268,7 @@ def train(args):
optimizer = AdamW(model.parameters(), lr=lr)
best_dev_acc = 0

# Initialize the tensorboard writer
name = f"{datetime.now().strftime('%Y%m%d-%H%M%S')}-lr={lr}-optimizer={type(optimizer).__name__}"
writer = SummaryWriter(log_dir=args.logdir + "/classifier/" + name)

@@ -290,6 +293,7 @@ def train(args):
optimizer.step()

train_loss += loss.item()

writer.add_scalar("Loss/Minibatches", loss.item(), loss_idx_value)
loss_idx_value += 1
num_batches += 1
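
The writer calls above follow the usual TensorBoard pattern: one SummaryWriter per run, plus a global step counter that keeps increasing across epochs so the loss curve stays continuous. A rough sketch (the log directory, run name, and dummy loop are illustrative):

from datetime import datetime
import torch
from torch.utils.tensorboard import SummaryWriter

lr = 1e-5
name = f"{datetime.now().strftime('%Y%m%d-%H%M%S')}-lr={lr}-optimizer=AdamW"
writer = SummaryWriter(log_dir="logdir/classifier/" + name)

loss_idx_value = 0
for epoch in range(2):
    for step in range(10):                 # stand-in for the training dataloader loop
        loss = torch.rand(())              # stand-in for the real minibatch loss
        writer.add_scalar("Loss/Minibatches", loss.item(), loss_idx_value)
        loss_idx_value += 1                # global step, never reset between epochs
writer.close()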
@@ -362,7 +366,6 @@ def get_args():
parser.add_argument("--logdir", type=str, default="logdir")
parser.add_argument("--dev_out", type=str, default="sst-dev-out.csv")
parser.add_argument("--test_out", type=str, default="sst-test-out.csv")


parser.add_argument("--batch_size", help='sst: 64 can fit a 12GB GPU', type=int, default=64)
parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
@@ -372,7 +375,7 @@ def get_args():

parser.add_argument("--lr", type=float, help="learning rate, default lr for 'pretrain': 1e-3, 'finetune': 1e-5",
default=1e-5 if args.option == 'finetune' else 1e-3)

args = parser.parse_args()
return args
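
Note that the --lr default reads args.option before the final parse_args() call, so the option flag has to be available from an earlier parsing pass. One way to arrange that (a sketch of the pattern, not necessarily how this repository does it) is parse_known_args():

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--option", type=str, choices=("pretrain", "finetune"), default="pretrain")
# first pass: only --option is known, everything else is ignored for now
args, _ = parser.parse_known_args()

parser.add_argument("--lr", type=float,
                    help="learning rate, default lr for 'pretrain': 1e-3, 'finetune': 1e-5",
                    default=1e-5 if args.option == "finetune" else 1e-3)

args = parser.parse_args()
print(args.option, args.lr)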

