Diffusion-Driven Data Replay: A Novel Approach to Combat Forgetting in Federated Class Continual Learning
Jinglin Liang1, Jin Zhong1, Hanlin Gu2, Zhongqi Lu3, Xingxing Tang2, Gang Dai1, Shuangping Huang1,5*, Lixin Fan4, Qiang Yang2,4
1South China University of Technology, 2The Hong Kong University of Science and Technology, 3China University of Petroleum, 4WeBank, 5Pazhou Laboratory
Abstract:
Federated Class Continual Learning (FCCL) merges the challenges of distributed client learning with the need for seamless adaptation to new classes without forgetting old ones. The key challenge in FCCL is catastrophic forgetting, an issue that has been explored to some extent in Continual Learning (CL). However, due to privacy preservation requirements, some conventional methods, such as experience replay, are not directly applicable to FCCL. Existing FCCL methods mitigate forgetting by generating historical data through federated training of GANs or data-free knowledge distillation. However, these approaches often suffer from unstable training of generators or low-quality generated data, limiting their guidance for the model. To address this challenge, we propose a novel method of data replay based on diffusion models. Instead of training a diffusion model, we employ a pre-trained conditional diffusion model to reverse-engineer each category, searching for the corresponding input conditions for each category within the model's input space, significantly reducing computational resources and time consumption while ensuring effective generation. Furthermore, we enhance the classifier's domain generalization ability on generated and real data through contrastive learning, indirectly improving the representational capability of generated data for real data. Extensive experiments demonstrate that our method significantly outperforms existing baselines.
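The condition search mentioned in the abstract can be read as an optimization over a per-class condition vector while the diffusion model stays frozen. The formula below is only a sketch in the style of textual inversion for latent diffusion models; this README does not state the exact objective, so every symbol is an assumption: c is the searched condition for class k, D_k the local data of class k, E the frozen image encoder, epsilon_theta the frozen denoiser, and alpha-bar_t the noise schedule.

```latex
c_k^{*} = \arg\min_{c} \;
\mathbb{E}_{x \sim \mathcal{D}_k,\; \epsilon \sim \mathcal{N}(0, I),\; t}
\left[ \left\| \epsilon - \epsilon_\theta\!\left(z_t, t, c\right) \right\|_2^2 \right],
\qquad
z_t = \sqrt{\bar{\alpha}_t}\, \mathcal{E}(x) + \sqrt{1 - \bar{\alpha}_t}\, \epsilon
```

Under this reading only c receives gradients, which is consistent with the abstract's claim of lower compute cost than training a generator.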
Overview of our DDDR
This repository is the official PyTorch implementation of:
Diffusion-Driven Data Replay: A Novel Approach to Combat Forgetting in Federated Class Continual Learning (ECCV 2024 Oral).
Install the dependencies:
pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
Then install CLIP from the official CLIP repository.
The program downloads the CIFAR-100 dataset automatically. Only the Tiny ImageNet dataset needs to be downloaded manually, using the following commands:
cd data
wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
unzip tiny-imagenet-200.zip
python preprocess.py
cd ..
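After the download, a quick sanity check can catch a truncated archive. The snippet below is a minimal sketch, not part of this repository: the function name is hypothetical, and it assumes preprocess.py leaves the train/ folder of the unzipped archive in its original layout (one sub-folder per class, with .JPEG files inside). The official Tiny ImageNet archive has 200 classes with 500 training images each.

```python
from pathlib import Path


def summarize_tiny_imagenet(root):
    """Return (number of class folders, number of .JPEG files) under root/train."""
    train_dir = Path(root) / "train"
    class_dirs = sorted(p for p in train_dir.iterdir() if p.is_dir())
    # rglob also finds images nested in a per-class "images" sub-folder.
    num_images = sum(1 for d in class_dirs for _ in d.rglob("*.JPEG"))
    return len(class_dirs), num_images


if __name__ == "__main__":
    # Path assumed from the commands above (archive unzipped inside data/).
    root = Path("data/tiny-imagenet-200")
    if root.is_dir():
        classes, images = summarize_tiny_imagenet(root)
        print(f"{classes} classes, {images} training images")
    else:
        print(f"{root} not found; run the download commands first")
```

If the counts differ from 200 classes and 100,000 training images, re-running the wget and unzip steps is a reasonable first thing to try.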
We use the pre-trained diffusion model from the LDM repository. You can obtain it with the following commands:
mkdir -p models/ldm/text2img-large
wget -O models/ldm/text2img-large/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/txt2img-f8-large/model.ckpt
Please download bert-base-uncased from here, and put it in models/bert.
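Before launching a run, it can help to confirm that both pre-trained assets are where the steps above put them. This is a minimal sketch with a hypothetical helper name; the two paths come from the instructions above, while the assumption that an empty file or empty folder counts as missing is ours.

```python
from pathlib import Path

# Paths taken from the download steps above.
REQUIRED_PATHS = (
    "models/ldm/text2img-large/model.ckpt",
    "models/bert",
)


def missing_assets(repo_root="."):
    """Return the required paths that are absent or empty under repo_root."""
    missing = []
    for rel in REQUIRED_PATHS:
        path = Path(repo_root) / rel
        if path.is_file():
            ok = path.stat().st_size > 0  # an empty file suggests a failed download
        elif path.is_dir():
            ok = any(path.iterdir())  # an empty folder has no model files
        else:
            ok = False
        if not ok:
            missing.append(rel)
    return missing


if __name__ == "__main__":
    problems = missing_assets()
    print("all assets found" if not problems else f"missing: {problems}")
```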
Set GPU_LIST in rep_run.sh, then run the following command to reproduce most of the results from our paper:
bash rep_run.sh
You can also reproduce our method or baseline methods separately using the following commands.
# DDDR
python main.py --dataset cifar100 --method ours --tasks 5 --beta 0.5 --seed 2024
# Finetune
python main.py --dataset cifar100 --method finetune --tasks 10 --beta 0.5 --seed 2024
# EWC
python main.py --dataset cifar100 --method ewc --tasks 10 --beta 0.5 --seed 2024
# Target
python main.py --dataset cifar100 --method target --tasks 5 --beta 0.5 --seed 2024 --w_kd 25
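To run the four commands above in sequence, a small launcher can assemble them from a table of settings. This is a sketch, not part of the repository: build_command and run_all are hypothetical helpers, and the only flags used are the ones that appear in the commands above. By default it only prints the commands.

```python
import shlex
import subprocess

# (method, tasks, extra flags), copied from the commands above.
RUNS = [
    ("ours", 5, []),
    ("finetune", 10, []),
    ("ewc", 10, []),
    ("target", 5, ["--w_kd", "25"]),
]


def build_command(method, tasks, extra=(), dataset="cifar100", beta=0.5, seed=2024):
    """Return the argument list for one main.py run."""
    return [
        "python", "main.py",
        "--dataset", dataset,
        "--method", method,
        "--tasks", str(tasks),
        "--beta", str(beta),
        "--seed", str(seed),
        *extra,
    ]


def run_all(dry_run=True):
    """Print each command; execute it as well when dry_run is False."""
    for method, tasks, extra in RUNS:
        cmd = build_command(method, tasks, extra)
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failing run


if __name__ == "__main__":
    run_all(dry_run=True)
```

Calling run_all(dry_run=False) from the repository root would execute the runs one after another; changing dataset or seed in build_command gives other configurations, provided main.py accepts those values.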
We would like to express our heartfelt gratitude for their contribution to our project.
If you find our work inspiring or use our codebase in your research, please cite our work:
@inproceedings{liang2024dddr,
title={Diffusion-Driven Data Replay: A Novel Approach to Combat Forgetting in Federated Class Continual Learning},
author={Liang, Jinglin and Zhong, Jin and Gu, Hanlin and Lu, Zhongqi and Tang, Xingxing and Dai, Gang and Huang, Shuangping and Fan, Lixin and Yang, Qiang},
booktitle={ECCV},
year={2024}
}