Skip to content

Commit

Permalink
- implement a new task_group filtering decorator 'with_selected_task_…
Browse files Browse the repository at this point in the history
…group' in Assigner class

- update all the sub-classes that use task_groups to use the decorator
- update fedeval sample workspace to use default assigner, tasks and aggregator
- use of federated-evaluation/aggregator.yaml for FedEval specific workspace example to use round_number as 1
- removed assigner and tasks yaml from defaults/federated-evaluation, superseded by default assigner/tasks
- added additional checks for assigner sub-classes that might not have task_groups
- Addressing review comments
- Updated existing test cases for Assigner sub-classes
- Remove hard-coded setting in assigner for torch_cnn_mnist ws, refer to default as in other Workspaces
- Use aggregator supplied --task_group to override the assinger selected_task_group
- update existing test cases of aggregator cli
- add test cases for the decorator
- rebased 25-Jan.1
- implemented the support of multiple task_group without selection
- defaulting of selected_task group 'percentage' to 1.0 post successful filtering
- updated test cases for multiple task group support
Signed-off-by: Shailesh Pant <shailesh.pant@intel.com>
  • Loading branch information
ishaileshpant committed Jan 29, 2025
1 parent 11cabf5 commit 0b51e3d
Show file tree
Hide file tree
Showing 17 changed files with 186 additions and 60 deletions.
12 changes: 2 additions & 10 deletions openfl-workspace/torch_cnn_mnist/plan/plan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,8 @@ aggregator:
rounds_to_train: 2
write_logs: false
template: openfl.component.aggregator.Aggregator
assigner:
settings:
task_groups:
- name: learning
percentage: 1.0
tasks:
- aggregated_model_validation
- train
- locally_tuned_model_validation
template: openfl.component.RandomGroupedAssigner
assigner :
defaults : plan/defaults/assigner.yaml
collaborator:
settings:
db_store_rounds: 1
Expand Down
8 changes: 5 additions & 3 deletions openfl-workspace/torch_cnn_mnist_fed_eval/plan/plan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ network :
defaults : plan/defaults/network.yaml

assigner :
defaults : plan/defaults/federated-evaluation/assigner.yaml

defaults : plan/defaults/assigner.yaml
settings :
selected_task_group : evaluation

tasks :
defaults : plan/defaults/federated-evaluation/tasks_torch.yaml
defaults : plan/defaults/tasks_torch.yaml

compression_pipeline :
defaults : plan/defaults/compression_pipeline.yaml
1 change: 1 addition & 0 deletions openfl-workspace/workspace/plan/defaults/aggregator.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ settings :
db_store_rounds : 2
persist_checkpoint: True
persistent_db_path: local_state/tensor.db

4 changes: 4 additions & 0 deletions openfl-workspace/workspace/plan/defaults/assigner.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ settings :
- aggregated_model_validation
- train
- locally_tuned_model_validation
- name : evaluation
percentage : 0
tasks :
- aggregated_model_validation

This file was deleted.

This file was deleted.

1 change: 1 addition & 0 deletions openfl/component/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""OpenFL Component Module."""

from openfl.component.aggregator.aggregator import Aggregator
from openfl.component.assigner.assigner import Assigner
Expand Down
14 changes: 12 additions & 2 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,25 @@ def __init__(
self.straggler_handling_policy = (
straggler_handling_policy or CutoffTimeBasedStragglerHandling()
)
self._end_of_round_check_done = [False] * rounds_to_train
self.stragglers = []

self.rounds_to_train = rounds_to_train
if self.task_group == "evaluation":
self.rounds_to_train = 1
logger.info(
f"task_group is {self.task_group}, setting rounds_to_train = {self.rounds_to_train}"
)

self._end_of_round_check_done = [False] * rounds_to_train
self.stragglers = []

# if the collaborator requests a delta, this value is set to true
self.authorized_cols = authorized_cols
self.uuid = aggregator_uuid
self.federation_uuid = federation_uuid
# # override the assigner selected_task_group
# # FIXME check the case of CustomAssigner as base class Assigner is redefined
# # and doesn't have selected_task_group as attribute
# assigner.selected_task_group = task_group
self.assigner = assigner
self.quit_job_sent_to = []

Expand Down
1 change: 1 addition & 0 deletions openfl/component/assigner/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""OpenFL Assigner Module."""

