restrict model revisions
chanind committed May 21, 2022
1 parent 7898bba commit 9a36fc4
Showing 2 changed files with 14 additions and 5 deletions.
18 changes: 13 additions & 5 deletions frame_semantic_transformer/FrameSemanticTransformer.py
@@ -4,7 +4,11 @@
 import torch
 from transformers import T5ForConditionalGeneration, T5Tokenizer

-from frame_semantic_transformer.constants import MODEL_MAX_LENGTH, OFFICIAL_RELEASES
+from frame_semantic_transformer.constants import (
+    MODEL_MAX_LENGTH,
+    MODEL_REVISION,
+    OFFICIAL_RELEASES,
+)
 from frame_semantic_transformer.data.data_utils import chunk_list, marked_string_to_locs
 from frame_semantic_transformer.data.framenet import ensure_framenet_downloaded
 from frame_semantic_transformer.predict import batch_predict
@@ -39,6 +43,7 @@ class FrameSemanticTransformer:
     _model: T5ForConditionalGeneration | None = None
     _tokenizer: T5Tokenizer | None = None
     model_path: str
+    model_revision: str | None = None
     device: torch.device
     max_batch_size: int
     predictions_per_sample: int
@@ -53,6 +58,7 @@ def __init__(
         self.model_path = model_name_or_path
         if model_name_or_path in OFFICIAL_RELEASES:
             self.model_path = f"chanind/frame-semantic-transformer-{model_name_or_path}"
+            self.model_revision = MODEL_REVISION
         self.device = torch.device("cuda" if use_gpu else "cpu")
         self.max_batch_size = max_batch_size
         self.predictions_per_sample = predictions_per_sample
@@ -62,11 +68,13 @@ def setup(self) -> None:
         Initialize the model and tokenizer, and download models / files as needed
         If this is not called explicitly it will be lazily called before inference
         """
-        self._model = T5ForConditionalGeneration.from_pretrained(self.model_path).to(
-            self.device
-        )
+        self._model = T5ForConditionalGeneration.from_pretrained(
+            self.model_path, revision=self.model_revision
+        ).to(self.device)
         self._tokenizer = T5Tokenizer.from_pretrained(
-            self.model_path, model_max_length=MODEL_MAX_LENGTH
+            self.model_path,
+            revision=self.model_revision,
+            model_max_length=MODEL_MAX_LENGTH,
         )
         ensure_framenet_downloaded()
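For context, a minimal sketch (assumed usage, not part of this commit) of what the new revision argument does when loading from the Hugging Face Hub: from_pretrained accepts a branch name, tag name, or commit hash via revision, and revision=None falls back to the repo's default branch.

# Sketch only: pinning a Hub model to a tagged revision (repo name and tag taken from this diff)
from transformers import T5ForConditionalGeneration, T5Tokenizer

model = T5ForConditionalGeneration.from_pretrained(
    "chanind/frame-semantic-transformer-base",
    revision="v0.0.1",  # matches MODEL_REVISION in constants.py
)
tokenizer = T5Tokenizer.from_pretrained(
    "chanind/frame-semantic-transformer-base",
    revision="v0.0.1",
    model_max_length=512,  # matches MODEL_MAX_LENGTH in constants.py
)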

1 change: 1 addition & 0 deletions frame_semantic_transformer/constants.py
@@ -1,2 +1,3 @@
 MODEL_MAX_LENGTH = 512
 OFFICIAL_RELEASES = ["base", "small"]  # TODO: small, large
+MODEL_REVISION = "v0.0.1"
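A short usage sketch of the behaviour this commit pins down (the package-level import and constructor call are assumptions, not shown in this diff): passing an official release name resolves to the chanind/frame-semantic-transformer-* Hub repo at MODEL_REVISION, while any other model path leaves model_revision as None so the default branch is loaded.

# Sketch only: assumed import path and constructor usage
from frame_semantic_transformer import FrameSemanticTransformer

official = FrameSemanticTransformer("base")
# official.model_path == "chanind/frame-semantic-transformer-base"
# official.model_revision == "v0.0.1"

custom = FrameSemanticTransformer("my-org/my-finetuned-t5")
# custom.model_revision is None, so the model's latest default-branch weights are used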
