Skip to content

Commit

Permalink
Handling next round model tensors
Browse files Browse the repository at this point in the history
Signed-off-by: Lerer, Eran <eran.lerer@intel.com>
  • Loading branch information
cloudnoize committed Jan 13, 2025
1 parent 6fd5eeb commit 00392a0
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 123 deletions.
104 changes: 66 additions & 38 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(
callbacks: List of callbacks to be used during the experiment.
"""
self.round_number = 0
self.next_model_round_number = 0

if single_col_cert_common_name:
logger.warning(
Expand Down Expand Up @@ -145,6 +146,7 @@ def __init__(
logger.info("Persistent checkpoint is enabled")
self.persistent_db = PersistentTensorDB(persistent_db_path)
else:
logger.info("Persistent checkpoint is disabled")
self.persistent_db = None
# FIXME: I think next line generates an error on the second round
# if it is set to 1 for the aggregator.
Expand Down Expand Up @@ -204,37 +206,63 @@ def __init__(
self.callbacks.on_round_begin(self.round_number)

def _recover(self):
if self.persistent_db.is_task_table_empty():
return False
# load tensors persistent DB
logger.info("Recovering previous state from persistent storage")
tensor_key_dict = self.persistent_db.load_tensors()
if len(tensor_key_dict) > 0:
self.tensor_db.cache_tensor(tensor_key_dict)
logger.debug("Recovery - this is the tensor_db after recovery: %s", self.tensor_db)
logger.info("Recovery - Finished populating tensor DB")
"""Populates the aggregator state to the state it was prior a restart
"""
recovered = False
# load tensors persistent DB
logger.info("Recovering previous state from persistent storage")
tensor_key_dict = self.persistent_db.load_tensors(self.persistent_db.get_tensors_table_name())
if len(tensor_key_dict) > 0:
logger.info(f"Recovering {len(tensor_key_dict)} model tensors")
recovered = True
self.tensor_db.cache_tensor(tensor_key_dict)
committed_round_number, self.best_model_score = self.persistent_db.get_round_and_best_score()
logger.info("Recovery - Setting model proto")
to_proto_tensor_dict = {}
for tk in tensor_key_dict:
tk_name, _, _, _, _ = tk
to_proto_tensor_dict[tk_name] = tensor_key_dict[tk]
self.model = utils.construct_model_proto(
to_proto_tensor_dict, committed_round_number, self.compression_pipeline
)
# round number is the current round which is still in process i.e. committed_round_number + 1
self.round_number = committed_round_number + 1
logger.info("Recovery - loaded round number %s and best score %s", self.round_number,self.best_model_score)
logger.info("Recovery - Replaying saved task results")
task_id = 1
while True:
task_result = self.persistent_db.get_task_result_by_id(task_id)
if not task_result:
break
collaborator_name = task_result["collaborator_name"]
round_number = task_result["round_number"]
task_name = task_result["task_name"]
data_size = task_result["data_size"]
serialized_tensors = task_result["named_tensors"]
named_tensors = [
NamedTensor.FromString(serialized_tensor)
for serialized_tensor in serialized_tensors
]
logger.info("Recovery - Replaying task results %s %s %s",collaborator_name ,round_number, task_name )
self.process_task_results(collaborator_name, round_number, task_name, data_size, named_tensors)
task_id += 1

next_round_tensor_key_dict = self.persistent_db.load_tensors(self.persistent_db.get_next_round_tensors_table_name())
if len(next_round_tensor_key_dict) > 0:
logger.info(f"Recovering {len(next_round_tensor_key_dict)} next round model tensors")
recovered = True
self.tensor_db.cache_tensor(next_round_tensor_key_dict)


logger.info("Recovery - Finished populating tensor DB")
logger.debug("Recovery - this is the tensor_db after recovery: %s", self.tensor_db)

if self.persistent_db.is_task_table_empty():
logger.debug("task table is empty")
return recovered

logger.info("Recovery - Replaying saved task results")
task_id = 1
while True:
task_result = self.persistent_db.get_task_result_by_id(task_id)
if not task_result:
break
recovered = True
collaborator_name = task_result["collaborator_name"]
round_number = task_result["round_number"]
task_name = task_result["task_name"]
data_size = task_result["data_size"]
serialized_tensors = task_result["named_tensors"]
named_tensors = [
NamedTensor.FromString(serialized_tensor)
for serialized_tensor in serialized_tensors
]
logger.info("Recovery - Replaying task results %s %s %s",collaborator_name ,round_number, task_name )
self.process_task_results(collaborator_name, round_number, task_name, data_size, named_tensors)
task_id += 1
return recovered

def _load_initial_tensors(self):
"""Load all of the tensors required to begin federated learning.
Expand Down Expand Up @@ -308,15 +336,14 @@ def _save_model(self, round_number, file_path):
round_number,
)
return
#E.L here we can save the tensor_dict as well. as transaction.
# we can omit the proto save, at the end of the experiment to write the last and best model tensors as proto
# and clean all the db.
if file_path == self.best_state_path:
self.best_tensor_dict = tensor_dict
if file_path == self.last_state_path:
# Transaction to persist/delete all data needed to increment the round
if self.persistent_db:
self.persistent_db.finalize_round(tensor_tuple_dict,self.round_number,self.best_model_score)
if self.next_model_round_number > 0:
next_round_tensors = self.tensor_db.get_tensors_by_round_and_tags(self.next_model_round_number,("model",))
self.persistent_db.finalize_round(tensor_tuple_dict,next_round_tensors,self.round_number,self.best_model_score)
logger.info(
"Persist model and clean task result for round %s",
round_number,
Expand Down Expand Up @@ -662,7 +689,13 @@ def send_local_task_results(
"""
# Save task and its metadata for recovery
serialized_tensors = [tensor.SerializeToString() for tensor in named_tensors]
self.persistent_db and self.persistent_db.save_task_results(collaborator_name,round_number,task_name,data_size,serialized_tensors)
if self.persistent_db:
self.persistent_db.save_task_results(collaborator_name,round_number,task_name,data_size,serialized_tensors)
logger.debug(f"Persisting task results {task_name} from {collaborator_name} round {round_number}")
logger.info(
f"Collaborator {collaborator_name} is sending task results "
f"for {task_name}, round {round_number}"
)
self.process_task_results(collaborator_name,round_number,task_name,data_size,named_tensors)

def process_task_results(
Expand All @@ -687,11 +720,6 @@ def process_task_results(
)
return

logger.info(
f"Collaborator {collaborator_name} is sending task results "
f"for {task_name}, round {round_number}"
)

task_key = TaskResultKey(task_name, collaborator_name, round_number)

# we mustn't have results already
Expand Down Expand Up @@ -931,7 +959,7 @@ def _prepare_trained(self, tensor_name, origin, round_number, report, agg_result
new_model_report,
("model",),
)

self.next_model_round_number = new_model_round_number
# Finally, cache the updated model tensor
self.tensor_db.cache_tensor({final_model_tk: new_model_nparray})

Expand Down
Loading

0 comments on commit 00392a0

Please sign in to comment.