Commit
📝 added comments to bert, multitask_classifier and datasets
lkaesberg committed Sep 2, 2023
1 parent 33cc713 commit 6c56d1b
Showing 3 changed files with 10 additions and 5 deletions.
bert.py (5 changes: 4 additions & 1 deletion)
@@ -218,16 +218,19 @@ def embed(self, input_ids, additional_input=False):
# Convert input_ids to tokens using the BERT tokenizer
tokens = self.tokenizer.convert_ids_to_tokens(sequence_id.tolist())

-# Convert tokens to strings
+# Convert tokens to strings and remove special tokens
token_strings = [
token for token in tokens if token not in ["[PAD]", "[CLS]", "[SEP]"]
]
input_string = self.tokenizer.convert_tokens_to_string(token_strings)
+# Process the input string with spaCy
tokenized = self.nlp(input_string)
pos_tags = [0] * len(tokens)
ner_tags = [0] * len(tokens)
counter = -1
+# Loop through the tokens and add the POS and NER tags
for i in range(len(token_strings)):
+# Add same POS and NER tag to all subwords of a word
if not token_strings[i].startswith("##"):
counter += 1
pos_tags[i + 1] = self.pos_tag_vocab.get(tokenized[counter].tag_, 0)
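
The comments added above spell out how embed() carries spaCy tags onto BERT wordpieces: the counter advances only on tokens that do not start with "##", so every subword of a word gets that word's POS/NER id, and the i + 1 offset leaves index 0 free for [CLS]. Below is a minimal, self-contained sketch of that alignment idea; the function name, example tags, and the fallback id 0 are illustrative, not the repository's actual helpers.

def align_tags_to_wordpieces(tokens, word_tags, tag_vocab, fallback_id=0):
    """Spread word-level tag ids over BERT wordpiece tokens.

    tokens    : wordpiece strings, e.g. ["[CLS]", "play", "##ing", "[SEP]"]
    word_tags : one spaCy-style tag per whole word, e.g. ["VBG"]
    tag_vocab : maps tag string -> integer id; unknown tags get fallback_id
    """
    tag_ids = [fallback_id] * len(tokens)
    word_idx = -1
    for i, tok in enumerate(tokens):
        if tok in ("[PAD]", "[CLS]", "[SEP]"):
            continue  # special tokens keep the fallback id
        if not tok.startswith("##"):
            word_idx += 1  # a new word starts on every non-continuation piece
        # Every subword of the current word receives the same tag id.
        tag_ids[i] = tag_vocab.get(word_tags[word_idx], fallback_id)
    return tag_ids


# Example: prints [0, 1, 1, 2, 0] -- "play" and "##ing" share the VBG id.
print(align_tags_to_wordpieces(
    ["[CLS]", "play", "##ing", "chess", "[SEP]"],
    ["VBG", "NN"],
    {"VBG": 1, "NN": 2},
))
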
datasets.py (2 changes: 2 additions & 0 deletions)
@@ -44,6 +44,7 @@ def __len__(self):
return self.override_length

def __getitem__(self, idx):
+# If we're overriding the length, we want to sample randomly from the dataset
if self.override_length is not None:
return random.choice(self.dataset)

@@ -139,6 +140,7 @@ def __len__(self):
return self.override_length

def __getitem__(self, idx):
+# If we're overriding the length, we want to sample randomly from the dataset
if self.override_length is not None:
return random.choice(self.dataset)

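Both new __getitem__ comments document the same pattern: when override_length is set, the dataset reports that fixed length and ignores idx, returning a uniformly random example instead, so one epoch becomes override_length random draws. A condensed sketch of the pattern, assuming a torch-style Dataset; the class name and constructor are illustrative, and the real SentenceClassificationDataset and SentencePairDataset additionally handle tokenization and collation.

import random

from torch.utils.data import Dataset


class OverridableDataset(Dataset):
    """Minimal stand-in for a dataset with a fixed, overridden epoch length."""

    def __init__(self, examples, override_length=None):
        self.dataset = examples
        self.override_length = override_length

    def __len__(self):
        if self.override_length is not None:
            return self.override_length
        return len(self.dataset)

    def __getitem__(self, idx):
        # If we're overriding the length, sample randomly instead of by index,
        # so every epoch sees override_length random examples.
        if self.override_length is not None:
            return random.choice(self.dataset)
        return self.dataset[idx]
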
multitask_classifier.py (8 changes: 4 additions & 4 deletions)
@@ -70,6 +70,7 @@ def __init__(self, config):
elif config.option == "finetune":
param.requires_grad = True

+# Freeze the layers if unfreeze_interval is set
if config.unfreeze_interval:
for name, param in self.bert.named_parameters():
if not name.startswith("bert_layers"):
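
The new comment marks the entry point for gradual unfreezing: when config.unfreeze_interval is set, parameters are walked by name so that the stacked encoder layers ("bert_layers...") can start frozen and be released later in training. The hunk is cut off before the loop body, so the exact policy is not visible here; the sketch below only illustrates the general name-prefix freezing idea and is an assumption, not the repository's implementation.

def freeze_encoder_layers(model, layer_prefix="bert_layers"):
    # Assumed policy: encoder layers start frozen, everything else (embeddings,
    # pooler, task heads) stays trainable; a schedule elsewhere can then set
    # requires_grad back to True for one more layer every unfreeze_interval.
    for name, param in model.named_parameters():
        param.requires_grad = not name.startswith(layer_prefix)
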
@@ -213,6 +214,7 @@ def train_multitask(args):
if isinstance(args, dict):
args = SimpleNamespace(**args)

+# Determine which datasets to train on
train_all_datasets = True
n_datasets = args.sst + args.sts + args.para
if args.sst or args.sts or args.para:
@@ -229,14 +231,15 @@
args.sst_dev, args.para_dev, args.sts_dev, split="train"
)

+# Generate datasets and dataloaders for training and testing
sst_train_dataloader = None
sst_dev_dataloader = None
para_train_dataloader = None
para_dev_dataloader = None
sts_train_dataloader = None
sts_dev_dataloader = None
total_num_batches = 0
-# if train_all_datasets or args.sst:

sst_train_data = SentenceClassificationDataset(
sst_train_data, args, override_length=args.samples_per_epoch
)
@@ -259,7 +262,6 @@
num_workers=2,
)

-# if train_all_datasets or args.para:
para_train_data = SentencePairDataset(
para_train_data, args, override_length=args.samples_per_epoch
)
@@ -282,7 +284,6 @@
num_workers=2,
)

-# if train_all_datasets or args.sts:
sts_train_data = SentencePairDataset(
sts_train_data, args, isRegression=True, override_length=args.samples_per_epoch
)
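
Every training set above is built with override_length=args.samples_per_epoch, which caps how many randomly drawn examples each task contributes per epoch. The repository bakes this into the Dataset classes (see datasets.py above); for comparison, stock PyTorch reaches the same fixed per-epoch budget on the DataLoader side with RandomSampler and num_samples. The sizes below are placeholders, not the repository's settings.

import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

# Tiny stand-in for the real SentenceClassificationDataset / SentencePairDataset.
data = TensorDataset(torch.arange(100))

samples_per_epoch = 10_000  # placeholder for args.samples_per_epoch
sampler = RandomSampler(data, replacement=True, num_samples=samples_per_epoch)

loader = DataLoader(data, sampler=sampler, batch_size=64, num_workers=2)
# len(loader) == ceil(samples_per_epoch / 64): each task contributes the same
# number of batches per epoch regardless of its true dataset size.
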
@@ -374,7 +375,6 @@ def train_multitask(args):
if args.optimizer == "adamw":
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=args.weight_decay)
elif args.optimizer == "sophiah":
-# TODO: Tune this further, https://github.com/Liuhong99/Sophia#hyper-parameter-tuning
optimizer = SophiaH(
model.parameters(),
lr=lr,
