Recover missing state after collaborator restart (#1268)
Signed-off-by: Lerer, Eran <eran.lerer@intel.com>
cloudnoize authored Jan 15, 2025
1 parent fdad4fb commit 3375609
Showing 2 changed files with 29 additions and 6 deletions.
24 changes: 19 additions & 5 deletions openfl/component/collaborator/collaborator.py
@@ -382,15 +382,17 @@ def get_data_for_tensorkey(self, tensor_key):
                     return nparray
                 prior_round -= 1
             logger.info(f"Cannot find any prior version of tensor {tensor_name} locally...")
-            logger.debug(
-                "Unable to get tensor from local store..." "attempting to retrieve from client"
-            )
             # Determine whether there are additional compression related
             # dependencies.
             # Typically, dependencies are only relevant to model layers
             tensor_dependencies = self.tensor_codec.find_dependencies(
                 tensor_key, self.delta_updates
             )
+            logger.debug(
+                "Unable to get tensor from local store..."
+                "attempting to retrieve from client len tensor_dependencies"
+                f" tensor_key {tensor_key}"
+            )
             if len(tensor_dependencies) > 0:
                 # Resolve dependencies
                 # tensor_dependencies[0] corresponds to the prior version
@@ -411,10 +413,10 @@ def get_data_for_tensorkey(self, tensor_key):
                     self.tensor_db.cache_tensor({new_model_tk: nparray})
                 else:
                     logger.info(
-                        "Count not find previous model layer."
+                        "Could not find previous model layer."
                         "Fetching latest layer from aggregator"
                     )
-                    # The original model tensor should be fetched from client
+                    # The original model tensor should be fetched from aggregator
                     nparray = self.get_aggregated_tensor_from_aggregator(
                         tensor_key, require_lossless=True
                     )
@@ -423,6 +425,18 @@ def get_data_for_tensorkey(self, tensor_key):
                 nparray = self.get_aggregated_tensor_from_aggregator(
                     tensor_key, require_lossless=True
                 )
+            else:
+                # we should try fetching the tensor from aggregator
+                tensor_name, origin, round_number, report, tags = tensor_key
+                tags = (self.collaborator_name,) + tags
+                tensor_key = (tensor_name, origin, round_number, report, tags)
+                logger.info(
+                    "Could not find previous model layer."
+                    f"Fetching latest layer from aggregator {tensor_key}"
+                )
+                nparray = self.get_aggregated_tensor_from_aggregator(
+                    tensor_key, require_lossless=True
+                )
         else:
             logger.debug("Found tensor %s in local TensorDB", tensor_key)

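For context on the collaborator.py change above: after a collaborator restart, tensors that existed only in the local TensorDB (tagged with the collaborator's own name) are gone, so the new else branch re-tags the requested key with the collaborator name and fetches it from the aggregator instead of failing. The snippet below is a minimal sketch of that re-tagging step in isolation; the TensorKey namedtuple is assumed to mirror openfl.utilities.TensorKey, and retag_for_collaborator is a hypothetical helper, not code from this commit.

# Sketch only: mirrors the re-tagging performed by the new fallback branch.
from collections import namedtuple

# Assumed to mirror openfl.utilities.TensorKey.
TensorKey = namedtuple("TensorKey", ["tensor_name", "origin", "round_number", "report", "tags"])

def retag_for_collaborator(tensor_key, collaborator_name):
    # Prefix the tags with the collaborator name so the aggregator can be
    # asked for the collaborator-scoped copy of the tensor.
    tensor_name, origin, round_number, report, tags = tensor_key
    return TensorKey(tensor_name, origin, round_number, report, (collaborator_name,) + tags)

# After a restart, a tensor that lived only in the local TensorDB is requested
# again from the aggregator under the re-tagged key.
key = TensorKey("conv1/kernel", "aggregator", 3, False, ("trained",))
print(retag_for_collaborator(key, "collaborator_1"))
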
11 changes: 10 additions & 1 deletion openfl/federated/task/runner_keras.py
@@ -182,7 +182,16 @@ def train_(self, batch_generator, metrics: list = None, **kwargs):
         # initialization (build_model).
         # If metrics are added (i.e. not a subset of what was originally
         # defined) then the model must be recompiled.
-        results = self.model.get_metrics_result()
+        try:
+            results = self.model.get_metrics_result()
+        except ValueError:
+            if "batch_size" in kwargs:
+                batch_size = kwargs["batch_size"]
+            else:
+                batch_size = 1
+            # evaluation needed before metrics can be resolved
+            self.model.evaluate(self.data_loader.get_valid_loader(batch_size), verbose=1)
+            results = self.model.get_metrics_result()

         # TODO if there are new metrics in the flplan that were not included
         # in the originally
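For context on the runner_keras.py change: on recent Keras versions, model.get_metrics_result() can raise a ValueError when no train or evaluate step has populated the metrics yet (for example when a restarted collaborator rebuilds the model), so the new try/except runs one evaluation pass before reading them. The standalone sketch below reproduces that guard pattern outside OpenFL; the toy model and random data are assumptions for illustration only.

# Sketch only: the same guard pattern outside KerasTaskRunner.
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="sgd", loss="mse", metrics=["mae"])

x = np.random.rand(8, 4).astype("float32")
y = np.random.rand(8, 1).astype("float32")

try:
    results = model.get_metrics_result()
except ValueError:
    # Metrics cannot be resolved until at least one evaluation has run,
    # so evaluate once and then read them back.
    model.evaluate(x, y, batch_size=1, verbose=1)
    results = model.get_metrics_result()

print(results)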
