diff --git a/stanza/models/lemma/trainer.py b/stanza/models/lemma/trainer.py
index 0ee2f88ef3..410de0cad0 100644
--- a/stanza/models/lemma/trainer.py
+++ b/stanza/models/lemma/trainer.py
@@ -32,10 +32,10 @@ def unpack_batch(batch, device):
 
 class Trainer(object):
     """ A trainer for training models. """
-    def __init__(self, args=None, vocab=None, emb_matrix=None, model_file=None, device=None, foundation_cache=None):
+    def __init__(self, args=None, vocab=None, emb_matrix=None, model_file=None, device=None, foundation_cache=None, lemma_classifier_args=None):
         if model_file is not None:
             # load everything from file
-            self.load(model_file, args, foundation_cache)
+            self.load(model_file, args, foundation_cache, lemma_classifier_args)
         else:
             # build model from scratch
             self.args = args
@@ -292,7 +292,7 @@ def save(self, filename, skip_modules=True):
         torch.save(params, filename, _use_new_zipfile_serialization=False)
         logger.info("Model saved to {}".format(filename))
 
-    def load(self, filename, args, foundation_cache):
+    def load(self, filename, args, foundation_cache, lemma_classifier_args=None):
         try:
             checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         except BaseException:
@@ -313,4 +313,4 @@ def load(self, filename, args, foundation_cache):
         self.vocab = MultiVocab.load_state_dict(checkpoint['vocab'])
         self.contextual_lemmatizers = []
         for contextual in checkpoint.get('contextual', []):
-            self.contextual_lemmatizers.append(LemmaClassifier.from_checkpoint(contextual))
+            self.contextual_lemmatizers.append(LemmaClassifier.from_checkpoint(contextual, args=lemma_classifier_args))
diff --git a/stanza/pipeline/lemma_processor.py b/stanza/pipeline/lemma_processor.py
index 023c260bf4..0f22307074 100644
--- a/stanza/pipeline/lemma_processor.py
+++ b/stanza/pipeline/lemma_processor.py
@@ -48,11 +48,14 @@ def _set_up_model(self, config, pipeline, device):
         # since a long running program will remember everything
         # (unless we go back and make it smarter)
         # we make this an option, not the default
+        # TODO: need to update the cache to skip the contextual lemmatizer
         self.store_results = config.get('store_results', False)
         self._use_identity = False
         args = {'charlm_forward_file': config.get('forward_charlm_path', None),
                 'charlm_backward_file': config.get('backward_charlm_path', None)}
-        self._trainer = Trainer(args=args, model_file=config['model_path'], device=device, foundation_cache=pipeline.foundation_cache)
+        lemma_classifier_args = dict(args)
+        lemma_classifier_args['wordvec_pretrain_file'] = config.get('pretrain_path', None)
+        self._trainer = Trainer(args=args, model_file=config['model_path'], device=device, foundation_cache=pipeline.foundation_cache, lemma_classifier_args=lemma_classifier_args)
 
     def _set_up_requires(self):
         self._pretagged = self._config.get('pretagged', None)
diff --git a/stanza/resources/prepare_resources.py b/stanza/resources/prepare_resources.py
index fd5ff4d444..f3ec65f249 100644
--- a/stanza/resources/prepare_resources.py
+++ b/stanza/resources/prepare_resources.py
@@ -22,6 +22,7 @@
 from stanza.models.common.constant import lcode2lang, two_to_three_letters, three_to_two_letters
 from stanza.resources.default_packages import PACKAGES, TRANSFORMERS, TRANSFORMER_NICKNAMES
 from stanza.resources.default_packages import *
+from stanza.utils.datasets.prepare_lemma_classifier import DATASET_MAPPING as LEMMA_CLASSIFIER_DATASETS
 from stanza.utils.get_tqdm import get_tqdm
 
 tqdm = get_tqdm()
@@ -179,14 +180,29 @@ def get_pos_dependencies(lang, package):
 
     return dependencies
 
+def get_lemma_pretrain_package(lang, package):
+    package, uses_pretrain, uses_charlm = split_package(package)
+    if not uses_pretrain:
+        return None
+    if not uses_charlm:
+        # currently the contextual lemma classifier is only active
+        # for the charlm lemmatizers
+        return None
+    if "%s_%s" % (lang, package) not in LEMMA_CLASSIFIER_DATASETS:
+        return None
+    return get_pretrain_package(lang, package, {}, default_pretrains)
+
 def get_lemma_charlm_package(lang, package):
     return get_charlm_package(lang, package, lemma_charlms, default_charlms)
 
 def get_lemma_dependencies(lang, package):
     dependencies = []
 
-    charlm_package = get_lemma_charlm_package(lang, package)
+    pretrain_package = get_lemma_pretrain_package(lang, package)
+    if pretrain_package is not None:
+        dependencies.append({'model': 'pretrain', 'package': pretrain_package})
+    charlm_package = get_lemma_charlm_package(lang, package)
     if charlm_package is not None:
         dependencies.append({'model': 'forward_charlm', 'package': charlm_package})
         dependencies.append({'model': 'backward_charlm', 'package': charlm_package})
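
Usage sketch (illustrative only, not part of the diff): loading a saved lemmatizer through the new Trainer argument, mirroring what lemma_processor.py now does. The file paths below are placeholders.

# Assemble the extra arguments the contextual LemmaClassifier needs when it
# is restored from the same checkpoint as the lemmatizer model.
from stanza.models.lemma.trainer import Trainer

args = {'charlm_forward_file': '/placeholder/forward_charlm.pt',
        'charlm_backward_file': '/placeholder/backward_charlm.pt'}
lemma_classifier_args = dict(args)
lemma_classifier_args['wordvec_pretrain_file'] = '/placeholder/pretrain.pt'

trainer = Trainer(args=args,
                  model_file='/placeholder/lemmatizer.pt',
                  device='cpu',
                  lemma_classifier_args=lemma_classifier_args)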