Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add device option to Pipeline & Components #42

Merged
merged 2 commits into from
Dec 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion zshot/linker/linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from abc import ABC, abstractmethod
from typing import Iterator, List, Optional, Union

import torch
from spacy.tokens import Doc
from spacy.util import ensure_path

Expand All @@ -17,9 +18,18 @@ class Linker(ABC):
extracted mentions or perform end-2-end extraction
"""

def __init__(self):
def __init__(self, device: Optional[Union[str, torch.device]] = None):
self._entities = None
self._is_end2end = False
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device

def set_device(self, device: Union[str, torch.device]):
"""
Set the device to use
:param device:
:return:
"""
self.device = device

def set_kg(self, entities: Iterator[Entity]):
"""
Expand Down
8 changes: 3 additions & 5 deletions zshot/linker/linker_smxm.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
from typing import Iterator, List, Optional, Union

import torch
from spacy.tokens import Doc
from transformers import BertTokenizerFast

from zshot.linker.linker import Linker
from zshot.utils.models.smxm.model import BertTaggerMultiClass, device
from zshot.utils.data_models import Span
from zshot.utils.models.smxm.model import BertTaggerMultiClass
from zshot.utils.models.smxm.utils import (
get_entities_names_descriptions,
smxm_predict
)
from zshot.utils.data_models import Span

ONTONOTES_MODEL_NAME = "ibm/smxm"

Expand All @@ -27,7 +26,6 @@ def __init__(self, model_name=ONTONOTES_MODEL_NAME):

self.model_name = model_name
self.model = None
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@property
def is_end2end(self) -> bool:
Expand All @@ -39,7 +37,7 @@ def load_models(self):
if self.model is None:
self.model = BertTaggerMultiClass.from_pretrained(
self.model_name, output_hidden_states=True
).to(device)
).to(self.device)

def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[Span]]:
"""
Expand Down
15 changes: 13 additions & 2 deletions zshot/mentions_extractor/mentions_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

import zlib
from abc import ABC, abstractmethod

import torch
from spacy.tokens import Doc
from typing import List, Iterator
from typing import List, Iterator, Optional, Union

from spacy.util import ensure_path

Expand All @@ -15,8 +17,17 @@

class MentionsExtractor(ABC):

def __init__(self):
def __init__(self, device: Optional[Union[str, torch.device]] = None):
self._mentions = None
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device

def set_device(self, device: Union[str, torch.device]):
"""
Set the device to use
:param device:
:return:
"""
self.device = device

def set_kg(self, mentions: Iterator[Entity]):
"""
Expand Down
8 changes: 3 additions & 5 deletions zshot/mentions_extractor/mentions_extractor_smxm.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
from typing import Iterator, List, Optional, Union

import torch
from spacy.tokens import Doc
from transformers import BertTokenizerFast

from zshot.mentions_extractor.mentions_extractor import MentionsExtractor
from zshot.utils.models.smxm.model import BertTaggerMultiClass, device
from zshot.utils.data_models import Span
from zshot.utils.models.smxm.model import BertTaggerMultiClass
from zshot.utils.models.smxm.utils import (
get_entities_names_descriptions,
smxm_predict,
)
from zshot.utils.data_models import Span

ONTONOTES_MODEL_NAME = "ibm/smxm"

Expand All @@ -27,14 +26,13 @@ def __init__(self, model_name=ONTONOTES_MODEL_NAME):

self.model_name = model_name
self.model = None
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_models(self):
""" Load SMXM model """
if self.model is None:
self.model = BertTaggerMultiClass.from_pretrained(
self.model_name, output_hidden_states=True
).to(device)
).to(self.device)