from openfl.component.assigner.assigner import Assigner
from openfl.component.assigner.random_grouped_assigner import RandomGroupedAssigner
Expand Down
63 changes: 62 additions & 1 deletion openfl/component/assigner/assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

"""Assigner module."""

import logging
from functools import wraps

logger = logging.getLogger(__name__)


class Assigner:
r"""
Expand Down Expand Up @@ -35,18 +40,27 @@ class Assigner:
\* - ``tasks`` argument is taken from ``tasks`` section of FL plan YAML file.
"""

def __init__(self, tasks, authorized_cols, rounds_to_train, **kwargs):
def __init__(
self,
tasks,
authorized_cols,
rounds_to_train,
selected_task_group: str = None,
**kwargs,
):
"""Initializes the Assigner.
Args:
tasks (list of object): List of tasks to assign.
authorized_cols (list of str): Collaborators.
rounds_to_train (int): Number of training rounds.
selected_task_group (str, optional): Selected task_group.
**kwargs: Additional keyword arguments.
"""
self.tasks = tasks
self.authorized_cols = authorized_cols
self.rounds = rounds_to_train
self.selected_task_group = selected_task_group
self.all_tasks_in_groups = []

self.task_group_collaborators = {}
Expand Down Expand Up @@ -93,3 +107,50 @@ def get_aggregation_type_for_task(self, task_name):
if "aggregation_type" not in self.tasks[task_name]:
return None
return self.tasks[task_name]["aggregation_type"]

@classmethod
def with_selected_task_group(cls, func):
"""Decorator to filter task groups based on selected_task_group.
This decorator should be applied to define_task_assignments() method
in Assigner subclasses to handle task_group filtering.
"""

@wraps(func)
def wrapper(self, *args, **kwargs):
# Check if selection of task_group is applicable
if hasattr(self, "selected_task_group") and self.selected_task_group is not None:
# Verify task_groups exists before attempting filtering
if not hasattr(self, "task_groups"):
logger.warning(
"Task group specified for selection but no task_groups found. "
"Skipping filtering. This might be intentional for custom assigners."
)
return func(self, *args, **kwargs)

assert self.task_groups, "No task_groups defined in assigner."

Check notice

Code scanning / Bandit

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. Note

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.

# Perform the filtering
selected_task_groups = [
group for group in self.task_groups if group["name"] == self.selected_task_group
]

assert len(selected_task_groups) == 1, (
f"Only one task group with name {self.selected_task_group} should exist"

Check notice

Code scanning / Bandit

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. Note

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
)

# Since we have filtered to one of the task_groups, we need to ensure that
# the selected_task_group percentage compute allocation is defaulted to 1.0
current_percentage = selected_task_groups[0]["percentage"]
logger.info(
f"`percentage` for task_group {self.selected_task_group} is "
f"{current_percentage}, setting it to 1.0"
)
selected_task_groups[0]["percentage"] = 1.0

self.task_groups = selected_task_groups

# Call the original method
return func(self, *args, **kwargs)

return wrapper
7 changes: 5 additions & 2 deletions openfl/component/assigner/random_grouped_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from openfl.component.assigner.assigner import Assigner
from openfl.component.assigner import Assigner


class RandomGroupedAssigner(Assigner):
Expand All @@ -33,16 +33,19 @@ class RandomGroupedAssigner(Assigner):
\* - Plan setting.
"""

with_selected_task_group = Assigner.with_selected_task_group

def __init__(self, task_groups, **kwargs):
"""Initializes the RandomGroupedAssigner.
Args:
task_groups (list of object): Task groups to assign.
**kwargs: Additional keyword arguments.
**kwargs: Additional keyword arguments, including mode.
"""
self.task_groups = task_groups
super().__init__(**kwargs)

@with_selected_task_group
def define_task_assignments(self):
"""Define task assignments for each round and collaborator.
Expand Down
3 changes: 3 additions & 0 deletions openfl/component/assigner/static_grouped_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class StaticGroupedAssigner(Assigner):
\* - Plan setting.
"""

with_selected_task_group = Assigner.with_selected_task_group

def __init__(self, task_groups, **kwargs):
"""Initializes the StaticGroupedAssigner.
Expand All @@ -42,6 +44,7 @@ def __init__(self, task_groups, **kwargs):
self.task_groups = task_groups
super().__init__(**kwargs)

