using multiple predictions when evaluating frame id task
chanind committed May 13, 2022
1 parent 197b93f commit d28961c
Showing 8 changed files with 44 additions and 19 deletions.
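In short, each task sample's evaluate_prediction now receives the full sequence of generated outputs for a sample (the beam-search candidates, best-first) instead of a single prediction string, and still returns counts of true positives, false positives, and false negatives. Summarizing the signature change made in TaskSample.py below:

# old: evaluate_prediction(prediction: str, target: str) -> tuple[int, int, int]
# new: evaluate_prediction(prediction_outputs: Sequence[str], target: str) -> tuple[int, int, int]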
frame_semantic_transformer/data/task_samples/ArgumentsExtractionSample.py
@@ -1,5 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Sequence

from frame_semantic_transformer.data.task_samples.TaskSample import TaskSample

@@ -25,9 +26,11 @@ def get_target(self) -> str:
)

@staticmethod
def evaluate_prediction(prediction: str, target: str) -> tuple[int, int, int]:
def evaluate_prediction(
prediction_outputs: Sequence[str], target: str
) -> tuple[int, int, int]:
# TODO: improve evaluation
if prediction == target:
if prediction_outputs[0] == target:
return (1, 0, 0)
else:
return (0, 1, 0)
frame_semantic_transformer/data/task_samples/FrameClassificationSample.py
@@ -1,6 +1,8 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Sequence
from frame_semantic_transformer.data.data_utils import standardize_punct
from frame_semantic_transformer.data.framenet import is_valid_frame

from frame_semantic_transformer.data.task_samples.TaskSample import TaskSample

@@ -23,8 +25,13 @@ def get_target(self) -> str:
return self.frame

@staticmethod
def evaluate_prediction(prediction: str, target: str) -> tuple[int, int, int]:
if prediction == target:
def evaluate_prediction(
prediction_outputs: Sequence[str], target: str
) -> tuple[int, int, int]:
valid_predictions = [
pred for pred in prediction_outputs if is_valid_frame(pred)
]
if len(valid_predictions) > 0 and valid_predictions[0] == target:
return (1, 0, 0)
else:
# sesame treats any non-correct frame as both a false pos and false neg
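A hedged sketch of the new frame-classification scoring with beam candidates: candidates that do not name a valid FrameNet frame are filtered out first, and only the best remaining candidate is compared to the target. The helper below is an illustrative stand-in with made-up frame names and a made-up valid_frames set, not the repository's is_valid_frame:

from typing import Sequence, Set, Tuple

def score_frame_candidates(
    prediction_outputs: Sequence[str], target: str, valid_frames: Set[str]
) -> Tuple[int, int, int]:
    # Keep only candidates that name a real frame (stand-in for is_valid_frame).
    valid_predictions = [pred for pred in prediction_outputs if pred in valid_frames]
    if len(valid_predictions) > 0 and valid_predictions[0] == target:
        return (1, 0, 0)  # true positive
    # SESAME-style scoring: any non-correct frame counts as both a false positive and a false negative.
    return (0, 1, 1)

# The top beam output "Givin" is not a valid frame, so the next candidate "Giving" is scored instead.
assert score_frame_candidates(["Givin", "Giving"], "Giving", {"Giving", "Aiming"}) == (1, 0, 0)
assert score_frame_candidates(["Aiming"], "Giving", {"Giving", "Aiming"}) == (0, 1, 1)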
5 changes: 4 additions & 1 deletion frame_semantic_transformer/data/task_samples/TaskSample.py
@@ -1,5 +1,6 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Sequence


class TaskSample(ABC):
@@ -21,6 +22,8 @@ def get_target(self) -> str:

@staticmethod
@abstractmethod
def evaluate_prediction(prediction: str, target: str) -> tuple[int, int, int]:
def evaluate_prediction(
prediction_outputs: Sequence[str], target: str
) -> tuple[int, int, int]:
"return a tuple indicating the number of true positives, false positives, and false negatives in the prediction"
pass
frame_semantic_transformer/data/task_samples/TriggerIdentificationSample.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass
import re
from typing import Sequence

from frame_semantic_transformer.data.data_utils import standardize_punct

@@ -31,12 +32,14 @@ def get_target(self) -> str:
return standardize_punct(output)

@staticmethod
def evaluate_prediction(prediction: str, target: str) -> tuple[int, int, int]:
def evaluate_prediction(
prediction_outputs: Sequence[str], target: str
) -> tuple[int, int, int]:
true_pos = 0
false_pos = 0
false_neg = 0

prediction_parts = process_text_for_evaluation(prediction).split()
prediction_parts = process_text_for_evaluation(prediction_outputs[0]).split()
target_parts = process_text_for_evaluation(target).split()

for i, target_part in enumerate(target_parts):
13 changes: 10 additions & 3 deletions frame_semantic_transformer/evaluate.py
@@ -98,7 +98,10 @@ def evaluate(


def evaluate_batch(
model: T5ForConditionalGeneration, tokenizer: T5Tokenizer, batch: Any
model: T5ForConditionalGeneration,
tokenizer: T5Tokenizer,
batch: Any,
predictions_per_sample: int = 5,
) -> dict[str, list[int]]:
predictions = predict_on_ids(
model,
@@ -107,15 +110,19 @@
batch["attention_mask"],
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
num_beams=predictions_per_sample,
num_return_sequences=predictions_per_sample,
)
batched_predictions = chunk_list(predictions, predictions_per_sample)
results: dict[str, list[int]] = defaultdict(lambda: [0, 0, 0])
for pred, task, label in zip(predictions, batch["task"], batch["labels"]):
for preds, task, label in zip(batched_predictions, batch["task"], batch["labels"]):
assert len(preds) == predictions_per_sample
target_tokens = [tok_id for tok_id in label.tolist() if tok_id != -100]
target = tokenizer.decode(
target_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
sample_class = TASK_SAMPLE_CLASS_MAP[task]
true_pos, false_pos, false_neg = sample_class.evaluate_prediction(pred, target)
true_pos, false_pos, false_neg = sample_class.evaluate_prediction(preds, target)
results[task][0] += true_pos
results[task][1] += false_pos
results[task][2] += false_neg
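With num_return_sequences=predictions_per_sample, generation returns predictions_per_sample decoded strings per input, flattened across the batch, so chunk_list regroups them per sample before evaluation. A minimal sketch of what a chunk_list-style helper does (chunk_list_sketch is a hypothetical stand-in, not necessarily the repository's implementation):

from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")

def chunk_list_sketch(items: Sequence[T], chunk_size: int) -> Iterator[List[T]]:
    # Yield consecutive chunks of chunk_size, e.g. the 5 beam outputs belonging to each input sample.
    for i in range(0, len(items), chunk_size):
        yield list(items[i : i + chunk_size])

# A batch of 2 samples with predictions_per_sample=3 comes back as one flat list of 6 strings.
flat_predictions = ["a1", "a2", "a3", "b1", "b2", "b3"]
assert list(chunk_list_sketch(flat_predictions, 3)) == [["a1", "a2", "a3"], ["b1", "b2", "b3"]]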
4 changes: 2 additions & 2 deletions tests/data/task_samples/test_ArgumentsExtractionSample.py
@@ -26,8 +26,8 @@ def test_get_target() -> None:
def test_evaluate_prediction_just_does_a_simple_string_match_for_now() -> None:
target = "Donor = Your | Recipient = to Goodwill"
incorrect_pred = "Donor = Your | Recipient = to Goodwill | meh"
assert ArgumentsExtractionSample.evaluate_prediction(target, target) == (1, 0, 0)
assert ArgumentsExtractionSample.evaluate_prediction(incorrect_pred, target) == (
assert ArgumentsExtractionSample.evaluate_prediction([target], target) == (1, 0, 0)
assert ArgumentsExtractionSample.evaluate_prediction([incorrect_pred], target) == (
0,
1,
0,
8 changes: 5 additions & 3 deletions tests/data/task_samples/test_FrameClassificationSample.py
@@ -27,19 +27,21 @@ def test_get_target() -> None:
def test_evaluate_prediction_correct_prediction() -> None:
correct_pred = "Giving"
assert FrameClassificationSample.evaluate_prediction(
correct_pred, correct_pred
[correct_pred], correct_pred
) == (1, 0, 0)


def test_evaluate_prediction_increments_fp_and_fn_on_incorrect_pred() -> None:
incorrect_pred = "Aiming"
nonsense_pred = "Nonsense"
assert FrameClassificationSample.evaluate_prediction(incorrect_pred, "Giving") == (
assert FrameClassificationSample.evaluate_prediction(
[incorrect_pred], "Giving"
) == (
0,
1,
1,
)
assert FrameClassificationSample.evaluate_prediction(nonsense_pred, "Giving") == (
assert FrameClassificationSample.evaluate_prediction([nonsense_pred], "Giving") == (
0,
1,
1,
8 changes: 4 additions & 4 deletions tests/data/task_samples/test_TriggerIdentificationSample.py
@@ -30,22 +30,22 @@ def test_get_target() -> None:

def test_evaluate_prediction() -> None:
pred = "Your contribution * to Goodwill * will * mean * more than you may * know."
assert TriggerIdentificationSample.evaluate_prediction(pred, target) == (4, 1, 2)
assert TriggerIdentificationSample.evaluate_prediction([pred], target) == (4, 1, 2)


def test_evaluate_prediction_fails_for_elements_whose_content_doesnt_match() -> None:
pred = "Your AHAHAHAHA * to BADWILL will * PSYCH * more than you may * know."
assert TriggerIdentificationSample.evaluate_prediction(pred, target) == (3, 1, 3)
assert TriggerIdentificationSample.evaluate_prediction([pred], target) == (3, 1, 3)


def test_evaluate_prediction_treats_missing_words_as_wrong() -> None:
pred = "Your * contribution * to Goodwill will * mean"
assert TriggerIdentificationSample.evaluate_prediction(pred, target) == (3, 2, 3)
assert TriggerIdentificationSample.evaluate_prediction([pred], target) == (3, 2, 3)


def test_evaluate_prediction_treats_excess_words_as_false_positives() -> None:
pred = "Your * contribution * to Goodwill will * mean * more than you * may * know. ha ha ha ha!"
assert TriggerIdentificationSample.evaluate_prediction(pred, target) == (6, 4, 0)
assert TriggerIdentificationSample.evaluate_prediction([pred], target) == (6, 4, 0)


def test_process_text_for_evaluation_handles_contractions() -> None:
