Commit 5bf876a
Try using normalized vectors with an MSELoss instead of Cosine. The CosineEmbeddingLoss was occasionally hitting nan...
AngledLuffa committed Dec 17, 2024
1 parent 6f2b56a commit 5bf876a
Showing 1 changed file with 6 additions and 8 deletions: stanza/models/constituency/parser_training.py
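For context on the nan: cosine similarity divides by the product of the vector norms, so a raw cosine of a zero-norm representation is 0/0 = nan, whereas F.normalize clamps its denominator with eps (default 1e-12) and so stays finite. A minimal sketch of that failure mode; the zero vector is a hypothetical degenerate input, and whether the built-in CosineEmbeddingLoss itself hits nan depends on inputs and version (per the commit message it was only occasional):

```python
import torch
import torch.nn.functional as F

# Hypothetical degenerate tree representation: an all-zero hidden vector.
x = torch.zeros(1, 4)
y = torch.randn(1, 4)

# A raw cosine similarity divides by the product of the norms,
# so a zero-norm input yields 0/0 = nan.
raw_cosine = (x * y).sum() / (x.norm() * y.norm())
print(raw_cosine)  # tensor(nan)

# F.normalize clamps the denominator with eps, so a zero vector is
# mapped to a zero vector instead of nan ...
x_hat = F.normalize(x)
y_hat = F.normalize(y)

# ... and the dot product of the normalized vectors stays finite.
print(torch.dot(x_hat.squeeze(0), y_hat.squeeze(0)))  # tensor(0.)
```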
@@ -363,7 +363,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
     model_loss_function.to(device)
 
     if args['contrastive_learning_rate'] > 0:
-        contrastive_loss_function = nn.CosineEmbeddingLoss(margin=args['contrastive_margin'])
+        contrastive_loss_function = nn.MSELoss(reduction='sum')
         contrastive_loss_function.to(device)
     else:
         contrastive_loss_function = None
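What the swap changes, as a toy sketch with random tensors (not repo code): with target -1, CosineEmbeddingLoss applies the hinge max(0, cos(x1, x2) - margin), while the replacement sums squared dot products of normalized vectors, i.e. squared cosines, against a zero target:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x1 = torch.randn(5, 8)   # stand-in for reparsed-tree representations
x2 = torch.randn(5, 8)   # stand-in for gold-tree representations

# Old objective: for target -1, CosineEmbeddingLoss is a hinge,
# max(0, cos(x1, x2) - margin), driving the cosine below the margin.
old_loss_fn = nn.CosineEmbeddingLoss(margin=0.0)
old_loss = old_loss_fn(x1, x2, -torch.ones(5))

# New objective: dot products of normalized vectors are exactly the
# cosines, and MSELoss(reduction='sum') against a zero target is the
# sum of squared cosines -- minimized when the pairs are orthogonal.
cos = (F.normalize(x1) * F.normalize(x2)).sum(dim=1)
new_loss_fn = nn.MSELoss(reduction='sum')
new_loss = new_loss_fn(cos, torch.zeros(5))

print(old_loss, new_loss)
```

One consequence visible in the diff: the margin argument is gone, and the objective now drives mismatched pairs toward cosine 0 (orthogonal) rather than below a margin toward -1.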
@@ -598,16 +598,14 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
         gold_trees = [x.constituents.value.value.value for x in gold_states]
         gold_tree_hx = [x.constituents.value.value.tree_hx for x in gold_states]
 
-        reparsed_negatives = [hx for hx, reparsed_tree, gold_tree in zip(reparsed_tree_hx, reparsed_trees, gold_trees) if reparsed_tree != gold_tree]
-        gold_negatives = [hx for hx, reparsed_tree, gold_tree in zip(gold_tree_hx, reparsed_trees, gold_trees) if reparsed_tree != gold_tree]
+        reparsed_negatives = [nn.functional.normalize(hx) for hx, reparsed_tree, gold_tree in zip(reparsed_tree_hx, reparsed_trees, gold_trees) if reparsed_tree != gold_tree]
+        gold_negatives = [nn.functional.normalize(hx) for hx, reparsed_tree, gold_tree in zip(gold_tree_hx, reparsed_trees, gold_trees) if reparsed_tree != gold_tree]
 
         if len(reparsed_negatives) > 0:
             reparsed_negatives = torch.cat(reparsed_negatives, dim=0)
             gold_negatives = torch.cat(gold_negatives, dim=0)
 
+            mse = torch.stack([torch.dot(x.squeeze(0), y.squeeze(0)) for x, y in zip(reparsed_negatives, gold_negatives)])
             device = next(model.parameters()).device
-            target = -torch.ones(reparsed_negatives.shape[0]).to(device)
-            contrastive_loss = args['contrastive_learning_rate'] * contrastive_loss_function(reparsed_negatives, gold_negatives, target)
+            target = torch.zeros(mse.shape[0]).to(device)
+            contrastive_loss = args['contrastive_learning_rate'] * contrastive_loss_function(mse, target)
 
         # now we add the state to the trees in the batch
         # the state is built as a bulk operation
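Putting the second hunk together, a self-contained toy run of the new path; the (1, hidden) tensors, the pair count, and the 0.1 weight are made-up stand-ins for reparsed_tree_hx / gold_tree_hx and args['contrastive_learning_rate']:

```python
import torch
import torch.nn as nn

hidden = 8
# Hypothetical per-tree representations, one (1, hidden) tensor per
# mismatched tree pair, mirroring reparsed_tree_hx / gold_tree_hx.
reparsed_negatives = [nn.functional.normalize(torch.randn(1, hidden)) for _ in range(3)]
gold_negatives = [nn.functional.normalize(torch.randn(1, hidden)) for _ in range(3)]

reparsed_negatives = torch.cat(reparsed_negatives, dim=0)   # (3, hidden)
gold_negatives = torch.cat(gold_negatives, dim=0)           # (3, hidden)

# Iterating a 2-D tensor yields 1-D rows, so squeeze(0) is a no-op here;
# it matters when the rows keep an extra leading dimension of size 1.
mse = torch.stack([torch.dot(x.squeeze(0), y.squeeze(0))
                   for x, y in zip(reparsed_negatives, gold_negatives)])  # (3,)

contrastive_loss_function = nn.MSELoss(reduction='sum')
target = torch.zeros(mse.shape[0])
contrastive_loss = 0.1 * contrastive_loss_function(mse, target)  # 0.1: made-up lr
print(contrastive_loss)
```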
