Skip to content

Commit

Permalink
Compare all trees when doing the contrastive loss
Browse files Browse the repository at this point in the history
  • Loading branch information
AngledLuffa committed Jan 16, 2025
1 parent 98544b6 commit 30bc91d
Showing 1 changed file with 21 additions and 16 deletions.
37 changes: 21 additions & 16 deletions stanza/models/constituency/parser_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,22 +620,27 @@ def contrast_trees(reparsed, gold):
if reparsed.is_preterminal() or gold.is_preterminal():
return

if (len(reparsed.children) == len(gold.children) and
all(x.start_index == y.start_index and x.end_index == y.end_index for x, y in zip(reparsed.children, gold.children))):
for x, y in zip(reparsed.children, gold.children):
contrast_trees(x, y)
return

# TODO: instead compare all subtrees?
# preterminals don't have values returned by the tree analysis functions above
if (reparsed.children[0].end_index != gold.children[0].end_index and
not reparsed.children[0].is_preterminal() and not gold.children[0].is_preterminal()):
reparsed_negatives.append(reparsed_hx[str(reparsed.children[0])])
gold_negatives.append(gold_hx[str(gold.children[0])])
if (reparsed.children[-1].start_index != gold.children[-1].start_index and
not reparsed.children[-1].is_preterminal() and not gold.children[-1].is_preterminal()):
reparsed_negatives.append(reparsed_hx[str(reparsed.children[-1])])
gold_negatives.append(gold_hx[str(gold.children[-1])])
reparsed_idx = 0
gold_idx = 0
while reparsed_idx < len(reparsed.children) and gold_idx < len(gold.children):
reparsed_child = reparsed.children[reparsed_idx]
gold_child = gold.children[gold_idx]
if not reparsed_child.is_preterminal() and not gold_child.is_preterminal():
# TODO: check that comparing labels is helpful
if (reparsed_child.label == gold_child.label and
reparsed_child.start_index == gold_child.start_index and
reparsed_child.end_index == gold_child.end_index):
contrast_trees(reparsed_child, gold_child)
else:
reparsed_negatives.append(reparsed_hx[str(reparsed_child)])
gold_negatives.append(gold_hx[str(gold_child)])
if reparsed_child.end_index == gold_child.end_index:
reparsed_idx += 1
gold_idx += 1
elif reparsed_child.end_index < gold_child.end_index:
reparsed_idx += 1
else:
gold_idx += 1

reparsed_tree.mark_spans()
gold_tree.mark_spans()
Expand Down

0 comments on commit 30bc91d

Please sign in to comment.