Skip to content

Commit

Permalink
♻️ Remove unused code & host SMXM model on Hugging Face hub (#14)
Browse files Browse the repository at this point in the history
* ♻️ Remove unused code

Signed-off-by: Gabriele Picco <gabriele.picco@ibm.comm>
Signed-off-by: Gabriele Picco <piccogabriele@gmail.com>

* 👷 Add CI cache

Signed-off-by: Gabriele Picco <piccogabriele@gmail.com>

* 👷 Update codecov action

Signed-off-by: Gabriele Picco <piccogabriele@gmail.com>

* ♻️ Download SMXM from the Hugging Face hub

Signed-off-by: Gabriele Picco <piccogabriele@gmail.com>

* ♻️ Use models from IBM org on Hugging Face hub

Signed-off-by: Gabriele Picco <piccogabriele@gmail.com>

* ⚡ Improve coverage

Signed-off-by: Gabriele Picco <piccogabriele@gmail.com>

* 🎨 Fix code style

Signed-off-by: Gabriele Picco <piccogabriele@gmail.com>

Signed-off-by: Gabriele Picco <gabriele.picco@ibm.comm>
Signed-off-by: Gabriele Picco <piccogabriele@gmail.com>
Co-authored-by: Gabriele Picco <gabriele.picco@ibm.comm>
  • Loading branch information
GabrielePicco and Gabriele Picco authored Oct 12, 2022
1 parent 129d159 commit c44156d
Show file tree
Hide file tree
Showing 10 changed files with 36 additions and 235 deletions.
9 changes: 6 additions & 3 deletions .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ jobs:
key: ${{ runner.os }}-build-models-cache
path: |
~/.cache/huggingface
~/linker_smxm
~/.cache/zshot
~/.pytest_cache
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand All @@ -43,6 +44,8 @@ jobs:
python -m spacy download en_core_web_sm
- name: Test with pytest
run: |
python -m pytest --cov -v
python -m pytest --cov -v --cov-report xml:/home/runner/coverage.xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
uses: codecov/codecov-action@v3.1.1
with:
files: /home/runner/coverage.xml
43 changes: 0 additions & 43 deletions zshot/evaluation/dataset/med_mentions/entities.py

This file was deleted.

130 changes: 0 additions & 130 deletions zshot/evaluation/dataset/med_mentions/utils.py

This file was deleted.

2 changes: 1 addition & 1 deletion zshot/linker/linker_regen/linker_regen.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from zshot.linker.linker_regen.utils import create_input
from zshot.utils.data_models import Entity, Span

MODEL_NAME = "gabriele-picco/regen-disambiguation"
MODEL_NAME = "ibm/regen-disambiguation"

START_ENT_TOKEN = "[START_ENT]"
END_ENT_TOKEN = "[END_ENT]"
Expand Down
14 changes: 5 additions & 9 deletions zshot/linker/linker_smxm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,21 @@
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast

from zshot.config import MODELS_CACHE_PATH
from zshot.linker.linker import Linker
from zshot.linker.smxm.data import (
ByDescriptionTaggerDataset,
encode_data,
tagger_multiclass_collator
)
from zshot.linker.smxm.model import BertTaggerMultiClass, device
from zshot.linker.smxm.utils import (
SmxmInput,
get_entities_names_descriptions,
load_model,
predictions_to_span_annotations,
)
from zshot.utils.data_models import Span

SMXM_MODEL_FILES_URL = (
"https://ibm.box.com/shared/static/duni7p7i4gbk0prksc6zv5uahiemfy00.zip"
)
SMXM_MODEL_FOLDER_NAME = "BertTaggerMultiClass_config03_mode_tagger_multiclass_filtered_classes__entity_descriptions_mode_annotation_guidelines__per_gpu_train_batch_size_7/checkpoint"
MODEL_NAME = "ibm/smxm"


class LinkerSMXM(Linker):
Expand All @@ -47,9 +43,9 @@ def is_end2end(self) -> bool:
def load_models(self):
""" Load SMXM model """
if self.model is None:
self.model = load_model(
SMXM_MODEL_FILES_URL, MODELS_CACHE_PATH, SMXM_MODEL_FOLDER_NAME
)
self.model = BertTaggerMultiClass.from_pretrained(
MODEL_NAME, output_hidden_states=True
).to(device)

def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[Span]]:
"""
Expand Down
2 changes: 1 addition & 1 deletion zshot/linker/smxm/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch.utils.data import Dataset
from transformers import BertTokenizerFast

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from zshot.linker.smxm.model import device


class ByDescriptionTaggerDataset(Dataset):
Expand Down
25 changes: 1 addition & 24 deletions zshot/linker/smxm/utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
import os
import zipfile
from typing import List, Tuple

import torch
from transformers import BertTokenizerFast

from zshot.linker.smxm.model import device
from zshot.utils.data_models import Entity
from zshot.utils.data_models import Span
from zshot.linker.smxm.model import BertTaggerMultiClass
from zshot.utils import download_file

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class SmxmInput(dict):
Expand All @@ -34,24 +29,6 @@ def __init__(
super().__init__(**config)


def load_model(url: str, output_path: str, folder_name: str) -> BertTaggerMultiClass:
filename = url.rsplit("/", 1)[1]
model_zipfile_path = os.path.join(output_path, filename)
model_folder_path = os.path.join(output_path, folder_name)

if not os.path.isdir(model_folder_path):
download_file(url, output_path)
with zipfile.ZipFile(model_zipfile_path, "r") as model_zip:
model_zip.extractall(output_path)
os.remove(model_zipfile_path)

model = BertTaggerMultiClass.from_pretrained(
model_folder_path, output_hidden_states=True
).to(device)

return model


def predictions_to_span_annotations(
sentences: List[str],
predictions: List[List[int]],
Expand Down
6 changes: 3 additions & 3 deletions zshot/pipeline_config.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from typing import Optional, Union, Dict, List
from typing import Optional, Union, List

import spacy

from zshot.utils.data_models import Entity
from zshot.linker import Linker
from zshot.mentions_extractor import MentionsExtractor
from zshot.utils.data_models import Entity


class PipelineConfig(dict):

def __init__(self,
mentions_extractor: Optional[MentionsExtractor] = None,
linker: Optional[Union[Linker, str]] = None,
entities: Optional[Union[Dict[str, str], List[Entity], List[str], str]] = None,
entities: Optional[Union[List[Entity], List[str], str]] = None,
disable_default_ner: Optional[bool] = True) -> None:
config = {}

Expand Down
15 changes: 13 additions & 2 deletions zshot/tests/test_zshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ def test_disable_ner():
assert "ner" not in nlp.pipe_names


def test_wrong_pipeline():
nlp = spacy.blank("en")
assert "ner" not in nlp.pipe_names
config_zshot = PipelineConfig(mentions_extractor=DummyMentionsExtractorWithNER())
try:
nlp.add_pipe("zshot", config=config_zshot, last=True)
except ValueError:
assert True


def test_disable_mentions_extractor():
nlp = spacy.load("en_core_web_sm")
config_zshot = PipelineConfig(mentions_extractor=DummyMentionsExtractorWithNER(), linker=DummyLinkerEnd2End())
Expand All @@ -35,7 +45,8 @@ def test_disable_mentions_extractor():

def test_serialization_zshot():
nlp = spacy.blank("en")
nlp.add_pipe("zshot", last=True)
config_zshot = PipelineConfig(mentions_extractor=DummyMentionsExtractor(), linker=DummyLinker())
nlp.add_pipe("zshot", config=config_zshot, last=True)
assert "zshot" in nlp.pipe_names
assert "ner" not in nlp.pipe_names
pipes = [p for p in nlp.pipe_names if p != "zshot"]
Expand Down Expand Up @@ -79,7 +90,7 @@ def get_entities() -> List[Entity]:
assert type(zshot_component.entities[0]) == Entity


def test_call_pipe_with_piepeline_configuration():
def test_call_pipe_with_pipeline_configuration():
nlp = spacy.load("en_core_web_sm")
nlp.add_pipe("zshot", config=PipelineConfig(
mentions_extractor=DummyMentionsExtractor(),
Expand Down
Loading

0 comments on commit c44156d

Please sign in to comment.