Add blink linker and flair mentions extractor (#5)
* Add Blink linker

* Add Flair mention detector
Gabriele Picco authored and GitHub Enterprise committed Feb 11, 2022
1 parent 182b1a4 commit 1e2f0b4
Showing 13 changed files with 224 additions and 21 deletions.
14 changes: 14 additions & 0 deletions README.md
@@ -10,3 +10,17 @@ Zero and Few shot named entity recognition plugin for Spacy
 # Run tests
 
 python -m pytest -v
+
+# To use the Blink linker
+
+pip install git+https://github.com/facebookresearch/BLINK.git#egg=BLINK
+
+
+## Examples with Flair, Blink and Displacy
+
+pip install flair
+pip install git+https://github.com/facebookresearch/BLINK.git#egg=BLINK
+
+Run Wikification on a test sentence
+
+python -m zshot.examples.wikification
7 changes: 5 additions & 2 deletions setup.py
@@ -6,7 +6,7 @@
     version=version,
     description="Zero and Few shot named entity recognition",
     long_description="""Zero and Few shot named entity recognition""",
-    classifiers=[],  # Get strings from http://pypi.python.org/pypi?%3Aaction=list_classifiers
+    classifiers=[],
     keywords='NER Zero-Shot Few-Shot',
     author='IBM Research',
     author_email='',
@@ -16,7 +16,10 @@
     include_package_data=True,
     zip_safe=False,
     install_requires=[
-        "spacy~=3.2.1"
+        "spacy~=3.2.1",
+        "requests~=2.27.1",
+        "appdata~=2.1.2",
+        "tqdm~=4.62.3",
     ],
     entry_points="""
     # -*- Entry points: -*-
60 changes: 44 additions & 16 deletions zshot/__init__.py
@@ -1,7 +1,22 @@
+from enum import Enum
 from typing import Dict, Optional, List, Union
 
 from spacy.language import Language
 from spacy.tokens import Doc
 
+from zshot.linker.linker_blink import Blink
+from zshot.mentions_extractor.flair_mentions_extractor import FlairMentionsExtractor
+
+
+class Linker(str, Enum):
+    BLINK = "BLINK"
+    NONE = "NONE"
+
+
+class MentionsExtractor(str, Enum):
+    FLAIR = "FLAIR"
+    NONE = "NONE"
+
+
 class Entity(Dict):
     def __init__(self, name: str, description: str, label: str = None, vocabulary: List[str] = None):
@@ -12,40 +12,53 @@ def __init__(self, name: str, description: str, label: str = None, vocabulary: List[str] = None):
         self.vocabulary = vocabulary


@Language.factory("zshot", default_config={"entities": None})
def create_zshot_component(nlp: Language, name: str, entities: Optional[Union[Dict[str, str], List[Entity]]]):
return Zshot(nlp, entities)
@Language.factory("zshot", default_config={"entities": None, "mentions_extractor": None, "linker": None})
def create_zshot_component(nlp: Language, name: str,
entities: Optional[Union[Dict[str, str], List[Entity]]],
mentions_extractor: Optional[MentionsExtractor],
linker: Optional[Linker]):
return Zshot(nlp, entities, mentions_extractor, linker)


 class Zshot:
 
-    def __init__(self, nlp: Language, entities: Optional[Union[Dict[str, str], List[Entity]]]):
-        # Register custom extension on the Doc
+    def __init__(self, nlp: Language,
+                 entities: Optional[Union[Dict[str, str], List[Entity]]],
+                 mentions_extractor: MentionsExtractor,
+                 linker: Linker):
         if isinstance(entities, dict):
             entities = [Entity(name=name, description=description) for name, description in entities.items()]
         self.nlp = nlp
         self.entities = entities
-        if not Doc.has_extension("acronyms"):
-            Doc.set_extension("acronyms", default=[])
+        self.mentions_extractor = None
+        self.linker = None
+        if mentions_extractor == MentionsExtractor.FLAIR:
+            self.mentions_extractor = FlairMentionsExtractor()
+        if linker == Linker.BLINK:
+            self.linker = Blink()
+        if not Doc.has_extension("mentions"):
+            Doc.set_extension("mentions", default=[])

     def __call__(self, doc: Doc) -> Doc:
-        # Add the matched spans when doc is processed
-        doc._.acronyms.append("test")
+        self.extracts_mentions([doc])
+        self.link_entities([doc])
         return doc

     def pipe(self, docs: List[Doc], batch_size: int, **kwargs):
         """Process a stream of documents.
         docs: A sequence of spaCy documents.
         YIELDS (Doc): A sequence of Doc objects, in order.
         """
-        self.extracts_mentions(docs)
-        self.link_entities(docs)
+        self.extracts_mentions(docs, batch_size=batch_size)
+        self.link_entities(docs, batch_size=batch_size)
         for doc in docs:
-            yield self(doc)
+            yield doc

-    def extracts_mentions(self, docs: List[Doc]):
-        # Extract and filter mentions
-        pass
+    def extracts_mentions(self, docs: List[Doc], batch_size=None):
+        if self.mentions_extractor:
+            self.mentions_extractor.extract_mentions(docs, batch_size=batch_size)
 
-    def link_entities(self, docs):
-        pass
+    def link_entities(self, docs, batch_size=None):
+        if self.linker:
+            self.linker.link(docs, batch_size=batch_size)
Empty file added zshot/examples/__init__.py
15 changes: 15 additions & 0 deletions zshot/examples/wikification.py
@@ -0,0 +1,15 @@
import spacy
from spacy import displacy

from zshot import Linker, MentionsExtractor

text = "International Business Machines Corporation (IBM) is an American multinational technology corporation " \
"headquartered in Armonk, New York, with operations in over 171 countries."

nlp = spacy.load("en_core_web_trf")
nlp.disable_pipes('ner')
nlp.add_pipe("zshot", config={"mentions_extractor": MentionsExtractor.FLAIR, "linker": Linker.BLINK}, last=True)
print(nlp.pipe_names)

doc = nlp(text)
displacy.serve(doc, style="ent")
Empty file added zshot/linker/__init__.py
10 changes: 10 additions & 0 deletions zshot/linker/linker.py
@@ -0,0 +1,10 @@
from abc import ABC, abstractmethod
from typing import List

from spacy.tokens import Doc


class Linker(ABC):
    @abstractmethod
    def link(self, docs: List[Doc], batch_size=None):
        pass
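
The abstract base keeps linking pluggable: a subclass only needs to implement link and attach spans to doc.ents. A hypothetical sketch under that contract (DictionaryLinker and its lookup table are illustrative, not part of this commit):

from typing import Dict, List

from spacy.tokens import Doc

from zshot.linker.linker import Linker


class DictionaryLinker(Linker):
    """Illustrative linker that resolves mentions via an exact-match lookup table."""

    def __init__(self, lookup: Dict[str, str]):
        self.lookup = lookup

    def link(self, docs: List[Doc], batch_size=None):
        for doc in docs:
            for mention in doc._.mentions:
                label = self.lookup.get(mention.text)
                if label is not None:
                    span = doc.char_span(mention.start_char, mention.end_char, label=label)
                    if span is not None:  # char_span is None for misaligned offsets
                        doc.ents += (span,)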
81 changes: 81 additions & 0 deletions zshot/linker/linker_blink.py
@@ -0,0 +1,81 @@
import argparse
from pathlib import Path
from typing import List
import pkgutil
from appdata import AppDataPaths
from spacy.tokens import Doc

from zshot.linker.linker import Linker
from zshot.utils import download_file

MODELS_PATH = AppDataPaths(f".{Path(__file__).stem}").app_data_path + "/"

BLINK_FILES = \
    ["http://dl.fbaipublicfiles.com/BLINK/entity.jsonl",
     "http://dl.fbaipublicfiles.com/BLINK/all_entities_large.t7"]
BLINK_BI_ENCODER_FILES = \
    ["http://dl.fbaipublicfiles.com/BLINK/biencoder_wiki_large.json",
     "http://dl.fbaipublicfiles.com/BLINK/biencoder_wiki_large.bin"]
BLINK_CROSS_ENCODER_FILES = \
    ["http://dl.fbaipublicfiles.com/BLINK/crossencoder_wiki_large.bin",
     "http://dl.fbaipublicfiles.com/BLINK/crossencoder_wiki_large.json"]

_config = {
    "test_entities": None,
    "test_mentions": None,
    "interactive": False,
    "top_k": 1,
    "biencoder_model": MODELS_PATH + "biencoder_wiki_large.bin",
    "biencoder_config": MODELS_PATH + "biencoder_wiki_large.json",
    "entity_catalogue": MODELS_PATH + "entity.jsonl",
    "entity_encoding": MODELS_PATH + "all_entities_large.t7",
    "crossencoder_model": MODELS_PATH + "crossencoder_wiki_large.bin",
    "crossencoder_config": MODELS_PATH + "crossencoder_wiki_large.json",
    "fast": True,
    "output_path": "logs/"
}


class Blink(Linker):

    def __init__(self):
        if not pkgutil.find_loader("blink"):
            raise Exception("Blink module not installed. You need to install Blink in order to use the Blink Linker. "
                            "Install it with: pip install git+https://github.com/facebookresearch/BLINK.git#egg=BLINK")
        self.config = argparse.Namespace(**_config)
        self.models = None

    def download_models(self):
        for f in BLINK_BI_ENCODER_FILES + BLINK_FILES:
            download_file(f, output_dir=MODELS_PATH)
        if not self.config.fast:
            for f in BLINK_CROSS_ENCODER_FILES:
                download_file(f, output_dir=MODELS_PATH)

    def load_models(self):
        import blink.main_dense as main_dense
        self.download_models()
        if self.models is None:
            self.models = main_dense.load_models(self.config, logger=None)

    def link(self, docs: List[Doc], batch_size=None):
        import blink.main_dense as main_dense
        self.load_models()
        data_to_link = []
        for doc_id, doc in enumerate(docs):
            for mention_id, mention in enumerate(doc._.mentions):
                data_to_link.append(
                    {
                        "id": doc_id,
                        "mention_id": mention_id,
                        "label": "unknown",
                        "label_id": -1,
                        "context_left": doc.text[:mention.start_char].lower(),
                        "mention": mention.text.lower(),
                        "context_right": doc.text[mention.end_char:].lower(),
                    })
        _, _, _, _, _, predictions, scores = main_dense.run(self.config, None, *self.models, test_data=data_to_link)
        for data, pred in zip(data_to_link, predictions):
            doc = docs[data['id']]
            mention = doc._.mentions[data['mention_id']]
            doc.ents += (doc.char_span(mention.start_char, mention.end_char, label=pred[0]),)
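
Design note: load_models is invoked lazily on the first link call, so the large BLINK weights are only fetched into the per-user MODELS_PATH when the linker is actually used. With the default fast=True config only the bi-encoder files are downloaded; the cross-encoder files are fetched only when fast is disabled. download_models runs on every link call, but download_file (see zshot/utils.py below) skips files that are already fully downloaded.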
Empty file added zshot/mentions_extractor/__init__.py
24 changes: 24 additions & 0 deletions zshot/mentions_extractor/flair_mentions_extractor.py
@@ -0,0 +1,24 @@
import pkgutil
from typing import List

from spacy.tokens import Doc

from zshot.mentions_extractor.mentions_extractor import MentionsExtractor


class FlairMentionsExtractor(MentionsExtractor):

    def __init__(self):
        if not pkgutil.find_loader("flair"):
            raise Exception("Flair module not installed. You need to install Flair to use this class. "
                            "Install it with: pip install flair")
        from flair.models import SequenceTagger
        self.model = SequenceTagger.load("ner")

    def extract_mentions(self, docs: List[Doc], batch_size=None):
        from flair.data import Sentence
        for doc in docs:
            sent = Sentence(str(doc), use_tokenizer=True)
            self.model.predict(sent)
            sent_mentions = sent.to_dict(tag_type="ner")["entities"]
            for mention in sent_mentions:
                doc._.mentions.append(doc.char_span(mention['start_pos'], mention['end_pos']))
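
Two caveats worth noting: extract_mentions runs Flair one document at a time, so the batch_size argument is accepted for interface compatibility but not yet used; and doc.char_span returns None when a predicted span does not align with spaCy's token boundaries, so misaligned Flair offsets would append None to doc._.mentions.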
9 changes: 9 additions & 0 deletions zshot/mentions_extractor/mentions_extractor.py
@@ -0,0 +1,9 @@
from abc import ABC, abstractmethod
from typing import List

from spacy.tokens import Doc


class MentionsExtractor(ABC):
    @abstractmethod
    def extract_mentions(self, docs: List[Doc], batch_size=None):
        pass
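
As with the Linker base class, custom extractors only need to implement extract_mentions and append spans to doc._.mentions. A hypothetical sketch (NounChunkMentionsExtractor is illustrative only, not part of this commit):

from typing import List

from spacy.tokens import Doc

from zshot.mentions_extractor.mentions_extractor import MentionsExtractor


class NounChunkMentionsExtractor(MentionsExtractor):
    """Illustrative extractor proposing noun chunks as candidate mentions."""

    def extract_mentions(self, docs: List[Doc], batch_size=None):
        for doc in docs:
            for chunk in doc.noun_chunks:  # requires a pipeline with a parser (e.g. en_core_web_sm)
                doc._.mentions.append(chunk)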
4 changes: 1 addition & 3 deletions zshot/tests/test_zshot.py
@@ -26,8 +26,6 @@ def test_call_pipe_with_dict():
                                          "in a north-bending arc of the river Seine"}}, last=True)
     # Process a doc and see the results
     nlp(DOCS[0])
-    for doc in nlp.pipe(DOCS):
-        print(doc._.acronyms)
     assert "zshot" in nlp.pipe_names


@@ -47,5 +45,5 @@ def test_call_pipe_with_entities():
     # Process a doc and see the results
     nlp(DOCS[0])
     for doc in nlp.pipe(DOCS):
-        print(doc._.acronyms)
+        print(doc._.mentions)
     assert "zshot" in nlp.pipe_names
21 changes: 21 additions & 0 deletions zshot/utils.py
@@ -0,0 +1,21 @@
import functools
import os
import pathlib
import shutil
from urllib.request import urlopen
import requests
from tqdm.auto import tqdm


def download_file(url, output_dir="."):
filename = url.rsplit('/', 1)[1]
path = pathlib.Path(os.path.join(output_dir, filename)).resolve()
path.parent.mkdir(parents=True, exist_ok=True)
with requests.get(url, stream=True) as r:
total_length = int(urlopen(url=url).info().get('Content-Length', 0))
if path.exists() and os.path.getsize(path) == total_length:
return
r.raw.read = functools.partial(r.raw.read, decode_content=True)
with tqdm.wrapattr(r.raw, "read", total=total_length, desc=f"Downloading {filename}") as raw:
with path.open("wb") as output:
shutil.copyfileobj(raw, output)
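
A quick sketch of calling the helper directly (the URL is one of the BLINK files listed in linker_blink.py; the output directory is an arbitrary example):

from zshot.utils import download_file

# Shows a tqdm progress bar and skips the download when the target file
# already exists with the expected Content-Length.
download_file("http://dl.fbaipublicfiles.com/BLINK/biencoder_wiki_large.json",
              output_dir="/tmp/blink_models")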
