FedZO #69

Open
wants to merge 12 commits into base: main
9 changes: 3 additions & 6 deletions README.md
@@ -6,7 +6,7 @@ DeComFL is a library designed for training/fine-tuning deep learning models in t

## Performance

From Tables 1 and 2, we observe DeComFL's effectiveness in reducing communication cost. We evaluate its performance with five and ten perturbations; it matches or even outperforms MeZO and FedZO on all datasets. Surprisingly, DeComFL requires only about **1MB of communication cost** to converge, a significant saving compared with other algorithms.

<table>
<caption style="caption-side: top; text-align: center; font-weight: bold;">Table 1: Test accuracy and communication cost on fine-tuning tasks</caption>
@@ -110,7 +110,6 @@ From Tables 1 and 2, we observe the DeComFL's effectiveness in communication cos
</tbody>
</table>


<table>
<caption style="caption-side: top; text-align: center; font-weight: bold;">Table 2: Test accuracy on fine-tuning tasks (LoRA)</caption>
<thead>
@@ -170,8 +169,6 @@ From Tables 1 and 2, we observe the DeComFL's effectiveness in communication cos
</tbody>
</table>



## Environment Setup

We use [conda](https://docs.conda.io/projects/conda/en/stable/) as our cross-platform environment management tool. However, because macOS lacks CUDA support, we maintain two different environment setup files:
@@ -197,7 +194,6 @@ For README.md, we will use `environment.yml` whenever an environment file is nee
- **Run DeComFL:** Follows the FL routine: the dataset is split into chunks and trained on different clients (a conceptual sketch of one round is given below).
Usage example: `python decomfl_main.py --large-model=opt-125m --dataset=sst2 --iterations=1000 --train-batch-size=32 --test-batch-size=200 --eval-iterations=25 --num-clients=3 --num-sample-clients=2 --local-update-steps=1 --num-pert=5 --lr=1e-5 --mu=1e-3 --grad-estimate-method=rge-forward`
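
For orientation, below is a minimal conceptual sketch of a seed-based zeroth-order FL round in the spirit of DeComFL. It is illustrative only: the function names (`zo_directional_grad`, `decomfl_like_round`) are hypothetical and do not correspond to classes in this repository; the actual entry point is `decomfl_main.py` shown above.

```python
import torch


def zo_directional_grad(loss_fn, params: torch.Tensor, seed: int, mu: float = 1e-3) -> float:
    """Scalar slope of loss_fn along a random direction reproduced from `seed`."""
    gen = torch.Generator().manual_seed(seed)
    u = torch.randn(params.shape, generator=gen)
    return float((loss_fn(params + mu * u) - loss_fn(params)) / mu)


def decomfl_like_round(client_loss_fns, params: torch.Tensor, seed: int, lr: float = 1e-5) -> torch.Tensor:
    """One round: the server sends only a seed downstream; clients return only scalars."""
    scalars = [zo_directional_grad(f, params, seed) for f in client_loss_fns]
    avg_scalar = sum(scalars) / len(scalars)
    gen = torch.Generator().manual_seed(seed)
    u = torch.randn(params.shape, generator=gen)  # the server rebuilds the same direction from the seed
    return params - lr * avg_scalar * u


# Toy usage: two "clients" with quadratic losses pulling toward different targets.
params = torch.zeros(8)
clients = [lambda p: ((p - 1.0) ** 2).sum(), lambda p: ((p + 1.0) ** 2).sum()]
params = decomfl_like_round(clients, params, seed=0, lr=1e-2)
```

Because only seeds travel downstream and only scalars travel upstream, the per-round payload does not grow with the model dimension, which is the intuition behind the small communication costs reported in Table 1.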


## Citation

```
@@ -210,7 +206,8 @@ For README.md, we will use `environment.yml` whenever an environment file is nee
```

## Contributors

DeComFL is currently contributed and maintained by <a href="https://zidongliu.github.io/" style="text-decoration: none;">**Zidong Liu**</a> (ComboCurve), <a href="https://scholar.google.com/citations?user=LuF6KX4AAAAJ&hl=en&oi=ao" style="text-decoration: none;">**Bicheng Ying**</a> (Google) and <a href="https://rogerrogerusc.github.io/" style="text-decoration: none;">**Zhe Li**</a> (RIT), and advised by Prof. <a href="https://haibo-yang-osu.github.io/homepage/" style="text-decoration: none;">**Haibo Yang**</a> (RIT).

<div style="display: flex; justify-content: space-between;">
<img src="https://github.com/user-attachments/assets/b3982917-e302-42c3-b396-e33bb9f52c90" alt="Image 1" style="width: 80%;" />
22 changes: 16 additions & 6 deletions cezo_fl/random_gradient_estimator.py
@@ -9,6 +9,8 @@

GradEstimateMethod: TypeAlias = Literal["forward", "central"]

BatchInput: TypeAlias = torch.Tensor | LLMBatchInput


# TODO: split this class into an abstract class and several subclasses.
class RandomGradientEstimator:
@@ -57,7 +59,7 @@ def __init__(
# TODO(zidong) move this func out of this class
def model_forward(
self,
batch_inputs: torch.Tensor | LLMBatchInput,
batch_inputs: BatchInput,
):
if self.generation_mode:
if not isinstance(self.model, (OPTForCausalLM, PeftModel)):
@@ -136,7 +138,14 @@ def generate_then_put_grad(self, seed: int, dir_grads: torch.Tensor) -> None:
assert update_grad is not None
self.put_grad(update_grad)

def compute_grad(self, batch_inputs, labels, criterion, seed: int) -> torch.Tensor:
def compute_grad(
self,
batch_inputs: BatchInput,
labels: torch.Tensor,
criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
seed: int | None,
) -> torch.Tensor:
"""When seed is None, it means we are just using random gradient estimator to estimate grad, we do not need to reconstruct each perturbation"""
if not self.paramwise_perturb:
# We generate the perturbation vectors all together. This should be faster but consumes
# more memory.
@@ -145,6 +154,7 @@ def compute_grad(self, batch_inputs, labels, criterion, seed: int) -> torch.Tens
)
self.put_grad(grad)
else:
assert isinstance(seed, int)
perturbation_dir_grads = self._zo_grad_estimate_paramwise(
batch_inputs, labels, criterion, seed
)
@@ -166,10 +176,10 @@ def sgd_no_optim_update_model(

def _zo_grad_estimate(
self,
batch_inputs: torch.Tensor,
batch_inputs: BatchInput,
labels: torch.Tensor,
criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
seed: int,
seed: int | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Calculate the zeroth-order gradient estimate.

@@ -188,7 +198,7 @@ def _zo_grad_estimate(
pert_minus_loss = criterion(self.model_forward(batch_inputs), labels)

for i in range(self.num_pert):
rng = self.get_rng(seed, i)
rng: torch.Generator | None = self.get_rng(seed, i) if isinstance(seed, int) else None
pb_norm = self.generate_perturbation_norm(rng)

self.perturb_model(pb_norm, alpha=self.mu)
@@ -236,7 +246,7 @@ def perturb_model_paramwise(self, rng: torch.Generator, alpha: float | int) -> N

def _zo_grad_estimate_paramwise(
self,
batch_inputs: torch.Tensor,
batch_inputs: BatchInput,
labels: torch.Tensor,
criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
seed: int,
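Context for the diff above: `_zo_grad_estimate` averages several randomized finite-difference estimates. A standalone sketch of the forward-difference form, simplified to one flat parameter tensor instead of the class's per-parameter handling (the helper name `rge_forward` is hypothetical, not repo code), might look like:

```python
import torch


def rge_forward(loss_fn, theta: torch.Tensor, mu: float = 1e-3, num_pert: int = 5,
                seed: int | None = None) -> torch.Tensor:
    """Average `num_pert` forward-difference directional estimates of the gradient."""
    gen = torch.Generator().manual_seed(seed) if seed is not None else None
    base_loss = loss_fn(theta)
    grad = torch.zeros_like(theta)
    for _ in range(num_pert):
        u = torch.randn(theta.shape, generator=gen)            # random perturbation direction
        dir_grad = (loss_fn(theta + mu * u) - base_loss) / mu  # scalar directional slope
        grad += dir_grad * u
    return grad / num_pert


# Sanity check against the exact gradient of a quadratic (2 * theta).
theta = torch.randn(16)
estimate = rge_forward(lambda t: (t ** 2).sum(), theta, num_pert=64, seed=42)
```

The `central` option in `GradEstimateMethod` additionally evaluates the loss at `theta - mu * u` and divides by `2 * mu`, trading one extra forward pass per perturbation for a lower-bias estimate.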
15 changes: 15 additions & 0 deletions cezo_fl/util/model_helpers.py
@@ -10,6 +10,10 @@

import torch
import torch.optim as optim
from peft import PeftModel
from transformers.models.opt.modeling_opt import OPTForCausalLM

from cezo_fl.util.language_utils import LLMBatchInput


def get_current_datetime_str():
@@ -73,3 +77,14 @@ def get_trainable_model_parameters(
for param in model.parameters():
if param.requires_grad:
yield param


def model_forward(
model: OPTForCausalLM | PeftModel | torch.nn.Module, batch_inputs: torch.Tensor | LLMBatchInput
):
if isinstance(model, (OPTForCausalLM, PeftModel)):
return model(input_ids=batch_inputs.input_ids, attention_mask=batch_inputs.attention_mask)
elif isinstance(model, torch.nn.Module):
return model(batch_inputs)
else:
raise Exception("This model type is not supported")
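
A quick, hypothetical usage check of the new `model_forward` helper with a plain `torch.nn.Module`; the `OPTForCausalLM`/`PeftModel` branch instead expects an `LLMBatchInput` carrying `input_ids` and `attention_mask`. This assumes the package is importable from your environment:

```python
import torch

from cezo_fl.util.model_helpers import model_forward

toy_model = torch.nn.Linear(4, 2)  # plain nn.Module, so the generic branch is taken
batch = torch.randn(3, 4)
logits = model_forward(toy_model, batch)
print(logits.shape)  # torch.Size([3, 2])
```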
191 changes: 191 additions & 0 deletions cezo_fl/util/prepare_settings.py
@@ -0,0 +1,191 @@
import torch
import torch.nn as nn

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer


from cezo_fl.util import model_helpers
from cezo_fl.models.cnn_fashion import CNN_FMNIST
from cezo_fl.models.cnn_mnist import CNN_MNIST
from cezo_fl.models.lenet import LeNet
from cezo_fl.models.lstm import CharLSTM
from cezo_fl.random_gradient_estimator import RandomGradientEstimator as RGE
from cezo_fl.util.language_utils import LM_TEMPLATE_MAP, SUPPORTED_LLM, get_lm_loss
from cezo_fl.util.metrics import accuracy


def prepare_settings_underseed(args, device, server_or_client: str = "server"):
torch_dtype = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}[args.model_dtype]
torch.manual_seed(args.seed)
if args.dataset == "mnist":
model = CNN_MNIST().to(torch_dtype).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
model_helpers.get_trainable_model_parameters(model),
lr=args.lr,
weight_decay=1e-5,
momentum=args.momentum,
)
accuracy_func = accuracy
# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)
elif args.dataset == "cifar10":
model = LeNet().to(torch_dtype).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
model_helpers.get_trainable_model_parameters(model),
lr=args.lr,
weight_decay=5e-4,
momentum=args.momentum,
)
accuracy_func = accuracy
# scheduler = torch.optim.lr_scheduler.MultiStepLR(
# optimizer, milestones=[200], gamma=0.1
# )
elif args.dataset == "fashion":
model = CNN_FMNIST().to(torch_dtype).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
model_helpers.get_trainable_model_parameters(model),
lr=args.lr,
weight_decay=1e-5,
momentum=args.momentum,
)
accuracy_func = accuracy
# scheduler = torch.optim.lr_scheduler.MultiStepLR(
# optimizer, milestones=[200], gamma=0.1
# )
elif args.dataset == "shakespeare":
model = CharLSTM().to(torch_dtype).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
model_helpers.get_trainable_model_parameters(model),
lr=args.lr,
momentum=0.9,
weight_decay=5e-4,
)
accuracy_func = accuracy
# scheduler = torch.optim.lr_scheduler.MultiStepLR(
# optimizer, milestones=[200], gamma=0.1
# )
elif args.dataset in LM_TEMPLATE_MAP.keys():
large_model = args.large_model
model_name = SUPPORTED_LLM[large_model]
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch_dtype).to(device)
model.model_name = large_model
tokenizer = AutoTokenizer.from_pretrained(
model_name, padding_side="left", truncate_side="left"
)
template = LM_TEMPLATE_MAP[args.dataset]()
if args.dataset in ["sst2", "cb", "wsc", "wic", "multirc", "rte", "boolq"]:
if args.lora:
# This step initializes the LoRA parameters, so it should be under the control of the seed.
lora_config = LoraConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
target_modules=["q_proj", "v_proj"],
)
model = get_peft_model(model, lora_config).to(torch_dtype)
verbalizer_id_map = template.get_verbalizer_id(tokenizer)
criterion = get_lm_loss("last_token", verbalizer_id_map=verbalizer_id_map)
optimizer = torch.optim.SGD(
model_helpers.get_trainable_model_parameters(model),
lr=args.lr,
momentum=0,
weight_decay=5e-4,
)
accuracy_func = get_lm_loss("accuracy", verbalizer_id_map=verbalizer_id_map)
elif args.dataset in ["squad", "drop", "xsum"]:
if server_or_client == "server":
criterion = get_lm_loss("f1", tokenizer=tokenizer)
optimizer = torch.optim.SGD(
model_helpers.get_trainable_model_parameters(model),
lr=args.lr,
momentum=0,
weight_decay=0,
)
accuracy_func = get_lm_loss("f1", tokenizer=tokenizer)
elif server_or_client == "client":
criterion = get_lm_loss("full_sentence", verbalizer_id_map={})
optimizer = torch.optim.SGD(
model_helpers.get_trainable_model_parameters(model),
lr=args.lr,
momentum=0,
weight_decay=0,
)
accuracy_func = get_lm_loss("full_sentence", verbalizer_id_map={})
else:
raise ValueError(
"server_or_client must be either 'server' or 'client'. "
f"But get {server_or_client}"
)
else:
raise ValueError(f"Dataset {args.dataset} is not supported")
else:
raise Exception(f"Dataset {args.dataset} is not supported")

if args.grad_estimate_method in ["rge-central", "rge-forward"]:
method = args.grad_estimate_method[4:]
print(f"Using RGE {method}")
if args.dataset in ["squad", "drop"] and server_or_client == "server":
generation_mode = True
# TODO move this setting partially to the args
generation_mode_kwargs = {
"do_sample": True,
"temperature": 1.0,
"num_beams": 2,
"top_p": 0.3,
"top_k": None,
"num_return_sequences": 1,
"max_new_tokens": 5, # will be adjusted dynamically later
"max_length": 2048,
"length_penalty": 2,
"early_stopping": True,
"eos_token_id": [
tokenizer.encode("\n", add_special_tokens=False)[-1],
tokenizer.eos_token_id,
],
}
elif args.dataset in ["xsum"] and server_or_client == "server":
generation_mode = True
# TODO move this setting partially to the args
generation_mode_kwargs = {
"do_sample": True,
"temperature": 1.0,
"num_beams": 2,
"top_p": 0.95,
"top_k": None,
"num_return_sequences": 1,
"max_new_tokens": 500, # will be adjusted dynamically later
"max_length": 2048,
"early_stopping": True,
"eos_token_id": [
tokenizer.encode("\n", add_special_tokens=False)[-1],
tokenizer.eos_token_id,
],
}
else:
generation_mode = False
generation_mode_kwargs = None
grad_estimator = RGE(
model,
parameters=model_helpers.get_trainable_model_parameters(model),
mu=args.mu,
num_pert=args.num_pert,
grad_estimate_method=method,
device=device,
torch_dtype=torch_dtype,
# To save memory consumption, we have to use parameter-wise perturb + no_optim together.
sgd_only_no_optim=args.no_optim,
paramwise_perturb=args.no_optim,
# For generation mode, the forward style is different
generation_mode=generation_mode,
generation_mode_kwargs=generation_mode_kwargs,
)
else:
raise Exception(f"Grad estimate method {args.grad_estimate_method} not supported")
return model, criterion, optimizer, grad_estimator, accuracy_func
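
A hedged usage sketch of `prepare_settings_underseed`: the attribute names on `args` mirror what the function body above reads for the `mnist` + `rge-forward` path, but in practice they come from the project's CLI parser, so treat the concrete values here as placeholders.

```python
from types import SimpleNamespace

import torch

from cezo_fl.util.prepare_settings import prepare_settings_underseed

# Only the attributes read on the mnist + rge-forward path are filled in.
args = SimpleNamespace(
    dataset="mnist", model_dtype="float32", seed=42,
    lr=1e-4, momentum=0.9,
    grad_estimate_method="rge-forward", mu=1e-3, num_pert=5, no_optim=False,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, criterion, optimizer, grad_estimator, accuracy_func = prepare_settings_underseed(
    args, device, server_or_client="client"
)
```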