Skip to content

Commit

Permalink
🩹 Remove cache after tests passed
Browse files Browse the repository at this point in the history
  • Loading branch information
Gabriele Picco committed Sep 29, 2022
1 parent bd03aea commit 28dd99d
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 16 deletions.
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
install_requires=[
"spacy>=3.4.1",
"requests>=2.28",
"appdata~=2.1.2",
"tqdm>=4.62.3",
"setuptools~=60.0.0", # Needed to install dynamic packages from source (e.g. Blink)
"prettytable>=3.4",
Expand Down
4 changes: 1 addition & 3 deletions zshot/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import os
import pathlib

from appdata import AppDataPaths

MODELS_CACHE_PATH = os.getenv("MODELS_CACHE_PATH") if "MODELS_CACHE_PATH" in os.environ \
else AppDataPaths(f"{pathlib.Path(__file__).stem}").app_data_path + "/"
else f"{pathlib.Path.home()}/.cache/zshot/"
16 changes: 16 additions & 0 deletions zshot/tests/linker/test_regen_linker.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
import logging
import shutil
from pathlib import Path

import pytest
import spacy

from zshot import PipelineConfig
from zshot.linker.linker_regen.linker_regen import LinkerRegen
from zshot.mentions_extractor import MentionsExtractorSpacy
from zshot.tests.config import EX_DOCS, EX_ENTITIES

logger = logging.getLogger(__name__)


@pytest.fixture(scope="module", autouse=True)
def teardown():
logger.warning("Starting regen tests")
yield True
logger.warning("Removing cache")
shutil.rmtree(f"{Path.home()}/.cache/huggingface", ignore_errors=True)
shutil.rmtree(f"{Path.home()}/.cache/zshot", ignore_errors=True)


def test_regen_linker():
nlp = spacy.load("en_core_web_sm")
Expand Down
29 changes: 17 additions & 12 deletions zshot/tests/linker/test_smxm_linker.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
import os
import logging
import shutil
from pathlib import Path

import pytest
import spacy

from zshot import PipelineConfig
from zshot.config import MODELS_CACHE_PATH
from zshot import PipelineConfig, Linker
from zshot.linker import LinkerSMXM
from zshot.linker.linker_smxm import SMXM_MODEL_FILES_URL, SMXM_MODEL_FOLDER_NAME
from zshot.linker.smxm.model import BertTaggerMultiClass
from zshot.linker.smxm.utils import load_model
from zshot.tests.config import EX_DOCS, EX_ENTITIES

logger = logging.getLogger(__name__)

def test_smxm_download():
model_folder_path = os.path.join(MODELS_CACHE_PATH, SMXM_MODEL_FOLDER_NAME)
if os.path.exists(model_folder_path):
shutil.rmtree(model_folder_path)

model = load_model(SMXM_MODEL_FILES_URL, MODELS_CACHE_PATH, SMXM_MODEL_FOLDER_NAME)
@pytest.fixture(scope="module", autouse=True)
def teardown():
logger.warning("Starting smxm tests")
yield True
logger.warning("Removing cache")
shutil.rmtree(f"{Path.home()}/.cache/huggingface", ignore_errors=True)
shutil.rmtree(f"{Path.home()}/.cache/zshot", ignore_errors=True)


assert isinstance(model, BertTaggerMultiClass)
def test_smxm_download():
linker = LinkerSMXM()
linker.load_models()
assert isinstance(linker, Linker)


def test_smxm_linker():
Expand Down

0 comments on commit 28dd99d

Please sign in to comment.