Commit
feat: Frame classification hints (#3)
* adding in lexical unit data for smarter frame classification
* adding in stemming for lu handling
* allow skipping validation in initial epochs for faster training
* use self.current_epoch instead of batch_idx
* using bigrams to reduce the amount of frame suggestions
* refactoring bigrams stuff and adding more tests
* fixing bug with trigger bigrams
* updating README
* updating model revision
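A minimal usage sketch of the frame hints added in this commit, based on the tests included below (the exact frames returned depend on the NLTK FrameNet data the library loads):

```python
from frame_semantic_transformer.data.get_possible_frames_for_trigger_bigrams import (
    get_possible_frames_for_trigger_bigrams,
)

# Bigrams around the trigger word "help", plus the trigger itself, are used to
# narrow the candidate frames the classifier has to choose between.
candidates = get_possible_frames_for_trigger_bigrams(
    [["can't", "help"], ["help", "it"], ["help"]]
)
print(candidates)  # ["Self_control", "Assistance"] per the tests added in this commit
```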
Showing 12 changed files with 171 additions and 13 deletions.
@@ -1,4 +1,4 @@
 MODEL_MAX_LENGTH = 512
 OFFICIAL_RELEASES = ["base", "small"] # TODO: small, large
-MODEL_REVISION = "v0.0.1"
+MODEL_REVISION = "v0.1.0"
 PADDING_LABEL_ID = -100
frame_semantic_transformer/data/get_possible_frames_for_trigger_bigrams.py (60 additions, 0 deletions)
@@ -0,0 +1,60 @@
from __future__ import annotations
from collections import defaultdict
from functools import lru_cache
import re
from nltk.stem import PorterStemmer
from .framenet import get_lexical_units


stemmer = PorterStemmer()
MONOGRAM_BLACKLIST = {"back", "down", "make", "take", "have", "into", "come"}


def get_possible_frames_for_trigger_bigrams(bigrams: list[list[str]]) -> list[str]:
    possible_frames = []
    lookup_map = get_lexical_unit_bigram_to_frame_lookup_map()
    for bigram in bigrams:
        normalized_bigram = normalize_lexical_unit_ngram(bigram)
        if normalized_bigram in lookup_map:
            bigram_frames = lookup_map[normalized_bigram]
            possible_frames += bigram_frames
    # remove duplicates, while preserving order
    # https://stackoverflow.com/questions/1653970/does-python-have-an-ordered-set/53657523#53657523
    return list(dict.fromkeys(possible_frames))


@lru_cache(1)
def get_lexical_unit_bigram_to_frame_lookup_map() -> dict[str, list[str]]:
    uniq_lookup_map: dict[str, set[str]] = defaultdict(set)
    for lu in get_lexical_units():
        parts = lu["name"].split()
        lu_bigrams: list[str] = []
        prev_part = None
        for part in parts:
            norm_part = normalize_lexical_unit_text(part)
            # also key this as a monogram if there's only 1 element or the word is rare enough
            if len(parts) == 1 or (
                len(norm_part) >= 4 and norm_part not in MONOGRAM_BLACKLIST
            ):
                lu_bigrams.append(normalize_lexical_unit_ngram([part]))
            if prev_part is not None:
                lu_bigrams.append(normalize_lexical_unit_ngram([prev_part, part]))
            prev_part = part

        for bigram in lu_bigrams:
            uniq_lookup_map[bigram].add(lu["frame"]["name"])
    sorted_lookup_map: dict[str, list[str]] = {}
    for lu_bigram, frames in uniq_lookup_map.items():
        sorted_lookup_map[lu_bigram] = sorted(list(frames))
    return sorted_lookup_map


def normalize_lexical_unit_ngram(ngram: list[str]) -> str:
    return "_".join([normalize_lexical_unit_text(tok) for tok in ngram])


def normalize_lexical_unit_text(lu: str) -> str:
    normalized_lu = lu.lower()
    normalized_lu = re.sub(r"\.[a-zA-Z]+$", "", normalized_lu)
    normalized_lu = re.sub(r"[^a-z0-9 ]", "", normalized_lu)
    return stemmer.stem(normalized_lu.strip())
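A quick illustration of the normalization and caching above, sketched from the tests added in this commit (the lookup map contents depend on the installed NLTK FrameNet data):

```python
from frame_semantic_transformer.data.get_possible_frames_for_trigger_bigrams import (
    get_lexical_unit_bigram_to_frame_lookup_map,
    normalize_lexical_unit_ngram,
)

# Tokens are lowercased, stripped of FrameNet POS suffixes (".v", ".n", ...) and
# punctuation, stemmed, then joined with underscores.
print(normalize_lexical_unit_ngram(["can't", "stop"]))  # "cant_stop"
print(normalize_lexical_unit_ngram(["he", "eats"]))     # "he_eat"

# The bigram-to-frame map is built once from the FrameNet lexical units and
# cached by lru_cache, so repeated lookups are cheap.
lookup_map = get_lexical_unit_bigram_to_frame_lookup_map()
print(lookup_map.get("cant_help"))  # ["Self_control"] per the tests in this commit
```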
File renamed without changes.
@@ -0,0 +1,16 @@
from frame_semantic_transformer.data.tasks.FrameClassificationTask import (
    FrameClassificationTask,
)


def test_trigger_bigrams() -> None:
    task = FrameClassificationTask(
        text="Your contribution to Goodwill will mean more than you may know .",
        trigger_loc=5,
    )

    assert task.trigger_bigrams == [
        ["Your", "contribution"],
        ["contribution", "to"],
        ["contribution"],
    ]
File renamed without changes.
tests/data/test_get_possible_frames_for_trigger_bigrams.py (34 additions, 0 deletions)
@@ -0,0 +1,34 @@
from __future__ import annotations
from frame_semantic_transformer.data.get_possible_frames_for_trigger_bigrams import (
    get_lexical_unit_bigram_to_frame_lookup_map,
    normalize_lexical_unit_ngram,
    get_possible_frames_for_trigger_bigrams,
)


def test_get_lexical_unit_bigram_to_frame_lookup_map() -> None:
    lookup_map = get_lexical_unit_bigram_to_frame_lookup_map()
    assert len(lookup_map) > 5000
    for frames in lookup_map.values():
        assert len(frames) < 20


def test_normalize_lexical_unit_ngram() -> None:
    assert normalize_lexical_unit_ngram(["can't", "stop"]) == "cant_stop"
    assert normalize_lexical_unit_ngram(["he", "eats"]) == "he_eat"
    assert normalize_lexical_unit_ngram(["eats"]) == "eat"


def test_get_possible_frames_for_trigger_bigrams() -> None:
    assert get_possible_frames_for_trigger_bigrams(
        [["can't", "help"], ["help", "it"], ["help"]]
    ) == ["Self_control", "Assistance"]
    assert get_possible_frames_for_trigger_bigrams([["can't", "help"]]) == [
        "Self_control"
    ]


def test_get_possible_frames_for_trigger_bigrams_stems_bigrams() -> None:
    assert get_possible_frames_for_trigger_bigrams([["can't", "helps"]]) == [
        "Self_control"
    ]