-
Notifications
You must be signed in to change notification settings - Fork 218
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Construct FL course when server does not have data #236
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,99 +11,144 @@ | |
f'available.') | ||
|
||
|
||
def get_shape_from_data(data, model_config, backend='torch'):
    """
    Extract the input shape from the given data, which can be used to build
    the model. Users can also use `data.input_shape` to specify the shape.

    Arguments:
        data (object): the data used for local training or evaluation.
            The expected data formats are:
                1): {train/val/test: {x: ndarray, y: ndarray}}
                2): {train/val/test: DataLoader}
        model_config: the model-related config node; its ``type`` (and, for
            GNNs, ``task``) decides how the shape is extracted.
        backend (str): 'torch' or 'tensorflow'.
    Returns:
        shape (tuple): the input shape
    """
    # Handle some special cases first, whose "shape" is not a plain
    # tensor shape.
    model_type = model_config.type.lower()
    if model_type in ['vmfnet', 'hmfnet']:
        # Matrix-factorization models need the column/row count instead.
        return data['train'].n_col if model_type == 'vmfnet' \
            else data['train'].n_row
    elif model_type in ['gcn', 'sage', 'gpr', 'gat', 'gin', 'mpnn']:
        # GNN models additionally need the label count and (for mpnn)
        # the number of edge features.
        num_label = data['num_label'] if 'num_label' in data else None
        num_edge_features = data[
            'num_edge_features'] if model_config.type == 'mpnn' else None
        if model_config.task.startswith('graph'):
            # graph-level task: data['train'] is an iterable of graphs
            data_representative = next(iter(data['train']))
            return (data_representative.x.shape, num_label, num_edge_features)
        else:
            # node/link-level task: data itself carries the node features
            return (data.x.shape, num_label, num_edge_features)

    if isinstance(data, dict):
        # Prefer the splits in a fixed order so the choice is deterministic.
        keys = list(data.keys())
        if 'test' in keys:
            key_representative = 'test'
        elif 'train' in keys:
            key_representative = 'train'
        elif 'data' in keys:
            key_representative = 'data'
        else:
            key_representative = keys[0]
            logger.warning(f'We chose the key {key_representative} as the '
                           f'representative key to extract data shape.')

        data_representative = data[key_representative]
    else:
        # Handle the data with non-dict format
        data_representative = data

    if isinstance(data_representative, dict):
        if 'x' in data_representative:
            shape = data_representative['x'].shape
            if len(shape) == 1:  # (batch, ) = (batch, 1)
                return 1
            else:
                return shape
    elif backend == 'torch':
        import torch
        if issubclass(type(data_representative), torch.utils.data.DataLoader):
            x, _ = next(iter(data_representative))
            return x.shape
        else:
            try:
                x, _ = data_representative
                return x.shape
            except Exception as err:
                # Narrowed from a bare `except`: KeyboardInterrupt etc. now
                # propagate, and the original cause is chained for debugging.
                raise TypeError('Unsupported data type.') from err
    elif backend == 'tensorflow':
        # TODO: Handle more tensorflow type here
        shape = data_representative['x'].shape
        if len(shape) == 1:  # (batch, ) = (batch, 1)
            return 1
        else:
            return shape
