diff --git a/stanza/models/constituency/parser_training.py b/stanza/models/constituency/parser_training.py index 71cea9aeed..369e156a09 100644 --- a/stanza/models/constituency/parser_training.py +++ b/stanza/models/constituency/parser_training.py @@ -34,15 +34,16 @@ TrainItem = namedtuple("TrainItem", ['tree', 'gold_sequence', 'preterminals']) -class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])): +class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'contrastive_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])): def __add__(self, other): transitions_correct = self.transitions_correct + other.transitions_correct transitions_incorrect = self.transitions_incorrect + other.transitions_incorrect repairs_used = self.repairs_used + other.repairs_used fake_transitions_used = self.fake_transitions_used + other.fake_transitions_used epoch_loss = self.epoch_loss + other.epoch_loss + contrastive_loss = self.contrastive_loss + other.contrastive_loss nans = self.nans + other.nans - return EpochStats(epoch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans) + return EpochStats(epoch_loss, contrastive_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans) def evaluate(args, model_file, retag_pipeline): """ @@ -339,6 +340,8 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d # Various experiments generally show about 0.5 F1 loss on various # datasets when using 'mean' instead of 'sum' for reduction # (Remember to adjust the weight decay when rerunning that experiment) + device = trainer.device + if args['loss'] == 'cross': tlogger.info("Building CrossEntropyLoss(sum)") process_outputs = lambda x: x @@ -357,9 +360,14 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d model_loss_function = LargeMarginInSoftmaxLoss(reduction='sum') else: raise ValueError("Unexpected loss term: %s" % args['loss']) - - device = trainer.device model_loss_function.to(device) + + if args['contrastive_learning_rate'] > 0: + contrastive_loss_function = nn.CosineEmbeddingLoss(margin=args['contrastive_margin']) + contrastive_loss_function.to(device) + else: + contrastive_loss_function = None + transition_tensors = {x: torch.tensor(y, requires_grad=False, device=device).unsqueeze(0) for (y, x) in enumerate(trainer.transitions)} trainer.train() @@ -409,7 +417,7 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d epoch_data = epoch_data + epoch_silver_data epoch_data.sort(key=lambda x: len(x[1])) - epoch_stats = train_model_one_epoch(trainer.epochs_trained, trainer, transition_tensors, process_outputs, model_loss_function, epoch_data, oracle, args) + epoch_stats = train_model_one_epoch(trainer.epochs_trained, trainer, transition_tensors, process_outputs, model_loss_function, contrastive_loss_function, epoch_data, oracle, args) # print statistics # by now we've forgotten about the original tags on the trees, @@ -430,9 +438,15 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d "Transitions correct: %s" % epoch_stats.transitions_correct, "Transitions incorrect: %s" % epoch_stats.transitions_incorrect, "Total loss for epoch: %.5f" % epoch_stats.epoch_loss, + ] + if args['contrastive_learning_rate'] > 0.0: + stats_log_lines.extend([ + "Contrastive loss for epoch: %.5f" % epoch_stats.contrastive_loss + ]) + stats_log_lines.extend([ "Dev score (%5d): %8f" % (trainer.epochs_trained, f1), "Best dev score (%5d): %8f" % (trainer.best_epoch, trainer.best_f1) - ] + ]) tlogger.info("\n ".join(stats_log_lines)) old_lr = trainer.optimizer.param_groups[0]['lr'] @@ -526,17 +540,17 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d return trainer -def train_model_one_epoch(epoch, trainer, transition_tensors, process_outputs, model_loss_function, epoch_data, oracle, args): +def train_model_one_epoch(epoch, trainer, transition_tensors, process_outputs, model_loss_function, contrastive_loss_function, epoch_data, oracle, args): interval_starts = list(range(0, len(epoch_data), args['train_batch_size'])) random.shuffle(interval_starts) optimizer = trainer.optimizer - epoch_stats = EpochStats(0.0, Counter(), Counter(), Counter(), 0, 0) + epoch_stats = EpochStats(0.0, 0.0, Counter(), Counter(), Counter(), 0, 0) for batch_idx, interval_start in enumerate(tqdm(interval_starts, postfix="Epoch %d" % epoch)): batch = epoch_data[interval_start:interval_start+args['train_batch_size']] - batch_stats = train_model_one_batch(epoch, batch_idx, trainer.model, batch, transition_tensors, process_outputs, model_loss_function, oracle, args) + batch_stats = train_model_one_batch(epoch, batch_idx, trainer.model, batch, transition_tensors, process_outputs, model_loss_function, contrastive_loss_function, oracle, args) trainer.batches_trained += 1 # Early in the training, some trees will be degenerate in a @@ -562,7 +576,7 @@ def train_model_one_epoch(epoch, trainer, transition_tensors, process_outputs, m return epoch_stats -def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_tensors, process_outputs, model_loss_function, oracle, args): +def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_tensors, process_outputs, model_loss_function, contrastive_loss_function, oracle, args): """ Train the model for one batch @@ -572,6 +586,29 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te ... although the indentation does get pretty ridiculous if this is merged into train_model_one_epoch and then iterate_training """ + contrastive_loss = 0.0 + if epoch >= args['contrastive_initial_epoch'] and contrastive_loss_function is not None: + reparsed_results = model.parse_sentences(iter([x.tree for x in training_batch]), model.build_batch_from_trees, len(training_batch), model.predict, keep_state=True) + reparsed_states = [x.state for x in reparsed_results] + reparsed_trees = [x.constituents.value.value.value for x in reparsed_states] + reparsed_tree_hx = [x.constituents.value.value.tree_hx for x in reparsed_states] + + gold_results = model.analyze_trees([x.tree for x in training_batch], keep_constituents=False, keep_scores=False) + gold_states = [x.state for x in gold_results] + gold_trees = [x.constituents.value.value.value for x in gold_states] + gold_tree_hx = [x.constituents.value.value.tree_hx for x in gold_states] + + reparsed_negatives = [hx for hx, reparsed_tree, gold_tree in zip(reparsed_tree_hx, reparsed_trees, gold_trees) if reparsed_tree != gold_tree] + gold_negatives = [hx for hx, reparsed_tree, gold_tree in zip(gold_tree_hx, reparsed_trees, gold_trees) if reparsed_tree != gold_tree] + + if len(reparsed_negatives) > 0: + reparsed_negatives = torch.cat(reparsed_negatives, dim=0) + gold_negatives = torch.cat(gold_negatives, dim=0) + + device = next(model.parameters()).device + target = -torch.ones(reparsed_negatives.shape[0]).to(device) + contrastive_loss = args['contrastive_learning_rate'] * contrastive_loss_function(reparsed_negatives, gold_negatives, target) + # now we add the state to the trees in the batch # the state is built as a bulk operation current_batch = model.initial_state_from_preterminals([x.preterminals for x in training_batch], @@ -660,6 +697,7 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te errors = process_outputs(errors) tree_loss = model_loss_function(errors, answers) + tree_loss += contrastive_loss tree_loss.backward() if args['watch_regex']: matched = False @@ -678,12 +716,15 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te tlogger.info(" (none found!)") if torch.any(torch.isnan(tree_loss)): batch_loss = 0.0 + contrastive_loss = 0.0 nans = 1 else: batch_loss = tree_loss.item() + if not isinstance(contrastive_loss, float): + contrastive_loss = contrastive_loss.item() nans = 0 - return EpochStats(batch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans) + return EpochStats(batch_loss, contrastive_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans) def run_dev_set(model, retagged_trees, original_trees, args, evaluator=None): """ diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py index cbfdc61291..47407024c5 100644 --- a/stanza/models/constituency_parser.py +++ b/stanza/models/constituency_parser.py @@ -553,6 +553,10 @@ def build_argparse(): parser.add_argument('--learning_rate_min_lr', default=None, type=float, help='Plateau learning rate minimum') parser.add_argument('--stage1_learning_rate_min_lr', default=None, type=float, help='Plateau learning rate minimum (stage 1)') + parser.add_argument('--contrastive_initial_epoch', default=1, type=int, help='When to start contrastive learning') + parser.add_argument('--contrastive_margin', default=0.0, type=float, help='epsilon for the negative examples of contrastive learning') + parser.add_argument('--contrastive_learning_rate', default=0.0, type=float, help='Multiplicative factor for constrastive learning') + parser.add_argument('--grad_clipping', default=None, type=float, help='Clip abs(grad) to this amount. Use --no_grad_clipping to turn off grad clipping') parser.add_argument('--no_grad_clipping', action='store_const', const=None, dest='grad_clipping', help='Use --no_grad_clipping to turn off grad clipping')