
Commit

Merge remote-tracking branch 'origin/feat/attention_layer'
# Conflicts:
#	classifier.py
#	multitask_classifier.py
lkaesberg committed Aug 8, 2023
2 parents 7a9b56e + 4d1547a commit e9da7fa
Showing 4 changed files with 30 additions and 2 deletions.
classifier.py (5 changes: 4 additions & 1 deletion)
@@ -10,6 +10,7 @@
 from sklearn.metrics import classification_report, f1_score, recall_score, accuracy_score
 from torch.utils.tensorboard import SummaryWriter
 
+from layers.AttentionLayer import AttentionLayer
 # change it with respect to the original model
 from tokenizer import BertTokenizer
 from bert import BertModel
@@ -52,6 +53,7 @@ def __init__(self, config):
 
         # self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # linear layer to get logits
+        self.attention_layer = AttentionLayer(config.hidden_size)
         self.linear_layer = nn.Linear(config.hidden_size, self.num_labels)
 
     def forward(self, input_ids, attention_mask):
@@ -63,7 +65,8 @@ def forward(self, input_ids, attention_mask):
 
         # No Dropout because it is the last layer before softmax, else worse performance
         result = self.bert(input_ids, attention_mask)
-        return self.linear_layer(result['pooler_output'])
+        attention_result = self.attention_layer(result['last_hidden_state'])
+        return self.linear_layer(attention_result)
 
 
 class SentimentDataset(Dataset):
layers/AttentionLayer.py (21 changes: 21 additions & 0 deletions)
@@ -0,0 +1,21 @@
+import torch
+import torch.nn as nn
+
+
+class AttentionLayer(nn.Module):
+    def __init__(self, input_size):
+        super(AttentionLayer, self).__init__()
+        self.W = nn.Linear(input_size, input_size)
+        self.v = nn.Linear(input_size, 1, bias=False)
+
+    def forward(self, embeddings):
+        # Apply linear transformation to the embeddings
+        transformed = torch.tanh(self.W(embeddings))
+
+        # Calculate attention weights
+        attention_weights = torch.softmax(self.v(transformed), dim=1)
+
+        # Apply attention weights to the embeddings
+        attended_embeddings = torch.sum(attention_weights * embeddings, dim=1)
+
+        return attended_embeddings
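For reference, a minimal usage sketch of the new AttentionLayer (not part of the commit; the tensor sizes below are made-up example values). It pools a [batch, seq_len, hidden] tensor of token embeddings into a single [batch, hidden] sentence embedding; because the softmax in forward() is taken over dim=1 (the sequence dimension), each sentence's token weights sum to 1.

import torch
from layers.AttentionLayer import AttentionLayer

layer = AttentionLayer(input_size=768)
embeddings = torch.randn(2, 16, 768)  # [batch, seq_len, hidden], example sizes

with torch.no_grad():
    pooled = layer(embeddings)  # attention-weighted sum over the 16 token positions
    weights = torch.softmax(layer.v(torch.tanh(layer.W(embeddings))), dim=1)

print(pooled.shape)                  # torch.Size([2, 768])
print(weights.sum(dim=1).squeeze())  # approximately tensor([1., 1.]): weights sum to 1 per sentence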
layers/__init__.py (empty file added)
multitask_classifier.py (6 changes: 5 additions & 1 deletion)
@@ -10,6 +10,7 @@
 from torch.utils.tensorboard import SummaryWriter
 
 from bert import BertModel
+from layers.AttentionLayer import AttentionLayer
 from optimizer import AdamW
 from tqdm import tqdm
 
@@ -56,6 +57,8 @@ def __init__(self, config):
             elif config.option == 'finetune':
                 param.requires_grad = True
 
+        self.attention_layer = AttentionLayer(config.hidden_size)
+
        self.linear_layer = nn.Linear(config.hidden_size, N_SENTIMENT_CLASSES)
 
         self.paraphrase_linear = nn.Linear(config.hidden_size, config.hidden_size)
@@ -69,7 +72,8 @@ def forward(self, input_ids, attention_mask):
         # (e.g., by adding other layers).
 
         result = self.bert(input_ids, attention_mask)
-        return result['pooler_output']
+        attention_result = self.attention_layer(result["last_hidden_state"])
+        return attention_result
 
     def predict_sentiment(self, input_ids, attention_mask):
         '''Given a batch of sentences, outputs logits for classifying sentiment.
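As context for the changed return value, a small sketch of how the attention-pooled embedding can feed the sentiment head declared in __init__. How predict_sentiment actually wires this up is not shown in this diff, so the composition and the sizes below are only illustrative assumptions.

import torch
import torch.nn as nn
from layers.AttentionLayer import AttentionLayer

hidden_size, n_sentiment_classes = 768, 5  # placeholder values, not from the repo config
attention_layer = AttentionLayer(hidden_size)
linear_layer = nn.Linear(hidden_size, n_sentiment_classes)

# Stand-in for result['last_hidden_state'] returned by the BERT encoder.
last_hidden_state = torch.randn(4, 64, hidden_size)      # [batch, seq_len, hidden]

sentence_embedding = attention_layer(last_hidden_state)  # what forward() now returns
sentiment_logits = linear_layer(sentence_embedding)      # [batch, n_sentiment_classes]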
