Skip to content

Commit

Permalink
Use dict comprehension suggested by @svlandeg
Browse files Browse the repository at this point in the history
  • Loading branch information
danieldk committed Aug 4, 2022
1 parent 51f72e4 commit 6e7b958
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 19 deletions.
7 changes: 4 additions & 3 deletions spacy/pipeline/edit_tree_lemmatizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,10 @@ def _scores2guesses(self, docs, scores):
def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT):
batch_tree_ids = activations["guesses"]
for i, doc in enumerate(docs):
doc.activations[self.name] = {}
for activation in self.store_activations:
doc.activations[self.name][activation] = activations[activation][i]
stored_activations = {
key: activations[key][i] for key in self.store_activations
}
doc.activations[self.name] = stored_activations
doc_tree_ids = batch_tree_ids[i]
if hasattr(doc_tree_ids, "get"):
doc_tree_ids = doc_tree_ids.get()
Expand Down
17 changes: 10 additions & 7 deletions spacy/pipeline/entity_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,11 @@ def predict(self, docs: Iterable[Doc]) -> ActivationsT:
# shortcut for efficiency reasons: take the 1 candidate
final_kb_ids.append(candidates[0].entity_)
self._add_activations(
doc_scores, doc_scores_lens, doc_ents, [1.0], [candidates[0].entity_]
doc_scores,
doc_scores_lens,
doc_ents,
[1.0],
[candidates[0].entity_],
)
else:
random.shuffle(candidates)
Expand Down Expand Up @@ -541,12 +545,11 @@ def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT) -> Non
i = 0
overwrite = self.cfg["overwrite"]
for j, doc in enumerate(docs):
doc.activations[self.name] = {}
for activation in self.store_activations:
# We only copy activations that are Ragged.
doc.activations[self.name][activation] = cast(
Ragged, activations[activation][j]
)
# We only copy activations that are Ragged.
stored_activations = {
key: cast(Ragged, activations[key][i]) for key in self.store_activations
}
doc.activations[self.name] = stored_activations
for ent in doc.ents:
kb_id = kb_ids[i]
i += 1
Expand Down
7 changes: 4 additions & 3 deletions spacy/pipeline/morphologizer.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,10 @@ class Morphologizer(Tagger):
# to allocate a compatible container out of the iterable.
labels = tuple(self.labels)
for i, doc in enumerate(docs):
doc.activations[self.name] = {}
for activation in self.store_activations:
doc.activations[self.name][activation] = activations[activation][i]
stored_activations = {
key: activations[key][i] for key in self.store_activations
}
doc.activations[self.name] = stored_activations
doc_tag_ids = batch_tag_ids[i]
if hasattr(doc_tag_ids, "get"):
doc_tag_ids = doc_tag_ids.get()
Expand Down
7 changes: 4 additions & 3 deletions spacy/pipeline/senter.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,10 @@ class SentenceRecognizer(Tagger):
cdef Doc doc
cdef bint overwrite = self.cfg["overwrite"]
for i, doc in enumerate(docs):
doc.activations[self.name] = {}
for activation in self.store_activations:
doc.activations[self.name][activation] = activations[activation][i]
stored_activations = {
key: activations[key][i] for key in self.store_activations
}
doc.activations[self.name] = stored_activations
doc_tag_ids = batch_tag_ids[i]
if hasattr(doc_tag_ids, "get"):
doc_tag_ids = doc_tag_ids.get()
Expand Down
7 changes: 4 additions & 3 deletions spacy/pipeline/tagger.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,10 @@ class Tagger(TrainablePipe):
cdef bint overwrite = self.cfg["overwrite"]
labels = self.labels
for i, doc in enumerate(docs):
doc.activations[self.name] = {}
for activation in self.store_activations:
doc.activations[self.name][activation] = activations[activation][i]
stored_activations = {
key: activations[key][i] for key in self.store_activations
}
doc.activations[self.name] = stored_activations
doc_tag_ids = batch_tag_ids[i]
if hasattr(doc_tag_ids, "get"):
doc_tag_ids = doc_tag_ids.get()
Expand Down

0 comments on commit 6e7b958

Please sign in to comment.