Skip to content

Commit

Permalink
fix(tests): flair inference tests
Browse files Browse the repository at this point in the history
  • Loading branch information
djaniak committed Apr 9, 2022
1 parent 5bebc1a commit a218b96
Showing 1 changed file with 5 additions and 10 deletions.
15 changes: 5 additions & 10 deletions tests/test_flair_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,8 @@

from embeddings.data.data_loader import HuggingFaceDataLoader
from embeddings.data.dataset import Dataset
from embeddings.embedding.flair_embedding import FlairDocumentPoolEmbedding
from embeddings.embedding.flair_loader import (
FlairDocumentPoolEmbeddingLoader,
FlairWordEmbeddingLoader,
)
from embeddings.embedding.auto_flair import AutoFlairDocumentEmbedding
from embeddings.embedding.flair_loader import FlairWordEmbeddingLoader
from embeddings.evaluator.sequence_labeling_evaluator import SequenceLabelingEvaluator
from embeddings.evaluator.text_classification_evaluator import TextClassificationEvaluator
from embeddings.model.flair_model import FlairModel
Expand Down Expand Up @@ -59,10 +56,9 @@ def text_classification_pipeline(
)
data_loader = HuggingFaceDataLoader()
transformation = ClassificationCorpusTransformation("text", "target").then(
DownsampleFlairCorpusTransformation(*(0.005, 0.01, 0.01), stratify=False)
DownsampleFlairCorpusTransformation(*(0.01, 0.01, 0.01), stratify=False)
)
embedding_loader = FlairDocumentPoolEmbeddingLoader("clarin-pl/word2vec-kgr10", "")
embedding = embedding_loader.get_embedding(FlairDocumentPoolEmbedding)
embedding = AutoFlairDocumentEmbedding.from_hub("allegro/herbert-base-cased")
task = TextClassification(output_path.name, task_train_kwargs={"max_epochs": 1})
model = FlairModel(embedding, task)
evaluator = TextClassificationEvaluator()
Expand All @@ -89,8 +85,7 @@ def sequence_labeling_pipeline(
hidden_size=256,
task_train_kwargs={"max_epochs": 1, "mini_batch_size": 64},
)
embedding_loader = FlairWordEmbeddingLoader("clarin-pl/word2vec-kgr10", "")
embedding = embedding_loader.get_embedding()
embedding = AutoFlairDocumentEmbedding.from_hub("allegro/herbert-base-cased")
model = FlairModel(embedding, task)
evaluator = SequenceLabelingEvaluator()
pipeline = StandardPipeline(dataset, data_loader, transformation, model, evaluator)
Expand Down

0 comments on commit a218b96

Please sign in to comment.