From d28961cd5f87d42e87d51219f293984d06ef1efc Mon Sep 17 00:00:00 2001
From: David Chanin
Date: Fri, 13 May 2022 17:15:28 +0100
Subject: [PATCH] using multiple predictions when evaluating frame id task

---
 .../data/task_samples/ArgumentsExtractionSample.py |  7 +++++--
 .../data/task_samples/FrameClassificationSample.py | 11 +++++++++--
 .../data/task_samples/TaskSample.py                |  5 ++++-
 .../task_samples/TriggerIdentificationSample.py    |  7 +++++--
 frame_semantic_transformer/evaluate.py             | 13 ++++++++++---
 .../task_samples/test_ArgumentsExtractionSample.py |  4 ++--
 .../task_samples/test_FrameClassificationSample.py |  8 +++++---
 .../test_TriggerIdentificationSample.py            |  8 ++++----
 8 files changed, 44 insertions(+), 19 deletions(-)

diff --git a/frame_semantic_transformer/data/task_samples/ArgumentsExtractionSample.py b/frame_semantic_transformer/data/task_samples/ArgumentsExtractionSample.py
index ed0af0b..63f51e7 100644
--- a/frame_semantic_transformer/data/task_samples/ArgumentsExtractionSample.py
+++ b/frame_semantic_transformer/data/task_samples/ArgumentsExtractionSample.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 from dataclasses import dataclass
+from typing import Sequence
 
 from frame_semantic_transformer.data.task_samples.TaskSample import TaskSample
 
@@ -25,9 +26,11 @@ def get_target(self) -> str:
         )
 
     @staticmethod
-    def evaluate_prediction(prediction: str, target: str) -> tuple[int, int, int]:
+    def evaluate_prediction(
+        prediction_outputs: Sequence[str], target: str
+    ) -> tuple[int, int, int]:
         # TODO: improve evaluation
-        if prediction == target:
+        if prediction_outputs[0] == target:
             return (1, 0, 0)
         else:
             return (0, 1, 0)
diff --git a/frame_semantic_transformer/data/task_samples/FrameClassificationSample.py b/frame_semantic_transformer/data/task_samples/FrameClassificationSample.py
index dbad981..cb7a2b8 100644
--- a/frame_semantic_transformer/data/task_samples/FrameClassificationSample.py
+++ b/frame_semantic_transformer/data/task_samples/FrameClassificationSample.py
@@ -1,6 +1,8 @@
 from __future__ import annotations
 from dataclasses import dataclass
+from typing import Sequence
 from frame_semantic_transformer.data.data_utils import standardize_punct
+from frame_semantic_transformer.data.framenet import is_valid_frame
 from frame_semantic_transformer.data.task_samples.TaskSample import TaskSample
 
 
@@ -23,8 +25,13 @@ def get_target(self) -> str:
         return self.frame
 
     @staticmethod
-    def evaluate_prediction(prediction: str, target: str) -> tuple[int, int, int]:
-        if prediction == target:
+    def evaluate_prediction(
+        prediction_outputs: Sequence[str], target: str
+    ) -> tuple[int, int, int]:
+        valid_predictions = [
+            pred for pred in prediction_outputs if is_valid_frame(pred)
+        ]
+        if len(valid_predictions) > 0 and valid_predictions[0] == target:
             return (1, 0, 0)
         else:
             # sesame treats any non-correct frame as both a false pos and false neg
diff --git a/frame_semantic_transformer/data/task_samples/TaskSample.py b/frame_semantic_transformer/data/task_samples/TaskSample.py
index 3f63b6c..ee3a227 100644
--- a/frame_semantic_transformer/data/task_samples/TaskSample.py
+++ b/frame_semantic_transformer/data/task_samples/TaskSample.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 from abc import ABC, abstractmethod
+from typing import Sequence
 
 
 class TaskSample(ABC):
@@ -21,6 +22,8 @@ def get_target(self) -> str:
 
     @staticmethod
     @abstractmethod
-    def evaluate_prediction(prediction: str, target: str) -> tuple[int, int, int]:
+    def evaluate_prediction(
+        prediction_outputs: Sequence[str], target: str
+    ) -> tuple[int, int, int]:
         "return a tuple indicating the number of true positives, false positives, and false negatives in the prediction"
         pass
diff --git a/frame_semantic_transformer/data/task_samples/TriggerIdentificationSample.py b/frame_semantic_transformer/data/task_samples/TriggerIdentificationSample.py
index 274e8dd..886e667 100644
--- a/frame_semantic_transformer/data/task_samples/TriggerIdentificationSample.py
+++ b/frame_semantic_transformer/data/task_samples/TriggerIdentificationSample.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 from dataclasses import dataclass
 import re
+from typing import Sequence
 
 from frame_semantic_transformer.data.data_utils import standardize_punct
 
@@ -31,12 +32,14 @@ def get_target(self) -> str:
         return standardize_punct(output)
 
     @staticmethod
-    def evaluate_prediction(prediction: str, target: str) -> tuple[int, int, int]:
+    def evaluate_prediction(
+        prediction_outputs: Sequence[str], target: str
+    ) -> tuple[int, int, int]:
         true_pos = 0
         false_pos = 0
         false_neg = 0
 
-        prediction_parts = process_text_for_evaluation(prediction).split()
+        prediction_parts = process_text_for_evaluation(prediction_outputs[0]).split()
         target_parts = process_text_for_evaluation(target).split()
 
         for i, target_part in enumerate(target_parts):
diff --git a/frame_semantic_transformer/evaluate.py b/frame_semantic_transformer/evaluate.py
index 2050805..3f8fc26 100644
--- a/frame_semantic_transformer/evaluate.py
+++ b/frame_semantic_transformer/evaluate.py
@@ -98,7 +98,10 @@ def evaluate(
 
 
 def evaluate_batch(
-    model: T5ForConditionalGeneration, tokenizer: T5Tokenizer, batch: Any
+    model: T5ForConditionalGeneration,
+    tokenizer: T5Tokenizer,
+    batch: Any,
+    predictions_per_sample: int = 5,
 ) -> dict[str, list[int]]:
     predictions = predict_on_ids(
         model,
@@ -107,15 +110,19 @@ def evaluate_batch(
         batch["attention_mask"],
         skip_special_tokens=True,
         clean_up_tokenization_spaces=True,
+        num_beams=predictions_per_sample,
+        num_return_sequences=predictions_per_sample,
     )
+    batched_predictions = chunk_list(predictions, predictions_per_sample)
     results: dict[str, list[int]] = defaultdict(lambda: [0, 0, 0])
-    for pred, task, label in zip(predictions, batch["task"], batch["labels"]):
+    for preds, task, label in zip(batched_predictions, batch["task"], batch["labels"]):
+        assert len(preds) == predictions_per_sample
         target_tokens = [tok_id for tok_id in label.tolist() if tok_id != -100]
         target = tokenizer.decode(
             target_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True
         )
         sample_class = TASK_SAMPLE_CLASS_MAP[task]
-        true_pos, false_pos, false_neg = sample_class.evaluate_prediction(pred, target)
+        true_pos, false_pos, false_neg = sample_class.evaluate_prediction(preds, target)
         results[task][0] += true_pos
         results[task][1] += false_pos
         results[task][2] += false_neg
diff --git a/tests/data/task_samples/test_ArgumentsExtractionSample.py b/tests/data/task_samples/test_ArgumentsExtractionSample.py
index 9a436e1..a04bda3 100644
--- a/tests/data/task_samples/test_ArgumentsExtractionSample.py
+++ b/tests/data/task_samples/test_ArgumentsExtractionSample.py
@@ -26,8 +26,8 @@ def test_get_target() -> None:
 def test_evaluate_prediction_just_does_a_simple_string_match_for_now() -> None:
     target = "Donor = Your | Recipient = to Goodwill"
     incorrect_pred = "Donor = Your | Recipient = to Goodwill | meh"
-    assert ArgumentsExtractionSample.evaluate_prediction(target, target) == (1, 0, 0)
-    assert ArgumentsExtractionSample.evaluate_prediction(incorrect_pred, target) == (
+    assert ArgumentsExtractionSample.evaluate_prediction([target], target) == (1, 0, 0)
+    assert ArgumentsExtractionSample.evaluate_prediction([incorrect_pred], target) == (
         0,
         1,
         0,
diff --git a/tests/data/task_samples/test_FrameClassificationSample.py b/tests/data/task_samples/test_FrameClassificationSample.py
index a67bf55..bc7e338 100644
--- a/tests/data/task_samples/test_FrameClassificationSample.py
+++ b/tests/data/task_samples/test_FrameClassificationSample.py
@@ -27,19 +27,21 @@ def test_get_target() -> None:
 def test_evaluate_prediction_correct_prediction() -> None:
     correct_pred = "Giving"
     assert FrameClassificationSample.evaluate_prediction(
-        correct_pred, correct_pred
+        [correct_pred], correct_pred
     ) == (1, 0, 0)
 
 
 def test_evaluate_prediction_increments_fp_and_fn_on_incorrect_pred() -> None:
     incorrect_pred = "Aiming"
     nonsense_pred = "Nonsense"
-    assert FrameClassificationSample.evaluate_prediction(incorrect_pred, "Giving") == (
+    assert FrameClassificationSample.evaluate_prediction(
+        [incorrect_pred], "Giving"
+    ) == (
         0,
         1,
         1,
     )
-    assert FrameClassificationSample.evaluate_prediction(nonsense_pred, "Giving") == (
+    assert FrameClassificationSample.evaluate_prediction([nonsense_pred], "Giving") == (
         0,
         1,
         1,
diff --git a/tests/data/task_samples/test_TriggerIdentificationSample.py b/tests/data/task_samples/test_TriggerIdentificationSample.py
index 2c38e6c..01bbc6e 100644
--- a/tests/data/task_samples/test_TriggerIdentificationSample.py
+++ b/tests/data/task_samples/test_TriggerIdentificationSample.py
@@ -30,22 +30,22 @@ def test_get_target() -> None:
 
 def test_evaluate_prediction() -> None:
     pred = "Your contribution * to Goodwill * will * mean * more than you may * know."
-    assert TriggerIdentificationSample.evaluate_prediction(pred, target) == (4, 1, 2)
+    assert TriggerIdentificationSample.evaluate_prediction([pred], target) == (4, 1, 2)
 
 
 def test_evaluate_prediction_fails_for_elements_whose_content_doesnt_match() -> None:
     pred = "Your AHAHAHAHA * to BADWILL will * PSYCH * more than you may * know."
-    assert TriggerIdentificationSample.evaluate_prediction(pred, target) == (3, 1, 3)
+    assert TriggerIdentificationSample.evaluate_prediction([pred], target) == (3, 1, 3)
 
 
 def test_evaluate_prediction_treats_missing_words_as_wrong() -> None:
     pred = "Your * contribution * to Goodwill will * mean"
-    assert TriggerIdentificationSample.evaluate_prediction(pred, target) == (3, 2, 3)
+    assert TriggerIdentificationSample.evaluate_prediction([pred], target) == (3, 2, 3)
 
 
 def test_evaluate_prediction_treats_excess_words_as_false_positives() -> None:
     pred = "Your * contribution * to Goodwill will * mean * more than you * may * know. ha ha ha ha!"
-    assert TriggerIdentificationSample.evaluate_prediction(pred, target) == (6, 4, 0)
+    assert TriggerIdentificationSample.evaluate_prediction([pred], target) == (6, 4, 0)
 
 
 def test_process_text_for_evaluation_handles_contractions() -> None:
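
Note (illustration only, not part of the patch): evaluate_batch now requests predictions_per_sample beam outputs per input and relies on chunk_list to regroup the flat list of generations into one ranked list per sample before scoring. chunk_list itself is not shown in this diff; the sketch below assumes it splits a flat sequence into consecutive fixed-size groups, which matches how evaluate_batch uses it here.

# Illustration only -- not part of the patch. Assumes chunk_list groups a flat
# sequence into consecutive fixed-size chunks; its real definition lives
# elsewhere in the repo and is not shown in this diff.
from __future__ import annotations
from typing import Sequence, TypeVar

T = TypeVar("T")


def chunk_list(items: Sequence[T], chunk_size: int) -> list[list[T]]:
    # e.g. chunk_list([1, 2, 3, 4, 5, 6], 3) -> [[1, 2, 3], [4, 5, 6]]
    return [list(items[i : i + chunk_size]) for i in range(0, len(items), chunk_size)]


# With predictions_per_sample=3 and a batch of 2 samples, beam search returns
# 6 generations in sample order; regrouping restores one ranked list per sample:
flat_predictions = ["Giving", "Aiming", "Giving", "Motion", "Arriving", "Motion"]
grouped = chunk_list(flat_predictions, 3)
# grouped == [["Giving", "Aiming", "Giving"], ["Motion", "Arriving", "Motion"]]
# Each inner list is what evaluate_prediction now receives, e.g. (hypothetical target):
# FrameClassificationSample.evaluate_prediction(grouped[0], "Giving") -> (1, 0, 0)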