PyTorch Adapt provides tools for domain adaptation, a type of machine learning algorithm that repurposes existing models to work in new domains. This library is:
- Fully featured: build a complete train/val domain adaptation pipeline in a few lines of code.
- Modular: use just the parts that suit your needs, whether it's the algorithms, loss functions, or validation methods.
- Highly customizable: customize and combine complex algorithms with ease.
- Compatible with frameworks: add functionality to your code by using one of the framework wrappers. Converting an algorithm into a PyTorch Lightning module is as simple as wrapping it with `Lightning`.
See the examples folder for notebooks you can download or run on Google Colab.
To use it in vanilla PyTorch:
```python
from tqdm import tqdm
from pytorch_adapt.hooks import DANNHook
from pytorch_adapt.utils.common_functions import batch_to_device

# Assuming that models, optimizers, dataloader, and device are already created.
hook = DANNHook(optimizers)
for data in tqdm(dataloader):
    data = batch_to_device(data, device)
    # Optimization is done inside the hook.
    # The returned loss is for logging.
    _, loss = hook({**models, **data})
```
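In the loop above, the training code only merges dicts and logs the returned loss, because the hook performs the optimization step itself. The sketch below is a rough, framework-free illustration of that calling convention only. `ToyLinearHook` is hypothetical and not part of pytorch-adapt. It fits y ≈ w·x by gradient descent. The `src_imgs`/`src_labels` keys and the `(outputs, losses)` return shape are assumptions chosen to mirror the loop above.

```python
from typing import Any, Dict, List, Tuple


class ToyLinearHook:
    """Hypothetical hook-style callable; NOT part of pytorch-adapt.

    It mimics the calling convention above: models and data arrive merged
    into one dict, the optimization step happens inside the call, and the
    returned losses are only meant for logging.
    """

    def __init__(self, lr: float = 0.1):
        self.lr = lr

    def __call__(self, inputs: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, float]]:
        model = inputs["model"]  # a mutable dict holding one weight, "w"
        xs: List[float] = inputs["src_imgs"]
        ys: List[float] = inputs["src_labels"]
        n = len(xs)
        # Mean squared error of the prediction w * x, and its gradient w.r.t. w.
        loss = sum((model["w"] * x - y) ** 2 for x, y in zip(xs, ys)) / n
        grad = sum(2 * (model["w"] * x - y) * x for x, y in zip(xs, ys)) / n
        # The "optimizer step" is taken inside the hook, updating the model in place.
        model["w"] -= self.lr * grad
        return {}, {"total_loss": loss}


# Usage: the loop body stays the same no matter what the hook does internally.
toy_models = {"model": {"w": 0.0}}
toy_data = {"src_imgs": [1.0, 2.0], "src_labels": [2.0, 4.0]}
hook = ToyLinearHook(lr=0.1)
for step in range(50):
    _, loss = hook({**toy_models, **toy_data})
print(toy_models["model"]["w"], loss["total_loss"])
```

Because the update lives inside the callable, swapping or extending the hook leaves the loop body unchanged. The next example relies on exactly this when it adds extra losses to `DANNHook`.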
Let's customize `DANNHook` with:
- minimum class confusion (MCC)
- virtual adversarial training (VAT)
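```python
import torch
from tqdm import tqdm
from pytorch_adapt.hooks import DANNHook, MCCHook, VATHook
from pytorch_adapt.utils.common_functions import batch_to_device

# G and C are the Generator and Classifier models
G, C = models["G"], models["C"]
# The combined G -> C model is passed in under "combined_model" (used by VATHook)
misc = {"combined_model": torch.nn.Sequential(G, C)}
hook = DANNHook(optimizers, post_g=[MCCHook(), VATHook()])
for data in tqdm(dataloader):
    data = batch_to_device(data, device)
    _, loss = hook({**models, **data, **misc})
```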
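`MCCHook` adds a minimum class confusion loss on target predictions. The sketch below shows roughly what that loss computes. It is a pure-Python version of the formulation in "Minimum Class Confusion for Versatile Domain Adaptation", not pytorch-adapt's implementation. The steps are:
- a temperature-scaled softmax;
- entropy-based sample weights;
- a weighted class confusion matrix;
- row normalization;
- the mean off-diagonal mass.

The default temperature `T=1.0` and the function names are assumptions for illustration.

```python
import math
from typing import List


def softmax(row: List[float], T: float = 1.0) -> List[float]:
    """Temperature-scaled softmax, shifted by the max for numerical stability."""
    m = max(row)
    exps = [math.exp((v - m) / T) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def mcc_loss(logits: List[List[float]], T: float = 1.0) -> float:
    """Minimum class confusion loss for a batch of target-domain logits (sketch)."""
    probs = [softmax(row, T) for row in logits]
    n, c = len(probs), len(probs[0])
    # Entropy-based reweighting: confident (low-entropy) samples get larger weights.
    entropies = [-sum(p * math.log(p + 1e-12) for p in row) for row in probs]
    raw = [1.0 + math.exp(-h) for h in entropies]
    total = sum(raw)
    weights = [n * r / total for r in raw]
    # Class confusion matrix: conf[j][k] = sum_i w_i * p_ij * p_ik.
    conf = [
        [sum(weights[i] * probs[i][j] * probs[i][k] for i in range(n)) for k in range(c)]
        for j in range(c)
    ]
    # Category normalization: rescale each row to sum to 1.
    conf = [[v / sum(row) for v in row] for row in conf]
    # Penalize between-class (off-diagonal) confusion, averaged over classes.
    return sum(conf[j][k] for j in range(c) for k in range(c) if j != k) / c


print(mcc_loss([[0.0, 0.0], [0.0, 0.0]]))        # uniform predictions: 0.5
print(mcc_loss([[10.0, -10.0], [-10.0, 10.0]]))  # confident, distinct predictions: near zero
```

Minimizing this value pushes the model toward confident predictions that do not mix classes.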
To wrap an algorithm with your favorite PyTorch framework, first set up the adapter and dataloaders:
```python
from pytorch_adapt.adapters import DANN
from pytorch_adapt.containers import Models
from pytorch_adapt.datasets import DataloaderCreator

# Assuming that models (a dict) and datasets are already created.
models_cont = Models(models)
adapter = DANN(models=models_cont)
dc = DataloaderCreator(num_workers=2)
dataloaders = dc(**datasets)
```
Then use a framework wrapper. With PyTorch Lightning:
```python
import pytorch_lightning as pl
from pytorch_adapt.frameworks.lightning import Lightning

L_adapter = Lightning(adapter)
trainer = pl.Trainer(gpus=1, max_epochs=1)
trainer.fit(L_adapter, dataloaders["train"])
```
With PyTorch Ignite:
```python
from pytorch_adapt.frameworks.ignite import Ignite

trainer = Ignite(adapter)
trainer.run(datasets, dataloader_creator=dc)
```
To check your model's performance in vanilla PyTorch:
```python
from pytorch_adapt.validators import SNDValidator

# Assuming predictions have been collected
target_train = {"preds": preds}
validator = SNDValidator()
score = validator(target_train=target_train)
```
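`SNDValidator` needs no target labels. It scores the predictions themselves. The sketch below is a pure-Python version of the soft neighborhood density (SND) idea from "Tune it the Right Way: Unsupervised Validation of Domain Adaptation via Soft Neighborhood Density". It is not the library's code. It makes three assumptions:
- `preds` are softmax probabilities;
- the temperature is `T=0.05`;
- the batch has at least two rows.

Under those assumptions, it L2-normalizes each prediction and computes a softmax over similarities to all other samples. It then averages the entropy of those neighbor distributions.

```python
import math
from typing import List


def snd_score(preds: List[List[float]], T: float = 0.05) -> float:
    """Soft neighborhood density of a batch of softmax predictions (sketch; needs >= 2 rows)."""
    # L2-normalize each prediction vector so dot products are cosine similarities.
    normed = []
    for row in preds:
        norm = math.sqrt(sum(v * v for v in row))
        normed.append([v / norm for v in row])
    n = len(normed)
    entropies = []
    for i in range(n):
        # Temperature-scaled similarity to every other sample (self excluded).
        sims = [
            sum(a * b for a, b in zip(normed[i], normed[j])) / T
            for j in range(n)
            if j != i
        ]
        # Softmax over neighbors, shifted by the max for numerical stability.
        m = max(sims)
        exps = [math.exp(s - m) for s in sims]
        total = sum(exps)
        probs = [e / total for e in exps]
        # Entropy of the neighbor distribution: higher means a denser neighborhood.
        entropies.append(-sum(p * math.log(p) for p in probs if p > 0))
    # The score is the mean entropy; higher is treated as better.
    return sum(entropies) / n


print(snd_score([[0.9, 0.1], [0.8, 0.2], [0.1, 0.9], [0.2, 0.8]]))
```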
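You can also check performance during training with a framework wrapper. With PyTorch Lightning:
```python
import pytorch_lightning as pl
from pytorch_adapt.frameworks.lightning import Lightning
from pytorch_adapt.frameworks.utils import filter_datasets
from pytorch_adapt.validators import SNDValidator

# Continues from the setup above (adapter, dc, datasets)
validator = SNDValidator()
dataloaders = dc(**filter_datasets(datasets, validator))
train_loader = dataloaders.pop("train")

L_adapter = Lightning(adapter, validator=validator)
trainer = pl.Trainer(gpus=1, max_epochs=1)
# The remaining dataloaders are passed in as validation dataloaders
trainer.fit(L_adapter, train_loader, list(dataloaders.values()))
```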
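With PyTorch Ignite:
```python
from pytorch_adapt.frameworks.ignite import Ignite
from pytorch_adapt.validators import ScoreHistory, SNDValidator

# ScoreHistory wraps the validator and keeps a history of its scores
validator = ScoreHistory(SNDValidator())
trainer = Ignite(adapter, validator=validator)
trainer.run(datasets, dataloader_creator=dc)
```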
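The Ignite example wraps `SNDValidator` in `ScoreHistory` so that scores can be compared across epochs. The sketch below shows the general pattern. `SimpleScoreHistory` is a hypothetical stand-in written for this illustration, and its attribute names are not guaranteed to match pytorch-adapt's `ScoreHistory`. It calls the wrapped validator, records each epoch's score, and tracks the best one. It assumes that higher scores are better.

```python
from typing import Callable, List, Optional


class SimpleScoreHistory:
    """Hypothetical score-history wrapper; NOT pytorch-adapt's ScoreHistory."""

    def __init__(self, validator: Callable[..., float]):
        self.validator = validator
        self.epochs: List[int] = []
        self.scores: List[float] = []
        self.best_score: Optional[float] = None
        self.best_epoch: Optional[int] = None

    def __call__(self, epoch: int, **kwargs) -> float:
        # Delegate scoring to the wrapped validator, then record the result.
        score = self.validator(**kwargs)
        self.epochs.append(epoch)
        self.scores.append(score)
        # Assumes higher scores are better.
        if self.best_score is None or score > self.best_score:
            self.best_score, self.best_epoch = score, epoch
        return score

    @property
    def latest_score(self) -> Optional[float]:
        return self.scores[-1] if self.scores else None


# Usage with a stand-in validator that replays fixed scores.
fake_scores = iter([0.2, 0.7, 0.5])
history = SimpleScoreHistory(lambda **kwargs: next(fake_scores))
for epoch in range(3):
    history(epoch=epoch)
print(history.best_epoch, history.best_score, history.latest_score)  # 1 0.7 0.5
```

Keeping the best epoch alongside the full history lets you pick a checkpoint afterwards without re-running validation.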
To run the examples above, see this notebook, and see the examples page for other notebooks.
Install with pip:
```shell
pip install pytorch-adapt
```
To get the latest dev version:
```shell
pip install pytorch-adapt --pre
```
To use `pytorch_adapt.frameworks.lightning`:
```shell
pip install "pytorch-adapt[lightning]"
```
To use `pytorch_adapt.frameworks.ignite`:
```shell
pip install "pytorch-adapt[ignite]"
```
Conda: coming soon.
Dependencies: see setup.py.
Thanks to the contributors who made pull requests!
| Contributor | Highlights |
| --- | --- |
| deepseek-eoghan | Improved the TargetDataset class |
Thank you to Ser-Nam Lim, and my research advisor, Professor Serge Belongie.
Thanks to Jeff Musgrave for designing the logo.
If you'd like to cite pytorch-adapt in your paper, you can cite this paper using the following BibTeX entry:
```bibtex
@article{Musgrave2022PyTorchA,
  title={PyTorch Adapt},
  author={Kevin Musgrave and Serge J. Belongie and Ser Nam Lim},
  journal={ArXiv},
  year={2022},
  volume={abs/2211.15673}
}
```
Code references (in no particular order):
- https://github.com/wgchang/DSBN
- https://github.com/jihanyang/AFN
- https://github.com/thuml/Versatile-Domain-Adaptation
- https://github.com/tim-learn/ATDOC
- https://github.com/thuml/CDAN
- https://github.com/takerum/vat_chainer
- https://github.com/takerum/vat_tf
- https://github.com/RuiShu/dirt-t
- https://github.com/lyakaap/VAT-pytorch
- https://github.com/9310gaurav/virtual-adversarial-training
- https://github.com/thuml/Deep-Embedded-Validation
- https://github.com/lr94/abas
- https://github.com/thuml/Batch-Spectral-Penalization
- https://github.com/jvanvugt/pytorch-domain-adaptation
- https://github.com/ptrblck/pytorch_misc