From cdb87299d0463799bfa17206700b6c1cd7414d81 Mon Sep 17 00:00:00 2001 From: ParthM-GitHub Date: Wed, 19 Jul 2023 14:10:49 +0530 Subject: [PATCH] External Loop Functionality Added External Loop Functionality Added Signed-off-by: Parth Mandaliya --- .../component/aggregator/aggregator.py | 23 ++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/openfl/experimental/component/aggregator/aggregator.py b/openfl/experimental/component/aggregator/aggregator.py index d6615d640f5..f3a21eaff0e 100644 --- a/openfl/experimental/component/aggregator/aggregator.py +++ b/openfl/experimental/component/aggregator/aggregator.py @@ -60,6 +60,7 @@ def __init__( self.authorized_cols = authorized_cols self.round_number = rounds_to_train + self.current_round = 1 self.collaborators_counter = 0 self.quit_job_sent_to = [] self.time_to_quit = False @@ -120,6 +121,7 @@ def run_flow_until_transition(self) -> None: """ f_name = self.flow.run() + self.logger.info(f"Starting round {self.current_round}...") while True: next_step = self.do_task(f_name) @@ -245,7 +247,7 @@ def get_tasks(self, collaborator_name: str) -> Tuple: next_step, clone = self.collaborator_tasks_queue[ collaborator_name].get() - return 0, next_step, pickle.dumps(clone), 0, self.time_to_quit + return self.current_round, next_step, pickle.dumps(clone), 0, self.time_to_quit def do_task(self, f_name: str) -> Any: """Execute aggregator steps until transition.""" @@ -258,10 +260,17 @@ def do_task(self, f_name: str) -> Any: if f.__name__ == "end": f() + # TODO: Think of different approach than deep-copying the + # flow object. self.call_checkpoint(deepcopy(self.flow), f, reserved_attributes=list(self.__private_attrs.keys())) - self.time_to_quit = True - not_at_transition_point = False + if self.current_round is self.round_number: + not_at_transition_point = False + self.time_to_quit = True + else: + self.current_round += 1 + self.logger.info(f"Starting round {self.current_round}...") + f_name = "start" continue if len(args) > 0: @@ -308,9 +317,11 @@ def send_task_results(self, collab_name: str, round_number: int, next_step: str, self.logger.info(f"Aggregator step received from {collab_name} for " + f"round number: {round_number}.") - # TODO: Think about taking values from collaborators. - # Do not take rn. - self.round_number = round_number + if round_number is not self.current_round: + self.logger.warning(f"Collaborator sent {round_number} results, aggregator " + + f"is executing {self.current_round}") + else: + self.logger.info(f"Collaborator sent task results for round {self.current_round}") clone = pickle.loads(clone_bytes) self.clones_dict[clone.input] = clone self.next_step = next_step[0]