diff --git a/federatedscope/core/auxiliaries/model_builder.py b/federatedscope/core/auxiliaries/model_builder.py index 435b70d41..3dcb460c4 100644 --- a/federatedscope/core/auxiliaries/model_builder.py +++ b/federatedscope/core/auxiliaries/model_builder.py @@ -33,7 +33,13 @@ def get_shape_from_data(data, model_config, backend='torch'): num_label = data['num_label'] if 'num_label' in data else None num_edge_features = data[ 'num_edge_features'] if model_config.type == 'mpnn' else None - return (data['train'].x.shape, num_label, num_edge_features) + if model_config.task.startswith('graph'): + # graph-level task + data_representative = next(iter(data['train'])) + return (data_representative.x.shape, num_label, num_edge_features) + else: + # node/link-level task + return (data.x.shape, num_label, num_edge_features) if isinstance(data, dict): keys = list(data.keys())