Skip to content

Commit

Permalink
Add the number of contrastive trees used as a debugging line
Browse files Browse the repository at this point in the history
  • Loading branch information
AngledLuffa committed Jan 16, 2025
1 parent 685e4c3 commit f04c9c1
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions stanza/models/constituency/parser_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

TrainItem = namedtuple("TrainItem", ['tree', 'gold_sequence', 'preterminals'])

class EpochStats(namedtuple("EpochStats", ['transition_loss', 'contrastive_loss', 'total_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])):
class EpochStats(namedtuple("EpochStats", ['transition_loss', 'contrastive_loss', 'total_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'contrastive_trees_used', 'nans'])):
def __add__(self, other):
transitions_correct = self.transitions_correct + other.transitions_correct
transitions_incorrect = self.transitions_incorrect + other.transitions_incorrect
Expand All @@ -43,8 +43,9 @@ def __add__(self, other):
transition_loss = self.transition_loss + other.transition_loss
contrastive_loss = self.contrastive_loss + other.contrastive_loss
total_loss = self.total_loss + other.total_loss
contrastive_trees_used = self.contrastive_trees_used + other.contrastive_trees_used
nans = self.nans + other.nans
return EpochStats(transition_loss, contrastive_loss, total_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
return EpochStats(transition_loss, contrastive_loss, total_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, contrastive_trees_used, nans)

def evaluate(args, model_file, retag_pipeline):
"""
Expand Down Expand Up @@ -456,6 +457,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
stats_log_lines.extend([
"Contrastive loss for epoch: %.5f" % epoch_stats.contrastive_loss,
"Total loss for epoch: %.5f" % epoch_stats.total_loss,
"Contrastive trees used: %d" % epoch_stats.contrastive_trees_used,
])
stats_log_lines.extend([
"Dev score (%5d): %8f" % (trainer.epochs_trained, f1),
Expand Down Expand Up @@ -560,7 +562,7 @@ def train_model_one_epoch(epoch, trainer, transition_tensors, process_outputs, m

optimizer = trainer.optimizer

epoch_stats = EpochStats(0.0, 0.0, 0.0, Counter(), Counter(), Counter(), 0, 0)
epoch_stats = EpochStats(0.0, 0.0, 0.0, Counter(), Counter(), Counter(), 0, 0, 0)

for batch_idx, interval_start in enumerate(tqdm(interval_starts, postfix="Epoch %d" % epoch)):
batch = epoch_data[interval_start:interval_start+args['train_batch_size']]
Expand Down Expand Up @@ -590,6 +592,7 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
merged into train_model_one_epoch and then iterate_training
"""
contrastive_loss = 0.0
contrastive_trees_used = 0
if epoch >= args['contrastive_initial_epoch'] and contrastive_loss_function is not None:
reparsed_results = model.parse_sentences(iter([x.tree for x in training_batch]), model.build_batch_from_trees, len(training_batch), model.predict, keep_state=True)
reparsed_states = [x.state for x in reparsed_results]
Expand All @@ -609,6 +612,7 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
device = next(model.parameters()).device
target = torch.zeros(mse.shape[0]).to(device)
contrastive_loss = args['contrastive_learning_rate'] * contrastive_loss_function(mse, target)
contrastive_trees_used += len(reparsed_negatives)

# now we add the state to the trees in the batch
# the state is built as a bulk operation
Expand Down Expand Up @@ -727,7 +731,7 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
contrastive_loss = contrastive_loss.item()
nans = 0

return EpochStats(transition_loss, contrastive_loss, total_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
return EpochStats(transition_loss, contrastive_loss, total_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, contrastive_trees_used, nans)

def run_dev_set(model, retagged_trees, original_trees, args, evaluator=None):
"""
Expand Down

0 comments on commit f04c9c1

Please sign in to comment.