-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: new models trained on Framenet exemplars (#18)
* include exemplars in framenet training * skipping invalid trigger exemplars * skip exemplars by default during training * fixing tests * improving data augmentations * ensure wordnet download for inference * updating snapshots * adding more info when augmentations fail validation * adding more augmentations from nlpaug * fixing linting * fixing keyboard augmentation * more checks on keyboard augmentation * tweaking augmentations * fixing tests * adding safety check to uppercase augmentation * lower augmentation rate * adding more augmentations * tweaking augs * removing debugging output * reduce augmentation * tweaking augmentation probs * tweaking augmentation probs * fixing type import * adding option to delete non-optimal models as training progresses * tweaking augmentations * updating models * updating README with new model stats
- Loading branch information
Showing
55 changed files
with
1,693 additions
and
313 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
26 changes: 19 additions & 7 deletions
26
frame_semantic_transformer/data/augmentations/DataAugmentation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
59 changes: 59 additions & 0 deletions
59
frame_semantic_transformer/data/augmentations/DoubleQuotesAugmentation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from __future__ import annotations | ||
import random | ||
|
||
from frame_semantic_transformer.data.augmentations.modification_helpers import ( | ||
splice_text, | ||
) | ||
from frame_semantic_transformer.data.tasks import TaskSample | ||
from .DataAugmentation import DataAugmentation | ||
from .modification_helpers.get_sample_text import get_sample_text | ||
|
||
|
||
# LaTeX-style double quotes: opening (``) and closing ('') forms
LATEX_QUOTES = ["``", "''"]
# the plain ASCII double-quote character
STANDARD_QUOTE = '"'
# every quote style this augmentation knows how to convert between
ALL_QUOTES = LATEX_QUOTES + [STANDARD_QUOTE]
|
||
|
||
class DoubleQuotesAugmentation(DataAugmentation):
    """
    Augmentation that swaps double-quote styles in the sample text:
    standard ASCII quotes become LaTeX-style quotes, and vice versa.
    """

    def apply_augmentation(self, task_sample: TaskSample) -> TaskSample:
        text = get_sample_text(task_sample)

        # Decide the conversion direction: if a standard quote appears
        # anywhere, convert standard -> latex; otherwise latex -> standard.
        if STANDARD_QUOTE in text:
            source_quotes = [STANDARD_QUOTE]
            target_quotes = LATEX_QUOTES
        else:
            source_quotes = LATEX_QUOTES
            target_quotes = [STANDARD_QUOTE]

        result = task_sample
        # Replace one occurrence per iteration, re-reading the text each
        # time since every splice shifts subsequent offsets.
        while count_instances(text, source_quotes) > 0:
            found_quote, found_pos = find_first_instance(text, source_quotes)
            try:
                result = splice_text(
                    result,
                    lambda _text, _critical_indices: (
                        found_pos,
                        len(found_quote),
                        random.choice(target_quotes),
                    ),
                )
                text = get_sample_text(result)
            except ValueError:
                # The splice failed, so just return the sample as-is
                return result

        return result
|
||
|
||
def count_instances(text: str, substrings: list[str]) -> int:
    """Return the total number of occurrences of all substrings in text."""
    total = 0
    for needle in substrings:
        total += text.count(needle)
    return total
|
||
|
||
def find_first_instance(text: str, substrings: list[str]) -> tuple[str, int]:
    """
    Find the first (leftmost) instance of any of the substrings in the text.
    Returns the substring and the start location of the substring.

    Raises ValueError if none of the substrings occur in the text.
    """
    # Track the leftmost match across all substrings. The previous
    # implementation returned the first substring in list order that
    # matched anywhere, which could be to the right of an earlier match
    # of a later-listed substring — contradicting "first instance".
    best: tuple[str, int] | None = None
    for substring in substrings:
        start_loc = text.find(substring)
        if start_loc >= 0 and (best is None or start_loc < best[1]):
            best = (substring, start_loc)
    if best is None:
        raise ValueError(f"Could not find any of {substrings} in {text}")
    return best
48 changes: 48 additions & 0 deletions
48
frame_semantic_transformer/data/augmentations/KeyboardAugmentation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from __future__ import annotations | ||
|
||
from nlpaug.augmenter.char import KeyboardAug | ||
|
||
from frame_semantic_transformer.data.augmentations.modification_helpers import ( | ||
modify_text_without_changing_length, | ||
) | ||
from frame_semantic_transformer.data.tasks import TaskSample | ||
from .DataAugmentation import DataAugmentation, ProbabilityType | ||
|
||
|
||
class KeyboardAugmentation(DataAugmentation):
    """
    Wrapper around nlpaug's KeyboardAug augmenter.
    Attempts to make spelling mistakes similar to what a user might make.
    """

    augmenter: KeyboardAug

    def __init__(self, probability: ProbabilityType):
        super().__init__(probability)
        self.augmenter = KeyboardAug(
            include_special_char=False, aug_char_p=0.1, aug_word_p=0.1
        )
        # ask the augmenter to report per-token change details so we can
        # re-apply them ourselves below
        self.augmenter.include_detail = True

    def apply_augmentation(self, task_sample: TaskSample) -> TaskSample:
        def augment_sent(sentence: str) -> str:
            # nlpaug's output removes spaces around punctuation, so instead of
            # using its text we replay the reported changes onto the original
            _, changes = self.augmenter.augment(sentence)[0]
            result = sentence
            for change in changes:
                orig_token = change["orig_token"]
                new_token = change["new_token"]
                start_pos = change["orig_start_pos"]
                # sometimes this augmenter changes token lengths or shifts
                # token positions, which we don't want; stop applying changes
                # if that happens and keep what we have so far
                if len(orig_token) != len(new_token):
                    return result
                if start_pos != change["new_start_pos"]:
                    return result
                end_pos = start_pos + len(orig_token)
                result = result[:start_pos] + new_token + result[end_pos:]
            return result

        return modify_text_without_changing_length(task_sample, augment_sent)
16 changes: 7 additions & 9 deletions
16
frame_semantic_transformer/data/augmentations/LowercaseAugmentation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,12 @@ | ||
from __future__ import annotations | ||
from frame_semantic_transformer.data.augmentations.modification_helpers import ( | ||
modify_text_without_changing_length, | ||
) | ||
|
||
from frame_semantic_transformer.data.tasks import TaskSample | ||
from .DataAugmentation import DataAugmentation | ||
|
||
|
||
class LowercaseAugmentation(DataAugmentation):
    """Augmentation that lowercases the text of a task sample."""

    def apply_augmentation(self, task_sample: TaskSample) -> TaskSample:
        # Delegate to the shared length-preserving helper.
        # NOTE(review): assumes lowercasing never changes the text length —
        # true for ASCII, but a few Unicode case mappings differ in length;
        # confirm the helper rejects such cases if non-ASCII input is possible.
        return modify_text_without_changing_length(
            task_sample, lambda sentence: sentence.lower()
        )
25 changes: 0 additions & 25 deletions
25
frame_semantic_transformer/data/augmentations/RemoveContractionsAugmentation.py
This file was deleted.
Oops, something went wrong.
29 changes: 23 additions & 6 deletions
29
frame_semantic_transformer/data/augmentations/RemoveEndPunctuationAugmentation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,30 @@ | ||
from __future__ import annotations | ||
from .DataAugmentation import DataAugmentation | ||
import re | ||
from frame_semantic_transformer.data.augmentations.modification_helpers import ( | ||
splice_text, | ||
) | ||
from frame_semantic_transformer.data.augmentations.modification_helpers.splice_text import ( | ||
is_valid_splice, | ||
) | ||
|
||
from frame_semantic_transformer.data.tasks import TaskSample | ||
from .DataAugmentation import DataAugmentation | ||
|
||
REMOVE_END_PUNCT_RE = r"\s*[.?!]\s*$" | ||
|
||
|
||
class RemoveEndPunctuationAugmentation(DataAugmentation):
    """
    Augmentation that strips trailing sentence punctuation (., ?, !) along
    with any surrounding whitespace from the sample text.
    """

    def apply_augmentation(self, task_sample: TaskSample) -> TaskSample:
        def remove_trailing_punct(
            sentence: str, critical_indices: list[int]
        ) -> tuple[int, int, str] | None:
            found = re.search(REMOVE_END_PUNCT_RE, sentence)
            if found is None:
                # nothing to strip
                return None
            start_index, end_index = found.span()
            deletion_length = end_index - start_index
            # only delete if the removal would not disturb a critical index
            if is_valid_splice(start_index, deletion_length, critical_indices):
                return start_index, deletion_length, ""
            return None

        return splice_text(task_sample, remove_trailing_punct)
Oops, something went wrong.