Merge pull request #96 from DavdGao/feature/hotfix_opt
Support optimizers with different parameters
yxdyc authored May 31, 2022
2 parents d7234f6 + c6bbaa0 commit cf97ccb
Showing 34 changed files with 165 additions and 132 deletions.
6 changes: 3 additions & 3 deletions federatedscope/attack/worker_as_attacker/server_attacker.py
@@ -164,14 +164,14 @@ def __init__(self,
             fl_model_criterion=get_criterion(self._cfg.criterion.type,
                                              device=self.device),
             device=self.device,
-            grad_clip=self._cfg.optimizer.grad_clip,
+            grad_clip=self._cfg.grad.grad_clip,
             dataset_name=self._cfg.data.type,
             fl_local_update_num=self._cfg.federate.local_update_steps,
-            fl_type_optimizer=self._cfg.fedopt.type_optimizer,
+            fl_type_optimizer=self._cfg.fedopt.optimizer.type,
             fl_lr=self._cfg.optimizer.lr,
             batch_size=100)

-        # self.optimizer = get_optimizer(type=self._cfg.fedopt.type_optimizer, model=self.model,lr=self._cfg.fedopt.lr_server)
+        # self.optimizer = get_optimizer(type=self._cfg.fedopt.type_optimizer, model=self.model,lr=self._cfg.fedopt.optimizer.lr)
         # print(self.optimizer)
     def callback_funcs_model_para(self, message: Message):
         round, sender, content = message.state, message.sender, message.content
6 changes: 2 additions & 4 deletions federatedscope/core/aggregator.py
@@ -174,10 +174,8 @@ def __init__(self, config, model, device='cpu'):
         self.cfg = config
         self.model = model
         self.device = device
-        self.optimizer = get_optimizer(type=config.fedopt.type_optimizer,
-                                       model=self.model,
-                                       lr=config.fedopt.lr_server,
-                                       momentum=config.fedopt.momentum_server)
+        self.optimizer = get_optimizer(model=self.model,
+                                       **config.fedopt.optimizer)

     def aggregate(self, agg_info):
         new_model = super().aggregate(agg_info)
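The call above unpacks the whole `config.fedopt.optimizer` node into keyword arguments, so every server-side hyperparameter present in the config reaches the optimizer constructor without further code changes. A minimal, self-contained sketch of that pattern (variable names and values are made up; this is not FederatedScope's actual FedOpt aggregator):

import torch

server_model = torch.nn.Linear(4, 2)

# stands in for config.fedopt.optimizer; 'momentum' is an extra key beyond type/lr
fedopt_optimizer_cfg = {'type': 'SGD', 'lr': 0.01, 'momentum': 0.9}

opt_kwargs = dict(fedopt_optimizer_cfg)
opt_cls = getattr(torch.optim, opt_kwargs.pop('type'))
optimizer = opt_cls(server_model.parameters(), **opt_kwargs)

# pretend this is the aggregated client update, used as a pseudo-gradient
aggregated_delta = {n: torch.randn_like(p) for n, p in server_model.named_parameters()}

optimizer.zero_grad()
for name, param in server_model.named_parameters():
    param.grad = -aggregated_delta[name]  # step along the aggregated delta
optimizer.step()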
2 changes: 1 addition & 1 deletion federatedscope/core/auxiliaries/optimizer_builder.py
@@ -4,7 +4,7 @@
     torch = None


-def get_optimizer(type, model, lr, **kwargs):
+def get_optimizer(model, type, lr, **kwargs):
     if torch is None:
         return None
     if isinstance(type, str):
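Putting `model` first lets callers forward an entire config node with `**`, as the aggregator above and the trainers below now do, while `type` and `lr` stay explicit. A rough sketch of what such a builder can look like, assuming `type` names a class in `torch.optim` (an illustration only, not the project's exact implementation):

try:
    import torch
except ImportError:
    torch = None


def get_optimizer(model, type, lr, **kwargs):
    # build a torch optimizer from a type name plus arbitrary extra kwargs
    if torch is None:
        return None
    if isinstance(type, str):
        if hasattr(torch.optim, type):
            return getattr(torch.optim, type)(model.parameters(), lr=lr, **kwargs)
        raise NotImplementedError(f'Optimizer {type} not found in torch.optim')
    return type(model.parameters(), lr=lr, **kwargs)  # already an optimizer class


model = torch.nn.Linear(8, 1)
sgd = get_optimizer(model, type='SGD', lr=0.1, momentum=0.9)
adam = get_optimizer(model, type='Adam', lr=1e-3, betas=(0.9, 0.999))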
7 changes: 4 additions & 3 deletions federatedscope/core/configs/cfg_fl_algo.py
@@ -9,9 +9,10 @@ def extend_fl_algo_cfg(cfg):
     cfg.fedopt = CN()

     cfg.fedopt.use = False
-    cfg.fedopt.lr_server = 0.01
-    cfg.fedopt.momentum_server = 0.
-    cfg.fedopt.type_optimizer = 'SGD'
+
+    cfg.fedopt.optimizer = CN(new_allowed=True)
+    cfg.fedopt.optimizer.type = 'SGD'
+    cfg.fedopt.optimizer.lr = 0.01

     # ------------------------------------------------------------------------ #
     # fedprox related options, general fl
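Because `cfg.fedopt.optimizer` is declared with `new_allowed=True`, server-side optimizer options beyond `type` and `lr` can come straight from a YAML file or an override without being pre-declared, which is what "support optimizers with different parameters" amounts to. A small self-contained sketch in plain yacs (assuming a yacs version that supports `new_allowed`; the `momentum` key and all values are illustrative):

from yacs.config import CfgNode as CN

cfg = CN()
cfg.fedopt = CN()
cfg.fedopt.use = False
cfg.fedopt.optimizer = CN(new_allowed=True)  # open-ended node
cfg.fedopt.optimizer.type = 'SGD'
cfg.fedopt.optimizer.lr = 0.01

# e.g. parsed from a YAML file that also sets fedopt.optimizer.momentum
override = CN()
override.fedopt = CN()
override.fedopt.optimizer = CN()
override.fedopt.optimizer.momentum = 0.9  # key absent from the defaults

cfg.merge_from_other_cfg(override)
print(dict(cfg.fedopt.optimizer))  # {'type': 'SGD', 'lr': 0.01, 'momentum': 0.9}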
11 changes: 7 additions & 4 deletions federatedscope/core/configs/cfg_training.py
@@ -19,13 +19,16 @@ def extend_training_cfg(cfg):
     # ------------------------------------------------------------------------ #
     # Optimizer related options
     # ------------------------------------------------------------------------ #
-    cfg.optimizer = CN()
+    cfg.optimizer = CN(new_allowed=True)

     cfg.optimizer.type = 'SGD'
     cfg.optimizer.lr = 0.1
     cfg.optimizer.weight_decay = .0
     cfg.optimizer.momentum = .0
-    cfg.optimizer.grad_clip = -1.0  # negative numbers indicate we do not clip grad
+
+    # ------------------------------------------------------------------------ #
+    # Gradient related options
+    # ------------------------------------------------------------------------ #
+    cfg.grad = CN()
+    cfg.grad.grad_clip = -1.0  # negative numbers indicate we do not clip grad

     # ------------------------------------------------------------------------ #
     # lr_scheduler related options
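Moving `grad_clip` out of `cfg.optimizer` matters once the whole node is `**`-unpacked into an optimizer constructor, since a torch optimizer would reject an unexpected `grad_clip` keyword; clipping is a trainer-level concern and now lives under `cfg.grad`. A sketch of the resulting split (names and values are illustrative):

import torch
from yacs.config import CfgNode as CN

cfg = CN()
cfg.optimizer = CN(new_allowed=True)
cfg.optimizer.type = 'SGD'
cfg.optimizer.lr = 0.1
cfg.optimizer.weight_decay = .0
cfg.optimizer.momentum = .0
cfg.grad = CN()
cfg.grad.grad_clip = 5.0  # kept separate from the optimizer kwargs

model = torch.nn.Linear(4, 2)
kwargs = dict(cfg.optimizer)
opt_cls = getattr(torch.optim, kwargs.pop('type'))
optimizer = opt_cls(model.parameters(), **kwargs)  # no stray grad_clip argument
clip_threshold = cfg.grad.grad_clip  # read separately wherever clipping is applied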
47 changes: 38 additions & 9 deletions federatedscope/core/configs/config.py
@@ -3,6 +3,8 @@
 import os

 from yacs.config import CfgNode
+from yacs.config import _assert_with_logging
+from yacs.config import _check_and_coerce_cfg_value_type

 import federatedscope.register as register

@@ -17,7 +19,14 @@ class CN(CfgNode):
     """
     def __init__(self, init_dict=None, key_list=None, new_allowed=False):
         super().__init__(init_dict, key_list, new_allowed)
-        self.cfg_check_funcs = []  # to check the config values validity
+        self.__dict__["cfg_check_funcs"] = list(
+        )  # to check the config values validity
+
+    def __getattr__(self, name):
+        if name in self:
+            return self[name]
+        else:
+            raise AttributeError(name)

     def register_cfg_check_fun(self, cfg_check_fun):
         self.cfg_check_funcs.append(cfg_check_fun)
@@ -29,9 +38,7 @@ def merge_from_file(self, cfg_filename):
         :param cfg_filename (string):
         :return:
         """
-        cfg_check_funcs = copy.copy(self.cfg_check_funcs)
         super(CN, self).merge_from_file(cfg_filename)
-        self.cfg_check_funcs = cfg_check_funcs
         self.assert_cfg()

     def merge_from_other_cfg(self, cfg_other):
@@ -41,21 +48,43 @@ def merge_from_other_cfg(self, cfg_other):
         :param cfg_other (CN):
         :return:
         """
-        cfg_check_funcs = copy.copy(self.cfg_check_funcs)
         super(CN, self).merge_from_other_cfg(cfg_other)
-        self.cfg_check_funcs = cfg_check_funcs
         self.assert_cfg()

     def merge_from_list(self, cfg_list):
         """
         load configs from a list stores the keys and values.
+        modified `merge_from_list` in `yacs.config.py` to allow adding new keys if `is_new_allowed()` returns True
         :param cfg_list (list):
         :return:
         """
-        cfg_check_funcs = copy.copy(self.cfg_check_funcs)
-        super(CN, self).merge_from_list(cfg_list)
-        self.cfg_check_funcs = cfg_check_funcs
+        _assert_with_logging(
+            len(cfg_list) % 2 == 0,
+            "Override list has odd length: {}; it must be a list of pairs".format(
+                cfg_list
+            ),
+        )
+        root = self
+        for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
+            if root.key_is_deprecated(full_key):
+                continue
+            if root.key_is_renamed(full_key):
+                root.raise_key_rename_error(full_key)
+            key_list = full_key.split(".")
+            d = self
+            for subkey in key_list[:-1]:
+                _assert_with_logging(
+                    subkey in d, "Non-existent key: {}".format(full_key)
+                )
+                d = d[subkey]
+            subkey = key_list[-1]
+            _assert_with_logging(subkey in d or d.is_new_allowed(), "Non-existent key: {}".format(full_key))
+            value = self._decode_cfg_value(v)
+            if subkey in d:
+                value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey, full_key)
+            d[subkey] = value
+
         self.assert_cfg()

     def assert_cfg(self):
@@ -99,7 +128,7 @@ def freeze(self, inform=True):
         from contextlib import redirect_stdout
         with redirect_stdout(outfile):
             tmp_cfg = copy.deepcopy(self)
-            tmp_cfg.cfg_check_funcs = []
+            tmp_cfg.cfg_check_funcs.clear()
             print(tmp_cfg.dump())
             if self.wandb.use:
                 # update the frozen config
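The rewritten `merge_from_list` is what makes command-line overrides of undeclared keys possible: vanilla yacs rejects any key that is not already present, while this version also accepts a new key when the enclosing node was created with `new_allowed=True`. Keeping `cfg_check_funcs` in `__dict__` (with the `__getattr__` fallback for ordinary keys) keeps it out of the config's dict contents, so merges no longer need to save and restore it as the removed lines did. A usage sketch, assuming FederatedScope's `CN` is importable and using made-up keys:

from federatedscope.core.configs.config import CN

cfg = CN()
cfg.optimizer = CN(new_allowed=True)
cfg.optimizer.type = 'SGD'
cfg.optimizer.lr = 0.1

# mirrors a CLI override such as `... optimizer.lr 0.01 optimizer.momentum 0.9`
cfg.merge_from_list(['optimizer.lr', '0.01', 'optimizer.momentum', '0.9'])

assert cfg.optimizer.lr == 0.01       # existing key: decoded and type-checked
assert cfg.optimizer.momentum == 0.9  # new key: accepted because the node is new_allowed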
9 changes: 2 additions & 7 deletions federatedscope/core/trainers/context.py
@@ -102,13 +102,8 @@ def setup_vars(self):
             self.criterion = get_criterion(self.cfg.criterion.type,
                                            self.device)
             self.regularizer = get_regularizer(self.cfg.regularizer.type)
-            self.optimizer = get_optimizer(
-                self.cfg.optimizer.type,
-                self.model,
-                self.cfg.optimizer.lr,
-                weight_decay=self.cfg.optimizer.weight_decay,
-                momentum=self.cfg.optimizer.momentum)
-            self.grad_clip = self.cfg.optimizer.grad_clip
+            self.optimizer = get_optimizer(self.model, **self.cfg.optimizer)
+            self.grad_clip = self.cfg.grad.grad_clip
         elif self.cfg.backend == 'tensorflow':
             self.trainable_para_names = self.model.trainable_variables()
             self.criterion = None
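With `ctx.grad_clip` now read from `cfg.grad.grad_clip`, clipping remains a step in the local training loop rather than an optimizer argument. A short sketch of how such a threshold is commonly applied (toy model and data; the exact trainer hook FederatedScope uses is not shown in this diff):

import torch

model = torch.nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
grad_clip = 5.0  # what ctx.grad_clip carries; negative means "do not clip"

loss = model(torch.randn(16, 3)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
if grad_clip > 0:
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
optimizer.step()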
14 changes: 4 additions & 10 deletions federatedscope/core/trainers/trainer_Ditto.py
@@ -67,16 +67,10 @@ def init_Ditto_ctx(base_trainer):
     ctx.local_model = copy.deepcopy(ctx.model)  # the personalized model
     ctx.models = [ctx.local_model, ctx.global_model]

-    ctx.optimizer_for_global_model = get_optimizer(
-        cfg.optimizer.type,
-        ctx.global_model,
-        cfg.optimizer.lr,
-        weight_decay=cfg.optimizer.weight_decay)
-    ctx.optimizer_for_local_model = get_optimizer(
-        cfg.optimizer.type,
-        ctx.local_model,
-        cfg.personalization.lr,
-        weight_decay=cfg.optimizer.weight_decay)
+    ctx.optimizer_for_global_model = get_optimizer(ctx.global_model,
+                                                   **cfg.optimizer)
+    ctx.optimizer_for_local_model = get_optimizer(ctx.local_model,
+                                                  **cfg.optimizer)
     ctx.optimizer_for_local_model = wrap_regularized_optimizer(
         ctx.optimizer_for_local_model, cfg.personalization.regular_weight)

5 changes: 1 addition & 4 deletions federatedscope/core/trainers/trainer_multi_model.py
@@ -96,10 +96,7 @@ def init_multiple_models(self):
         self.ctx.models = [self.ctx.model] + additional_models

         additional_optimizers = [
-            get_optimizer(self.cfg.optimizer.type,
-                          self.ctx.models[i],
-                          self.cfg.optimizer.lr,
-                          weight_decay=self.cfg.optimizer.weight_decay)
+            get_optimizer(self.ctx.models[i], **self.cfg.optimizer)
             for i in range(1, self.model_nums)
         ]
         self.ctx.optimizers = [self.ctx.optimizer] + additional_optimizers
1 change: 1 addition & 0 deletions federatedscope/cv/baseline/fedavg_convnet2_on_femnist.yaml
@@ -25,6 +25,7 @@ model:
 optimizer:
   lr: 0.01
   weight_decay: 0.0
+grad:
   grad_clip: 5.0
 criterion:
   type: CrossEntropyLoss
1 change: 1 addition & 0 deletions federatedscope/cv/baseline/fedbn_convnet2_on_femnist.yaml
@@ -27,6 +27,7 @@ personalization:
 optimizer:
   lr: 0.01
   weight_decay: 0.0
+grad:
   grad_clip: 5.0
 criterion:
   type: CrossEntropyLoss
1 change: 1 addition & 0 deletions federatedscope/example_configs/femnist.yaml
@@ -24,6 +24,7 @@ model:
 optimizer:
   lr: 0.01
   weight_decay: 0.0
+grad:
   grad_clip: 5.0
 criterion:
   type: CrossEntropyLoss
37 changes: 0 additions & 37 deletions scripts/dp_exp_scrips/run_femnist_dp_standalone.sh

This file was deleted.

19 changes: 0 additions & 19 deletions scripts/dp_exp_scrips/run_femnist_fedopt_standalone.sh

This file was deleted.

17 changes: 0 additions & 17 deletions scripts/dp_exp_scrips/run_femnist_standard_standalone.sh

This file was deleted.

2 changes: 1 addition & 1 deletion scripts/fedopt_exp_scripts/run_fedopt_femnist.sh
@@ -16,7 +16,7 @@ do
     data.root /mnt/gaodawei.gdw/data/ \
     fedopt.use True \
     federate.method FedOpt \
-    fedopt.lr_server ${lrs[$il]} \
+    fedopt.optimizer.lr ${lrs[$il]} \
     >>out_fedopt_femnist/nothing.out \
     2>>out_fedopt_femnist/lr_${lrs[$il]}.log
 done
2 changes: 1 addition & 1 deletion scripts/fedopt_exp_scripts/run_fedopt_lr.sh
@@ -16,7 +16,7 @@ do
     data.root /mnt/gaodawei.gdw/data/ \
     fedopt.use True \
     federate.method FedOpt \
-    fedopt.lr_server ${lrs[$il]} \
+    fedopt.optimizer.lr ${lrs[$il]} \
     >>out_fedopt_lr/nothing.out \
     2>>out_fedopt_lr/lr_${lrs[$il]}.log
 done
2 changes: 1 addition & 1 deletion scripts/fedopt_exp_scripts/run_fedopt_shakespeare.sh
@@ -16,7 +16,7 @@ do
     data.root /mnt/gaodawei.gdw/data/ \
     fedopt.use True \
     federate.method FedOpt \
-    fedopt.lr_server ${lrs[$il]} \
+    fedopt.optimizer.lr ${lrs[$il]} \
     >>out_fedopt_shakespeare/nothing.out \
     2>>out_fedopt_shakespeare/lr_${lrs[$il]}.log
 done
2 changes: 1 addition & 1 deletion scripts/gnn_exp_scripts/run_graph_level_opt.sh
@@ -36,7 +36,7 @@ for (( s=0; s<${#lr_servers[@]}; s++ ))
 do
     for k in {1..5}
     do
-        python federatedscope/main.py --cfg federatedscope/gfl/baseline/fedavg_gcn_minibatch_on_hiv.yaml device ${cudaid} data.type ${dataset} data.splitter ${splitter} optimizer.lr ${lr} federate.local_update_steps ${local_update} model.type ${gnn} model.out_channels ${out_channels} model.hidden ${hidden} seed $k federate.method FedOpt fedopt.lr_server ${lr_servers[$s]} >>out/${gnn}_${lr}_${local_update}_on_${dataset}_${splitter}_${lr_servers[$s]}_opt.log 2>&1
+        python federatedscope/main.py --cfg federatedscope/gfl/baseline/fedavg_gcn_minibatch_on_hiv.yaml device ${cudaid} data.type ${dataset} data.splitter ${splitter} optimizer.lr ${lr} federate.local_update_steps ${local_update} model.type ${gnn} model.out_channels ${out_channels} model.hidden ${hidden} seed $k federate.method FedOpt fedopt.optimizer.lr ${lr_servers[$s]} >>out/${gnn}_${lr}_${local_update}_on_${dataset}_${splitter}_${lr_servers[$s]}_opt.log 2>&1
     done
 done

2 changes: 1 addition & 1 deletion scripts/gnn_exp_scripts/run_link_level_opt.sh
@@ -36,7 +36,7 @@ for (( s=0; s<${#lr_servers[@]}; s++ ))
 do
     for k in {1..5}
     do
-        python federatedscope/main.py --cfg federatedscope/gfl/baseline/fedavg_gcn_fullbatch_on_kg.yaml device ${cudaid} data.type ${dataset} data.splitter ${splitter} optimizer.lr ${lr} federate.local_update_steps ${local_update} model.type ${gnn} model.out_channels ${out_channels} model.hidden ${hidden} seed $k federate.method FedOpt fedopt.lr_server ${lr_servers[$s]} model.layer ${layer} >>out/${gnn}_${lr}_${local_update}_on_${dataset}_${splitter}_${lr_servers[$s]}_opt.log 2>&1
+        python federatedscope/main.py --cfg federatedscope/gfl/baseline/fedavg_gcn_fullbatch_on_kg.yaml device ${cudaid} data.type ${dataset} data.splitter ${splitter} optimizer.lr ${lr} federate.local_update_steps ${local_update} model.type ${gnn} model.out_channels ${out_channels} model.hidden ${hidden} seed $k federate.method FedOpt fedopt.optimizer.lr ${lr_servers[$s]} model.layer ${layer} >>out/${gnn}_${lr}_${local_update}_on_${dataset}_${splitter}_${lr_servers[$s]}_opt.log 2>&1
     done
 done

2 changes: 1 addition & 1 deletion scripts/gnn_exp_scripts/run_multi_opt.sh
@@ -34,7 +34,7 @@ for (( s=0; s<${#lr_servers[@]}; s++ ))
 do
     for k in {1..5}
    do
-        python federatedscope/main.py --cfg federatedscope/gfl/baseline/fedavg_gnn_minibatch_on_multi_task.yaml device ${cudaid} data.type ${dataset} data.splitter ${splitter} optimizer.lr ${lr} federate.local_update_steps ${local_update} model.type ${gnn} model.out_channels ${out_channels} model.hidden ${hidden} seed $k federate.method FedOpt fedopt.lr_server ${lr_servers[$s]} >>out/${gnn}_${lr}_${local_update}_on_${dataset}_${splitter}_${lr_servers[$s]}_opt.log 2>&1
+        python federatedscope/main.py --cfg federatedscope/gfl/baseline/fedavg_gnn_minibatch_on_multi_task.yaml device ${cudaid} data.type ${dataset} data.splitter ${splitter} optimizer.lr ${lr} federate.local_update_steps ${local_update} model.type ${gnn} model.out_channels ${out_channels} model.hidden ${hidden} seed $k federate.method FedOpt fedopt.optimizer.lr ${lr_servers[$s]} >>out/${gnn}_${lr}_${local_update}_on_${dataset}_${splitter}_${lr_servers[$s]}_opt.log 2>&1
     done
 done

2 changes: 1 addition & 1 deletion scripts/gnn_exp_scripts/run_node_level_opt.sh
@@ -39,7 +39,7 @@ for (( s=0; s<${#lr_servers[@]}; s++ ))
 do
     for k in {1..5}
     do
-        python federatedscope/main.py --cfg federatedscope/gfl/baseline/fedavg_gnn_node_fullbatch_citation.yaml device ${cudaid} data.type ${dataset} data.splitter ${splitter} optimizer.lr ${lr} federate.local_update_steps ${local_update} model.type ${gnn} model.out_channels ${out_channels} model.hidden ${hidden} seed $k federate.method FedOpt fedopt.lr_server ${lr_servers[$s]} model.layer ${layer} >>out/${gnn}_${lr}_${local_update}_on_${dataset}_${splitter}_${lr_servers[$s]}_opt.log 2>&1
+        python federatedscope/main.py --cfg federatedscope/gfl/baseline/fedavg_gnn_node_fullbatch_citation.yaml device ${cudaid} data.type ${dataset} data.splitter ${splitter} optimizer.lr ${lr} federate.local_update_steps ${local_update} model.type ${gnn} model.out_channels ${out_channels} model.hidden ${hidden} seed $k federate.method FedOpt fedopt.optimizer.lr ${lr_servers[$s]} model.layer ${layer} >>out/${gnn}_${lr}_${local_update}_on_${dataset}_${splitter}_${lr_servers[$s]}_opt.log 2>&1
     done
 done

@@ -28,6 +28,7 @@ personalization:
 optimizer:
   lr: 0.5
   weight_decay: 0.0
+grad:
   grad_clip: 5.0
 criterion:
   type: CrossEntropyLoss
@@ -26,6 +26,7 @@ personalization:
 optimizer:
   lr: 0.1
   weight_decay: 0.0
+grad:
   grad_clip: 5.0
 criterion:
   type: CrossEntropyLoss
(The remaining changed files of this commit are not shown here.)