@with_selected_task_group
def define_task_assignments(self):
"""Define task assignments for each round and collaborator.
Expand Down
15 changes: 8 additions & 7 deletions openfl/interface/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ def aggregator(context):
@option(
"--task_group",
required=False,
default="learning",
help="Selected task-group for assignment - defaults to learning",
help="Selected task-group for assignment",
)
def start_(plan, authorized_cols, task_group):
"""Start the aggregator service.
Expand All @@ -94,11 +93,13 @@ def start_(plan, authorized_cols, task_group):
cols_config_path=Path(authorized_cols).absolute(),
)

# Set task_group in aggregator settings
if "settings" not in parsed_plan.config["aggregator"]:
parsed_plan.config["aggregator"]["settings"] = {}
parsed_plan.config["aggregator"]["settings"]["task_group"] = task_group
logger.info(f"Setting aggregator to assign: {task_group} task_group")
# Set task_group in aggregator and assigner settings if provided
if task_group:
if "settings" not in parsed_plan.config["aggregator"]:
parsed_plan.config["aggregator"]["settings"] = {}
parsed_plan.config["aggregator"]["settings"]["task_group"] = task_group
parsed_plan.config["assigner"]["settings"]["selected_task_group"] = task_group
logger.info(f"Setting aggregator to assign: {task_group} task_group")

logger.info("🧿 Starting the Aggregator Service.")

Expand Down
28 changes: 18 additions & 10 deletions tests/openfl/component/assigner/test_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def assigner():

def test_get_aggregation_type_for_task_none(assigner):
"""Assert that aggregation type of custom task is None."""
task_name = 'test_name'
task_name = "test_name"
tasks = {task_name: {}}

assigner = assigner(tasks, None, None)
Expand All @@ -31,11 +31,9 @@ def test_get_aggregation_type_for_task_none(assigner):

def test_get_aggregation_type_for_task(assigner):
"""Assert that aggregation type of task is getting correctly."""
task_name = 'test_name'
test_aggregation_type = 'test_aggregation_type'
tasks = {task_name: {
'aggregation_type': test_aggregation_type
}}
task_name = "test_name"
test_aggregation_type = "test_aggregation_type"
tasks = {task_name: {"aggregation_type": test_aggregation_type}}
assigner = assigner(tasks, None, None)

aggregation_type = assigner.get_aggregation_type_for_task(task_name)
Expand All @@ -46,13 +44,23 @@ def test_get_aggregation_type_for_task(assigner):
def test_get_all_tasks_for_round(assigner):
"""Assert that assigner tasks object is list."""
assigner = Assigner(None, None, None)
tasks = assigner.get_all_tasks_for_round('test')
tasks = assigner.get_all_tasks_for_round("test")

assert isinstance(tasks, list)

def test_default_task_group(assigner):
"""Assert that by default learning task_group is assigned."""
assigner = Assigner(None,None,None)
assert assigner.selected_task_group == None

Check notice

Code scanning / Bandit

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. Note test

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.

class TestNotImplError(TestCase):
def test_task_group_filtering_no_task_groups(assigner):
"""Assert that task_group_filtering does not filter when no task_groups are defined."""
assigner = Assigner(None,None,None)
assigner.selected_task_group = "test_group"
assigner.define_task_assignments()
assert not hasattr(assigner, "task_groups")

class TestNotImplError(TestCase):
def test_define_task_assignments(self):
# TODO: define_task_assignments is defined as a mock in multiple fixtures,
# which leads the function to behave as a mock here and other tests.
Expand All @@ -61,9 +69,9 @@ def test_define_task_assignments(self):
def test_get_tasks_for_collaborator(self):
with self.assertRaises(NotImplementedError):
assigner = Assigner(None, None, None)
assigner.get_tasks_for_collaborator('col1', 0)
assigner.get_tasks_for_collaborator("col1", 0)

def test_get_collaborators_for_task(self):
with self.assertRaises(NotImplementedError):
assigner = Assigner(None, None, None)
assigner.get_collaborators_for_task('task_name', 0)
assigner.get_collaborators_for_task("task_name", 0)
Loading

0 comments on commit 0b51e3d

Please sign in to comment.