diff --git a/zshot/evaluation/dataset/med_mentions/med_mentions.py b/zshot/evaluation/dataset/med_mentions/med_mentions.py index 7d293a9..46d2bb5 100644 --- a/zshot/evaluation/dataset/med_mentions/med_mentions.py +++ b/zshot/evaluation/dataset/med_mentions/med_mentions.py @@ -1,6 +1,7 @@ import json +from typing import Dict, Union -from datasets import load_dataset, DatasetDict +from datasets import load_dataset, Split from huggingface_hub import hf_hub_download from zshot.evaluation.dataset.dataset import DatasetWithEntities @@ -10,7 +11,7 @@ ENTITIES_FN = "entities.json" -def load_medmentions() -> DatasetDict[DatasetWithEntities]: +def load_medmentions() -> Dict[Union[str, Split], DatasetWithEntities]: dataset = load_dataset(REPO_ID) entities_file = hf_hub_download(repo_id=REPO_ID, repo_type='dataset', filename=ENTITIES_FN) diff --git a/zshot/evaluation/dataset/ontonotes/onto_notes.py b/zshot/evaluation/dataset/ontonotes/onto_notes.py index c6ca1a6..f967af7 100644 --- a/zshot/evaluation/dataset/ontonotes/onto_notes.py +++ b/zshot/evaluation/dataset/ontonotes/onto_notes.py @@ -1,4 +1,6 @@ -from datasets import ClassLabel, load_dataset, DatasetDict +from typing import Dict, Union + +from datasets import ClassLabel, load_dataset, DatasetDict, Split from zshot.evaluation.dataset.dataset import DatasetWithEntities from zshot.evaluation.dataset.ontonotes.entities import ONTONOTES_ENTITIES @@ -52,7 +54,7 @@ def remove_out_of_split(sentence, split): return sentence -def load_ontonotes() -> DatasetDict[DatasetWithEntities]: +def load_ontonotes() -> Dict[Union[str, Split], DatasetWithEntities]: dataset_zs = load_dataset("conll2012_ontonotesv5", "english_v12") ontonotes_zs = DatasetDict()