Code for Learning Representations that Support Robust Transfer of Predictors
TL;DR: We introduce a simple robust estimation criterion, transfer risk, that is specifically geared towards optimizing transfer to new environments. The criterion amounts to finding a representation that minimizes the risk of applying any optimal predictor trained on one environment to another. The transfer risk essentially decomposes into two terms: a direct transfer term and a weighted gradient-matching term that arises from the optimality of the per-environment predictors.
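To make the definition concrete, here is a toy sketch of the direct transfer-risk quantity: fit the optimal predictor in each environment on top of a shared representation, then average its risk on every other environment. The assumptions are a scalar representation, 1-D least-squares predictors without bias, squared loss, and a uniform average over ordered pairs of distinct environments. All names (`transfer_risk`, `phi_inv`, `phi_spu`, the toy environments) are hypothetical. This is not the repository's TRM implementation, and it does not compute the gradient-matching decomposition.

```python
from itertools import permutations
from typing import Callable, Dict, List, Sequence, Tuple

# One environment = list of (x, y) samples; x is a raw feature vector.
Env = List[Tuple[Sequence[float], float]]
Representation = Callable[[Sequence[float]], float]


def fit_optimal_predictor(feats: List[float], ys: List[float]) -> float:
    """Closed-form least-squares slope w minimizing mean((w * f - y) ** 2)."""
    denom = sum(f * f for f in feats)
    if denom == 0:
        return 0.0
    return sum(f * y for f, y in zip(feats, ys)) / denom


def risk(w: float, feats: List[float], ys: List[float]) -> float:
    """Mean squared error of the linear predictor w on one environment."""
    return sum((w * f - y) ** 2 for f, y in zip(feats, ys)) / len(ys)


def transfer_risk(phi: Representation, envs: Dict[str, Env]) -> float:
    """Average risk of applying env e's optimal predictor to every other env e'."""
    feats = {e: [phi(x) for x, _ in data] for e, data in envs.items()}
    labels = {e: [y for _, y in data] for e, data in envs.items()}
    # Per-environment optimal predictor on top of the shared representation.
    w_opt = {e: fit_optimal_predictor(feats[e], labels[e]) for e in envs}
    pairs = list(permutations(envs, 2))  # ordered pairs (e, e') with e != e'
    total = sum(risk(w_opt[src], feats[tgt], labels[tgt]) for src, tgt in pairs)
    return total / len(pairs)


# Toy data: x = (invariant feature, spurious feature); y equals the invariant
# feature in both environments, while the spurious feature flips sign.
envs: Dict[str, Env] = {
    "A": [((1.0, 1.0), 1.0), ((2.0, 2.0), 2.0)],
    "B": [((1.0, -1.0), 1.0), ((2.0, -2.0), 2.0)],
}


def phi_inv(x: Sequence[float]) -> float:
    return x[0]  # keep only the invariant feature


def phi_spu(x: Sequence[float]) -> float:
    return x[1]  # keep only the spurious feature


if __name__ == "__main__":
    print(transfer_risk(phi_inv, envs))  # 0.0
    print(transfer_risk(phi_spu, envs))  # 10.0
```

Both representations fit each environment perfectly (zero in-environment risk). Only the invariant one transfers, though: with the spurious feature, the predictor fitted on A has the wrong sign on B, so the transfer risk exposes what per-environment risk alone does not.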
Download the datasets:
python scripts/download.py --data_dir {data_dir}
The Places dataset can be downloaded from:
http://data.csail.mit.edu/places/places365/train_256_places365standard.tar
The COCO dataset (2017 train/val annotations) can be downloaded from:
http://images.cocodataset.org/annotations/annotations_trainval2017.zip
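If you prefer to script the two downloads listed above, a minimal standard-library sketch follows. The function names and the `raw_dir` destination are hypothetical. The sketch does not unpack the archives and does not know where `coco.py` and `places.py` expect their inputs, so check those scripts before pointing them at the files. `dry_run=True` (the default) only returns the planned (URL, destination) pairs.

```python
import os
import urllib.request
from typing import List, Tuple
from urllib.parse import urlparse

# Archive URLs listed in this README.
URLS = [
    "http://data.csail.mit.edu/places/places365/train_256_places365standard.tar",
    "http://images.cocodataset.org/annotations/annotations_trainval2017.zip",
]


def plan_downloads(raw_dir: str) -> List[Tuple[str, str]]:
    """Pair each URL with a local path named after the URL's file name."""
    return [
        (url, os.path.join(raw_dir, os.path.basename(urlparse(url).path)))
        for url in URLS
    ]


def fetch_archives(raw_dir: str, dry_run: bool = True) -> List[Tuple[str, str]]:
    """Download any archive not yet present; with dry_run=True only plan."""
    plan = plan_downloads(raw_dir)
    if not dry_run:
        os.makedirs(raw_dir, exist_ok=True)
        for url, dest in plan:
            if not os.path.exists(dest):  # skip archives already on disk
                urllib.request.urlretrieve(url, dest)
    return plan


if __name__ == "__main__":
    for url, dest in fetch_archives("raw_archives"):  # hypothetical directory
        print(f"{url} -> {dest}")
```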
# preprocess COCO
python coco.py
# preprocess Places
python places.py
# generate the SceneCOCO dataset
python cocoplaces.py
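The three preprocessing commands above are run in the order shown. Because `cocoplaces.py` builds SceneCOCO from COCO and Places, it presumably consumes the outputs of the first two. A small Python wrapper can chain the steps and stop at the first failure. It assumes the scripts sit together in one directory (`script_dir`, a hypothetical parameter) and take no arguments, exactly as written above. With `dry_run=True` it only returns the commands.

```python
import subprocess
import sys
from pathlib import Path
from typing import List

# Order matters: cocoplaces.py is assumed to use the outputs of coco.py and places.py.
STEPS = ["coco.py", "places.py", "cocoplaces.py"]


def build_commands(script_dir: str) -> List[List[str]]:
    """One `python <script>` command per step, using the current interpreter."""
    base = Path(script_dir).resolve()
    return [[sys.executable, str(base / step)] for step in STEPS]


def prepare_scenecoco(script_dir: str, dry_run: bool = True) -> List[List[str]]:
    """Run the steps in order; check=True raises on the first failing step."""
    cmds = build_commands(script_dir)
    if not dry_run:
        for cmd in cmds:
            subprocess.run(cmd, check=True, cwd=script_dir)
    return cmds


if __name__ == "__main__":
    for cmd in prepare_scenecoco("."):
        print(" ".join(cmd))
```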
Datasets:
- Synthetic datasets for controlled experiments: ColoredMNIST / SceneCOCO
- Real-world datasets: PACS / Office-Home
python -m domainbed.scripts.train --data_dir {root} --algorithm {alg} \
--dataset {dataset} --trial_seed {t_seed} --epochs {epochs} (--shift {shift}) (--resnet50) (--test_eval)
Arguments in parentheses are optional:
- root: root directory of the data
- alg: ERM, VREx, IRM, GroupDRO, Fish, MLDG, or TRM
- t_seed: seed for data splitting
- dataset: PACS, OfficeHome, ColoredMNIST, or SceneCOCO
- epochs: number of training epochs
- resnet50: use ResNet50 as the backbone (default: ResNet18)
- shift: ColoredMNIST and SceneCOCO only; 0: label-correlated shift, 1: label-uncorrelated shift, 2: combined shift
- test_eval: use the test-domain validation set (default: train-domain validation set)
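To run several algorithms or seeds, it can help to generate the training command lines programmatically. The sketch below assembles only the flags documented above and rejects combinations the list rules out (for example `--shift` on PACS). The helper name `train_command`, the data path, the epoch count, and the sweep grid are hypothetical. `--resnet50` and `--test_eval` are assumed to be switches that take no value, as the parenthesized usage suggests.

```python
import shlex
from itertools import product
from typing import List, Optional

ALGORITHMS = {"ERM", "VREx", "IRM", "GroupDRO", "Fish", "MLDG", "TRM"}
DATASETS = {"PACS", "OfficeHome", "ColoredMNIST", "SceneCOCO"}
SHIFT_DATASETS = {"ColoredMNIST", "SceneCOCO"}  # --shift applies only to these


def train_command(
    root: str,
    alg: str,
    dataset: str,
    t_seed: int,
    epochs: int,
    shift: Optional[int] = None,
    resnet50: bool = False,
    test_eval: bool = False,
) -> List[str]:
    """Assemble the documented training command as an argv list."""
    if alg not in ALGORITHMS:
        raise ValueError(f"unknown algorithm: {alg}")
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset: {dataset}")
    cmd = [
        "python", "-m", "domainbed.scripts.train",
        "--data_dir", root,
        "--algorithm", alg,
        "--dataset", dataset,
        "--trial_seed", str(t_seed),
        "--epochs", str(epochs),
    ]
    if shift is not None:
        if dataset not in SHIFT_DATASETS:
            raise ValueError("--shift is only used with ColoredMNIST and SceneCOCO")
        if shift not in (0, 1, 2):
            raise ValueError("shift must be 0, 1 or 2")
        cmd += ["--shift", str(shift)]
    if resnet50:
        cmd.append("--resnet50")  # assumed to be a boolean switch
    if test_eval:
        cmd.append("--test_eval")  # assumed to be a boolean switch
    return cmd


# Hypothetical sweep: ERM vs. TRM on SceneCOCO (combined shift), three data splits.
SWEEP = [
    train_command("/path/to/data", alg, "SceneCOCO", seed, epochs=10, shift=2)
    for alg, seed in product(["ERM", "TRM"], range(3))
]

if __name__ == "__main__":
    for argv in SWEEP:
        print(shlex.join(argv))
```

`shlex.join` (Python 3.8+) prints shell-ready lines. Passing an argv list directly to `subprocess.run` instead avoids quoting issues altogether.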
This implementation is based on or inspired by:
- https://github.com/facebookresearch/DomainBed (code structure)
- https://github.com/Faruk-Ahmed/predictive_group_invariance (data generation)