diff --git a/openfl/experimental/component/aggregator/aggregator.py b/openfl/experimental/component/aggregator/aggregator.py index e9e2fb20985..b2c4e34eab8 100644 --- a/openfl/experimental/component/aggregator/aggregator.py +++ b/openfl/experimental/component/aggregator/aggregator.py @@ -121,7 +121,7 @@ def __set_attributes_to_clone(self, clone: Any) -> None: for name, attr in self.__private_attrs.items(): setattr(clone, name, attr) - def __delete_agg_attrs_from_clone(self, clone: Any) -> None: + def __delete_agg_attrs_from_clone(self, clone: Any, replace_str: str = None) -> None: """ Remove aggregator private attributes from FLSpec clone before transition from Aggregator step to collaborator steps. @@ -132,7 +132,10 @@ def __delete_agg_attrs_from_clone(self, clone: Any) -> None: for attr_name in self.__private_attrs: if hasattr(clone, attr_name): self.__private_attrs.update({attr_name: getattr(clone, attr_name)}) - delattr(clone, attr_name) + if replace_str: + setattr(clone, attr_name, replace_str) + else: + delattr(clone, attr_name) def _log_big_warning(self) -> None: """Warn user about single collaborator cert mode.""" @@ -221,11 +224,6 @@ def call_checkpoint(self, ctx: Any, f: Callable, stream_buffer: bytes = None, # Set stream buffer as function parameter setattr(f.__func__, "_stream_buffer", pickle.loads(stream_buffer)) - # Replce reserved attribute values with string - for attr in reserved_attributes: - if hasattr(ctx, attr): - setattr(ctx, attr, "Private attributes: Not Available.") - checkpoint(ctx, f) def get_tasks(self, collaborator_name: str) -> Tuple: @@ -288,7 +286,7 @@ def do_task(self, f_name: str) -> Any: if f.__name__ == "end": f() # Take the checkpoint of "end" step - self.__delete_agg_attrs_from_clone(self.flow) + self.__delete_agg_attrs_from_clone(self.flow, "Private attributes: Not Available.") self.call_checkpoint(self.flow, f) self.__set_attributes_to_clone(self.flow) # self.call_checkpoint(deepcopy(self.flow), f, diff --git a/openfl/experimental/component/collaborator/collaborator.py b/openfl/experimental/component/collaborator/collaborator.py index e90a467c069..044ce128474 100644 --- a/openfl/experimental/component/collaborator/collaborator.py +++ b/openfl/experimental/component/collaborator/collaborator.py @@ -83,7 +83,7 @@ def __set_attributes_to_clone(self, clone: Any) -> None: for name, attr in self.__private_attrs.items(): setattr(clone, name, attr) - def __delete_agg_attrs_from_clone(self, clone: Any) -> None: + def __delete_agg_attrs_from_clone(self, clone: Any, replace_str: str = None) -> None: """ Remove aggregator private attributes from FLSpec clone before transition from Aggregator step to collaborator steps @@ -101,7 +101,10 @@ def __delete_agg_attrs_from_clone(self, clone: Any) -> None: for attr_name in self.__private_attrs: if hasattr(clone, attr_name): self.__private_attrs.update({attr_name: getattr(clone, attr_name)}) - delattr(clone, attr_name) + if replace_str: + setattr(clone, attr_name, replace_str) + else: + delattr(clone, attr_name) def call_checkpoint(self, ctx: Any, f: Callable, stream_buffer: Any) -> None: """ @@ -117,8 +120,7 @@ def call_checkpoint(self, ctx: Any, f: Callable, stream_buffer: Any) -> None: """ self.client.call_checkpoint( self.name, - deepcopy(ctx), pickle.dumps(f), pickle.dumps(stream_buffer), - list(self.__private_attrs.keys()) + pickle.dumps(ctx), pickle.dumps(f), pickle.dumps(stream_buffer) ) def run(self) -> None: @@ -200,7 +202,9 @@ def do_task(self, f_name: str, ctx: Any) -> Tuple: f = getattr(ctx, f_name) f() # Checkpoint the function + self.__delete_agg_attrs_from_clone(ctx, "Private attributes: Not Available.") self.call_checkpoint(ctx, f, f._stream_buffer) + self.__set_attributes_to_clone(ctx) _, f, parent_func = ctx.execute_task_args[:3] # Display transition logs if transition diff --git a/openfl/experimental/transport/grpc/aggregator_client.py b/openfl/experimental/transport/grpc/aggregator_client.py index afcc6d1bf94..82458792772 100644 --- a/openfl/experimental/transport/grpc/aggregator_client.py +++ b/openfl/experimental/transport/grpc/aggregator_client.py @@ -304,20 +304,10 @@ def get_tasks(self, collaborator_name): @_atomic_connection @_resend_data_on_reconnection - def call_checkpoint(self, collaborator_name, ctx, function, stream_buffer, private_attrs): + def call_checkpoint(self, collaborator_name, clone_bytes, function, stream_buffer): """Perform checkpoint for collaborator task.""" self._set_header(collaborator_name) - # Remove private attributes from context - if len(private_attrs) > 0: - import pickle - - for attr in private_attrs: - if hasattr(ctx, attr): - setattr(ctx, attr, "Private attributes: Not Available.") - - clone_bytes = pickle.dumps(ctx) - request = aggregator_pb2.CheckpointRequest( header=self.header, execution_environment=clone_bytes,