FedOpt with batchnorm #1851

Merged
4 commits merged on Jul 12, 2023
29 changes: 25 additions & 4 deletions nvflare/app_opt/pt/fedopt.py
@@ -21,7 +21,7 @@
 from nvflare.apis.fl_context import FLContext
 from nvflare.apis.shareable import Shareable
 from nvflare.app_common.abstract.learnable import Learnable
-from nvflare.app_common.abstract.model import make_model_learnable
+from nvflare.app_common.abstract.model import ModelLearnableKey, make_model_learnable
 from nvflare.app_common.app_constant import AppConstants
 from nvflare.app_common.shareablegenerators.full_model_shareable_generator import FullModelShareableGenerator
 from nvflare.security.logging import secure_format_exception
@@ -40,6 +40,8 @@ def __init__(
         The algorithm is proposed in Reddi, Sashank, et al. "Adaptive federated optimization." arXiv preprint arXiv:2003.00295 (2020).
         This SharableGenerator will update the global model using the specified
         PyTorch optimizer and learning rate scheduler.
+        Note: This class will use FedOpt to optimize the global trainable parameters (i.e. `self.model.named_parameters()`)
+        but use FedAvg to update any other layers such as batch norm statistics.
 
         Args:
             optimizer_args: dictionary of optimizer arguments, e.g.
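
Aside (not part of the diff): the note above relies on the fact that PyTorch exposes trainable weights through `named_parameters()`, while batch-norm running statistics are buffers that only appear in `state_dict()`. A minimal sketch, using a hypothetical toy model:

```python
# Illustration only: show which state_dict entries FedOpt can step (named parameters)
# versus entries that need the FedAvg fallback (buffers such as batch-norm stats).
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))

param_keys = {name for name, _ in model.named_parameters()}
all_keys = set(model.state_dict().keys())

# Prints buffer keys like '1.running_mean', '1.running_var', '1.num_batches_tracked',
# which are not trainable parameters and therefore get no optimizer state on the server.
print(sorted(all_keys - param_keys))
```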
@@ -164,14 +166,17 @@ def server_update(self, model_diff):
 
         # Apply the update to the model. We must multiply weights_delta by -1.0 to
         # view it as a gradient that should be applied to the server_optimizer.
+        updated_params = []
         for name, param in self.model.named_parameters():
-            param.grad = torch.tensor(-1.0 * model_diff[name]).to(self.device)
+            if name in model_diff:
+                param.grad = torch.tensor(-1.0 * model_diff[name]).to(self.device)
+                updated_params.append(name)
 
         self.optimizer.step()
         if self.lr_scheduler is not None:
             self.lr_scheduler.step()
 
-        return self.model.state_dict()
+        return self.model.state_dict(), updated_params
 
     def shareable_to_learnable(self, shareable: Shareable, fl_ctx: FLContext) -> Learnable:
         """Convert Shareable to Learnable while doing a FedOpt update step.
@@ -206,7 +211,7 @@ def shareable_to_learnable(self, shareable: Shareable, fl_ctx: FLContext) -> Learnable:
         model_diff = dxo.data
 
         start = time.time()
-        weights = self.server_update(model_diff)
+        weights, updated_params = self.server_update(model_diff)
         secs = time.time() - start
 
         # convert to numpy dict of weights
@@ -215,12 +220,28 @@ def shareable_to_learnable(self, shareable: Shareable, fl_ctx: FLContext) -> Learnable:
             weights[key] = weights[key].detach().cpu().numpy()
         secs_detach = time.time() - start
 
+        # update unnamed parameters such as batch norm layers if there are any using the averaged update
+        base_model = fl_ctx.get_prop(AppConstants.GLOBAL_MODEL)
+        if not base_model:
+            self.system_panic(reason="No global base model!", fl_ctx=fl_ctx)
+            return base_model
+
+        base_model_weights = base_model[ModelLearnableKey.WEIGHTS]
+
+        n_fedavg = 0
+        for key, value in model_diff.items():
+            if key not in updated_params:
+                weights[key] = base_model_weights[key] + value
+                n_fedavg += 1
+
         self.log_info(
             fl_ctx,
             f"FedOpt ({self.optimizer_name}, {self.device}) server model update "
             f"round {fl_ctx.get_prop(AppConstants.CURRENT_ROUND)}, "
             f"{self.lr_scheduler_name if self.lr_scheduler_name else ''} "
             f"lr: {self.optimizer.param_groups[-1]['lr']}, "
+            f"fedopt layers: {len(updated_params)}, "
+            f"fedavg layers: {n_fedavg}, "
             f"update: {secs} secs., detach: {secs_detach} secs.",
         )
         # TODO: write server-side lr to tensorboard
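
To see the combined effect of the two code paths (FedOpt for trainable parameters, FedAvg-style addition for everything else), here is a self-contained toy round; all names are hypothetical and only the update rules mirror the diff:

```python
# Toy end-to-end round (illustration only): FedOpt steps the trainable parameters,
# while batch-norm buffers fall back to a plain FedAvg-style addition of the delta.
import torch
import torch.nn as nn

global_model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
server_opt = torch.optim.SGD(global_model.parameters(), lr=1.0)

base = {k: v.clone() for k, v in global_model.state_dict().items()}
# Pretend this is the averaged client update for every state_dict entry.
model_diff = {k: 0.1 * torch.ones_like(v, dtype=torch.float32) for k, v in base.items()}

# FedOpt path: only named (trainable) parameters receive a gradient and an optimizer step.
updated_params = []
for name, param in global_model.named_parameters():
    if name in model_diff:
        param.grad = -1.0 * model_diff[name]
        updated_params.append(name)
server_opt.step()

# FedAvg fallback: buffers such as running_mean/running_var just get the delta added.
weights = global_model.state_dict()
n_fedavg = 0
for key, value in model_diff.items():
    if key not in updated_params:
        weights[key] = base[key] + value
        n_fedavg += 1

print(f"fedopt layers: {len(updated_params)}, fedavg layers: {n_fedavg}")  # 4 and 3
```

The printed counts correspond to the "fedopt layers" and "fedavg layers" fields added to the log message in the hunk above.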