Skip to content

Commit

Permalink
fix(tests): flair inference tests
Browse files Browse the repository at this point in the history
  • Loading branch information
djaniak committed Apr 9, 2022
1 parent 5bebc1a commit f6b9ad5
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions tests/test_flair_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,8 @@

from embeddings.data.data_loader import HuggingFaceDataLoader
from embeddings.data.dataset import Dataset
from embeddings.embedding.flair_embedding import FlairDocumentPoolEmbedding
from embeddings.embedding.flair_loader import (
FlairDocumentPoolEmbeddingLoader,
FlairWordEmbeddingLoader,
)
from embeddings.embedding.auto_flair import AutoFlairDocumentEmbedding
from embeddings.embedding.flair_loader import FlairWordEmbeddingLoader
from embeddings.evaluator.sequence_labeling_evaluator import SequenceLabelingEvaluator
from embeddings.evaluator.text_classification_evaluator import TextClassificationEvaluator
from embeddings.model.flair_model import FlairModel
Expand Down Expand Up @@ -59,10 +56,9 @@ def text_classification_pipeline(
)
data_loader = HuggingFaceDataLoader()
transformation = ClassificationCorpusTransformation("text", "target").then(
DownsampleFlairCorpusTransformation(*(0.005, 0.01, 0.01), stratify=False)
DownsampleFlairCorpusTransformation(*(0.01, 0.01, 0.01), stratify=False)
)
embedding_loader = FlairDocumentPoolEmbeddingLoader("clarin-pl/word2vec-kgr10", "")
embedding = embedding_loader.get_embedding(FlairDocumentPoolEmbedding)
embedding = AutoFlairDocumentEmbedding.from_hub("allegro/herbert-base-cased")
task = TextClassification(output_path.name, task_train_kwargs={"max_epochs": 1})
model = FlairModel(embedding, task)
evaluator = TextClassificationEvaluator()
Expand Down

0 comments on commit f6b9ad5

Please sign in to comment.