Skip to content

Commit

Permalink
Recurse from the top of two trees until finding different trees to contrast, rather than contrasting at top regardless of where the inaccuracy is
Browse files Browse the repository at this point in the history
  • Loading branch information
AngledLuffa committed Jan 16, 2025
1 parent fd8cd1e commit 98544b6
Showing 1 changed file with 49 additions and 12 deletions.
61 changes: 49 additions & 12 deletions stanza/models/constituency/parser_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,18 +594,55 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
contrastive_loss = 0.0
contrastive_trees_used = 0
if epoch >= args['contrastive_initial_epoch'] and contrastive_loss_function is not None:
reparsed_results = model.parse_sentences(iter([x.tree for x in training_batch]), model.build_batch_from_trees, len(training_batch), model.predict, keep_state=True)
reparsed_states = [x.state for x in reparsed_results]
reparsed_trees = [x.constituents.value.value.value for x in reparsed_states]
reparsed_tree_hx = [x.constituents.value.value.tree_hx for x in reparsed_states]

gold_results = model.analyze_trees([x.tree for x in training_batch], keep_constituents=False, keep_scores=False)
gold_states = [x.state for x in gold_results]
gold_trees = [x.constituents.value.value.value for x in gold_states]
gold_tree_hx = [x.constituents.value.value.tree_hx for x in gold_states]

reparsed_negatives = [nn.functional.normalize(hx) for hx, reparsed_tree, gold_tree in zip(reparsed_tree_hx, reparsed_trees, gold_trees) if reparsed_tree != gold_tree]
gold_negatives = [nn.functional.normalize(hx) for hx, reparsed_tree, gold_tree in zip(gold_tree_hx, reparsed_trees, gold_trees) if reparsed_tree != gold_tree]
reparsed_results = model.parse_sentences(iter([x.tree for x in training_batch]), model.build_batch_from_trees, len(training_batch), model.predict, keep_state=True, keep_constituents=True)
gold_results = model.analyze_trees([x.tree for x in training_batch], keep_constituents=True, keep_scores=False)

reparsed_negatives = []
gold_negatives = []

for reparsed_result, gold_result in zip(reparsed_results, gold_results):
reparsed_constituents = reparsed_result.constituents
reparsed_hx = {}
for con in reparsed_constituents:
reparsed_hx[str(con.value)] = con.tree_hx

gold_constituents = gold_result.constituents
gold_hx = {}
for con in gold_constituents:
gold_hx[str(con.value)] = con.tree_hx

reparsed_state = reparsed_result.state
reparsed_tree = reparsed_state.constituents.value.value.value
gold_state = gold_result.state
gold_tree = gold_state.constituents.value.value.value

def contrast_trees(reparsed, gold):
# Walk the reparsed tree and the gold tree in lockstep from the given pair
# of nodes and, at the first place where their bracketing differs, record
# one (reparsed, gold) pair of subtree embeddings to contrast.
#
# Nothing is returned: the function works purely by appending to the
# enclosing scope's `reparsed_negatives` / `gold_negatives` lists, using
# the enclosing scope's `reparsed_hx` / `gold_hx` dicts (keyed by the
# string form of a subtree) to look up each subtree's `tree_hx`.
#
# NOTE(review): relies on `start_index` / `end_index` being set on every
# node; presumably that is what the caller's `mark_spans()` does - confirm.

# A preterminal on either side has no internal bracketing to compare.
if reparsed.is_preterminal() or gold.is_preterminal():
return

# Same number of children covering exactly the same spans: this level
# agrees, so descend into each aligned pair of children and stop here.
if (len(reparsed.children) == len(gold.children) and
all(x.start_index == y.start_index and x.end_index == y.end_index for x, y in zip(reparsed.children, gold.children))):
for x, y in zip(reparsed.children, gold.children):
contrast_trees(x, y)
return

# From here on the children of the two nodes are bracketed differently.
# Only the leftmost and rightmost children are compared (see TODO).
# TODO: instead compare all subtrees?
# preterminals don't have values returned by the tree analysis functions above
# Leftmost children end at different positions and both are real
# constituents (not preterminals, so both should have an hx entry).
if (reparsed.children[0].end_index != gold.children[0].end_index and
not reparsed.children[0].is_preterminal() and not gold.children[0].is_preterminal()):
reparsed_negatives.append(reparsed_hx[str(reparsed.children[0])])
gold_negatives.append(gold_hx[str(gold.children[0])])
# Rightmost children start at different positions; same preterminal guard.
# NOTE(review): when a node has exactly one child, children[0] and
# children[-1] are the same node, so that child can be appended by both
# checks - confirm this is intended.
if (reparsed.children[-1].start_index != gold.children[-1].start_index and
not reparsed.children[-1].is_preterminal() and not gold.children[-1].is_preterminal()):
reparsed_negatives.append(reparsed_hx[str(reparsed.children[-1])])
gold_negatives.append(gold_hx[str(gold.children[-1])])

reparsed_tree.mark_spans()
gold_tree.mark_spans()
contrast_trees(reparsed_tree, gold_tree)

reparsed_negatives = [nn.functional.normalize(hx) for hx in reparsed_negatives]
gold_negatives = [nn.functional.normalize(hx) for hx in gold_negatives]

if len(reparsed_negatives) > 0:
mse = torch.stack([torch.dot(x.squeeze(0), y.squeeze(0)) for x, y in zip(reparsed_negatives, gold_negatives)])
Expand Down

0 comments on commit 98544b6

Please sign in to comment.