feat: Frame classification hints (#3)
* adding in lexical unit data for smarter frame classification

* adding in stemming for lu handling

* allow skipping validation in initial epochs for faster training

* use self.current_epoch instead of batch_idx

* using bigrams to reduce the amount of frame suggestions

* refactoring bigrams stuff and adding more tests

* fixing bug with trigger bigrams

* updating README

* updating model revision
chanind authored May 24, 2022
1 parent 8bf3275 commit 201ed51
Showing 12 changed files with 171 additions and 13 deletions.
18 changes: 15 additions & 3 deletions README.md
@@ -1,13 +1,25 @@
# Frame Semantic Transformer

[![ci](https://img.shields.io/github/workflow/status/chanind/frame-semantic-transformer/CI/main)](https://github.com/chanind/frame-semantic-transformer)
[![PyPI](https://img.shields.io/pypi/v/frame-semantic-transformer?color=blue)](https://pypi.org/project/frame-semantic-transformer/)


Frame-based semantic parsing library trained on [FrameNet](https://framenet2.icsi.berkeley.edu/) and built on HuggingFace's [T5 Transformer](https://huggingface.co/docs/transformers/model_doc/t5). This library is designed to be easy to use, yet powerful.

**Live Demo: [chanind.github.io/frame-semantic-transformer](https://chanind.github.io/frame-semantic-transformer)**

This library draws heavily on [Open-Sesame](https://github.com/swabhs/open-sesame) ([paper](https://arxiv.org/abs/1706.09528)) for inspiration on training and evaluation on FrameNet 1.7, and uses ideas from the paper [Open-Domain Frame Semantic Parsing Using Transformers](https://arxiv.org/abs/2010.10998) for using T5 as a frame-semantic parser. [SimpleT5](https://github.com/Shivanandroy/simpleT5) was also used as a base for the initial training setup.

## Performance

This library uses the same train/dev/test documents and evaluation methodology as Open-Sesame, so the results should be comparable between the two libraries. There are two pretrained models available, `base` and `small`, corresponding to `t5-base` and `t5-small` in HuggingFace, respectively.

| Task | Sesame F1 (dev/test) | Small Model F1 (dev/test) | Base Model F1 (dev/test) |
| ---------------------- | -------------------- | ------------------------- | ------------------------ |
| Trigger identification | 0.80 / 0.73 | 0.69 / 0.66 | 0.76 / 0.72 |
| Frame classification | 0.90 / 0.87 | 0.82 / 0.81 | 0.88 / 0.87 |
| Argument extraction | 0.61 / 0.61 | 0.68 / 0.61 | 0.74 / 0.72 |

The base model performs similarly to Open-Sesame on trigger identification and frame classification tasks, but outperforms it by a significant margin on argument extraction. The small pretrained model has lower F1 than base across the board, but is 1/4 the size and is still comparable to Open-Sesame at argument extraction.
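
For illustration (not part of this diff), a minimal usage sketch for picking between the two pretrained models; the `FrameSemanticTransformer` class and its `detect_frames` method are the package's public entry point, but the result attribute names in the comments below are assumptions:

```python
from frame_semantic_transformer import FrameSemanticTransformer

# "base" (the default) or "small" selects the corresponding pretrained T5 model
frame_transformer = FrameSemanticTransformer("small")

result = frame_transformer.detect_frames(
    "Your contribution to Goodwill will mean more than you may know."
)
for frame in result.frames:  # assumed: each detected frame exposes .name and .frame_elements
    print(frame.name, [(fe.name, fe.text) for fe in frame.frame_elements])
```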

## Installation

2 changes: 1 addition & 1 deletion frame_semantic_transformer/constants.py
@@ -1,4 +1,4 @@
MODEL_MAX_LENGTH = 512
OFFICIAL_RELEASES = ["base", "small"] # TODO: small, large
MODEL_REVISION = "v0.0.1"
MODEL_REVISION = "v0.1.0"
PADDING_LABEL_ID = -100
4 changes: 4 additions & 0 deletions frame_semantic_transformer/data/framenet.py
@@ -52,5 +52,9 @@ def get_all_valid_frame_names() -> set[str]:
return {frame.name for frame in fn.frames()}


def get_lexical_units() -> Sequence[Mapping[str, Any]]:
return fn.lus()


def get_fulltext_docs() -> Sequence[Mapping[str, Any]]:
return fn.docs()
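
For context, `fn` here is NLTK's FrameNet corpus reader, and the lexical-unit records behave like nested mappings; a quick illustrative sketch, assuming the FrameNet 1.7 data has been fetched via `nltk.download("framenet_v17")` (the example values in the comments are assumptions):

```python
from nltk.corpus import framenet as fn

lu = fn.lus()[0]
print(lu["name"])           # lemma plus POS suffix, e.g. "fall.v"
print(lu["frame"]["name"])  # name of the frame this lexical unit evokes
```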
60 changes: 60 additions & 0 deletions frame_semantic_transformer/data/get_possible_frames_for_trigger_bigrams.py
@@ -0,0 +1,60 @@
from __future__ import annotations
from collections import defaultdict
from functools import lru_cache
import re
from nltk.stem import PorterStemmer
from .framenet import get_lexical_units


stemmer = PorterStemmer()
MONOGRAM_BLACKLIST = {"back", "down", "make", "take", "have", "into", "come"}


def get_possible_frames_for_trigger_bigrams(bigrams: list[list[str]]) -> list[str]:
possible_frames = []
lookup_map = get_lexical_unit_bigram_to_frame_lookup_map()
for bigram in bigrams:
normalized_bigram = normalize_lexical_unit_ngram(bigram)
if normalized_bigram in lookup_map:
bigram_frames = lookup_map[normalized_bigram]
possible_frames += bigram_frames
# remove duplicates, while preserving order
# https://stackoverflow.com/questions/1653970/does-python-have-an-ordered-set/53657523#53657523
return list(dict.fromkeys(possible_frames))


@lru_cache(1)
def get_lexical_unit_bigram_to_frame_lookup_map() -> dict[str, list[str]]:
uniq_lookup_map: dict[str, set[str]] = defaultdict(set)
for lu in get_lexical_units():
parts = lu["name"].split()
lu_bigrams: list[str] = []
prev_part = None
for part in parts:
norm_part = normalize_lexical_unit_text(part)
# also key this as a monogram if there's only 1 part or the word is rare enough
if len(parts) == 1 or (
len(norm_part) >= 4 and norm_part not in MONOGRAM_BLACKLIST
):
lu_bigrams.append(normalize_lexical_unit_ngram([part]))
if prev_part is not None:
lu_bigrams.append(normalize_lexical_unit_ngram([prev_part, part]))
prev_part = part

for bigram in lu_bigrams:
uniq_lookup_map[bigram].add(lu["frame"]["name"])
sorted_lookup_map: dict[str, list[str]] = {}
for lu_bigram, frames in uniq_lookup_map.items():
sorted_lookup_map[lu_bigram] = sorted(list(frames))
return sorted_lookup_map


def normalize_lexical_unit_ngram(ngram: list[str]) -> str:
return "_".join([normalize_lexical_unit_text(tok) for tok in ngram])


def normalize_lexical_unit_text(lu: str) -> str:
normalized_lu = lu.lower()
normalized_lu = re.sub(r"\.[a-zA-Z]+$", "", normalized_lu)
normalized_lu = re.sub(r"[^a-z0-9 ]", "", normalized_lu)
return stemmer.stem(normalized_lu.strip())
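
The new module builds a FrameNet lexical-unit lookup keyed by normalized unigrams and bigrams: tokens are lowercased, the `.v`/`.n` POS suffix and non-alphanumerics are stripped, and each token is Porter-stemmed before joining with underscores. A short sketch of its behavior, mirroring the tests added later in this commit (requires the NLTK FrameNet data to be downloaded):

```python
from frame_semantic_transformer.data.get_possible_frames_for_trigger_bigrams import (
    get_possible_frames_for_trigger_bigrams,
    normalize_lexical_unit_ngram,
)

print(normalize_lexical_unit_ngram(["he", "eats"]))  # "he_eat"

# candidate frames for the trigger "help" in "can't help it"
print(get_possible_frames_for_trigger_bigrams([["can't", "help"], ["help", "it"], ["help"]]))
# ["Self_control", "Assistance"]
```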
28 changes: 26 additions & 2 deletions frame_semantic_transformer/data/tasks/FrameClassificationTask.py
@@ -2,7 +2,12 @@
from dataclasses import dataclass
from typing import Sequence
from frame_semantic_transformer.data.data_utils import standardize_punct
from frame_semantic_transformer.data.framenet import (
is_valid_frame,
)
from frame_semantic_transformer.data.get_possible_frames_for_trigger_bigrams import (
get_possible_frames_for_trigger_bigrams,
)

from .Task import Task

@@ -19,7 +24,8 @@ def get_task_name() -> str:
return "frame_classification"

def get_input(self) -> str:
return f"FRAME: {self.trigger_labeled_text}"
potential_frames = get_possible_frames_for_trigger_bigrams(self.trigger_bigrams)
return f"FRAME {' '.join(potential_frames)} : {self.trigger_labeled_text}"

@staticmethod
def parse_output(prediction_outputs: Sequence[str]) -> str | None:
@@ -30,6 +36,24 @@ def parse_output(prediction_outputs: Sequence[str]) -> str | None:

# -- helper properties --

@property
def trigger_bigrams(self) -> list[list[str]]:
"""
Return bigrams for the trigger: prev word + trigger, trigger + next word, and the trigger alone
"""
pre_trigger_tokens = self.text[: self.trigger_loc].split()
trigger_and_after_tokens = self.text[self.trigger_loc :].split()
trigger = trigger_and_after_tokens[0]
post_trigger_tokens = trigger_and_after_tokens[1:]
bigrams: list[list[str]] = []
if len(pre_trigger_tokens) > 0:
bigrams.append([pre_trigger_tokens[-1], trigger])
if len(post_trigger_tokens) > 0:
bigrams.append([trigger, post_trigger_tokens[0]])
# add the monogram last
bigrams.append([trigger])
return bigrams

@property
def trigger_labeled_text(self) -> str:
pre_span = self.text[0 : self.trigger_loc]
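Putting the pieces together, a small sketch of how the hints reach the model input; the values mirror the tests added below in this commit, and building the input requires the FrameNet lexical-unit data to be available locally:

```python
from frame_semantic_transformer.data.tasks.FrameClassificationTask import (
    FrameClassificationTask,
)

task = FrameClassificationTask(
    text="Your contribution to Goodwill will mean more than you may know.",
    trigger_loc=5,  # character offset where the trigger "contribution" starts
)
print(task.trigger_bigrams)
# [["Your", "contribution"], ["contribution", "to"], ["contribution"]]
print(task.get_input())
# something like: "FRAME Condition_symptom_relation Giving : Your * contribution to Goodwill ..."
```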
16 changes: 12 additions & 4 deletions frame_semantic_transformer/train.py
@@ -92,6 +92,7 @@ class TrainingModelWrapper(pl.LightningModule):
trainer: pl.Trainer
output_dir: str
save_only_last_epoch: bool
skip_initial_epochs_validation: int

def __init__(
self,
@@ -100,13 +101,15 @@ def __init__(
lr: float = 1e-4,
output_dir: str = "outputs",
save_only_last_epoch: bool = False,
skip_initial_epochs_validation: int = 0,
):
super().__init__()
self.lr = lr
self.model = model
self.tokenizer = tokenizer
self.output_dir = output_dir
self.save_only_last_epoch = save_only_last_epoch
self.skip_initial_epochs_validation = skip_initial_epochs_validation

def forward(
self,
@@ -143,6 +146,8 @@ def training_step(self, batch: Any, _batch_idx: int) -> Any: # type: ignore
def validation_step(self, batch: Any, _batch_idx: int) -> Any: # type: ignore
output = self._step(batch)
loss = output.loss
if self.current_epoch < self.skip_initial_epochs_validation:
return {"loss": loss}
metrics = evaluate_batch(self.model, self.tokenizer, batch)
self.log(
"val_loss",
@@ -191,6 +196,9 @@ def validation_epoch_end(self, validation_step_outputs: list[Any]) -> None:
torch.mean(torch.stack(losses)).item(),
4,
)
if self.current_epoch < self.skip_initial_epochs_validation:
# no validation metrics to calculate in this epoch, just return early
return

metrics = merge_metrics([out["metrics"] for out in validation_step_outputs])
for task_name, counts in metrics.items():
@@ -225,6 +233,7 @@ def train(
save_only_last_epoch: bool = False,
balance_tasks: bool = True,
max_task_duplication_factor: int = 2,
skip_initial_epochs_validation: int = 0,
) -> tuple[T5ForConditionalGeneration, T5Tokenizer]:
device = torch.device("cuda" if use_gpu else "cpu")
logger.info("loading base T5 model")
@@ -244,14 +253,12 @@
val_dataset = TaskSampleDataset(
load_sesame_dev_samples(),
tokenizer,
balance_tasks=False,
)
test_dataset = TaskSampleDataset(
load_sesame_test_samples(),
tokenizer,
balance_tasks=False,
)

data_module = TrainDataModule(
@@ -268,6 +275,7 @@
lr=lr,
output_dir=output_dir,
save_only_last_epoch=save_only_last_epoch,
skip_initial_epochs_validation=skip_initial_epochs_validation,
)

# add callbacks
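A hedged sketch of how the new `skip_initial_epochs_validation` option might be used; only keyword arguments that appear in this diff are shown, and everything else keeps its defaults:

```python
from frame_semantic_transformer.train import train

# Skip the expensive generation-based validation metrics for the first 2 epochs;
# the plain validation loss is still computed every epoch.
model, tokenizer = train(
    skip_initial_epochs_validation=2,
    save_only_last_epoch=False,
    balance_tasks=True,
    max_task_duplication_factor=2,
)
```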
2 changes: 2 additions & 0 deletions setup.cfg
@@ -1,5 +1,6 @@
[flake8]
extend-ignore = E203,E501
exclude = dist

[mypy]
follow_imports = silent
@@ -11,6 +12,7 @@ check_untyped_defs = True
disallow_untyped_defs = True
namespace_packages = True
mypy_path = $MYPY_CONFIG_FILE_DIR/stubs
exclude = dist

[mypy-tests.*]
ignore_missing_imports = True
@@ -18,9 +18,7 @@


def test_get_input() -> None:
expected = "FRAME Condition_symptom_relation Giving : Your * contribution to Goodwill will mean more than you may know."
assert sample.get_input() == expected


16 changes: 16 additions & 0 deletions tests/data/tasks/test_FrameClassificationTask.py
@@ -0,0 +1,16 @@
from frame_semantic_transformer.data.tasks.FrameClassificationTask import (
FrameClassificationTask,
)


def test_trigger_bigrams() -> None:
task = FrameClassificationTask(
text="Your contribution to Goodwill will mean more than you may know .",
trigger_loc=5,
)

assert task.trigger_bigrams == [
["Your", "contribution"],
["contribution", "to"],
["contribution"],
]
34 changes: 34 additions & 0 deletions tests/data/test_get_possible_frames_for_trigger_bigrams.py
@@ -0,0 +1,34 @@
from __future__ import annotations
from frame_semantic_transformer.data.get_possible_frames_for_trigger_bigrams import (
get_lexical_unit_bigram_to_frame_lookup_map,
normalize_lexical_unit_ngram,
get_possible_frames_for_trigger_bigrams,
)


def test_get_lexical_unit_bigram_to_frame_lookup_map() -> None:
lookup_map = get_lexical_unit_bigram_to_frame_lookup_map()
assert len(lookup_map) > 5000
for frames in lookup_map.values():
assert len(frames) < 20


def test_normalize_lexical_unit_ngram() -> None:
assert normalize_lexical_unit_ngram(["can't", "stop"]) == "cant_stop"
assert normalize_lexical_unit_ngram(["he", "eats"]) == "he_eat"
assert normalize_lexical_unit_ngram(["eats"]) == "eat"


def test_get_possible_frames_for_trigger_bigrams() -> None:
assert get_possible_frames_for_trigger_bigrams(
[["can't", "help"], ["help", "it"], ["help"]]
) == ["Self_control", "Assistance"]
assert get_possible_frames_for_trigger_bigrams([["can't", "help"]]) == [
"Self_control"
]


def test_get_possible_frames_for_trigger_bigrams_stems_bigrams() -> None:
assert get_possible_frames_for_trigger_bigrams([["can't", "helps"]]) == [
"Self_control"
]