|
||
|
||
def get_model(model_config, local_data=None, backend='torch'): | ||
""" | ||
Arguments: | ||
local_data (object): the model to be instantiated is | ||
responsible for the given data. | ||
Returns: | ||
model (torch.Module): the instantiated model. | ||
""" | ||
if local_data is not None: | ||
input_shape = get_shape_from_data(local_data, model_config, backend) | ||
else: | ||
input_shape = model_config.input_shape | ||
|
||
if input_shape is None: | ||
logger.warning('The input shape is None. Please specify the ' | ||
'`data.input_shape`(a tuple) or give the ' | ||
'representative data to `get_model` if necessary') | ||
|
||
for func in register.model_dict.values(): | ||
model = func(model_config, local_data) | ||
model = func(model_config, input_shape) | ||
if model is not None: | ||
return model | ||
|
||
if model_config.type.lower() == 'lr': | ||
if backend == 'torch': | ||
from federatedscope.core.lr import LogisticRegression | ||
# TODO: make the instantiation more general | ||
if isinstance( | ||
local_data, dict | ||
) and 'test' in local_data and 'x' in local_data['test']: | ||
model = LogisticRegression( | ||
in_channels=local_data['test']['x'].shape[-1], | ||
class_num=1, | ||
use_bias=model_config.use_bias) | ||
else: | ||
if isinstance(local_data, dict): | ||
if 'data' in local_data.keys(): | ||
data = local_data['data'] | ||
elif 'train' in local_data.keys(): | ||
# local_data['train'] is Dataloader | ||
data = next(iter(local_data['train'])) | ||
else: | ||
raise TypeError('Unsupported data type.') | ||
else: | ||
data = local_data | ||
|
||
x, _ = data | ||
model = LogisticRegression(in_channels=x.shape[-1], | ||
class_num=model_config.out_channels) | ||
model = LogisticRegression(in_channels=input_shape[-1], | ||
class_num=model_config.out_channels) | ||
elif backend == 'tensorflow': | ||
from federatedscope.cross_backends import LogisticRegression | ||
model = LogisticRegression( | ||
in_channels=local_data['test']['x'].shape[-1], | ||
class_num=1, | ||
use_bias=model_config.use_bias) | ||
model = LogisticRegression(in_channels=input_shape[-1], | ||
class_num=1, | ||
use_bias=model_config.use_bias) | ||
else: | ||
raise ValueError | ||
|
||
elif model_config.type.lower() == 'mlp': | ||
from federatedscope.core.mlp import MLP | ||
if isinstance(local_data, dict): | ||
if 'data' in local_data.keys(): | ||
data = local_data['data'] | ||
elif 'train' in local_data.keys(): | ||
# local_data['train'] is Dataloader | ||
data = next(iter(local_data['train'])) | ||
else: | ||
raise TypeError('Unsupported data type.') | ||
else: | ||
data = local_data | ||
|
||
x, _ = data | ||
model = MLP(channel_list=[x.shape[-1]] + [model_config.hidden] * | ||
model = MLP(channel_list=[input_shape[-1]] + [model_config.hidden] * | ||
(model_config.layer - 1) + [model_config.out_channels], | ||
dropout=model_config.dropout) | ||
|
||
elif model_config.type.lower() == 'quadratic': | ||
from federatedscope.tabular.model import QuadraticModel | ||
if isinstance(local_data, dict): | ||
data = next(iter(local_data['train'])) | ||
else: | ||
# TODO: complete the branch | ||
data = local_data | ||
x, _ = data | ||
model = QuadraticModel(x.shape[-1], 1) | ||
model = QuadraticModel(input_shape[-1], 1) | ||
|
||
elif model_config.type.lower() in ['convnet2', 'convnet5', 'vgg11', 'lr']: | ||
from federatedscope.cv.model import get_cnn | ||
model = get_cnn(model_config, local_data) | ||
model = get_cnn(model_config, input_shape) | ||
elif model_config.type.lower() in ['lstm']: | ||
from federatedscope.nlp.model import get_rnn | ||
model = get_rnn(model_config, local_data) | ||
model = get_rnn(model_config, input_shape) | ||
elif model_config.type.lower().endswith('transformers'): | ||
from federatedscope.nlp.model import get_transformer | ||
model = get_transformer(model_config, local_data) | ||
model = get_transformer(model_config, input_shape) | ||
elif model_config.type.lower() in [ | ||
'gcn', 'sage', 'gpr', 'gat', 'gin', 'mpnn' | ||
]: | ||
from federatedscope.gfl.model import get_gnn | ||
model = get_gnn(model_config, local_data) | ||
model = get_gnn(model_config, input_shape) | ||
elif model_config.type.lower() in ['vmfnet', 'hmfnet']: | ||
from federatedscope.mf.model.model_builder import get_mfnet | ||
model = get_mfnet(model_config, local_data) | ||
model = get_mfnet(model_config, input_shape) | ||
else: | ||
raise ValueError('Model {} is not provided'.format(model_config.type)) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -54,7 +54,12 @@ def extend_fl_setting_cfg(cfg): | |
cfg.distribute.client_port = 50050 | ||
cfg.distribute.role = 'client' | ||
cfg.distribute.data_file = 'data' | ||
> **Reviewer:** The data-related keyword … *(comment truncated in export)*
|
||
cfg.distribute.data_idx = -1 | ||
cfg.distribute.data_idx = -1 # data_idx is used to specify the data | ||
# index in distributed mode when adopting a centralized dataset for | ||
# simulation (formatted as {data_idx: data/dataloader}). | ||
# data_idx = -1 means that the whole dataset is owned by the participant. | ||
# when data_idx is other invalid values excepted for -1, we randomly | ||
# sample the data_idx for simulation | ||
cfg.distribute.grpc_max_send_message_length = 100 * 1024 * 1024 | ||
cfg.distribute.grpc_max_receive_message_length = 100 * 1024 * 1024 | ||
cfg.distribute.grpc_enable_http_proxy = False | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Server-side configuration for distributed FL where the server owns no data
# (data.type is empty); the model is built from model.input_shape instead.
use_gpu: True
federate:
  client_num: 3
  mode: 'distributed'
  total_round_num: 20
  make_global_eval: False
  online_aggr: False
distribute:
  use: True
  server_host: '127.0.0.1'
  server_port: 50051
  role: 'server'
trainer:
  type: 'general'
eval:
  freq: 10  # evaluate every 10 rounds
data:
  type: ''  # empty: the server holds no dataset in this example
model:
  type: 'lr'
  input_shape: (5,)  # explicit shape, used since no data is available
> **Reviewer:** When `data.type` is `None` or `""` while `distribute.data_idx` is given, we should assert in `configs`.

> **Author:** When `data.type` is `None` or `""`, the `distribute.data_idx` does not work. IMO, we can provide a WARNING but not an assertion here.