Skip to content

Commit

Permalink
test: add training test with conditional
Browse files Browse the repository at this point in the history
  • Loading branch information
mj-will committed Dec 11, 2023
1 parent 5d41e7a commit 5c9e6c0
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions tests/test_flowmodel/test_flowmodel_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,3 +624,28 @@ def test_train_without_validation(tmp_path):
history = flow.train(data)

assert np.isnan(history["val_loss"]).all()


@pytest.mark.integration_test
def test_train_conditional(tmp_path):
"""Assert training with conditional data works"""
output = tmp_path / "test_train_conditional"
output.mkdir()

config = dict(
max_epochs=10,
model_config=dict(
n_inputs=2,
n_blocks=2,
kwargs=dict(
linear_transform="lu",
context_features=1,
),
),
)

flow = FlowModel(config=config, output=output)
data = np.random.randn(100, 2)
conditional = np.random.randint(2, size=(100, 1))

_ = flow.train(data, conditional=conditional)

0 comments on commit 5c9e6c0

Please sign in to comment.