def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[Span]]:
"""
Expand Down
6 changes: 5 additions & 1 deletion zshot/pipeline_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def __init__(self,
mentions: Optional[Union[List[Entity], List[str], str]] = None,
entities: Optional[Union[List[Entity], List[str], str]] = None,
relations: Optional[Union[List[Relation], str]] = None,
disable_default_ner: Optional[bool] = True) -> None:
disable_default_ner: Optional[bool] = True,
device: Optional[str] = None) -> None:
config = {}

if mentions_extractor:
Expand Down Expand Up @@ -48,6 +49,9 @@ def __init__(self,
if disable_default_ner:
config.update({'disable_default_ner': disable_default_ner})

if device:
config.update({'device': device})

super().__init__(**config)

@staticmethod
Expand Down
45 changes: 38 additions & 7 deletions zshot/relation_extractor/relation_extractor_zsrc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import torch
from torch.utils.data import DataLoader

from zshot.relation_extractor.relations_extractor import RelationsExtractor
from zshot.relation_extractor.zsrc.zero_shot_rel_class import predict, load_model
from zshot.relation_extractor.zsrc import data_helper
from zshot.relation_extractor.zsrc.zero_shot_rel_class import load_model
import numpy as np
from tqdm import tqdm
from typing import Iterator, List
Expand All @@ -10,16 +14,16 @@

class RelationsExtractorZSRC(RelationsExtractor):
def __init__(self, thr=0.5):
super().__init__()
self.model = None
self.load_models()
self.thr = thr
super(RelationsExtractor, self).__init__()

def load_models(
self,
self,
):
if self.model is None:
self.model = load_model()
self.model = load_model(self.device)

def predict(self, docs: Iterator[Doc], batch_size=None) -> List[List[RelationSpan]]:
relations_pred = []
Expand All @@ -29,7 +33,7 @@ def predict(self, docs: Iterator[Doc], batch_size=None) -> List[List[RelationSpa
for i, e1 in enumerate(doc._.spans):
for j, e2 in enumerate(doc._.spans):
if (
i == j or (e1, e2) in items_to_process or (
i == j or (e1, e2) in items_to_process or (
e2, e1) in items_to_process
):
continue
Expand All @@ -39,8 +43,7 @@ def predict(self, docs: Iterator[Doc], batch_size=None) -> List[List[RelationSpa
relations_probs = []
if self.relations is not None:
for rel in self.relations:
_, probs = predict(
self.model,
_, probs = self._predict_internal(
[(e1, e2, doc.text)],
rel.description,
batch_size,
Expand All @@ -55,3 +58,31 @@ def predict(self, docs: Iterator[Doc], batch_size=None) -> List[List[RelationSpa
)
relations_pred.append(relations_doc)
return relations_pred

def _predict_internal(self, items_to_process, relation_description, batch_size=4):
trainset = data_helper.ZSDataset(
'test', items_to_process, relation_description)
trainloader = DataLoader(trainset, batch_size=batch_size,
collate_fn=data_helper.create_mini_batch_fewrel_aio, shuffle=False)
all_preds = []
all_probs = []
for data in trainloader:
tokens_tensors, segments_tensors, marked_e1, marked_e2, masks_tensors, labels = [
t.to(self.device) for t in data]
if tokens_tensors.shape[1] <= 512:
with torch.no_grad():
outputs = self.model(input_ids=tokens_tensors,
token_type_ids=segments_tensors,
e1_mask=marked_e1,
e2_mask=marked_e2,
attention_mask=masks_tensors,
labels=labels)
preds = outputs[1]
probs = preds.detach().cpu().numpy()[:, 1]
all_probs.extend(probs)
all_preds.extend([item >= 0.5 for item in probs])
else:
all_probs.extend([-1] * tokens_tensors.shape[0])
all_preds.extend([False] * tokens_tensors.shape[0])

return all_preds, all_probs
12 changes: 11 additions & 1 deletion zshot/relation_extractor/relations_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from abc import ABC, abstractmethod
from typing import List, Iterator, Optional, Union

import torch
from spacy.tokens import Doc
from spacy.util import ensure_path

Expand All @@ -13,8 +14,17 @@

class RelationsExtractor(ABC):

def __init__(self):
def __init__(self, device: Optional[Union[str, torch.device]] = None):
self._relations = None
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device

def set_device(self, device: Union[str, torch.device]):
"""
Set the device to use
:param device:
:return:
"""
self.device = device

def set_relations(self, relations: Iterator[Relation]):
"""
Expand Down
Loading