From d0d8efb58d2f45dedaf7165f9bcc8a26b649b195 Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Tue, 7 Nov 2023 00:09:22 -0800
Subject: [PATCH] Add a test which builds a couple models with different
 batchings of inputs to make sure the batching doesn't fail in some weird
 way. Then, redo the calls to update() for the batches and check that the
 losses are the same for a batch of size one or a batch of size two

---
 stanza/tests/pos/test_tagger.py | 45 +++++++++++++++++++++++++++++++--
 1 file changed, 43 insertions(+), 2 deletions(-)

diff --git a/stanza/tests/pos/test_tagger.py b/stanza/tests/pos/test_tagger.py
index c8e7cbc4c0..491ce78aa0 100644
--- a/stanza/tests/pos/test_tagger.py
+++ b/stanza/tests/pos/test_tagger.py
@@ -133,14 +133,13 @@ def charlm_args(self):
         charlm_args = build_charlm_args("en", charlm, model_dir=TEST_MODELS_DIR)
         return charlm_args
 
-    def run_training(self, tmp_path, wordvec_pretrain_file, train_text, dev_text, augment_nopunct=False, extra_args=None):
+    def run_training(self, tmp_path, wordvec_pretrain_file, train_text, dev_text, augment_nopunct=False, save_name='test_tagger.pt', extra_args=None):
         """
         Run the training for a few iterations, load & return the model
         """
         dev_file = str(tmp_path / "dev.conllu")
         pred_file = str(tmp_path / "pred.conllu")
-        save_name = "test_tagger.pt"
         save_file = str(tmp_path / save_name)
 
         if isinstance(train_text, str):
@@ -235,6 +234,48 @@ def test_train_charlm_projection(self, tmp_path, wordvec_pretrain_file, charlm_a
         extra_args = charlm_args + ['--charlm_transform_dim', '100']
         trainer = self.run_training(tmp_path, wordvec_pretrain_file, TRAIN_DATA, DEV_DATA, extra_args=extra_args)
 
+    def test_missing_batched_xpos(self, tmp_path, wordvec_pretrain_file):
+        base_extra_args = ['--seed', '1000', '--save_each', '--eval_interval', '1', '--optim', 'sgd', '--augment_nopunct', '0.0']
+        single_args = ['--max_steps', '2', '--batch_size', '1']
+        single_item_trainer = self.run_training(tmp_path, wordvec_pretrain_file, [TRAIN_DATA_NO_FEATS, TRAIN_DATA_NO_XPOS], DEV_DATA, save_name='single_item_batch.pt', extra_args=base_extra_args + single_args)
+
+        double_args = ['--max_steps', '1', '--batch_size', '2']
+        double_item_trainer = self.run_training(tmp_path, wordvec_pretrain_file, [TRAIN_DATA_NO_FEATS, TRAIN_DATA_NO_XPOS], DEV_DATA, save_name='double_item_batch.pt', extra_args=base_extra_args + double_args)
+
+        pt = pretrain.Pretrain(wordvec_pretrain_file)
+
+        save_each_name = tagger.save_each_file_name(single_item_trainer.args)
+        single_model_files = [save_each_name % i for i in range(3)]
+        assert all(os.path.exists(x) for x in single_model_files)
+        single_trainer = Trainer(pretrain=pt, model_file=single_model_files[0])
+
+        save_each_name = tagger.save_each_file_name(double_item_trainer.args)
+        double_model_files = [save_each_name % i for i in range(2)]
+        assert all(os.path.exists(x) for x in double_model_files)
+        double_trainer = Trainer(pretrain=pt, model_file=double_model_files[0])
+
+        # these should be created with the same weights
+        assert torch.allclose(single_trainer.model.upos_clf.weight, double_trainer.model.upos_clf.weight)
+        assert torch.allclose(single_trainer.model.xpos_clf.W_bilin.weight, double_trainer.model.xpos_clf.W_bilin.weight)
+
+        _, _, train_batches = tagger.load_training_data(single_trainer.args, pt)
+        total_single_loss = 0.0
+        for batch in iter(train_batches):
+            batch = batch._replace(upos=torch.zeros_like(batch.upos),
+                                   ufeats=torch.zeros_like(batch.ufeats))
+            # one should be zero, one should have the expected loss
+            # incidentally this checks that the loss is not nan for having a blank xpos
+            total_single_loss += single_trainer.update(batch)
+
+        _, _, train_batches = tagger.load_training_data(double_trainer.args, pt)
+        total_double_loss = 0.0
+        for batch in iter(train_batches):
+            batch = batch._replace(upos=torch.zeros_like(batch.upos),
+                                   ufeats=torch.zeros_like(batch.ufeats))
+            total_double_loss += double_trainer.update(batch)
+
+        assert pytest.approx(total_single_loss) == total_double_loss
+
     def test_missing_column(self, tmp_path, wordvec_pretrain_file):
         """
         Test that using train files with missing columns works
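
A note on why the final assertion in the new test can hold at all: both
trainers train from the same seed and are reloaded from their first
--save_each checkpoint, which should hold identical weights (the
torch.allclose asserts above confirm this). Per the comment in the test,
the batch with no xpos annotation (and with upos/ufeats masked to zero)
should contribute zero loss, hence a zero gradient, so the SGD step the
single-item trainer takes between its two update() calls does not move the
weights. What remains is the additivity of a summed loss across batches.
Below is a minimal sketch of that invariant using a toy linear model rather
than the Stanza tagger; the names are hypothetical, and it assumes the loss
is summed over the examples in a batch rather than averaged:

    import torch
    import torch.nn as nn

    torch.manual_seed(1000)
    model = nn.Linear(4, 3)                         # stand-in for the tagger
    loss_fn = nn.CrossEntropyLoss(reduction='sum')  # summed, not averaged

    x = torch.randn(2, 4)     # two toy "sentences"
    y = torch.tensor([0, 2])  # their gold tags

    # two batches of size one
    single = sum(loss_fn(model(x[i:i+1]), y[i:i+1]).item() for i in range(2))
    # one batch of size two
    double = loss_fn(model(x), y).item()

    # the totals agree up to float rounding
    assert abs(single - double) < 1e-4

If the loss were instead averaged within each batch, the two totals would
differ by roughly a factor of two here, so the equality check also pins
down the reduction the tagger uses.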