diff --git a/stanza/models/constituency/parser_training.py b/stanza/models/constituency/parser_training.py index 6ac39bd92..4790fc859 100644 --- a/stanza/models/constituency/parser_training.py +++ b/stanza/models/constituency/parser_training.py @@ -593,18 +593,55 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te contrastive_loss = 0.0 contrastive_trees_used = 0 if epoch >= args['contrastive_initial_epoch'] and contrastive_loss_function is not None: - reparsed_results = model.parse_sentences(iter([x.tree for x in training_batch]), model.build_batch_from_trees, len(training_batch), model.predict, keep_state=True) - reparsed_states = [x.state for x in reparsed_results] - reparsed_trees = [x.constituents.value.value.value for x in reparsed_states] - reparsed_tree_hx = [x.constituents.value.value.tree_hx for x in reparsed_states] - - gold_results = model.analyze_trees([x.tree for x in training_batch], keep_constituents=False, keep_scores=False) - gold_states = [x.state for x in gold_results] - gold_trees = [x.constituents.value.value.value for x in gold_states] - gold_tree_hx = [x.constituents.value.value.tree_hx for x in gold_states] - - reparsed_negatives = [nn.functional.normalize(hx) for hx, reparsed_tree, gold_tree in zip(reparsed_tree_hx, reparsed_trees, gold_trees) if reparsed_tree != gold_tree] - gold_negatives = [nn.functional.normalize(hx) for hx, reparsed_tree, gold_tree in zip(gold_tree_hx, reparsed_trees, gold_trees) if reparsed_tree != gold_tree] + reparsed_results = model.parse_sentences(iter([x.tree for x in training_batch]), model.build_batch_from_trees, len(training_batch), model.predict, keep_state=True, keep_constituents=True) + gold_results = model.analyze_trees([x.tree for x in training_batch], keep_constituents=True, keep_scores=False) + + reparsed_negatives = [] + gold_negatives = [] + + for reparsed_result, gold_result in zip(reparsed_results, gold_results): + reparsed_constituents = reparsed_result.constituents 
+            # key each reparsed constituent's hidden state by the printed form of its subtree
+            reparsed_hx = {str(con.value): con.tree_hx for con in reparsed_constituents}
+
+            # build the same lookup table for the gold analysis
+            gold_constituents = gold_result.constituents
+            gold_hx = {str(con.value): con.tree_hx for con in gold_constituents}
+
+            reparsed_state = reparsed_result.state
+            gold_state = gold_result.state
+            reparsed_tree = reparsed_state.constituents.value.value.value
+            gold_tree = gold_state.constituents.value.value.value
+
+            def contrast_trees(pred_node, gold_node):
+                """
+                Descend through both trees in parallel and collect hidden states where the brackets diverge
+
+                While the children of the two nodes have the same count and identical
+                start/end indices, the children are visited pairwise.  At the first node
+                where that no longer holds, the leftmost children are paired if they end
+                at different indices, and the rightmost children are paired if they start
+                at different indices.  Each pair is appended to reparsed_negatives and
+                gold_negatives together, so the two lists stay aligned.
+                """
+                if pred_node.is_preterminal() or gold_node.is_preterminal():
+                    return
+
+                pred_kids = pred_node.children
+                gold_kids = gold_node.children
+
+                same_bracketing = len(pred_kids) == len(gold_kids) and all(
+                    p.start_index == g.start_index and p.end_index == g.end_index
+                    for p, g in zip(pred_kids, gold_kids))
+                if same_bracketing:
+                    for p, g in zip(pred_kids, gold_kids):
+                        contrast_trees(p, g)
+                    return
+
+                # TODO: instead compare all subtrees?
+                # Only the outermost children are checked: the leftmost pair by where
+                # they end, the rightmost pair by where they start.
+                for edge, boundary in ((0, "end_index"), (-1, "start_index")):
+                    pred_kid = pred_kids[edge]
+                    gold_kid = gold_kids[edge]
+                    if getattr(pred_kid, boundary) == getattr(gold_kid, boundary):
+                        continue
+                    # preterminals don't have values returned by the tree analysis functions above
+                    if pred_kid.is_preterminal() or gold_kid.is_preterminal():
+                        continue
+                    reparsed_negatives.append(reparsed_hx[str(pred_kid)])
+                    gold_negatives.append(gold_hx[str(gold_kid)])
+
+            # presumably mark_spans() fills in the start_index / end_index read by contrast_trees - TODO confirm
+            reparsed_tree.mark_spans()
+            gold_tree.mark_spans()
+            contrast_trees(reparsed_tree, gold_tree)
+
+        reparsed_negatives = [nn.functional.normalize(hx) for hx in reparsed_negatives]
+        gold_negatives = [nn.functional.normalize(hx) for hx in gold_negatives]
         if len(reparsed_negatives) > 0:
             mse = torch.stack([torch.dot(x.squeeze(0), y.squeeze(0)) for x, y in zip(reparsed_negatives, gold_negatives)])