This is the official PyTorch implementation for our papers:
- Free-form flows: Make Any Architecture a Normalizing Flow (full-dimensional normalizing flows):
@inproceedings{draxler2024freeform, title = {{Free-form flows: Make Any Architecture a Normalizing Flow}}, author = {Draxler, Felix and Sorrenson, Peter and Zimmermann, Lea and Rousselot, Armand and Köthe, Ullrich}, booktitle = {International Conference on Artificial Intelligence and Statistics}, year = {2024} }
- Lifting Architectural Constraints of Injective Flows (jointly learning a manifold and the distribution on it):
@inproceedings{sorrenson2024lifting, title = {{Lifting Architectural Constraints of Injective Flows}}, booktitle = {International {{Conference}} on {{Learning Representations}}}, author = {Sorrenson, Peter and Draxler, Felix and Rousselot, Armand and Hummerich, Sander and Zimmermann, Lea and Köthe, Ullrich}, year = {2024} }
- Learning Distributions on Manifolds with Free-form Flows (learning distributions on a known manifold):
@article{sorrenson2023learning, title = {Learning Distributions on Manifolds with Free-form Flows}, author = {Sorrenson, Peter and Draxler, Felix and Rousselot, Armand and Hummerich, Sander and Köthe, Ullrich}, journal = {arXiv preprint arXiv:2312.09852}, year = {2023} }
The following will install our package along with all of its dependencies:
git clone https://github.com/vislearn/FFF.git
cd FFF
pip install -r requirements.txt
pip install .
If you want to edit the code, install the package in editable mode instead, i.e., replace the last line with:
pip install -e .
Then you can import the package via
import fff
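import torch
from fff.loss import fff_loss


class FreeFormFlow(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Placeholders: any architecture can be used for encoder and decoder
        self.encoder = torch.nn.Sequential(...)
        self.decoder = torch.nn.Sequential(...)


model = FreeFormFlow()
optim = ...        # e.g. torch.optim.Adam(model.parameters())
data_loader = ...  # assumed to yield plain data tensors (no labels)
n_epochs = ...
beta = ...         # weight of the reconstruction term

for epoch in range(n_epochs):
    for batch in data_loader:
        optim.zero_grad()
        # .mean() reduces per-sample losses to a scalar (harmless if already a scalar);
        # importing fff_loss directly avoids shadowing the module with the loss tensor
        loss = fff_loss(batch, model.encoder, model.decoder, beta).mean()
        loss.backward()
        optim.step()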
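As a rough guide to what beta controls: the Free-form flows paper trains the encoder f_θ (model.encoder) and decoder g_φ (model.decoder) with a maximum-likelihood term plus a reconstruction term weighted by β. Below, p_Z is the latent density (e.g. a standard normal), J denotes a Jacobian, and SG a stop-gradient. The trace term is a surrogate whose θ-gradient approximates that of log|det J_{f_θ}(x)| when g_φ is close to the inverse of f_θ. This paraphrases the paper's full-dimensional objective and is not a line-by-line description of fff.loss.fff_loss, which may, for example, estimate the trace stochastically (Hutchinson estimator) and return one value per sample.

```latex
\mathcal{L}(x) = -\log p_Z\big(f_\theta(x)\big)
  - \operatorname{tr}\Big(\operatorname{SG}\big[J_{g_\phi}\big(f_\theta(x)\big)\big]\, J_{f_\theta}(x)\Big)
  + \beta\,\big\lVert x - g_\phi\big(f_\theta(x)\big)\big\rVert_2^2
```

Intuitively, a larger β pushes g_φ closer to an inverse of f_θ, which the trace surrogate relies on.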
All training configurations from our papers can be found in the configs/fff and configs/fif directories.
Our training framework is built on lightning-trainable, a configuration wrapper around PyTorch Lightning. There is no main.py, but you can train all our models via the lightning_trainable.launcher.fit module.
For example, to train the Boltzmann generator on DW4:
python -m lightning_trainable.launcher.fit configs/fff/dw4.yaml --name '{data_set[name]}'
This will create a new directory lightning_logs/dw4/. You can monitor the run via tensorboard:
tensorboard --logdir lightning_logs
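The directory name dw4 comes from the --name argument. '{data_set[name]}' reads like a Python str.format template; assuming the launcher fills it in from the loaded config (an assumption about lightning_trainable internals) and that configs/fff/dw4.yaml sets data_set.name to dw4 (inferred from the resulting directory), the expansion works like this sketch with a hypothetical, trimmed-down config:

```python
from pathlib import Path

# Hypothetical, trimmed-down config: only the key used by the name template.
# The real dw4.yaml contains many more entries.
config = {"data_set": {"name": "dw4"}}

name_template = "{data_set[name]}"  # the value passed to --name

# str.format resolves {data_set[name]} as config["data_set"]["name"]
run_name = name_template.format(**config)

# Lightning-style log layout: lightning_logs/<name>/version_<n>/
log_dir = Path("lightning_logs") / run_name

print(run_name)            # dw4
print(log_dir.as_posix())  # lightning_logs/dw4
```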
When training has finished, you can load the trained model from its checkpoint via
import fff
model = fff.FreeFormFlow.load_from_checkpoint(
    'lightning_logs/dw4/version_0/checkpoints/last.ckpt'
)
If you want to override the default parameters, you can add key=value pairs after the config file:
python -m lightning_trainable.launcher.fit configs/fff/dw4.yaml batch_size=128 loss_weights.noisy_reconstruction=20 --name '{data_set[name]}'
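Judging from loss_weights.noisy_reconstruction=20, each key=value pair overrides one entry of the loaded YAML config, with dots addressing nested keys. The sketch below illustrates that mapping on a plain dict. apply_overrides, parse_value and the default values are hypothetical, and lightning_trainable's own parser may handle value types and edge cases differently.

```python
import ast
import copy


def parse_value(text):
    # Try to read the value as a Python literal (int, float, list, ...);
    # fall back to the raw string for things like names or paths.
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        return text


def apply_overrides(config, overrides):
    """Return a copy of `config` with dotted key=value overrides applied."""
    result = copy.deepcopy(config)
    for item in overrides:
        key, _, raw_value = item.partition("=")
        *parents, leaf = key.split(".")
        node = result
        for part in parents:
            # Walk (and create, if missing) the nested dicts along the dotted path
            node = node.setdefault(part, {})
        node[leaf] = parse_value(raw_value)
    return result


# Hypothetical defaults standing in for the contents of configs/fff/dw4.yaml
defaults = {"batch_size": 256, "loss_weights": {"noisy_reconstruction": 10}}
updated = apply_overrides(
    defaults, ["batch_size=128", "loss_weights.noisy_reconstruction=20"]
)
print(updated)
# {'batch_size': 128, 'loss_weights': {'noisy_reconstruction': 20}}
```

Working on a deep copy keeps the loaded defaults untouched, so the same base config can be reused with different overrides.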
To continue training from a checkpoint, add the --continue-from [CHECKPOINT] flag to the training command, for example:
python -m lightning_trainable.launcher.fit configs/fff/dw4.yaml --name '{data_set[name]}' --continue-from lightning_logs/dw4/version_0/checkpoints/last.ckpt
This reloads the entire training state (model weights, optimizer state, epoch, etc.) from the checkpoint and continues training from there.
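The path version_0 above only matches the first run: assuming the usual PyTorch Lightning behavior of creating a new, incremented version_<n> directory for every run with the same name, later runs land in version_1, version_2, and so on. A small hypothetical helper like latest_checkpoint can locate the newest last.ckpt to pass to --continue-from or load_from_checkpoint; it assumes the lightning_logs/<name>/version_<n>/checkpoints/last.ckpt layout used in this README.

```python
import re
from pathlib import Path


def latest_checkpoint(log_root, name):
    """Return last.ckpt of the highest-numbered version_<n> run, or None."""
    best_version, best_path = -1, None
    for run_dir in Path(log_root, name).glob("version_*"):
        match = re.fullmatch(r"version_(\d+)", run_dir.name)
        ckpt = run_dir / "checkpoints" / "last.ckpt"
        # Compare version numbers numerically so that version_10 beats version_9,
        # and skip runs that have not written a last.ckpt yet
        if match and ckpt.is_file() and int(match.group(1)) > best_version:
            best_version, best_path = int(match.group(1)), ckpt
    return best_path
```

For example, latest_checkpoint("lightning_logs", "dw4") would return the checkpoint of the most recent DW4 run, or None if no run has saved a last.ckpt yet.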
Start with the config file in configs/fff or configs/fif that best fits your needs and modify it.
For custom data sets, add the data set to fff.data.