Add a test which builds a couple of models with different batchings of inputs to make sure the batching doesn't fail in some weird way. Then, redo the calls to update() for the batches and check that the losses are the same for a batch of size one or a batch of size two.
AngledLuffa committed Nov 7, 2023
1 parent f4db150 commit d0d8efb
Showing 1 changed file with 43 additions and 2 deletions.
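
The equality the commit message relies on holds whenever the tagger's loss is summed, rather than averaged, over the sentences in a batch: two update() calls on batches of size one then accumulate the same total loss as one update() on the batch of size two. A minimal standalone sketch of that identity with an ordinary cross-entropy loss (plain PyTorch, under the summed-loss assumption, not stanza's actual loss code):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 5)   # two items, five classes
    gold = torch.tensor([1, 3])

    # two batches of size one
    single = (F.cross_entropy(logits[0:1], gold[0:1], reduction='sum')
              + F.cross_entropy(logits[1:2], gold[1:2], reduction='sum'))
    # one batch of size two
    double = F.cross_entropy(logits, gold, reduction='sum')

    assert torch.allclose(single, double)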
45 changes: 43 additions & 2 deletions stanza/tests/pos/test_tagger.py
@@ -133,14 +133,13 @@ def charlm_args(self):
         charlm_args = build_charlm_args("en", charlm, model_dir=TEST_MODELS_DIR)
         return charlm_args

-    def run_training(self, tmp_path, wordvec_pretrain_file, train_text, dev_text, augment_nopunct=False, extra_args=None):
+    def run_training(self, tmp_path, wordvec_pretrain_file, train_text, dev_text, augment_nopunct=False, save_name='test_tagger.pt', extra_args=None):
"""
Run the training for a few iterations, load & return the model
"""
dev_file = str(tmp_path / "dev.conllu")
pred_file = str(tmp_path / "pred.conllu")

-        save_name = "test_tagger.pt"
         save_file = str(tmp_path / save_name)

         if isinstance(train_text, str):

@@ -235,6 +234,48 @@ def test_train_charlm_projection(self, tmp_path, wordvec_pretrain_file, charlm_args):
         extra_args = charlm_args + ['--charlm_transform_dim', '100']
         trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=extra_args)

+    def test_missing_batched_xpos(self, tmp_path, wordvec_pretrain_file):
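+        # fixed seed, plain SGD, and augmentation disabled, so that the two
+        # training runs below are deterministic and directly comparable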
+        base_extra_args = ['--seed', '1000', '--save_each', '--eval_interval', '1', '--optim', 'sgd', '--augment_nopunct', '0.0']
+        single_args = ['--max_steps', '2', '--batch_size', '1']
+        single_item_trainer = self.run_training(tmp_path, wordvec_pretrain_file, [TRAIN_DATA_NO_FEATS, TRAIN_DATA_NO_XPOS], DEV_DATA, save_name='single_item_batch.pt', extra_args=base_extra_args + single_args)

+        double_args = ['--max_steps', '1', '--batch_size', '2']
+        double_item_trainer = self.run_training(tmp_path, wordvec_pretrain_file, [TRAIN_DATA_NO_FEATS, TRAIN_DATA_NO_XPOS], DEV_DATA, save_name='double_item_batch.pt', extra_args=base_extra_args + double_args)

+        pt = pretrain.Pretrain(wordvec_pretrain_file)

+        save_each_name = tagger.save_each_file_name(single_item_trainer.args)
+        single_model_files = [save_each_name % i for i in range(3)]
+        assert all(os.path.exists(x) for x in single_model_files)
+        single_trainer = Trainer(pretrain=pt, model_file=single_model_files[0])

+        save_each_name = tagger.save_each_file_name(double_item_trainer.args)
+        double_model_files = [save_each_name % i for i in range(2)]
+        assert all(os.path.exists(x) for x in double_model_files)
+        double_trainer = Trainer(pretrain=pt, model_file=double_model_files[0])

+        # these should be created with the same weights
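+        # (same --seed gives both runs identical initial weights; file 0 is
+        # presumably saved before the first update, so nothing has diverged yet)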
+        assert torch.allclose(single_trainer.model.upos_clf.weight, double_trainer.model.upos_clf.weight)
+        assert torch.allclose(single_trainer.model.xpos_clf.W_bilin.weight, double_trainer.model.xpos_clf.W_bilin.weight)

+        _, _, train_batches = tagger.load_training_data(single_trainer.args, pt)
+        total_single_loss = 0.0
+        for batch in iter(train_batches):
+            batch = batch._replace(upos=torch.zeros_like(batch.upos),
+                                   ufeats=torch.zeros_like(batch.ufeats))
+            # one should be zero, one should have the expected loss
+            # incidentally this checks that the loss is not nan for having a blank xpos
+            total_single_loss += single_trainer.update(batch)

+        _, _, train_batches = tagger.load_training_data(double_trainer.args, pt)
+        total_double_loss = 0.0
+        for batch in iter(train_batches):
+            batch = batch._replace(upos=torch.zeros_like(batch.upos),
+                                   ufeats=torch.zeros_like(batch.ufeats))
+            total_double_loss += double_trainer.update(batch)

+        assert pytest.approx(total_single_loss) == total_double_loss

     def test_missing_column(self, tmp_path, wordvec_pretrain_file):
         """
         Test that using train files with missing columns works
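A note on the batch._replace(...) calls in the new test: judging from the call pattern, the batches produced by tagger.load_training_data behave like collections.namedtuple instances, so _replace builds a copy with the named fields swapped while leaving the original, and any untouched fields, alone. A standalone illustration with a hypothetical Batch tuple:

    from collections import namedtuple

    import torch

    # hypothetical stand-in for the tagger's batch type
    Batch = namedtuple('Batch', ['upos', 'xpos', 'ufeats'])
    batch = Batch(upos=torch.tensor([3, 7]),
                  xpos=torch.tensor([2, 4]),
                  ufeats=torch.tensor([1, 5]))

    masked = batch._replace(upos=torch.zeros_like(batch.upos),
                            ufeats=torch.zeros_like(batch.ufeats))

    assert torch.all(masked.upos == 0) and torch.all(masked.ufeats == 0)
    assert torch.equal(masked.xpos, batch.xpos)          # untouched field carries over
    assert torch.equal(batch.upos, torch.tensor([3, 7])) # original batch is unchanged

The final assertion in the test compares the two accumulated losses with pytest.approx, whose default relative tolerance of 1e-6 absorbs the floating-point differences introduced by summing the per-batch losses in a different order.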
