diff --git a/README.md b/README.md index b45c6a9..f0a4ade 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ DeComFL is a library designed for training/fine-tuning deep learning models in t ## Performance -From Tables 1 and 2, we observe the DeComFL's effectiveness in communication cost reduction. We evaluate its performance with five and ten perturbations. Its performance matches or even outperforms MeZO and FedZO in all datasets. Surprisingly, DeComFL can just require about **1MB communication cost** to converge, which is a significant saving compared with other algorithms. +From Tables 1 and 2, we observe the DeComFL's effectiveness in communication cost reduction. We evaluate its performance with five and ten perturbations. Its performance matches or even outperforms MeZO and FedZO in all datasets. Surprisingly, DeComFL can just require about **1MB communication cost** to converge, which is a significant saving compared with other algorithms. @@ -110,7 +110,6 @@ From Tables 1 and 2, we observe the DeComFL's effectiveness in communication cos
Table 1: Test accuracy and communication cost on fine-tuning tasks
- @@ -170,8 +169,6 @@ From Tables 1 and 2, we observe the DeComFL's effectiveness in communication cos
Table 2: Test accuracy on fine-tuning tasks (LoRA)
- - ## Environment Setup We use [conda](https://docs.conda.io/projects/conda/en/stable/) as our cross-platform environment management tool. However, due to macOS' lacking support for cuda, we have to make 2 different environment setup files: @@ -197,7 +194,6 @@ For READMD.md, we will use `environment.yml` whenever an environment file is nee - **Run DeComFL:** Follow FL routine, split data into chunks and train on different clients. Usage example: `python decomfl_main.py --large-model=opt-125m --dataset=sst2 --iterations=1000 --train-batch-size=32 --test-batch-size=200 --eval-iterations=25 --num-clients=3 --num-sample-clients=2 --local-update-steps=1 --num-pert=5 --lr=1e-5 --mu=1e-3 --grad-estimate-method=rge-forward` - ## Citation ``` @@ -210,7 +206,8 @@ For READMD.md, we will use `environment.yml` whenever an environment file is nee ``` ## Contributors -DeComFL is currently contributed and maintained by **Zidong Liu** (ComboCurve), **Bicheng Ying** (Google) and **Zhe Li** (RIT), and advised by Prof. **Haibo Yang** (RIT). + +DeComFL is currently contributed and maintained by **Zidong Liu** (ComboCurve), **Bicheng Ying** (Google) and **Zhe Li** (RIT), and advised by Prof. **Haibo Yang** (RIT).
Image 1 diff --git a/cezo_fl/random_gradient_estimator.py b/cezo_fl/random_gradient_estimator.py index 8906594..ce12b22 100644 --- a/cezo_fl/random_gradient_estimator.py +++ b/cezo_fl/random_gradient_estimator.py @@ -9,6 +9,8 @@ GradEstimateMethod: TypeAlias = Literal["forward", "central"] +BatchInput: TypeAlias = torch.Tensor | LLMBatchInput + # TODO: split this class into abstract class and several subcalsses. class RandomGradientEstimator: @@ -57,7 +59,7 @@ def __init__( # TODO(zidong) move this func out of this class def model_forward( self, - batch_inputs: torch.Tensor | LLMBatchInput, + batch_inputs: BatchInput, ): if self.generation_mode: if not isinstance(self.model, (OPTForCausalLM, PeftModel)): @@ -136,7 +138,14 @@ def generate_then_put_grad(self, seed: int, dir_grads: torch.Tensor) -> None: assert update_grad is not None self.put_grad(update_grad) - def compute_grad(self, batch_inputs, labels, criterion, seed: int) -> torch.Tensor: + def compute_grad( + self, + batch_inputs: BatchInput, + labels: torch.Tensor, + criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + seed: int | None, + ) -> torch.Tensor: + """When seed is None, it means we are just using random gradient estimator to estimate grad, we do not need to reconstruct each perturbation""" if not self.paramwise_perturb: # We generate the perturbation vector all together. It should be faster but consume # more memory @@ -145,6 +154,7 @@ def compute_grad(self, batch_inputs, labels, criterion, seed: int) -> torch.Tens ) self.put_grad(grad) else: + assert isinstance(seed, int) perturbation_dir_grads = self._zo_grad_estimate_paramwise( batch_inputs, labels, criterion, seed ) @@ -166,10 +176,10 @@ def sgd_no_optim_update_model( def _zo_grad_estimate( self, - batch_inputs: torch.Tensor, + batch_inputs: BatchInput, labels: torch.Tensor, criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], - seed: int, + seed: int | None, ) -> tuple[torch.Tensor, torch.Tensor]: """Calculate the zeroth-order gradient estimate. @@ -188,7 +198,7 @@ def _zo_grad_estimate( pert_minus_loss = criterion(self.model_forward(batch_inputs), labels) for i in range(self.num_pert): - rng = self.get_rng(seed, i) + rng: torch.Generator | None = self.get_rng(seed, i) if isinstance(seed, int) else None pb_norm = self.generate_perturbation_norm(rng) self.perturb_model(pb_norm, alpha=self.mu) @@ -236,7 +246,7 @@ def perturb_model_paramwise(self, rng: torch.Generator, alpha: float | int) -> N def _zo_grad_estimate_paramwise( self, - batch_inputs: torch.Tensor, + batch_inputs: BatchInput, labels: torch.Tensor, criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], seed: int, diff --git a/cezo_fl/util/model_helpers.py b/cezo_fl/util/model_helpers.py index f3357e2..d15a58f 100644 --- a/cezo_fl/util/model_helpers.py +++ b/cezo_fl/util/model_helpers.py @@ -10,6 +10,10 @@ import torch import torch.optim as optim +from peft import PeftModel +from transformers.models.opt.modeling_opt import OPTForCausalLM + +from cezo_fl.util.language_utils import LLMBatchInput def get_current_datetime_str(): @@ -73,3 +77,14 @@ def get_trainable_model_parameters( for param in model.parameters(): if param.requires_grad: yield param + + +def model_forward( + model: OPTForCausalLM | PeftModel | torch.nn.Module, batch_inputs: torch.Tensor | LLMBatchInput +): + if isinstance(model, (OPTForCausalLM, PeftModel)): + return model(input_ids=batch_inputs.input_ids, attention_mask=batch_inputs.attention_mask) + elif isinstance(model, torch.nn.Module): + return model(batch_inputs) + else: + raise Exception("This model type is not supported") diff --git a/cezo_fl/util/prepare_settings.py b/cezo_fl/util/prepare_settings.py new file mode 100644 index 0000000..a88f280 --- /dev/null +++ b/cezo_fl/util/prepare_settings.py @@ -0,0 +1,191 @@ +import torch +import torch.nn as nn + +from peft import LoraConfig, get_peft_model +from transformers import AutoModelForCausalLM, AutoTokenizer + + +from cezo_fl.util import model_helpers +from cezo_fl.models.cnn_fashion import CNN_FMNIST +from cezo_fl.models.cnn_mnist import CNN_MNIST +from cezo_fl.models.lenet import LeNet +from cezo_fl.models.lstm import CharLSTM +from cezo_fl.random_gradient_estimator import RandomGradientEstimator as RGE +from cezo_fl.util.language_utils import LM_TEMPLATE_MAP, SUPPORTED_LLM, get_lm_loss +from cezo_fl.util.metrics import accuracy + + +def prepare_settings_underseed(args, device, server_or_client: str = "server"): + torch_dtype = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[args.model_dtype] + torch.manual_seed(args.seed) + if args.dataset == "mnist": + model = CNN_MNIST().to(torch_dtype).to(device) + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.SGD( + model_helpers.get_trainable_model_parameters(model), + lr=args.lr, + weight_decay=1e-5, + momentum=args.momentum, + ) + accuracy_func = accuracy + # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8) + elif args.dataset == "cifar10": + model = LeNet().to(torch_dtype).to(device) + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.SGD( + model_helpers.get_trainable_model_parameters(model), + lr=args.lr, + weight_decay=5e-4, + momentum=args.momentum, + ) + accuracy_func = accuracy + # scheduler = torch.optim.lr_scheduler.MultiStepLR( + # optimizer, milestones=[200], gamma=0.1 + # ) + elif args.dataset == "fashion": + model = CNN_FMNIST().to(torch_dtype).to(device) + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.SGD( + model_helpers.get_trainable_model_parameters(model), + lr=args.lr, + weight_decay=1e-5, + momentum=args.momentum, + ) + accuracy_func = accuracy + # scheduler = torch.optim.lr_scheduler.MultiStepLR( + # optimizer, milestones=[200], gamma=0.1 + # ) + elif args.dataset == "shakespeare": + model = CharLSTM().to(torch_dtype).to(device) + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.SGD( + model_helpers.get_trainable_model_parameters(model), + lr=args.lr, + momentum=0.9, + weight_decay=5e-4, + ) + accuracy_func = accuracy + # scheduler = torch.optim.lr_scheduler.MultiStepLR( + # optimizer, milestones=[200], gamma=0.1 + # ) + elif args.dataset in LM_TEMPLATE_MAP.keys(): + large_model = args.large_model + model_name = SUPPORTED_LLM[large_model] + model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch_dtype).to(device) + model.model_name = large_model + tokenizer = AutoTokenizer.from_pretrained( + model_name, padding_side="left", truncate_side="left" + ) + template = LM_TEMPLATE_MAP[args.dataset]() + if args.dataset in ["sst2", "cb", "wsc", "wic", "multirc", "rte", "boolq"]: + if args.lora: + # this step initialize lora parameters, which should be under control of seed + lora_config = LoraConfig( + r=args.lora_r, + lora_alpha=args.lora_alpha, + target_modules=["q_proj", "v_proj"], + ) + model = get_peft_model(model, lora_config).to(torch_dtype) + verbalizer_id_map = template.get_verbalizer_id(tokenizer) + criterion = get_lm_loss("last_token", verbalizer_id_map=verbalizer_id_map) + optimizer = torch.optim.SGD( + model_helpers.get_trainable_model_parameters(model), + lr=args.lr, + momentum=0, + weight_decay=5e-4, + ) + accuracy_func = get_lm_loss("accuracy", verbalizer_id_map=verbalizer_id_map) + elif args.dataset in ["squad", "drop", "xsum"]: + if server_or_client == "server": + criterion = get_lm_loss("f1", tokenizer=tokenizer) + optimizer = torch.optim.SGD( + model_helpers.get_trainable_model_parameters(model), + lr=args.lr, + momentum=0, + weight_decay=0, + ) + accuracy_func = get_lm_loss("f1", tokenizer=tokenizer) + elif server_or_client == "client": + criterion = get_lm_loss("full_sentence", verbalizer_id_map={}) + optimizer = torch.optim.SGD( + model_helpers.get_trainable_model_parameters(model), + lr=args.lr, + momentum=0, + weight_decay=0, + ) + accuracy_func = get_lm_loss("full_sentence", verbalizer_id_map={}) + else: + raise ValueError( + "server_or_client must be either 'server' or 'client'. " + f"But get {server_or_client}" + ) + else: + raise ValueError(f"Dataset {args.dataset} is not supported") + else: + raise Exception(f"Dataset {args.dataset} is not supported") + + if args.grad_estimate_method in ["rge-central", "rge-forward"]: + method = args.grad_estimate_method[4:] + print(f"Using RGE {method}") + if args.dataset in ["squad", "drop"] and server_or_client == "server": + generation_mode = True + # TODO move this setting partially to the args + generation_mode_kwargs = { + "do_sample": True, + "temperature": 1.0, + "num_beams": 2, + "top_p": 0.3, + "top_k": None, + "num_return_sequences": 1, + "max_new_tokens": 5, # will be adjusted dynamically later + "max_length": 2048, + "length_penalty": 2, + "early_stopping": True, + "eos_token_id": [ + tokenizer.encode("\n", add_special_tokens=False)[-1], + tokenizer.eos_token_id, + ], + } + elif args.dataset in ["xsum"] and server_or_client == "server": + generation_mode = True + # TODO move this setting partially to the args + generation_mode_kwargs = { + "do_sample": True, + "temperature": 1.0, + "num_beams": 2, + "top_p": 0.95, + "top_k": None, + "num_return_sequences": 1, + "max_new_tokens": 500, # will be adjusted dynamically later + "max_length": 2048, + "early_stopping": True, + "eos_token_id": [ + tokenizer.encode("\n", add_special_tokens=False)[-1], + tokenizer.eos_token_id, + ], + } + else: + generation_mode = False + generation_mode_kwargs = None + grad_estimator = RGE( + model, + parameters=model_helpers.get_trainable_model_parameters(model), + mu=args.mu, + num_pert=args.num_pert, + grad_estimate_method=method, + device=device, + torch_dtype=torch_dtype, + # To save memory consumption, we have to use parameter-wise perturb + no_optim together. + sgd_only_no_optim=args.no_optim, + paramwise_perturb=args.no_optim, + # For generation mode, the forward style is different + generation_mode=generation_mode, + generation_mode_kwargs=generation_mode_kwargs, + ) + else: + raise Exception(f"Grad estimate method {args.grad_estimate_method} not supported") + return model, criterion, optimizer, grad_estimator, accuracy_func diff --git a/decomfl_main.py b/decomfl_main.py index 0ec0e71..9e4890a 100644 --- a/decomfl_main.py +++ b/decomfl_main.py @@ -1,215 +1,22 @@ import functools from os import path - import torch -import torch.nn as nn -from peft import LoraConfig, get_peft_model + + from tensorboardX import SummaryWriter from tqdm import tqdm -from transformers import AutoModelForCausalLM, AutoTokenizer from byzantine import aggregation as byz_agg from byzantine import attack as byz_attack from cezo_fl.client import ResetClient from cezo_fl.fl_helpers import get_client_name -from cezo_fl.models.cnn_fashion import CNN_FMNIST -from cezo_fl.models.cnn_mnist import CNN_MNIST -from cezo_fl.models.lenet import LeNet -from cezo_fl.models.lstm import CharLSTM -from cezo_fl.random_gradient_estimator import RandomGradientEstimator as RGE from cezo_fl.server import CeZO_Server -from cezo_fl.util import model_helpers -from cezo_fl.util.language_utils import ( - LM_TEMPLATE_MAP, - SUPPORTED_LLM, - get_lm_loss, - ClassificationTemplate, -) -from cezo_fl.util.metrics import accuracy +from cezo_fl.util import model_helpers, prepare_settings + from config import get_args_str, get_params from preprocess import preprocess -def prepare_settings_underseed(args, device, server_or_client: str = "server"): - torch_dtype = { - "float32": torch.float32, - "float16": torch.float16, - "bfloat16": torch.bfloat16, - }[args.model_dtype] - torch.manual_seed(args.seed) - model: CNN_MNIST | LeNet | CNN_FMNIST | CharLSTM | AutoModelForCausalLM - if args.dataset == "mnist": - model = CNN_MNIST().to(torch_dtype).to(device) - criterion = nn.CrossEntropyLoss() - optimizer = torch.optim.SGD( - model_helpers.get_trainable_model_parameters(model), - lr=args.lr, - weight_decay=1e-5, - momentum=args.momentum, - ) - accuracy_func = accuracy - # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8) - elif args.dataset == "cifar10": - model = LeNet().to(torch_dtype).to(device) - criterion = nn.CrossEntropyLoss() - optimizer = torch.optim.SGD( - model_helpers.get_trainable_model_parameters(model), - lr=args.lr, - weight_decay=5e-4, - momentum=args.momentum, - ) - accuracy_func = accuracy - # scheduler = torch.optim.lr_scheduler.MultiStepLR( - # optimizer, milestones=[200], gamma=0.1 - # ) - elif args.dataset == "fashion": - model = CNN_FMNIST().to(torch_dtype).to(device) - criterion = nn.CrossEntropyLoss() - optimizer = torch.optim.SGD( - model_helpers.get_trainable_model_parameters(model), - lr=args.lr, - weight_decay=1e-5, - momentum=args.momentum, - ) - accuracy_func = accuracy - # scheduler = torch.optim.lr_scheduler.MultiStepLR( - # optimizer, milestones=[200], gamma=0.1 - # ) - elif args.dataset == "shakespeare": - model = CharLSTM().to(torch_dtype).to(device) - criterion = nn.CrossEntropyLoss() - optimizer = torch.optim.SGD( - model_helpers.get_trainable_model_parameters(model), - lr=args.lr, - momentum=0.9, - weight_decay=5e-4, - ) - accuracy_func = accuracy - # scheduler = torch.optim.lr_scheduler.MultiStepLR( - # optimizer, milestones=[200], gamma=0.1 - # ) - elif args.dataset in LM_TEMPLATE_MAP.keys(): - large_model = args.large_model - model_name = SUPPORTED_LLM[large_model] - model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch_dtype).to(device) - model.model_name = large_model - tokenizer = AutoTokenizer.from_pretrained( - model_name, padding_side="left", truncate_side="left" - ) - template = LM_TEMPLATE_MAP[args.dataset]() - if args.dataset in ["sst2", "cb", "wsc", "wic", "multirc", "rte", "boolq"] and isinstance( - template, ClassificationTemplate - ): - if args.lora: - # this step initialize lora parameters, which should be under control of seed - lora_config = LoraConfig( - r=args.lora_r, - lora_alpha=args.lora_alpha, - target_modules=["q_proj", "v_proj"], - ) - model = get_peft_model(model, lora_config).to(torch_dtype) - - verbalizer_id_map = template.get_verbalizer_id(tokenizer) - criterion = get_lm_loss("last_token", verbalizer_id_map=verbalizer_id_map) - optimizer = torch.optim.SGD( - model_helpers.get_trainable_model_parameters(model), - lr=args.lr, - momentum=0, - weight_decay=5e-4, - ) - accuracy_func = get_lm_loss("accuracy", verbalizer_id_map=verbalizer_id_map) - elif args.dataset in ["squad", "drop", "xsum"]: - if server_or_client == "server": - criterion = get_lm_loss("f1", tokenizer=tokenizer) - optimizer = torch.optim.SGD( - model_helpers.get_trainable_model_parameters(model), - lr=args.lr, - momentum=0, - weight_decay=0, - ) - accuracy_func = get_lm_loss("f1", tokenizer=tokenizer) - elif server_or_client == "client": - criterion = get_lm_loss("full_sentence", verbalizer_id_map={}) - optimizer = torch.optim.SGD( - model_helpers.get_trainable_model_parameters(model), - lr=args.lr, - momentum=0, - weight_decay=0, - ) - accuracy_func = get_lm_loss("full_sentence", verbalizer_id_map={}) - else: - raise ValueError( - "server_or_client must be either 'server' or 'client'. " - f"But get {server_or_client}" - ) - else: - raise ValueError(f"Dataset {args.dataset} is not supported") - else: - raise Exception(f"Dataset {args.dataset} is not supported") - - if args.grad_estimate_method in ["rge-central", "rge-forward"]: - method = args.grad_estimate_method[4:] - print(f"Using RGE {method}") - if args.dataset in ["squad", "drop"] and server_or_client == "server": - generation_mode = True - # TODO move this setting partially to the args - generation_mode_kwargs = { - "do_sample": True, - "temperature": 1.0, - "num_beams": 2, - "top_p": 0.3, - "top_k": None, - "num_return_sequences": 1, - "max_new_tokens": 5, # will be adjusted dynamically later - "max_length": 2048, - "length_penalty": 2, - "early_stopping": True, - "eos_token_id": [ - tokenizer.encode("\n", add_special_tokens=False)[-1], - tokenizer.eos_token_id, - ], - } - elif args.dataset in ["xsum"] and server_or_client == "server": - generation_mode = True - # TODO move this setting partially to the args - generation_mode_kwargs = { - "do_sample": True, - "temperature": 1.0, - "num_beams": 2, - "top_p": 0.95, - "top_k": None, - "num_return_sequences": 1, - "max_new_tokens": 500, # will be adjusted dynamically later - "max_length": 2048, - "early_stopping": True, - "eos_token_id": [ - tokenizer.encode("\n", add_special_tokens=False)[-1], - tokenizer.eos_token_id, - ], - } - else: - generation_mode = False - generation_mode_kwargs = None - grad_estimator = RGE( - model, - parameters=model_helpers.get_trainable_model_parameters(model), - mu=args.mu, - num_pert=args.num_pert, - grad_estimate_method=method, - device=device, - torch_dtype=torch_dtype, - # To save memory consumption, we have to use parameter-wise perturb + no_optim together. - sgd_only_no_optim=args.no_optim, - paramwise_perturb=args.no_optim, - # For generation mode, the forward style is different - generation_mode=generation_mode, - generation_mode_kwargs=generation_mode_kwargs, - ) - else: - raise Exception(f"Grad estimate method {args.grad_estimate_method} not supported") - return model, criterion, optimizer, grad_estimator, accuracy_func - - def setup_server_and_clients( args, device_map: dict[str, torch.device], train_loaders ) -> CeZO_Server: @@ -224,7 +31,7 @@ def setup_server_and_clients( client_optimizer, client_grad_estimator, client_accuracy_func, - ) = prepare_settings_underseed(args, client_device, "client") + ) = prepare_settings.prepare_settings_underseed(args, client_device, "client") client_model.to(client_device) client = ResetClient( @@ -253,7 +60,7 @@ def setup_server_and_clients( server_optimizer, server_grad_estimator, server_accuracy_func, - ) = prepare_settings_underseed(args, server_device, "server") + ) = prepare_settings.prepare_settings_underseed(args, server_device, "server") server_model.to(server_device) server.set_server_model_and_criterion( server_model, diff --git a/fed_avg/__init__.py b/fed_avg/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fed_avg/client.py b/fed_avg/client.py new file mode 100644 index 0000000..36da2cf --- /dev/null +++ b/fed_avg/client.py @@ -0,0 +1,71 @@ +from typing import Iterator, Callable + +import torch +from peft import PeftModel +from torch.utils.data import DataLoader +from transformers.models.opt.modeling_opt import OPTForCausalLM + +from cezo_fl.shared import CriterionType +from cezo_fl.util.metrics import Metric +from cezo_fl.util.model_helpers import model_forward + + +class FedAvgClient: + def __init__( + self, + model: torch.nn.Module, + dataloader: DataLoader, + optimizer: torch.optim.Optimizer, + criterion: CriterionType, + accuracy_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + device: torch.device | None = None, + ): + self.model = model + self.dataloader = dataloader + + self._device = device + + self.optimizer = optimizer + self.criterion = criterion + self.accuracy_func = accuracy_func + + self.data_iterator = self._get_train_batch_iterator() + self.dtype = next(model.parameters()).dtype + + @property + def device(self) -> torch.device: + return torch.device(self._device) + + def _get_train_batch_iterator(self) -> Iterator: + # NOTE: used only in init, will generate an infinite iterator from dataloader + while True: + for v in self.dataloader: + yield v + + def local_update(self, local_update_steps: int) -> tuple[float, float]: + train_loss = Metric("Client train loss") + train_accuracy = Metric("Client train accuracy") + + for _ in range(local_update_steps): + self.optimizer.zero_grad() + # NOTE:dataloader manage its own randomnes state thus not affected by seed + batch_inputs, labels = next(self.data_iterator) + if self.device != torch.device("cpu") or self.dtype != torch.float32: + batch_inputs = batch_inputs.to(self.device, self.dtype) + # NOTE: label does not convert to dtype + labels = labels.to(self.device) + + pred = model_forward(self.model, batch_inputs) + loss = self.criterion(pred, labels) + loss.backward() + self.optimizer.step() + # get_train_info + train_loss.update(loss.detach().item()) + train_accuracy.update(self.accuracy_func(pred, labels).detach().item()) + + return train_loss.avg, train_accuracy.avg + + def pull_model(self, server_model: OPTForCausalLM | PeftModel | torch.nn.Module) -> None: + with torch.no_grad(): + for p, updated_p in zip(self.model.parameters(), server_model.parameters()): + p.set_(updated_p.to(self._device)) diff --git a/fed_avg/fed_zo_client.py b/fed_avg/fed_zo_client.py new file mode 100644 index 0000000..9cde966 --- /dev/null +++ b/fed_avg/fed_zo_client.py @@ -0,0 +1,76 @@ +from typing import Iterator, Callable + +import torch +from peft import PeftModel +from torch.utils.data import DataLoader +from transformers.models.opt.modeling_opt import OPTForCausalLM + +from cezo_fl.shared import CriterionType +from cezo_fl.util.metrics import Metric +from cezo_fl.util.model_helpers import model_forward +from cezo_fl.random_gradient_estimator import RandomGradientEstimator + + +class FedZOClient: + def __init__( + self, + model: torch.nn.Module, + dataloader: DataLoader, + optimizer: torch.optim.Optimizer, + criterion: CriterionType, + accuracy_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + rge: RandomGradientEstimator, + device: torch.device | None = None, + ): + self.model = model + self.dataloader = dataloader + + self._device = device + + self.optimizer = optimizer + self.criterion = criterion + self.accuracy_func = accuracy_func + + self.rge = rge + + self.data_iterator = self._get_train_batch_iterator() + self.dtype = next(model.parameters()).dtype + + @property + def device(self) -> torch.device: + return torch.device(self._device) + + def _get_train_batch_iterator(self) -> Iterator: + # NOTE: used only in init, will generate an infinite iterator from dataloader + while True: + for v in self.dataloader: + yield v + + def local_update(self, local_update_steps: int) -> tuple[float, float]: + with torch.no_grad(): + train_loss = Metric("Client train loss") + train_accuracy = Metric("Client train accuracy") + + for _ in range(local_update_steps): + self.optimizer.zero_grad() + # NOTE:dataloader manage its own randomnes state thus not affected by seed + batch_inputs, labels = next(self.data_iterator) + if self.device != torch.device("cpu") or self.dtype != torch.float32: + batch_inputs = batch_inputs.to(self.device, self.dtype) + # NOTE: label does not convert to dtype + labels = labels.to(self.device) + + self.rge.compute_grad(batch_inputs, labels, self.criterion, None) + self.optimizer.step() + # get_train_info + pred = model_forward(self.model, batch_inputs) + loss = self.criterion(pred, labels) + train_loss.update(loss.detach().item()) + train_accuracy.update(self.accuracy_func(pred, labels).detach().item()) + + return train_loss.avg, train_accuracy.avg + + def pull_model(self, server_model: OPTForCausalLM | PeftModel | torch.nn.Module) -> None: + with torch.no_grad(): + for p, updated_p in zip(self.model.parameters(), server_model.parameters()): + p.set_(updated_p.to(self._device)) diff --git a/fed_avg/server.py b/fed_avg/server.py new file mode 100644 index 0000000..f127bab --- /dev/null +++ b/fed_avg/server.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import random +from typing import Any, Callable, Iterable, Sequence + +import torch + +from cezo_fl.shared import CriterionType +from cezo_fl.util.metrics import Metric +from cezo_fl.util.model_helpers import model_forward + +from fed_avg.client import FedAvgClient +from fed_avg.fed_zo_client import FedZOClient + + +class FedAvgServer: + def __init__( + self, + clients: Sequence[FedAvgClient | FedZOClient], + device: torch.device, + server_model: torch.nn.Module, + server_criterion: CriterionType, + server_accuracy_func: Callable, + num_sample_clients: int = 10, + local_update_steps: int = 10, + ) -> None: + self.clients = clients + self.device = device + self.num_sample_clients = num_sample_clients + self.local_update_steps = local_update_steps + + self.server_model = server_model + self.server_criterion = server_criterion + self.server_accuracy_func = server_accuracy_func + + self.dtype = next(server_model.parameters()).dtype + + def get_sampled_client_index(self) -> list[int]: + return random.sample(range(len(self.clients)), self.num_sample_clients) + + def set_learning_rate(self, lr: float) -> None: + # Client + for client in self.clients: + for p in client.optimizer.param_groups: + p["lr"] = lr + # Server + if self.server_model: + for p in self.optim.param_groups: + p["lr"] = lr + + def aggregate_client_models(self, client_indices: list[int]) -> None: + self.server_model.train() + with torch.no_grad(): + running_sum: Sequence[torch.Tensor] = [0 for _ in self.server_model.parameters()] + for client_index in client_indices: + client = self.clients[client_index] + for i, p in enumerate(client.model.parameters()): + running_sum[i] += p.to(self.device) + + for p, to_set_p in zip(self.server_model.parameters(), running_sum): + p.set_(to_set_p.div_(self.num_sample_clients)) + + def train_one_step(self) -> tuple[float, float]: + # Step 0: initiate something + sampled_client_indices: list[int] = self.get_sampled_client_index() + + # Step 1 & 2: pull model and local update + step_train_loss = Metric("train_loss") + step_train_accuracy = Metric("train_loss") + for index in sampled_client_indices: + client = self.clients[index] + client.pull_model(self.server_model) + client_loss, client_accuracy = client.local_update(self.local_update_steps) + step_train_loss.update(client_loss) + step_train_accuracy.update(client_accuracy) + + self.aggregate_client_models(sampled_client_indices) + + return step_train_loss.avg, step_train_accuracy.avg + + def eval_model(self, test_loader: Iterable[Any], iteration: int) -> tuple[float, float]: + self.server_model.eval() + eval_loss = Metric("Eval loss") + eval_accuracy = Metric("Eval accuracy") + with torch.no_grad(): + for _, (batch_inputs, batch_labels) in enumerate(test_loader): + if self.device != torch.device("cpu") or self.dtype != torch.float32: + batch_inputs = batch_inputs.to(self.device, self.dtype) + # NOTE: label does not convert to dtype + batch_labels = batch_labels.to(self.device) + pred = model_forward(self.server_model, batch_inputs) + eval_loss.update(self.server_criterion(pred, batch_labels)) + eval_accuracy.update(self.server_accuracy_func(pred, batch_labels)) + print( + f"\nEvaluation(Iteration {iteration + 1}): ", + f"Eval Loss:{eval_loss.avg:.4f}, " f"Accuracy:{eval_accuracy.avg * 100:.2f}%", + ) + return eval_loss.avg, eval_accuracy.avg diff --git a/fed_zo_main.py b/fed_zo_main.py new file mode 100644 index 0000000..abe5d77 --- /dev/null +++ b/fed_zo_main.py @@ -0,0 +1,108 @@ +from os import path + +import torch +from tensorboardX import SummaryWriter +from tqdm import tqdm + +from cezo_fl.fl_helpers import get_client_name +from cezo_fl.util import model_helpers, prepare_settings + +from config import get_args_str, get_params + +from fed_avg.fed_zo_client import FedZOClient +from fed_avg.server import FedAvgServer +from preprocess import preprocess + + +def setup_server_and_clients( + args, device_map: dict[str, torch.device], train_loaders +) -> FedAvgServer: + clients = [] + + for i in range(args.num_clients): + client_name = get_client_name(i) + client_device = device_map[client_name] + ( + client_model, + client_criterion, + client_optimizer, + client_grad_estimator, + client_accuracy_func, + ) = prepare_settings.prepare_settings_underseed(args, client_device) + client_model.to(client_device) + + client = FedZOClient( + client_model, + train_loaders[i], + client_optimizer, + client_criterion, + client_accuracy_func, + rge=client_grad_estimator, + device=client_device, + ) + clients.append(client) + + server_device = device_map["server"] + ( + server_model, + server_criterion, + _, + _, + server_accuracy_func, + ) = prepare_settings.prepare_settings_underseed(args, server_device) + server_model.to(server_device) + server = FedAvgServer( + clients, + server_device, + server_model=server_model, + server_criterion=server_criterion, + server_accuracy_func=server_accuracy_func, + num_sample_clients=args.num_sample_clients, + local_update_steps=args.local_update_steps, + ) + + return server + + +if __name__ == "__main__": + args = get_params().parse_args() + if args.dataset == "shakespeare": + args.num_clients = 139 + print(args) + device_map, train_loaders, test_loader = preprocess(args) + + server = setup_server_and_clients(args, device_map, train_loaders) + + if args.log_to_tensorboard: + tensorboard_sub_folder = "-".join( + [ + get_args_str(args), + server.server_model.model_name, + model_helpers.get_current_datetime_str(), + ] + ) + writer = SummaryWriter( + path.join( + "tensorboards", + "fed_avg", + args.dataset, + args.log_to_tensorboard, + tensorboard_sub_folder, + ) + ) + + with tqdm(total=args.iterations, desc="Training:") as t: + for ite in range(args.iterations): + step_loss, step_accuracy = server.train_one_step() + t.set_postfix({"Loss": step_loss, "Accuracy": step_accuracy}) + t.update(1) + + if args.log_to_tensorboard: + writer.add_scalar("Loss/train", step_loss, ite) + writer.add_scalar("Accuracy/train", step_accuracy, ite) + # eval + if args.eval_iterations != 0 and (ite + 1) % args.eval_iterations == 0: + eval_loss, eval_accuracy = server.eval_model(test_loader, ite) + if args.log_to_tensorboard: + writer.add_scalar("Loss/test", eval_loss, ite) + writer.add_scalar("Accuracy/test", eval_accuracy, ite) diff --git a/fo_fl_main.py b/fo_fl_main.py new file mode 100644 index 0000000..bbe03b6 --- /dev/null +++ b/fo_fl_main.py @@ -0,0 +1,105 @@ +from os import path + +import torch +from tensorboardX import SummaryWriter +from tqdm import tqdm + +from cezo_fl.fl_helpers import get_client_name +from cezo_fl.util import model_helpers, prepare_settings +from config import get_args_str, get_params +from fed_avg.client import FedAvgClient +from fed_avg.server import FedAvgServer +from preprocess import preprocess + + +def setup_server_and_clients( + args, device_map: dict[str, torch.device], train_loaders +) -> FedAvgServer: + clients = [] + + for i in range(args.num_clients): + client_name = get_client_name(i) + client_device = device_map[client_name] + ( + client_model, + client_criterion, + client_optimizer, + _, + client_accuracy_func, + ) = prepare_settings.prepare_settings_underseed(args, client_device) + client_model.to(client_device) + + client = FedAvgClient( + client_model, + train_loaders[i], + client_optimizer, + client_criterion, + client_accuracy_func, + client_device, + ) + clients.append(client) + + server_device = device_map["server"] + ( + server_model, + server_criterion, + _, + _, + server_accuracy_func, + ) = prepare_settings.prepare_settings_underseed(args, server_device) + server_model.to(server_device) + server = FedAvgServer( + clients, + server_device, + server_model=server_model, + server_criterion=server_criterion, + server_accuracy_func=server_accuracy_func, + num_sample_clients=args.num_sample_clients, + local_update_steps=args.local_update_steps, + ) + + return server + + +if __name__ == "__main__": + args = get_params().parse_args() + if args.dataset == "shakespeare": + args.num_clients = 139 + print(args) + device_map, train_loaders, test_loader = preprocess(args) + + server = setup_server_and_clients(args, device_map, train_loaders) + + if args.log_to_tensorboard: + tensorboard_sub_folder = "-".join( + [ + get_args_str(args), + server.server_model.model_name, + model_helpers.get_current_datetime_str(), + ] + ) + writer = SummaryWriter( + path.join( + "tensorboards", + "fed_avg", + args.dataset, + args.log_to_tensorboard, + tensorboard_sub_folder, + ) + ) + + with tqdm(total=args.iterations, desc="Training:") as t: + for ite in range(args.iterations): + step_loss, step_accuracy = server.train_one_step() + t.set_postfix({"Loss": step_loss, "Accuracy": step_accuracy}) + t.update(1) + + if args.log_to_tensorboard: + writer.add_scalar("Loss/train", step_loss, ite) + writer.add_scalar("Accuracy/train", step_accuracy, ite) + # eval + if args.eval_iterations != 0 and (ite + 1) % args.eval_iterations == 0: + eval_loss, eval_accuracy = server.eval_model(test_loader, ite) + if args.log_to_tensorboard: + writer.add_scalar("Loss/test", eval_loss, ite) + writer.add_scalar("Accuracy/test", eval_accuracy, ite) diff --git a/llm_fo_fine_tune_main.py b/llm_fo_fine_tune_main.py index ff60f94..c809d13 100644 --- a/llm_fo_fine_tune_main.py +++ b/llm_fo_fine_tune_main.py @@ -1,8 +1,9 @@ import torch from tqdm import tqdm -import decomfl_main + from cezo_fl.util.metrics import Metric +from cezo_fl.util import prepare_settings from config import get_params from preprocess import preprocess @@ -36,7 +37,7 @@ def inf_loader(dl): # args_str = get_args_str(args) + "-" + server.server_model.model_name model, criterion, optimizer, grad_estimator, accuracy_func = ( - decomfl_main.prepare_settings_underseed(args, device) + prepare_settings.prepare_settings_underseed(args, device) ) model.to(device) diff --git a/run_experiments.sh b/run_experiments.sh new file mode 100644 index 0000000..5f49df0 --- /dev/null +++ b/run_experiments.sh @@ -0,0 +1,12 @@ +python fo_fl_main.py --dataset=sst2 --eval-iterations=20 --large-model=opt-125m --model-dtype=float32 --seed=66 --iterations=2000 --train-batch-size=32 --test-batch-size=64 --num-clients=8 --num-sample-clients=2 --local-update-steps=1 --momentum=0 --lr=1e-3 --no-optim --no-iid --dirichlet-alpha=1 --log-to-tensorboard=fedavg-125m-sst2 +python fo_fl_main.py --dataset=cb --eval-iterations=20 --large-model=opt-125m --model-dtype=float32 --seed=66 --iterations=2000 --train-batch-size=32 --test-batch-size=64 --num-clients=8 --num-sample-clients=2 --local-update-steps=1 --momentum=0 --lr=1e-3 --no-optim --no-iid --dirichlet-alpha=1 --log-to-tensorboard=fedavg-125m-cb +python fo_fl_main.py --dataset=wsc --eval-iterations=20 --large-model=opt-125m --model-dtype=float32 --seed=66 --iterations=2000 --train-batch-size=32 --test-batch-size=64 --num-clients=8 --num-sample-clients=2 --local-update-steps=1 --momentum=0 --lr=1e-3 --no-optim --no-iid --dirichlet-alpha=1 --log-to-tensorboard=fedavg-125m-wsc +python fo_fl_main.py --dataset=wic --eval-iterations=20 --large-model=opt-125m --model-dtype=float32 --seed=66 --iterations=2000 --train-batch-size=32 --test-batch-size=64 --num-clients=8 --num-sample-clients=2 --local-update-steps=1 --momentum=0 --lr=1e-3 --no-optim --no-iid --dirichlet-alpha=1 --log-to-tensorboard=fedavg-125m-wic +python fo_fl_main.py --dataset=rte --eval-iterations=20 --large-model=opt-125m --model-dtype=float32 --seed=66 --iterations=2000 --train-batch-size=32 --test-batch-size=64 --num-clients=8 --num-sample-clients=2 --local-update-steps=1 --momentum=0 --lr=1e-3 --no-optim --no-iid --dirichlet-alpha=1 --log-to-tensorboard=fedavg-125m-rte +python fo_fl_main.py --dataset=boolq --eval-iterations=20 --large-model=opt-125m --model-dtype=float32 --seed=66 --iterations=2000 --train-batch-size=16 --test-batch-size=32 --num-clients=8 --num-sample-clients=2 --local-update-steps=1 --momentum=0 --lr=1e-3 --no-optim --no-iid --dirichlet-alpha=1 --log-to-tensorboard=fedavg-125m-boolq +python fo_fl_main.py --dataset=sst2 --eval-iterations=20 --large-model=opt-125m --model-dtype=float32 --seed=66 --iterations=2000 --train-batch-size=32 --test-batch-size=64 --num-clients=8 --num-sample-clients=2 --local-update-steps=1 --momentum=0 --lr=1e-3 --no-optim --no-iid --dirichlet-alpha=1 --log-to-tensorboard=fedavg-1.3b-sst2 +python fo_fl_main.py --dataset=cb --eval-iterations=20 --large-model=opt-125m --model-dtype=float32 --seed=66 --iterations=2000 --train-batch-size=32 --test-batch-size=64 --num-clients=8 --num-sample-clients=2 --local-update-steps=1 --momentum=0 --lr=1e-3 --no-optim --no-iid --dirichlet-alpha=1 --log-to-tensorboard=fedavg-1.3b-cb +python fo_fl_main.py --dataset=wsc --eval-iterations=20 --large-model=opt-125m --model-dtype=float32 --seed=66 --iterations=2000 --train-batch-size=32 --test-batch-size=64 --num-clients=8 --num-sample-clients=2 --local-update-steps=1 --momentum=0 --lr=1e-3 --no-optim --no-iid --dirichlet-alpha=1 --log-to-tensorboard=fedavg-1.3b-wsc +python fo_fl_main.py --dataset=wic --eval-iterations=20 --large-model=opt-125m --model-dtype=float32 --seed=66 --iterations=2000 --train-batch-size=32 --test-batch-size=64 --num-clients=8 --num-sample-clients=2 --local-update-steps=1 --momentum=0 --lr=1e-3 --no-optim --no-iid --dirichlet-alpha=1 --log-to-tensorboard=fedavg-1.3b-wic +python fo_fl_main.py --dataset=rte --eval-iterations=20 --large-model=opt-125m --model-dtype=float32 --seed=66 --iterations=2000 --train-batch-size=32 --test-batch-size=64 --num-clients=8 --num-sample-clients=2 --local-update-steps=1 --momentum=0 --lr=1e-3 --no-optim --no-iid --dirichlet-alpha=1 --log-to-tensorboard=fedavg-1.3b-rte +python fo_fl_main.py --dataset=boolq --eval-iterations=20 --large-model=opt-125m --model-dtype=float32 --seed=66 --iterations=2000 --train-batch-size=16 --test-batch-size=32 --num-clients=8 --num-sample-clients=2 --local-update-steps=1 --momentum=0 --lr=1e-3 --no-optim --no-iid --dirichlet-alpha=1 --log-to-tensorboard=fedavg-1.3b-boolq