Commit

feat: added classifier

lkaesberg committed May 25, 2023
1 parent 109b9dd commit 8912598
Showing 1 changed file with 56 additions and 47 deletions.
103 changes: 56 additions & 47 deletions classifier.py
@@ -4,6 +4,7 @@

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report, f1_score, recall_score, accuracy_score

@@ -13,8 +14,9 @@
from optimizer import AdamW
from tqdm import tqdm

TQDM_DISABLE = False


TQDM_DISABLE=False
# fix the random seed
def seed_everything(seed=11711):
random.seed(seed)
@@ -25,13 +27,15 @@ def seed_everything(seed=11711):
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


class BertSentimentClassifier(torch.nn.Module):
'''
This module performs sentiment classification using BERT embeddings on the SST dataset.
In the SST dataset, there are 5 sentiment categories (from 0 - "negative" to 4 - "positive").
Thus, your forward() should return one logit for each of the 5 classes.
'''

def __init__(self, config):
super(BertSentimentClassifier, self).__init__()
self.num_labels = config.num_labels
@@ -45,17 +49,18 @@ def __init__(self, config):
param.requires_grad = True

### TODO
raise NotImplementedError

self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.linear_layer = nn.Linear(config.hidden_size, self.num_labels)
# raise NotImplementedError

def forward(self, input_ids, attention_mask):
'''Takes a batch of sentences and returns logits for sentiment classes'''
# The final BERT contextualized embedding is the hidden state of [CLS] token (the first token).
# HINT: you should consider what is the appropriate output to return given that
# the training loop currently uses F.cross_entropy as the loss function.
### TODO
raise NotImplementedError

# Cross entropy already has a softmax, therefore this should be okay
result = self.bert(input_ids, attention_mask)
return self.linear_layer(self.dropout(result['pooler_output']))
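
An aside, not part of the diff: the comment above is the key design point. F.cross_entropy combines log-softmax and negative log-likelihood, so forward() should return raw, unnormalized logits for the 5 classes. A minimal sketch, with an invented batch size and invented labels:

import torch
import torch.nn.functional as F

logits = torch.randn(8, 5)          # 8 sentences, 5 SST sentiment classes
labels = torch.randint(0, 5, (8,))  # gold class indices
loss_a = F.cross_entropy(logits, labels)
loss_b = F.nll_loss(F.log_softmax(logits, dim=-1), labels)
assert torch.allclose(loss_a, loss_b)  # same loss, so no softmax belongs in forward()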


class SentimentDataset(Dataset):
@@ -71,7 +76,6 @@ def __getitem__(self, idx):
return self.dataset[idx]

def pad_data(self, data):

sents = [x[0] for x in data]
labels = [x[1] for x in data]
sent_ids = [x[2] for x in data]
@@ -84,18 +88,19 @@ def pad_data(self, data):
return token_ids, attention_mask, labels, sents, sent_ids

def collate_fn(self, all_data):
token_ids, attention_mask, labels, sents, sent_ids= self.pad_data(all_data)
token_ids, attention_mask, labels, sents, sent_ids = self.pad_data(all_data)

batched_data = {
'token_ids': token_ids,
'attention_mask': attention_mask,
'labels': labels,
'sents': sents,
'sent_ids': sent_ids
}
'token_ids': token_ids,
'attention_mask': attention_mask,
'labels': labels,
'sents': sents,
'sent_ids': sent_ids
}

return batched_data
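
For orientation, not part of the diff: collate_fn is meant to be handed to a DataLoader so that tokenization and padding happen per batch, which is how test() wires up the dev set further down. A sketch, assuming args comes from get_args() and using a placeholder training path:

train_data, num_labels = load_data('path/to/train.tsv', 'train')   # placeholder path
train_dataset = SentimentDataset(train_data, args)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size,
                              collate_fn=train_dataset.collate_fn)
for batch in train_dataloader:
    b_ids, b_mask, b_labels = batch['token_ids'], batch['attention_mask'], batch['labels']
    # ...move to device, run the model, and apply F.cross_entropy, as train() does...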


class SentimentTestDataset(Dataset):
def __init__(self, dataset, args):
self.dataset = dataset
@@ -109,7 +114,6 @@ def __getitem__(self, idx):
return self.dataset[idx]

def pad_data(self, data):

sents = [x[0] for x in data]
sent_ids = [x[1] for x in data]

@@ -120,54 +124,55 @@ def pad_data(self, data):
return token_ids, attention_mask, sents, sent_ids

def collate_fn(self, all_data):
token_ids, attention_mask, sents, sent_ids= self.pad_data(all_data)
token_ids, attention_mask, sents, sent_ids = self.pad_data(all_data)

batched_data = {
'token_ids': token_ids,
'attention_mask': attention_mask,
'sents': sents,
'sent_ids': sent_ids
}
'token_ids': token_ids,
'attention_mask': attention_mask,
'sents': sents,
'sent_ids': sent_ids
}

return batched_data


# Load the data: a list of (sentence, label)
def load_data(filename, flag='train'):
num_labels = {}
data = []
if flag == 'test':
with open(filename, 'r') as fp:
for record in csv.DictReader(fp,delimiter = '\t'):
for record in csv.DictReader(fp, delimiter='\t'):
sent = record['sentence'].lower().strip()
sent_id = record['id'].lower().strip()
data.append((sent,sent_id))
data.append((sent, sent_id))
else:
with open(filename, 'r') as fp:
for record in csv.DictReader(fp,delimiter = '\t'):
for record in csv.DictReader(fp, delimiter='\t'):
sent = record['sentence'].lower().strip()
sent_id = record['id'].lower().strip()
label = int(record['sentiment'].strip())
if label not in num_labels:
num_labels[label] = len(num_labels)
data.append((sent, label,sent_id))
data.append((sent, label, sent_id))
print(f"load {len(data)} data from {filename}")

if flag == 'train':
return data, len(num_labels)
else:
return data
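
For reference, not part of the diff: load_data() reads a tab-separated file whose header matches the DictReader keys used above. The rows below are invented to show the layout (columns are tab-separated, shown here with spaces), and the training path is a placeholder since only the dev/test paths are visible in this diff:

# id        sentence                                   sentiment
# sst-0001  a gorgeous , witty , seductive movie .     4
# sst-0002  the plot is paper thin .                   1
#
# train_data, num_labels = load_data('path/to/train.tsv', 'train')   # -> [(sent, label, sent_id), ...]
# dev_data = load_data('data/ids-sst-dev.csv', 'valid')              # -> [(sent, label, sent_id), ...]
# test_data = load_data('data/ids-sst-test-student.csv', 'test')     # -> [(sent, sent_id), ...]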


# Evaluate the model for accuracy.
def model_eval(dataloader, model, device):
model.eval() # switch to eval mode, will turn off randomness like dropout
model.eval()  # switch to eval mode, will turn off randomness like dropout
y_true = []
y_pred = []
sents = []
sent_ids = []
for step, batch in enumerate(tqdm(dataloader, desc=f'eval', disable=TQDM_DISABLE)):
b_ids, b_mask, b_labels, b_sents, b_sent_ids = batch['token_ids'],batch['attention_mask'], \
batch['labels'], batch['sents'], batch['sent_ids']

b_ids, b_mask, b_labels, b_sents, b_sent_ids = batch['token_ids'], batch['attention_mask'], \
batch['labels'], batch['sents'], batch['sent_ids']

b_ids = b_ids.to(device)
b_mask = b_mask.to(device)
@@ -189,14 +194,13 @@ def model_eval(dataloader, model, device):


def model_test_eval(dataloader, model, device):
model.eval() # switch to eval mode, will turn off randomness like dropout
model.eval()  # switch to eval mode, will turn off randomness like dropout
y_pred = []
sents = []
sent_ids = []
for step, batch in enumerate(tqdm(dataloader, desc=f'eval', disable=TQDM_DISABLE)):
b_ids, b_mask, b_sents, b_sent_ids = batch['token_ids'],batch['attention_mask'], \
batch['sents'], batch['sent_ids']

b_ids, b_mask, b_sents, b_sent_ids = batch['token_ids'], batch['attention_mask'], \
batch['sents'], batch['sent_ids']

b_ids = b_ids.to(device)
b_mask = b_mask.to(device)
@@ -283,14 +287,15 @@ def train(args):

train_loss = train_loss / (num_batches)

train_acc, train_f1, *_ = model_eval(train_dataloader, model, device)
train_acc, train_f1, *_ = model_eval(train_dataloader, model, device)
dev_acc, dev_f1, *_ = model_eval(dev_dataloader, model, device)

if dev_acc > best_dev_acc:
best_dev_acc = dev_acc
save_model(model, optimizer, args, config, args.filepath)

print(f"Epoch {epoch}: train loss :: {train_loss :.3f}, train acc :: {train_acc :.3f}, dev acc :: {dev_acc :.3f}")
print(
f"Epoch {epoch}: train loss :: {train_loss :.3f}, train acc :: {train_acc :.3f}, dev acc :: {dev_acc :.3f}")


def test(args):
@@ -302,29 +307,33 @@ def test(args):
model.load_state_dict(saved['model'])
model = model.to(device)
print(f"load model from {args.filepath}")

dev_data = load_data(args.dev, 'valid')
dev_dataset = SentimentDataset(dev_data, args)
dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=dev_dataset.collate_fn)
dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=args.batch_size,
collate_fn=dev_dataset.collate_fn)

test_data = load_data(args.test, 'test')
test_dataset = SentimentTestDataset(test_data, args)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=test_dataset.collate_fn)

test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size,
collate_fn=test_dataset.collate_fn)

dev_acc, dev_f1, dev_pred, dev_true, dev_sents, dev_sent_ids = model_eval(dev_dataloader, model, device)
print('DONE DEV')
test_pred, test_sents, test_sent_ids = model_test_eval(test_dataloader, model, device)
print('DONE Test')
with open(args.dev_out, "w+") as f:
print(f"dev acc :: {dev_acc :.3f}")
f.write(f"id \t Predicted_Sentiment \n")
for p, s in zip(dev_sent_ids,dev_pred ):
for p, s in zip(dev_sent_ids, dev_pred):
f.write(f"{p} , {s} \n")

with open(args.test_out, "w+") as f:
f.write(f"id \t Predicted_Sentiment \n")
for p, s in zip(test_sent_ids,test_pred ):
for p, s in zip(test_sent_ids, test_pred):
f.write(f"{p} , {s} \n")


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=11711)
@@ -335,7 +344,6 @@ def get_args():
parser.add_argument("--use_gpu", action='store_true')
parser.add_argument("--dev_out", type=str, default="cfimdb-dev-output.txt")
parser.add_argument("--test_out", type=str, default="cfimdb-test-output.txt")


parser.add_argument("--batch_size", help='sst: 64, cfimdb: 8 can fit a 12GB GPU', type=int, default=8)
parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
@@ -345,10 +353,11 @@ def get_args():
args = parser.parse_args()
return args


if __name__ == "__main__":
args = get_args()
seed_everything(args.seed)
#args.filepath = f'{args.option}-{args.epochs}-{args.lr}.pt'
# args.filepath = f'{args.option}-{args.epochs}-{args.lr}.pt'

print('Training Sentiment Classifier on SST...')
config = SimpleNamespace(
@@ -362,8 +371,8 @@ def get_args():
dev='data/ids-sst-dev.csv',
test='data/ids-sst-test-student.csv',
option=args.option,
dev_out = 'predictions/'+args.option+'-sst-dev-out.csv',
test_out = 'predictions/'+args.option+'-sst-test-out.csv'
dev_out='predictions/' + args.option + '-sst-dev-out.csv',
test_out='predictions/' + args.option + '-sst-test-out.csv'
)

train(config)
@@ -383,8 +392,8 @@ def get_args():
dev='data/ids-cfimdb-dev.csv',
test='data/ids-cfimdb-test-student.csv',
option=args.option,
dev_out = 'predictions/'+args.option+'-cfimdb-dev-out.csv',
test_out = 'predictions/'+args.option+'-cfimdb-test-out.csv'
dev_out='predictions/' + args.option + '-cfimdb-dev-out.csv',
test_out='predictions/' + args.option + '-cfimdb-test-out.csv'
)

train(config)
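
End-to-end, the intended flow appears to be: train(config) saves the best-dev checkpoint via save_model(), and test(config) reloads it to produce the prediction files. Whether test(config) is invoked right after this point is not visible in the diff; the lines below are a hedged summary, not part of the commit:

# train(config)  # saves to config.filepath whenever dev accuracy improves
# test(config)   # reloads config.filepath, writes config.dev_out and config.test_out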
