Skip to content

Commit

Permalink
Changing peristed path to be saved along with the proto model files, …
Browse files Browse the repository at this point in the history
…for Gramine as well

Signed-off-by: Lerer, Eran <eran.lerer@intel.com>
  • Loading branch information
cloudnoize committed Jan 13, 2025
1 parent 00392a0 commit 3ac042c
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 40 deletions.
2 changes: 1 addition & 1 deletion docs/about/features_index/taskrunner.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ Configurable Settings
- :code:`rounds_to_train`: (int) Specifies the number of rounds in a federation. A federated learning round is defined as one complete iteration when the collaborators train the model and send the updated model weights back to the aggregator to form a new global model. Within a round, collaborators can train the model for multiple iterations called epochs.
- :code:`write_logs`: (boolean) Metric logging callback feature. By default, logging is done through `tensorboard <https://www.tensorflow.org/tensorboard/get_started>`_ but users can also use custom metric logging function for each task.
- :code:`persist_checkpoint`: (boolean) Specifies whether to enable the storage of a persistent checkpoint in non-volatile storage for recovery purposes. When enabled, the aggregator will restore its state to what it was prior to the restart, ensuring continuity after a restart.

- :code:`persistent_db_path`: (str:path) Defines the persisted database path.

- :class:`Collaborator <openfl.component.Collaborator>`
`openfl.component.Collaborator <https://github.com/intel/openfl/blob/develop/openfl/component/collaborator/collaborator.py>`_
Expand Down
1 change: 1 addition & 0 deletions openfl-workspace/workspace/plan/defaults/aggregator.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ template : openfl.component.Aggregator
settings :
db_store_rounds : 2
persist_checkpoint: True
persistent_db_path: save/tensor.db
21 changes: 12 additions & 9 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ def __init__(

self.tensor_db = TensorDB()
if persist_checkpoint:
logger.info("Persistent checkpoint is enabled")
persistent_db_path = persistent_db_path or "tensor.db"
logger.info("Persistent checkpoint is enabled, setting persistent db at path %s",persistent_db_path)
self.persistent_db = PersistentTensorDB(persistent_db_path)
else:
logger.info("Persistent checkpoint is disabled")
Expand All @@ -168,7 +169,7 @@ def __init__(
# these enable getting all tensors for a task
self.collaborator_tasks_results = {} # {TaskResultKey: list of TensorKeys}
self.collaborator_task_weight = {} # {TaskResultKey: data_size}


# maintain a list of collaborators that have completed task and
# reported results in a given round
Expand All @@ -177,8 +178,13 @@ def __init__(
self.lock = Lock()
self.use_delta_updates = use_delta_updates

self.model = None # Initialize the model attribute to None
if self.persistent_db and self._recover():
logger.info("recovered state of aggregator")

# The model is built by recovery if at least one round has finished
if self.model:
logger.info("Model was loaded by recovery")
elif initial_tensor_dict:
self._load_initial_tensors_from_dict(initial_tensor_dict)
self.model = utils.construct_model_proto(
Expand All @@ -204,13 +210,12 @@ def __init__(
# https://github.com/securefederatedai/openfl/pull/1195#discussion_r1879479537
self.callbacks.on_experiment_begin()
self.callbacks.on_round_begin(self.round_number)

def _recover(self):
"""Populates the aggregator state to the state it was prior a restart
"""
recovered = False
# load tensors persistent DB
logger.info("Recovering previous state from persistent storage")
tensor_key_dict = self.persistent_db.load_tensors(self.persistent_db.get_tensors_table_name())
if len(tensor_key_dict) > 0:
logger.info(f"Recovering {len(tensor_key_dict)} model tensors")
Expand All @@ -228,21 +233,19 @@ def _recover(self):
# round number is the current round which is still in process i.e. committed_round_number + 1
self.round_number = committed_round_number + 1
logger.info("Recovery - loaded round number %s and best score %s", self.round_number,self.best_model_score)

next_round_tensor_key_dict = self.persistent_db.load_tensors(self.persistent_db.get_next_round_tensors_table_name())
if len(next_round_tensor_key_dict) > 0:
logger.info(f"Recovering {len(next_round_tensor_key_dict)} next round model tensors")
recovered = True
self.tensor_db.cache_tensor(next_round_tensor_key_dict)


logger.info("Recovery - Finished populating tensor DB")

logger.debug("Recovery - this is the tensor_db after recovery: %s", self.tensor_db)

if self.persistent_db.is_task_table_empty():
logger.debug("task table is empty")
return recovered

logger.info("Recovery - Replaying saved task results")
task_id = 1
while True:
Expand Down
56 changes: 27 additions & 29 deletions openfl/databases/persistent_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import sqlite3
import numpy as np
from threading import Lock
from typing import Dict, Iterator, Optional
import os
from typing import Dict, Optional
pass
import logging

from openfl.utilities import LocalTensor, TensorKey, change_tags
Expand All @@ -27,15 +27,13 @@ class PersistentTensorDB:
NEXT_ROUND_TENSORS_TABLE = "next_round_tensors"
TASK_RESULT_TABLE = "task_results"
KEY_VALUE_TABLE = "key_value_store"
def __init__(self, db_path: str = "") -> None:
def __init__(self, db_path) -> None:
"""Initializes a new instance of the PersistentTensorDB class."""
full_path = "tensordb.sqlite"
if db_path:
full_path = os.path.join(db_path, full_path)
logger.info("Initializing persistent db at %s",full_path)
self.conn = sqlite3.connect(full_path, check_same_thread=False)

logger.info("Initializing persistent db at %s",db_path)
self.conn = sqlite3.connect(db_path, check_same_thread=False)
self.lock = Lock()

cursor = self.conn.cursor()
self._create_model_tensors_table(cursor,PersistentTensorDB.TENSORS_TABLE)
self._create_model_tensors_table(cursor,PersistentTensorDB.NEXT_ROUND_TENSORS_TABLE)
Expand All @@ -57,8 +55,8 @@ def _create_model_tensors_table(self,cursor,table_name) -> None:
)
"""
cursor.execute(query)


def _create_task_results_table(self,cursor) -> None:
"""Creates a table for storing task results."""
query = f"""
Expand All @@ -72,7 +70,7 @@ def _create_task_results_table(self,cursor) -> None:
)
"""
cursor.execute(query)

def _create_key_value_store(self,cursor) -> None:
"""Create a key-value store table for storing additional metadata."""
query = f"""
Expand Down Expand Up @@ -106,7 +104,7 @@ def save_task_results(

# Insert into the database
insert_query = f"""
INSERT INTO {PersistentTensorDB.TASK_RESULT_TABLE}
INSERT INTO {PersistentTensorDB.TASK_RESULT_TABLE}
(collaborator_name, round_number, task_name, data_size, named_tensors)
VALUES (?, ?, ?, ?, ?);
"""
Expand Down Expand Up @@ -171,7 +169,7 @@ def __repr__(self) -> str:
return f"PersistentTensorDB contents:\n{rows}"

def finalize_round(self,tensor_key_dict: Dict[TensorKey, np.ndarray],next_round_tensor_key_dict: Dict[TensorKey, np.ndarray],round_number: int, best_score: float):
"""Finalize a training round by saving tensors, preparing for the next round,
"""Finalize a training round by saving tensors, preparing for the next round,
and updating metadata in the database.
This function performs the following steps as a single transaction:
Expand All @@ -183,17 +181,17 @@ def finalize_round(self,tensor_key_dict: Dict[TensorKey, np.ndarray],next_round_
If any step fails, the transaction is rolled back to ensure data integrity.
Args:
tensor_key_dict (Dict[TensorKey, np.ndarray]):
tensor_key_dict (Dict[TensorKey, np.ndarray]):
A dictionary mapping tensor keys to their corresponding NumPy arrays for the current round.
next_round_tensor_key_dict (Dict[TensorKey, np.ndarray]):
next_round_tensor_key_dict (Dict[TensorKey, np.ndarray]):
A dictionary mapping tensor keys to their corresponding NumPy arrays for the next round.
round_number (int):
round_number (int):
The current training round number.
best_score (float):
best_score (float):
The best score achieved during the current round.
Raises:
RuntimeError: If an error occurs during the transaction, the transaction is rolled back,
RuntimeError: If an error occurs during the transaction, the transaction is rolled back,
and a RuntimeError is raised with the details of the failure.
"""
with self.lock:
Expand All @@ -212,35 +210,35 @@ def finalize_round(self,tensor_key_dict: Dict[TensorKey, np.ndarray],next_round_
# Rollback transaction in case of an error
self.conn.rollback()
raise RuntimeError(f"Failed to finalize round: {e}")

def _persist_tensors(self,cursor,table_name, tensor_key_dict: Dict[TensorKey, np.ndarray]) -> None:
"""Insert a dictionary of tensors into the SQLite as part of transaction"""
for tensor_key, nparray in tensor_key_dict.items():
tensor_name, origin, fl_round, report, tags = tensor_key
serialized_array = self._serialize_array(nparray)
serialized_tags = json.dumps(tags)
serialized_tags = json.dumps(tags)
query = f"""
INSERT INTO {table_name} (tensor_name, origin, round, report, tags, nparray)
VALUES (?, ?, ?, ?, ?, ?)
"""
cursor.execute(query, (tensor_name, origin, fl_round, int(report), serialized_tags, serialized_array))

def _persist_next_round_tensors(self,cursor, tensor_key_dict: Dict[TensorKey, np.ndarray]) -> None:
"""Persisting the last round next_round tensors."""
drop_table_query = f"DROP TABLE IF EXISTS {PersistentTensorDB.NEXT_ROUND_TENSORS_TABLE}"
cursor.execute(drop_table_query)
self._create_model_tensors_table(cursor,PersistentTensorDB.NEXT_ROUND_TENSORS_TABLE)
self._persist_tensors(cursor,PersistentTensorDB.NEXT_ROUND_TENSORS_TABLE,tensor_key_dict)


def _init_task_results_table(self,cursor):
"""
Creates a table for storing task results. Drops the table first if it already exists.
"""
drop_table_query = "DROP TABLE IF EXISTS task_results"
cursor.execute(drop_table_query)
self._create_task_results_table(cursor)

def _save_round_and_best_score(self,cursor, round_number: int, best_score: float) -> None:
"""Save the round number and best score as key-value pairs in the database."""
# Create a table with key-value structure where values can be integer or float
Expand All @@ -258,15 +256,15 @@ def _save_round_and_best_score(self,cursor, round_number: int, best_score: float

def get_tensors_table_name(self) -> str:
return PersistentTensorDB.TENSORS_TABLE

def get_next_round_tensors_table_name(self) -> str:
return PersistentTensorDB.NEXT_ROUND_TENSORS_TABLE

def load_tensors(self,tensor_table) -> Dict[TensorKey, np.ndarray]:
"""Load all tensors from the SQLite database and return them as a dictionary."""
tensor_dict = {}
with self.lock:
cursor = self.conn.cursor()
cursor = self.conn.cursor()
query = f"SELECT tensor_name, origin, round, report, tags, nparray FROM {tensor_table}"
cursor.execute(query)
rows = cursor.fetchall()
Expand All @@ -278,11 +276,11 @@ def load_tensors(self,tensor_table) -> Dict[TensorKey, np.ndarray]:
tensor_dict[tensor_key] = self._deserialize_array(nparray)
return tensor_dict


def get_round_and_best_score(self) -> tuple[int, float]:
"""Retrieve the round number and best score from the database."""
with self.lock:
cursor = self.conn.cursor()
cursor = self.conn.cursor()
# Fetch the round_number
cursor.execute("""
SELECT value FROM key_value_store WHERE key = ?
Expand Down Expand Up @@ -330,7 +328,7 @@ def close(self) -> None:
def is_task_table_empty(self) -> bool:
"""Check if the task table is empty."""
with self.lock:
cursor = self.conn.cursor()
cursor = self.conn.cursor()
cursor.execute("SELECT COUNT(*) FROM task_results")
count = cursor.fetchone()[0]
return count == 0
2 changes: 1 addition & 1 deletion openfl/databases/tensor_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def get_tensor_from_cache(self, tensor_key: TensorKey) -> Optional[np.ndarray]:
if len(df) == 0:
return None
return np.array(df["nparray"].iloc[0])

def get_tensors_by_round_and_tags(self, fl_round: int, tags: tuple) -> dict:
"""Retrieve all tensors that match the specified round and tags.
Expand Down

0 comments on commit 3ac042c

Please sign in to comment.