-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add blink linker and flair mentions extractor (#5)
* Add Blink linker * Add Flair mention detector
- Loading branch information
Gabriele Picco
authored and
GitHub Enterprise
committed
Feb 11, 2022
1 parent
182b1a4
commit 1e2f0b4
Showing
13 changed files
with
224 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import spacy | ||
from spacy import displacy | ||
|
||
from zshot import Linker, MentionsExtractor | ||
|
||
text = "International Business Machines Corporation (IBM) is an American multinational technology corporation " \ | ||
"headquartered in Armonk, New York, with operations in over 171 countries." | ||
|
||
nlp = spacy.load("en_core_web_trf") | ||
nlp.disable_pipes('ner') | ||
nlp.add_pipe("zshot", config={"mentions_extractor": MentionsExtractor.FLAIR, "linker": Linker.BLINK}, last=True) | ||
print(nlp.pipe_names) | ||
|
||
doc = nlp(text) | ||
displacy.serve(doc, style="ent") |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import List | ||
|
||
from spacy.tokens import Doc | ||
|
||
|
||
class Linker(ABC): | ||
@abstractmethod | ||
def link(self, docs: List[Doc], batch_size=None): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
import argparse | ||
from pathlib import Path | ||
from typing import List | ||
import pkgutil | ||
from appdata import AppDataPaths | ||
from spacy.tokens import Doc | ||
|
||
from zshot.linker.linker import Linker | ||
from zshot.utils import download_file | ||
|
||
MODELS_PATH = AppDataPaths(f".{Path(__file__).stem}").app_data_path + "/" | ||
|
||
BLINK_FILES = \ | ||
["http://dl.fbaipublicfiles.com/BLINK/entity.jsonl", | ||
"http://dl.fbaipublicfiles.com/BLINK/all_entities_large.t7"] | ||
BLINK_BI_ENCODER_FILES = \ | ||
["http://dl.fbaipublicfiles.com/BLINK/biencoder_wiki_large.json", | ||
"http://dl.fbaipublicfiles.com/BLINK/biencoder_wiki_large.bin"] | ||
BLINK_CROSS_ENCODER_FILES = \ | ||
["http://dl.fbaipublicfiles.com/BLINK/crossencoder_wiki_large.bin", | ||
"http://dl.fbaipublicfiles.com/BLINK/crossencoder_wiki_large.json"] | ||
|
||
_config = { | ||
"test_entities": None, | ||
"test_mentions": None, | ||
"interactive": False, | ||
"top_k": 1, | ||
"biencoder_model": MODELS_PATH + "biencoder_wiki_large.bin", | ||
"biencoder_config": MODELS_PATH + "biencoder_wiki_large.json", | ||
"entity_catalogue": MODELS_PATH + "entity.jsonl", | ||
"entity_encoding": MODELS_PATH + "all_entities_large.t7", | ||
"crossencoder_model": MODELS_PATH + "crossencoder_wiki_large.bin", | ||
"crossencoder_config": MODELS_PATH + "crossencoder_wiki_large.json", | ||
"fast": True, | ||
"output_path": "logs/" | ||
} | ||
|
||
|
||
class Blink(Linker): | ||
|
||
def __init__(self): | ||
if not pkgutil.find_loader("blink"): | ||
raise Exception("Blink module not installed. You need to install blink in order to use the Blink Linker." | ||
"Install it with: pip install git+https://github.com/facebookresearch/BLINK.git#egg=BLINK") | ||
self.config = argparse.Namespace(**_config) | ||
self.models = None | ||
|
||
def download_models(self): | ||
for f in BLINK_BI_ENCODER_FILES + BLINK_FILES: | ||
download_file(f, output_dir=MODELS_PATH) | ||
if not self.config.fast: | ||
for f in BLINK_CROSS_ENCODER_FILES: | ||
download_file(f, output_dir=MODELS_PATH) | ||
|
||
def load_models(self): | ||
import blink.main_dense as main_dense | ||
self.download_models() | ||
if self.models is None: | ||
self.models = main_dense.load_models(self.config, logger=None) | ||
|
||
def link(self, docs: List[Doc], batch_size=None): | ||
import blink.main_dense as main_dense | ||
self.load_models() | ||
data_to_link = [] | ||
for doc_id, doc in enumerate(docs): | ||
for mention_id, mention in enumerate(doc._.mentions): | ||
data_to_link.append( | ||
{ | ||
"id": doc_id, | ||
"mention_id": mention_id, | ||
"label": "unknown", | ||
"label_id": -1, | ||
"context_left": doc.text[:mention.start_char].lower(), | ||
"mention": mention.text.lower(), | ||
"context_right": doc.text[mention.end_char:].lower(), | ||
}) | ||
_, _, _, _, _, predictions, scores, = main_dense.run(self.config, None, *self.models, test_data=data_to_link) | ||
for data, pred in zip(data_to_link, predictions): | ||
doc = docs[data['id']] | ||
mention = doc._.mentions[data['mention_id']] | ||
doc.ents += (doc.char_span(mention.start_char, mention.end_char, label=pred[0]),) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import pkgutil | ||
from pydoc import Doc | ||
from typing import List | ||
|
||
from zshot.mentions_extractor.mentions_extractor import MentionsExtractor | ||
|
||
|
||
class FlairMentionsExtractor(MentionsExtractor): | ||
|
||
def __init__(self): | ||
if not pkgutil.find_loader("flair"): | ||
raise Exception("Flair module not installed. You need to install Flair for using this class." | ||
"Install it with: pip install flair") | ||
from flair.models import SequenceTagger | ||
self.model = SequenceTagger.load("ner") | ||
|
||
def extract_mentions(self, docs: List[Doc], batch_size=None): | ||
from flair.data import Sentence | ||
for doc in docs: | ||
sent = Sentence(str(doc), use_tokenizer=True) | ||
self.model.predict(sent) | ||
sent_mentions = sent.to_dict(tag_type="ner")["entities"] | ||
for mention in sent_mentions: | ||
doc._.mentions.append(doc.char_span(mention['start_pos'], mention['end_pos'])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from abc import ABC, abstractmethod | ||
from pydoc import Doc | ||
from typing import List | ||
|
||
|
||
class MentionsExtractor(ABC): | ||
@abstractmethod | ||
def extract_mentions(self, docs: List[Doc], batch_size=None): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import functools | ||
import os | ||
import pathlib | ||
import shutil | ||
from urllib.request import urlopen | ||
import requests | ||
from tqdm.auto import tqdm | ||
|
||
|
||
def download_file(url, output_dir="."): | ||
filename = url.rsplit('/', 1)[1] | ||
path = pathlib.Path(os.path.join(output_dir, filename)).resolve() | ||
path.parent.mkdir(parents=True, exist_ok=True) | ||
with requests.get(url, stream=True) as r: | ||
total_length = int(urlopen(url=url).info().get('Content-Length', 0)) | ||
if path.exists() and os.path.getsize(path) == total_length: | ||
return | ||
r.raw.read = functools.partial(r.raw.read, decode_content=True) | ||
with tqdm.wrapattr(r.raw, "read", total=total_length, desc=f"Downloading {filename}") as raw: | ||
with path.open("wb") as output: | ||
shutil.copyfileobj(raw, output) |