FedOpt with batchnorm (#1851)
* support FedOpt with batch norm layers

* support filtered model diff
holgerroth authored Jul 12, 2023
1 parent 6afc958 · commit 75536cd

Showing 1 changed file with 25 additions and 4 deletions: nvflare/app_opt/pt/fedopt.py
@@ -21,7 +21,7 @@
 from nvflare.apis.fl_context import FLContext
 from nvflare.apis.shareable import Shareable
 from nvflare.app_common.abstract.learnable import Learnable
-from nvflare.app_common.abstract.model import make_model_learnable
+from nvflare.app_common.abstract.model import ModelLearnableKey, make_model_learnable
 from nvflare.app_common.app_constant import AppConstants
 from nvflare.app_common.shareablegenerators.full_model_shareable_generator import FullModelShareableGenerator
 from nvflare.security.logging import secure_format_exception
@@ -40,6 +40,8 @@ def __init__(
         The algorithm is proposed in Reddi, Sashank, et al. "Adaptive federated optimization." arXiv preprint arXiv:2003.00295 (2020).
         This SharableGenerator will update the global model using the specified
         PyTorch optimizer and learning rate scheduler.
+        Note: This class will use FedOpt to optimize the global trainable parameters (i.e. `self.model.named_parameters()`)
+        but use FedAvg to update any other layers such as batch norm statistics.
         Args:
             optimizer_args: dictionary of optimizer arguments, e.g.
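The split described in the new docstring note hinges on which state_dict keys show up in named_parameters(). A quick way to see the two buckets on a toy model (illustrative sketch, not part of this commit):

import torch

# Keys returned by named_parameters() would get the FedOpt update; the remaining
# state_dict entries (e.g. batch norm running statistics) would fall back to FedAvg.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))

trainable = {name for name, _ in model.named_parameters()}
non_trainable = set(model.state_dict().keys()) - trainable

print(sorted(trainable))      # ['0.bias', '0.weight', '1.bias', '1.weight']
print(sorted(non_trainable))  # ['1.num_batches_tracked', '1.running_mean', '1.running_var']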
@@ -164,14 +166,17 @@ def server_update(self, model_diff):
 
         # Apply the update to the model. We must multiply weights_delta by -1.0 to
         # view it as a gradient that should be applied to the server_optimizer.
+        updated_params = []
         for name, param in self.model.named_parameters():
-            param.grad = torch.tensor(-1.0 * model_diff[name]).to(self.device)
+            if name in model_diff:
+                param.grad = torch.tensor(-1.0 * model_diff[name]).to(self.device)
+                updated_params.append(name)
 
         self.optimizer.step()
         if self.lr_scheduler is not None:
             self.lr_scheduler.step()
 
-        return self.model.state_dict()
+        return self.model.state_dict(), updated_params
 
     def shareable_to_learnable(self, shareable: Shareable, fl_ctx: FLContext) -> Learnable:
         """Convert Shareable to Learnable while doing a FedOpt update step.
@@ -206,7 +211,7 @@ def shareable_to_learnable(self, shareable: Shareable, fl_ctx: FLContext) -> Learnable:
         model_diff = dxo.data
 
         start = time.time()
-        weights = self.server_update(model_diff)
+        weights, updated_params = self.server_update(model_diff)
         secs = time.time() - start
 
         # convert to numpy dict of weights
@@ -215,12 +220,28 @@ def shareable_to_learnable(self, shareable: Shareable, fl_ctx: FLContext) -> Learnable:
             weights[key] = weights[key].detach().cpu().numpy()
         secs_detach = time.time() - start
 
+        # update unnamed parameters such as batch norm layers if there are any using the averaged update
+        base_model = fl_ctx.get_prop(AppConstants.GLOBAL_MODEL)
+        if not base_model:
+            self.system_panic(reason="No global base model!", fl_ctx=fl_ctx)
+            return base_model
+
+        base_model_weights = base_model[ModelLearnableKey.WEIGHTS]
+
+        n_fedavg = 0
+        for key, value in model_diff.items():
+            if key not in updated_params:
+                weights[key] = base_model_weights[key] + value
+                n_fedavg += 1
+
         self.log_info(
             fl_ctx,
             f"FedOpt ({self.optimizer_name}, {self.device}) server model update "
             f"round {fl_ctx.get_prop(AppConstants.CURRENT_ROUND)}, "
             f"{self.lr_scheduler_name if self.lr_scheduler_name else ''} "
             f"lr: {self.optimizer.param_groups[-1]['lr']}, "
+            f"fedopt layers: {len(updated_params)}, "
+            f"fedavg layers: {n_fedavg}, "
             f"update: {secs} secs., detach: {secs_detach} secs.",
         )
         # TODO: write server-side lr to tensorboard
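By this point in shareable_to_learnable the updated weights have already been converted to numpy arrays, so the fallback for keys the optimizer did not touch is a plain element-wise addition onto the previous global model. A tiny illustration with made-up keys and values, mirroring the counts reported in the log line:

import numpy as np

updated_params = ["conv.weight"]                      # keys already handled by the server optimizer
base_model_weights = {                                # previous global model (made-up values)
    "conv.weight": np.ones((8, 3, 3, 3), dtype=np.float32),
    "bn.running_mean": np.zeros(8, dtype=np.float32),
    "bn.running_var": np.ones(8, dtype=np.float32),
}
model_diff = {k: np.full_like(v, 0.1) for k, v in base_model_weights.items()}

weights = {"conv.weight": base_model_weights["conv.weight"] + model_diff["conv.weight"]}  # stand-in for the FedOpt result
n_fedavg = 0
for key, value in model_diff.items():
    if key not in updated_params:                     # e.g. batch norm statistics -> FedAvg
        weights[key] = base_model_weights[key] + value
        n_fedavg += 1

print(f"fedopt layers: {len(updated_params)}, fedavg layers: {n_fedavg}")  # fedopt layers: 1, fedavg layers: 2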
