Skip to content

Commit

Permalink
Avoided deep copying the context by removing private attributes from …
Browse files Browse the repository at this point in the history
…context before checkpoint

Signed-off-by: Parth Mandaliya <parthx.mandaliya@intel.com>
  • Loading branch information
ParthM-GitHub committed Sep 28, 2023
1 parent 8a40e91 commit 07218ce
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 23 deletions.
14 changes: 6 additions & 8 deletions openfl/experimental/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def __set_attributes_to_clone(self, clone: Any) -> None:
for name, attr in self.__private_attrs.items():
setattr(clone, name, attr)

def __delete_agg_attrs_from_clone(self, clone: Any) -> None:
def __delete_agg_attrs_from_clone(self, clone: Any, replace_str: str = None) -> None:
"""
Remove aggregator private attributes from FLSpec clone before
transition from Aggregator step to collaborator steps.
Expand All @@ -132,7 +132,10 @@ def __delete_agg_attrs_from_clone(self, clone: Any) -> None:
for attr_name in self.__private_attrs:
if hasattr(clone, attr_name):
self.__private_attrs.update({attr_name: getattr(clone, attr_name)})
delattr(clone, attr_name)
if replace_str:
setattr(clone, attr_name, replace_str)
else:
delattr(clone, attr_name)

def _log_big_warning(self) -> None:
"""Warn user about single collaborator cert mode."""
Expand Down Expand Up @@ -221,11 +224,6 @@ def call_checkpoint(self, ctx: Any, f: Callable, stream_buffer: bytes = None,
# Set stream buffer as function parameter
setattr(f.__func__, "_stream_buffer", pickle.loads(stream_buffer))

# Replce reserved attribute values with string
for attr in reserved_attributes:
if hasattr(ctx, attr):
setattr(ctx, attr, "Private attributes: Not Available.")

checkpoint(ctx, f)

def get_tasks(self, collaborator_name: str) -> Tuple:
Expand Down Expand Up @@ -288,7 +286,7 @@ def do_task(self, f_name: str) -> Any:
if f.__name__ == "end":
f()
# Take the checkpoint of "end" step
self.__delete_agg_attrs_from_clone(self.flow)
self.__delete_agg_attrs_from_clone(self.flow, "Private attributes: Not Available.")
self.call_checkpoint(self.flow, f)
self.__set_attributes_to_clone(self.flow)
# self.call_checkpoint(deepcopy(self.flow), f,
Expand Down
12 changes: 8 additions & 4 deletions openfl/experimental/component/collaborator/collaborator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __set_attributes_to_clone(self, clone: Any) -> None:
for name, attr in self.__private_attrs.items():
setattr(clone, name, attr)

def __delete_agg_attrs_from_clone(self, clone: Any) -> None:
def __delete_agg_attrs_from_clone(self, clone: Any, replace_str: str = None) -> None:
"""
Remove aggregator private attributes from FLSpec clone before
transition from Aggregator step to collaborator steps
Expand All @@ -101,7 +101,10 @@ def __delete_agg_attrs_from_clone(self, clone: Any) -> None:
for attr_name in self.__private_attrs:
if hasattr(clone, attr_name):
self.__private_attrs.update({attr_name: getattr(clone, attr_name)})
delattr(clone, attr_name)
if replace_str:
setattr(clone, attr_name, replace_str)
else:
delattr(clone, attr_name)

def call_checkpoint(self, ctx: Any, f: Callable, stream_buffer: Any) -> None:
"""
Expand All @@ -117,8 +120,7 @@ def call_checkpoint(self, ctx: Any, f: Callable, stream_buffer: Any) -> None:
"""
self.client.call_checkpoint(
self.name,
deepcopy(ctx), pickle.dumps(f), pickle.dumps(stream_buffer),
list(self.__private_attrs.keys())
pickle.dumps(ctx), pickle.dumps(f), pickle.dumps(stream_buffer)
)

def run(self) -> None:
Expand Down Expand Up @@ -200,7 +202,9 @@ def do_task(self, f_name: str, ctx: Any) -> Tuple:
f = getattr(ctx, f_name)
f()
# Checkpoint the function
self.__delete_agg_attrs_from_clone(ctx, "Private attributes: Not Available.")
self.call_checkpoint(ctx, f, f._stream_buffer)
self.__set_attributes_to_clone(ctx)

_, f, parent_func = ctx.execute_task_args[:3]
# Display transition logs if transition
Expand Down
12 changes: 1 addition & 11 deletions openfl/experimental/transport/grpc/aggregator_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,20 +304,10 @@ def get_tasks(self, collaborator_name):

@_atomic_connection
@_resend_data_on_reconnection
def call_checkpoint(self, collaborator_name, ctx, function, stream_buffer, private_attrs):
def call_checkpoint(self, collaborator_name, clone_bytes, function, stream_buffer):
"""Perform checkpoint for collaborator task."""
self._set_header(collaborator_name)

# Remove private attributes from context
if len(private_attrs) > 0:
import pickle

for attr in private_attrs:
if hasattr(ctx, attr):
setattr(ctx, attr, "Private attributes: Not Available.")

clone_bytes = pickle.dumps(ctx)

request = aggregator_pb2.CheckpointRequest(
header=self.header,
execution_environment=clone_bytes,
Expand Down

0 comments on commit 07218ce

Please sign in to comment.