Skip to content

Commit

Permalink
augment training samples dynamically during training
Browse files Browse the repository at this point in the history
  • Loading branch information
chanind committed May 21, 2022
1 parent 57f24dc commit 3b40f07
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 104 deletions.
18 changes: 7 additions & 11 deletions frame_semantic_transformer/FrameSemanticTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,16 @@
from typing import cast
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

from frame_semantic_transformer.constants import MODEL_MAX_LENGTH, OFFICIAL_RELEASES
from frame_semantic_transformer.data.data_utils import chunk_list, marked_string_to_locs
from frame_semantic_transformer.data.framenet import ensure_framenet_downloaded
from frame_semantic_transformer.data.tasks.ArgumentsExtractionTask import (
from frame_semantic_transformer.predict import batch_predict
from frame_semantic_transformer.data.tasks import (
ArgumentsExtractionTask,
)
from frame_semantic_transformer.data.tasks.FrameClassificationTask import (
FrameClassificationTask,
)

from frame_semantic_transformer.data.tasks.TriggerIdentificationTask import (
TriggerIdentificationTask,
)
from frame_semantic_transformer.predict import batch_predict


OFFICIAL_RELEASES = ["base", "small"] # TODO: small, large


@dataclass
Expand Down Expand Up @@ -71,7 +65,9 @@ def setup(self) -> None:
self._model = T5ForConditionalGeneration.from_pretrained(self.model_path).to(
self.device
)
self._tokenizer = T5Tokenizer.from_pretrained(self.model_path)
self._tokenizer = T5Tokenizer.from_pretrained(
self.model_path, model_max_length=MODEL_MAX_LENGTH
)
ensure_framenet_downloaded()

@property
Expand Down
2 changes: 2 additions & 0 deletions frame_semantic_transformer/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Token-length limit shared across the package: passed to the T5 tokenizer as
# `model_max_length` and used as the padding/truncation length for encoded inputs.
MODEL_MAX_LENGTH = 512
# Release names accepted as official models (presumably T5 size variants — confirm).
OFFICIAL_RELEASES = ["base", "small"] # TODO: large ("small" is already listed)
123 changes: 57 additions & 66 deletions frame_semantic_transformer/data/TaskSampleDataset.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
from __future__ import annotations
from collections import defaultdict
import random
from typing import Any, Sequence
from typing import Any, Callable, Optional, Sequence
import torch
from torch.utils.data import Dataset
from transformers import T5Tokenizer
from frame_semantic_transformer.constants import MODEL_MAX_LENGTH

from frame_semantic_transformer.data.augmentations import (
LowercaseAugmentation,
RemoveContractionsAugmentation,
RemoveEndPunctuationAugmentation,
chain_augmentations,
)

from frame_semantic_transformer.data.tasks.TaskSample import TaskSample
from frame_semantic_transformer.data.tasks import TaskSample


MAX_SOURCE_LEN = 512
MAX_TARGET_LEN = 512


class TaskSampleDataset(Dataset[Any]):
input_ids: torch.Tensor
attention_mask: torch.Tensor
labels: torch.Tensor
samples: Sequence[TaskSample]
augmentation: Optional[Callable[[str, str], tuple[str, str]]] = None
tokenizer: T5Tokenizer

def __init__(
self,
Expand All @@ -34,30 +33,67 @@ def __init__(
max_task_duplication_factor: int = 2,
augment_data: bool = False,
):
samples_to_parse = samples
self.samples = samples
if balance_tasks:
samples_to_parse = balance_tasks_by_type(
self.samples = balance_tasks_by_type(
samples, seed=seed, max_duplication_factor=max_task_duplication_factor
)
input_ids, attention_mask, labels = parse_samples(
samples_to_parse, tokenizer, augment_data
)
self.input_ids = input_ids
self.attention_mask = attention_mask
self.labels = labels
self.task_names = [sample.get_task_name() for sample in samples_to_parse]
self.tokenizer = tokenizer
if augment_data:
self.augmentation = chain_augmentations(
[
RemoveEndPunctuationAugmentation(0.3),
LowercaseAugmentation(0.2),
RemoveContractionsAugmentation(0.2),
]
)

def __len__(self) -> int:
return len(self.input_ids)
return len(self.samples)

def __getitem__(self, index: int) -> dict[str, Any]:
sample = self.samples[index]

input_ids, attention_mask, labels = self.parse_sample(sample)

return {
"input_ids": self.input_ids[index],
"attention_mask": self.attention_mask[index],
"labels": self.labels[index],
"task": self.task_names[index],
"input_ids": input_ids.squeeze(),
"attention_mask": attention_mask.squeeze(),
"labels": labels,
"task": sample.get_task_name(),
}

def parse_sample(
    self, sample: TaskSample
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Tokenize a single sample into ``(input_ids, attention_mask, labels)``.

    The sample's input and target strings are optionally passed through
    ``self.augmentation`` first, so a different augmented variant may be
    produced each time the same sample is parsed.

    The input is padded/truncated to ``MODEL_MAX_LENGTH`` tokens and encoded
    as PyTorch tensors; the target is padded/truncated to ``MAX_TARGET_LEN``
    tokens and converted to a tensor afterwards.
    """
    source_text = sample.get_input()
    target_text = sample.get_target()
    if self.augmentation:
        source_text, target_text = self.augmentation(source_text, target_text)

    encoded_source = self.tokenizer(
        source_text,
        padding="max_length",
        max_length=MODEL_MAX_LENGTH,
        truncation=True,
        return_tensors="pt",
    )
    encoded_target = self.tokenizer(
        target_text,
        padding="max_length",
        max_length=MAX_TARGET_LEN,
        truncation=True,
    )

    labels = torch.tensor(encoded_target.input_ids)
    # Replace padding positions with -100 (presumably so the loss skips them,
    # following the usual PyTorch / Hugging Face ignore-index convention).
    padding_positions = labels == self.tokenizer.pad_token_id
    labels[padding_positions] = -100

    return (encoded_source.input_ids, encoded_source.attention_mask, labels)


def balance_tasks_by_type(
samples: Sequence[TaskSample],
Expand All @@ -81,48 +117,3 @@ def balance_tasks_by_type(
balanced_samples.append(sample)
random.Random(seed).shuffle(balanced_samples)
return balanced_samples


def parse_samples(
    samples: Sequence[TaskSample], tokenizer: T5Tokenizer, augment_data: bool
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Tokenize all samples at once into ``(input_ids, attention_mask, labels)``.

    When ``augment_data`` is true, every sample's input/target pair is passed
    once through a chained augmentation (end-punctuation removal, lowercasing,
    contraction removal) before tokenization. Inputs and targets are each
    padded to the longest sequence in their batch and truncated at
    ``MAX_SOURCE_LEN`` / ``MAX_TARGET_LEN`` respectively.
    """
    augmentation = chain_augmentations(
        [
            RemoveEndPunctuationAugmentation(0.3),
            LowercaseAugmentation(0.2),
            RemoveContractionsAugmentation(0.2),
        ]
    )

    text_pairs: list[tuple[str, str]] = []
    for sample in samples:
        pair = (sample.get_input(), sample.get_target())
        if augment_data:
            pair = augmentation(*pair)
        text_pairs.append(pair)

    input_sequences: list[str] = [source for source, _ in text_pairs]
    output_sequences: list[str] = [target for _, target in text_pairs]

    encoded_inputs = tokenizer(
        input_sequences,
        padding="longest",
        max_length=MAX_SOURCE_LEN,
        truncation=True,
        return_tensors="pt",
    )
    encoded_outputs = tokenizer(
        output_sequences,
        padding="longest",
        max_length=MAX_TARGET_LEN,
        truncation=True,
    )

    labels = torch.tensor(encoded_outputs.input_ids)
    # Replace padding positions with -100 (presumably so the loss skips them,
    # following the usual PyTorch / Hugging Face ignore-index convention).
    padding_positions = labels == tokenizer.pad_token_id
    labels[padding_positions] = -100

    return (encoded_inputs.input_ids, encoded_inputs.attention_mask, labels)
14 changes: 2 additions & 12 deletions frame_semantic_transformer/data/load_framenet_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,13 @@
SESAME_DEV_FILES,
SESAME_TEST_FILES,
)
from frame_semantic_transformer.data.tasks.ArgumentsExtractionSample import (
from frame_semantic_transformer.data.tasks import (
ArgumentsExtractionSample,
)
from frame_semantic_transformer.data.tasks.ArgumentsExtractionTask import (
ArgumentsExtractionTask,
)
from frame_semantic_transformer.data.tasks.FrameClassificationSample import (
FrameClassificationSample,
)
from frame_semantic_transformer.data.tasks.FrameClassificationTask import (
FrameClassificationTask,
)
from frame_semantic_transformer.data.tasks.TaskSample import TaskSample
from frame_semantic_transformer.data.tasks.TriggerIdentificationSample import (
TaskSample,
TriggerIdentificationSample,
)
from frame_semantic_transformer.data.tasks.TriggerIdentificationTask import (
TriggerIdentificationTask,
)

Expand Down
19 changes: 19 additions & 0 deletions frame_semantic_transformer/data/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from .Task import Task
from .TaskSample import TaskSample
from .ArgumentsExtractionSample import ArgumentsExtractionSample
from .ArgumentsExtractionTask import ArgumentsExtractionTask
from .FrameClassificationSample import FrameClassificationSample
from .FrameClassificationTask import FrameClassificationTask
from .TriggerIdentificationSample import TriggerIdentificationSample
from .TriggerIdentificationTask import TriggerIdentificationTask

# Names re-exported by this package, so callers can import every task and
# sample class from `frame_semantic_transformer.data.tasks` directly instead
# of from the individual submodules.
__all__ = (
"Task",
"TaskSample",
"ArgumentsExtractionSample",
"ArgumentsExtractionTask",
"FrameClassificationSample",
"FrameClassificationTask",
"TriggerIdentificationSample",
"TriggerIdentificationTask",
)
5 changes: 4 additions & 1 deletion frame_semantic_transformer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.base import Callback
from frame_semantic_transformer.constants import MODEL_MAX_LENGTH

from frame_semantic_transformer.data.TaskSampleDataset import TaskSampleDataset
from frame_semantic_transformer.data.load_framenet_samples import (
Expand Down Expand Up @@ -222,7 +223,9 @@ def train(
device = torch.device("cuda" if use_gpu else "cpu")
logging.info("loading base T5 model")
model = T5ForConditionalGeneration.from_pretrained(base_model).to(device)
tokenizer = T5Tokenizer.from_pretrained(base_model)
tokenizer = T5Tokenizer.from_pretrained(
base_model, model_max_length=MODEL_MAX_LENGTH
)

logging.info("loading train/test/val datasets")
train_dataset = TaskSampleDataset(
Expand Down
18 changes: 4 additions & 14 deletions tests/data/test_TaskSampleDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,12 @@
from frame_semantic_transformer.data.load_framenet_samples import (
parse_samples_from_fulltext_doc,
)
from frame_semantic_transformer.data.tasks.ArgumentsExtractionSample import (
from frame_semantic_transformer.data.tasks import (
ArgumentsExtractionSample,
)
from frame_semantic_transformer.data.tasks.ArgumentsExtractionTask import (
ArgumentsExtractionTask,
)
from frame_semantic_transformer.data.tasks.FrameClassificationSample import (
FrameClassificationSample,
)
from frame_semantic_transformer.data.tasks.FrameClassificationTask import (
FrameClassificationTask,
)
from frame_semantic_transformer.data.tasks.TriggerIdentificationSample import (
TriggerIdentificationSample,
)
from frame_semantic_transformer.data.tasks.TriggerIdentificationTask import (
TriggerIdentificationTask,
)

Expand All @@ -39,9 +29,9 @@ def test_TaskSampleDataset() -> None:
dataset = TaskSampleDataset(samples, tokenizer)

assert len(dataset) == 8
assert len(dataset[0]["input_ids"]) == 99
assert len(dataset[0]["attention_mask"]) == 99
assert len(dataset[0]["labels"]) == 30
assert len(dataset[0]["input_ids"]) == 512
assert len(dataset[0]["attention_mask"]) == 512
assert len(dataset[0]["labels"]) == 512


def test_balance_tasks_by_type() -> None:
Expand Down

0 comments on commit 3b40f07

Please sign in to comment.