diff --git a/stanza/models/constituency/parse_tree.py b/stanza/models/constituency/parse_tree.py
index 26cea9a553..041f33f312 100644
--- a/stanza/models/constituency/parse_tree.py
+++ b/stanza/models/constituency/parse_tree.py
@@ -589,3 +589,34 @@ def write_treebank(trees, out_file, fmt="{}"):
             for tree in trees:
                 fout.write(fmt.format(tree))
                 fout.write("\n")
+
+    def mark_spans(self):
+        """
+        Label every node in this tree with the span of positions it covers
+
+        Sets start_index and end_index on this node and all of its
+        descendants.  Positions are counted from 0 at the leftmost
+        childless node, and each span is half-open: [start_index, end_index)
+        """
+        self._mark_spans(0)
+
+    def _mark_spans(self, start_index):
+        """
+        Recursive helper for mark_spans
+
+        start_index: position of the first childless node under this node
+        On return, end_index is one past the last position under this node
+        """
+        self.start_index = start_index
+
+        if not self.children:
+            # a childless node occupies exactly one position
+            self.end_index = start_index + 1
+            return
+
+        # each child starts where its left sibling ended
+        for child in self.children:
+            child._mark_spans(start_index)
+            start_index = child.end_index
+
+        self.end_index = start_index
diff --git a/stanza/tests/constituency/test_parse_tree.py b/stanza/tests/constituency/test_parse_tree.py
index ea3dd71af5..4de6e9fc5d 100644
--- a/stanza/tests/constituency/test_parse_tree.py
+++ b/stanza/tests/constituency/test_parse_tree.py
@@ -367,3 +367,28 @@ def test_reverse():
     assert len(trees) == 1
     reversed_tree = trees[0].reverse()
     assert str(reversed_tree) == "(ROOT (S (VP (S (VP (VP (NP (NNS antennae) (NP (POS 's) (NNP Jennifer))) (VB lick)) (TO to))) (VBP want)) (NP (PRP I))))"
+
+def test_mark_spans():
+    """
+    Test that mark_spans labels the nodes of a tree with the spans they cover
+    """
+    text = "(ROOT (S (NP (PRP I)) (VP (VBP want) (S (VP (TO to) (VP (VB lick) (NP (NP (NNP Jennifer) (POS 's)) (NNS antennae))))))))"
+    trees = tree_reader.read_trees(text)
+    assert len(trees) == 1
+    tree = trees[0]
+
+    tree.mark_spans()
+
+    # the root covers all seven words
+    assert tree.start_index == 0
+    assert tree.end_index == 7
+
+    # each preterminal covers exactly one position
+    for idx, pt in enumerate(tree.yield_preterminals()):
+        assert pt.start_index == idx
+        assert pt.end_index == idx + 1
+
+    # internal nodes: the subject NP is [0, 1) and the main VP is [1, 7)
+    subject, predicate = tree.children[0].children
+    assert (subject.start_index, subject.end_index) == (0, 1)
+    assert (predicate.start_index, predicate.end_index) == (1, 7)