diff --git a/tests/test_flowmodel/test_flowmodel_base.py b/tests/test_flowmodel/test_flowmodel_base.py index e9f68671..6cc83c5d 100644 --- a/tests/test_flowmodel/test_flowmodel_base.py +++ b/tests/test_flowmodel/test_flowmodel_base.py @@ -624,3 +624,28 @@ def test_train_without_validation(tmp_path): history = flow.train(data) assert np.isnan(history["val_loss"]).all() + + +@pytest.mark.integration_test +def test_train_conditional(tmp_path): + """Assert training with conditional data works""" + output = tmp_path / "test_train_conditional" + output.mkdir() + + config = dict( + max_epochs=10, + model_config=dict( + n_inputs=2, + n_blocks=2, + kwargs=dict( + linear_transform="lu", + context_features=1, + ), + ), + ) + + flow = FlowModel(config=config, output=output) + data = np.random.randn(100, 2) + conditional = np.random.randint(2, size=(100, 1)) + + _ = flow.train(data, conditional=conditional)