From 8f140ba079f86e7db29753b5492269fade93612d Mon Sep 17 00:00:00 2001 From: ThibaultFy Date: Wed, 2 Oct 2024 15:56:05 +0200 Subject: [PATCH] chore: copy paste code Signed-off-by: ThibaultFy --- substra/tools/__init__.py | 20 ++ substra/tools/__version__.py | 1 + substra/tools/exceptions.py | 30 +++ substra/tools/function.py | 238 ++++++++++++++++++++ substra/tools/opener.py | 155 +++++++++++++ substra/tools/task_resources.py | 121 ++++++++++ substra/tools/utils.py | 127 +++++++++++ substra/tools/workspace.py | 95 ++++++++ tests/tools/__init__.py | 0 tests/tools/test_aggregatealgo.py | 244 ++++++++++++++++++++ tests/tools/test_compositealgo.py | 312 ++++++++++++++++++++++++++ tests/tools/test_function.py | 343 +++++++++++++++++++++++++++++ tests/tools/test_genericalgo.py | 2 + tests/tools/test_metrics.py | 158 +++++++++++++ tests/tools/test_opener.py | 98 +++++++++ tests/tools/test_task_resources.py | 86 ++++++++ tests/tools/test_utils.py | 44 ++++ tests/tools/test_workflow.py | 140 ++++++++++++ tests/tools/tools_conftest.py | 85 +++++++ tests/tools/utils.py | 65 ++++++ 20 files changed, 2364 insertions(+) create mode 100644 substra/tools/__init__.py create mode 100644 substra/tools/__version__.py create mode 100644 substra/tools/exceptions.py create mode 100644 substra/tools/function.py create mode 100644 substra/tools/opener.py create mode 100644 substra/tools/task_resources.py create mode 100644 substra/tools/utils.py create mode 100644 substra/tools/workspace.py create mode 100644 tests/tools/__init__.py create mode 100644 tests/tools/test_aggregatealgo.py create mode 100644 tests/tools/test_compositealgo.py create mode 100644 tests/tools/test_function.py create mode 100644 tests/tools/test_genericalgo.py create mode 100644 tests/tools/test_metrics.py create mode 100644 tests/tools/test_opener.py create mode 100644 tests/tools/test_task_resources.py create mode 100644 tests/tools/test_utils.py create mode 100644 tests/tools/test_workflow.py create mode 100644 tests/tools/tools_conftest.py create mode 100644 tests/tools/utils.py diff --git a/substra/tools/__init__.py b/substra/tools/__init__.py new file mode 100644 index 00000000..ab92f57f --- /dev/null +++ b/substra/tools/__init__.py @@ -0,0 +1,20 @@ +from substratools.__version__ import __version__ + +from . import function +from . 
import opener +from .function import execute +from .function import load_performance +from .function import register +from .function import save_performance +from .opener import Opener + +__all__ = [ + "__version__", + function, + opener, + Opener, + execute, + load_performance, + register, + save_performance, +] diff --git a/substra/tools/__version__.py b/substra/tools/__version__.py new file mode 100644 index 00000000..3142c7bf --- /dev/null +++ b/substra/tools/__version__.py @@ -0,0 +1 @@ +__version__ = "0.22.0a2" diff --git a/substra/tools/exceptions.py b/substra/tools/exceptions.py new file mode 100644 index 00000000..982db94d --- /dev/null +++ b/substra/tools/exceptions.py @@ -0,0 +1,30 @@ +class InvalidInterfaceError(Exception): + pass + + +class EmptyInterfaceError(InvalidInterfaceError): + pass + + +class NotAFileError(Exception): + pass + + +class MissingFileError(Exception): + pass + + +class InvalidInputOutputsError(Exception): + pass + + +class InvalidCLIError(Exception): + pass + + +class FunctionNotFoundError(Exception): + pass + + +class ExistingRegisteredFunctionError(Exception): + pass diff --git a/substra/tools/function.py b/substra/tools/function.py new file mode 100644 index 00000000..35b276e9 --- /dev/null +++ b/substra/tools/function.py @@ -0,0 +1,238 @@ +# coding: utf8 +import argparse +import json +import logging +import os +import sys +from copy import deepcopy +from typing import Any +from typing import Callable +from typing import Dict +from typing import Optional + +from substratools import exceptions +from substratools import opener +from substratools import utils +from substratools.exceptions import ExistingRegisteredFunctionError +from substratools.exceptions import FunctionNotFoundError +from substratools.task_resources import StaticInputIdentifiers +from substratools.task_resources import TaskResources +from substratools.workspace import FunctionWorkspace + +logger = logging.getLogger(__name__) + + +def _parser_add_default_arguments(parser): + parser.add_argument( + "--function-name", + type=str, + help="The name of the function to execute from the given file", + ) + parser.add_argument( + "-r", + "--task-properties", + type=str, + default="{}", + help="Define the task properties", + ), + parser.add_argument( + "-d", + "--fake-data", + action="store_true", + default=False, + help="Enable fake data mode", + ) + parser.add_argument( + "--n-fake-samples", + default=None, + type=int, + help="Number of fake samples if fake data is used.", + ) + parser.add_argument( + "--log-path", + default=None, + help="Define log filename path", + ) + parser.add_argument( + "--log-level", + default="info", + choices=utils.MAPPING_LOG_LEVEL.keys(), + help="Choose log level", + ) + parser.add_argument( + "--inputs", + type=str, + default="[]", + help="Inputs of the compute task", + ) + parser.add_argument( + "--outputs", + type=str, + default="[]", + help="Outputs of the compute task", + ) + + +class FunctionRegister: + """Class to create a decorator to register function in substratools. The functions are registered in the _functions + dictionary, with the function.__name__ as key. + Register a function in substratools means that this function can be access by the function.execute functions through + the --function-name CLI argument.""" + + def __init__(self): + self._functions = {} + + def __call__(self, function: Callable, function_name: Optional[str] = None): + """Function called when using an instance of the class as a decorator. 
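(A minimal usage sketch of the registry defined above, assuming the copied modules remain importable as `substratools`, which is how they import each other in this patch; the decorated function name is illustrative.)

```python
import substratools as tools


@tools.register  # stored in the registry under its __name__, i.e. "my_train"
def my_train(inputs, outputs, task_properties):
    ...


assert "my_train" in tools.register.get_registered_functions()

# an explicit name can also be given by calling the registry directly
tools.register(my_train, function_name="my_train_v2")
```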
+ + Args: + function (Callable): function to register in substratools. + function_name (str, optional): function name to register the given function. + If None, function.__name__ is used for registration. + Raises: + ExistingRegisteredFunctionError: Raise if a function with the same function.__name__ + has already been registered in substratools. + + Returns: + Callable: returns the function without decorator + """ + + function_name = function_name or function.__name__ + if function_name not in self._functions: + self._functions[function_name] = function + else: + raise ExistingRegisteredFunctionError("A function with the same name is already registered.") + + return function + + def get_registered_functions(self): + return self._functions + + +# Instance of the decorator to store the function to register in memory. +# Can be imported directly from substratools. +register = FunctionRegister() + + +class FunctionWrapper(object): + """Wrapper to execute a function on the platform.""" + + def __init__(self, workspace: FunctionWorkspace, opener_wrapper: Optional[opener.OpenerWrapper]): + self._workspace = workspace + self._opener_wrapper = opener_wrapper + + def _assert_outputs_exists(self, outputs: Dict[str, str]): + for key, path in outputs.items(): + if os.path.isdir(path): + raise exceptions.NotAFileError(f"Expected output file at {path}, found dir for output `{key}`") + if not os.path.isfile(path): + raise exceptions.MissingFileError(f"Output file {path} used to save argument `{key}` does not exists.") + + @utils.Timer(logger) + def execute( + self, function: Callable, task_properties: dict = {}, fake_data: bool = False, n_fake_samples: int = None + ): + """Execute a compute task""" + + # load inputs + inputs = deepcopy(self._workspace.task_inputs) + + # load data from opener + if self._opener_wrapper: + loaded_datasamples = self._opener_wrapper.get_data(fake_data, n_fake_samples) + + if fake_data: + logger.info("Using fake data with %i fake samples." % int(n_fake_samples)) + + assert ( + StaticInputIdentifiers.datasamples.value not in inputs.keys() + ), f"{StaticInputIdentifiers.datasamples.value} must be an input of kind `datasamples`" + inputs.update({StaticInputIdentifiers.datasamples.value: loaded_datasamples}) + + # load outputs + outputs = deepcopy(self._workspace.task_outputs) + + logger.info("Launching task: executing `%s` function." 
% function.__name__) + function( + inputs=inputs, + outputs=outputs, + task_properties=task_properties, + ) + + self._assert_outputs_exists( + self._workspace.task_outputs, + ) + + +def _generate_function_cli(): + """Helper to generate a command line interface client.""" + + def _function_from_args(args): + inputs = TaskResources(args.inputs) + outputs = TaskResources(args.outputs) + log_path = args.log_path + chainkeys_path = inputs.chainkeys_path + + workspace = FunctionWorkspace( + log_path=log_path, + chainkeys_path=chainkeys_path, + inputs=inputs, + outputs=outputs, + ) + + utils.configure_logging(workspace.log_path, log_level=args.log_level) + + opener_wrapper = opener.load_from_module( + workspace=workspace, + ) + + return FunctionWrapper(workspace, opener_wrapper) + + def _user_func(args, function): + function_wrapper = _function_from_args(args) + function_wrapper.execute( + function=function, + task_properties=json.loads(args.task_properties), + fake_data=args.fake_data, + n_fake_samples=args.n_fake_samples, + ) + + parser = argparse.ArgumentParser(fromfile_prefix_chars="@") + _parser_add_default_arguments(parser) + parser.set_defaults(func=_user_func) + + return parser + + +def _get_function_from_name(functions: dict, function_name: str): + + if function_name not in functions: + raise FunctionNotFoundError( + f"The function {function_name} given as --function-name argument as not been found." + ) + + return functions[function_name] + + +def save_performance(performance: Any, path: os.PathLike): + with open(path, "w") as f: + json.dump({"all": performance}, f) + + +def load_performance(path: os.PathLike) -> Any: + with open(path, "r") as f: + performance = json.load(f)["all"] + return performance + + +def execute(sysargs=None): + """Launch function command line interface.""" + + cli = _generate_function_cli() + + sysargs = sysargs if sysargs is not None else sys.argv[1:] + args = cli.parse_args(sysargs) + function = _get_function_from_name(register.get_registered_functions(), args.function_name) + args.func(args, function) + + return args diff --git a/substra/tools/opener.py b/substra/tools/opener.py new file mode 100644 index 00000000..2dc846f4 --- /dev/null +++ b/substra/tools/opener.py @@ -0,0 +1,155 @@ +import abc +import logging +import os +import types +from typing import Optional + +from substratools import exceptions +from substratools import utils +from substratools.workspace import OpenerWorkspace + +logger = logging.getLogger(__name__) + + +REQUIRED_FUNCTIONS = set( + [ + "get_data", + "fake_data", + ] +) + + +class Opener(abc.ABC): + """Dataset opener abstract base class. + + To define a new opener script, subclass this class and implement the + following abstract methods: + + - #Opener.get_data() + - #Opener.fake_data() + + # Example + + ```python + import os + import pandas as pd + import string + import numpy as np + + import substratools as tools + + class DummyOpener(tools.Opener): + def get_data(self, folders): + return [ + pd.read_csv(os.path.join(folder, 'train.csv')) + for folder in folders + ] + + def fake_data(self, n_samples): + return [] # compute random fake data + ``` + + # How to test locally an opener script + + An opener can be imported and used in python scripts as would any other class. 
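(Aside on the performance helpers defined just above in function.py — a hedged sketch, with an illustrative file name. They simply round-trip the value through a `{"all": ...}` JSON document.)

```python
from substratools import load_performance, save_performance

save_performance(0.87, "perf.json")  # writes {"all": 0.87}
assert load_performance("perf.json") == 0.87
```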
+ + For example, assuming that you have a local file named `opener.py` that contains + an `Opener` named `MyOpener`: + + ```python + import os + from opener import MyOpener + + folders = os.listdir('./sandbox/data_samples/') + + o = MyOpener() + loaded_datasamples = o.get_data(folders) + ``` + """ + + @abc.abstractmethod + def get_data(self, folders): + """Datasamples loader + + # Arguments + + folders: list of folders. Each folder represents a data sample. + + # Returns + + data: data object. + """ + raise NotImplementedError + + @abc.abstractmethod + def fake_data(self, n_samples): + """Generate fake loaded datasamples for offline testing. + + # Arguments + + n_samples (int): number of samples to return + + # Returns + + data: data object. + """ + raise NotImplementedError + + +class OpenerWrapper(object): + """Internal wrapper to call opener interface.""" + + def __init__(self, interface, workspace=None): + assert isinstance(interface, Opener) or isinstance(interface, types.ModuleType) + + self._workspace = workspace or OpenerWorkspace() + self._interface = interface + + @property + def data_folder_paths(self): + return self._workspace.input_data_folder_paths + + def get_data(self, fake_data=False, n_fake_samples=None): + if fake_data: + logger.info("loading data from fake data") + return self._interface.fake_data(n_samples=n_fake_samples) + else: + logger.info("loading data from '{}'".format(self.data_folder_paths)) + return self._interface.get_data(self.data_folder_paths) + + def _assert_output_exists(self, path, key): + + if os.path.isdir(path): + raise exceptions.NotAFileError(f"Expected output file at {path}, found dir for output `{key}`") + if not os.path.isfile(path): + raise exceptions.MissingFileError(f"Output file {path} used to save argument `{key}` does not exists.") + + +def load_from_module(workspace=None) -> Optional[OpenerWrapper]: + """Load opener interface. + + If a workspace is given, the associated opener will be returned. This means that if no + opener_path is defined within the workspace, no opener will be returned + If no workspace is given, the opener interface will be directly loaded as a module. + + Return an OpenerWrapper instance. + """ + if workspace is None: + # import from module + path = None + + elif workspace.opener_path is None: + # no opener within this workspace + return None + + else: + # import opener from workspace specified path + path = workspace.opener_path + + interface = utils.load_interface_from_module( + "opener", + interface_class=Opener, + interface_signature=None, # XXX does not support interface for debugging + path=path, + ) + return OpenerWrapper(interface, workspace=workspace) diff --git a/substra/tools/task_resources.py b/substra/tools/task_resources.py new file mode 100644 index 00000000..012df6ef --- /dev/null +++ b/substra/tools/task_resources.py @@ -0,0 +1,121 @@ +import json +from enum import Enum +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + +from substratools import exceptions + + +class StaticInputIdentifiers(str, Enum): + opener = "opener" + datasamples = "datasamples" + chainkeys = "chainkeys" + rank = "rank" + + +_RESOURCE_ID = "id" +_RESOURCE_VALUE = "value" +_RESOURCE_MULTIPLE = "multiple" + + +def _check_resources_format(resource_list): + + _required_keys = set((_RESOURCE_ID, _RESOURCE_VALUE, _RESOURCE_MULTIPLE)) + _error_message = ( + "`--inputs` and `--outputs` args should be json serialized list of dict. 
Each dict containing " + f"the following keys: {_required_keys}. {_RESOURCE_ID} and {_RESOURCE_VALUE} must be strings, " + f"{_RESOURCE_MULTIPLE} must be a bool." + ) + + if not isinstance(resource_list, list): + raise exceptions.InvalidCLIError(_error_message) + + if not all([isinstance(d, dict) for d in resource_list]): + raise exceptions.InvalidCLIError(_error_message) + + if not all([set(d.keys()) == _required_keys for d in resource_list]): + raise exceptions.InvalidCLIError(_error_message) + + if not all([isinstance(d[_RESOURCE_MULTIPLE], bool) for d in resource_list]): + raise exceptions.InvalidCLIError(_error_message) + + if not all([isinstance(d[_RESOURCE_ID], str) for d in resource_list]): + raise exceptions.InvalidCLIError(_error_message) + + if not all([isinstance(d[_RESOURCE_VALUE], str) for d in resource_list]): + raise exceptions.InvalidCLIError(_error_message) + + +def _check_resources_multiplicity(resource_dict): + for k, v in resource_dict.items(): + if not v[_RESOURCE_MULTIPLE] and len(v[_RESOURCE_VALUE]) > 1: + raise exceptions.InvalidInputOutputsError(f"There is more than one path for the non multiple resource {k}") + + +class TaskResources: + """TaskResources is created from stdin to provide a nice abstraction over inputs/outputs""" + + _values: Dict[str, List[str]] + + def __init__(self, argstr: str) -> None: + """Argstr is expected to be a JSON array like: + [ + {"id": "local", "value": "/sandbox/output/model/uuid", "multiple": False}, + {"id": "shared", ...} + ] + """ + self._values = {} + resource_list = json.loads(argstr.replace("\\", "/")) + + _check_resources_format(resource_list) + + for item in resource_list: + self._values.setdefault( + item[_RESOURCE_ID], {_RESOURCE_VALUE: [], _RESOURCE_MULTIPLE: item[_RESOURCE_MULTIPLE]} + ) + self._values[item[_RESOURCE_ID]][_RESOURCE_VALUE].append(item[_RESOURCE_VALUE]) + + _check_resources_multiplicity(self._values) + + self.opener_path = self.get_value(StaticInputIdentifiers.opener.value) + self.input_data_folder_paths = self.get_value(StaticInputIdentifiers.datasamples.value) + self.chainkeys_path = self.get_value(StaticInputIdentifiers.chainkeys.value) + + def get_value(self, key: str) -> Optional[Union[List[str], str]]: + """Returns the value for a given key. Return None if there is no matching resource. + Will raise if there is a mismatch between the given multiplicity and the number of returned + elements. 
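(A hedged sketch of the resource parsing described here — paths and identifiers are illustrative. The CLI passes `--inputs`/`--outputs` as a JSON list, and `get_value` returns a list for a multiple resource and a single path otherwise.)

```python
import json

from substratools.task_resources import TaskResources

resources = TaskResources(
    json.dumps(
        [
            {"id": "shared", "value": "/sandbox/in/model_a", "multiple": True},
            {"id": "shared", "value": "/sandbox/in/model_b", "multiple": True},
            {"id": "opener", "value": "/sandbox/opener.py", "multiple": False},
        ]
    )
)

assert resources.get_value("shared") == ["/sandbox/in/model_a", "/sandbox/in/model_b"]
assert resources.opener_path == "/sandbox/opener.py"
# static identifiers (opener, datasamples, chainkeys) are excluded from the user-facing dict
assert resources.formatted_dynamic_resources == {"shared": ["/sandbox/in/model_a", "/sandbox/in/model_b"]}
```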
+
+        If multiple is True, a list is returned; otherwise a single value is returned.
+        """
+        if key not in self._values:
+            return None
+
+        val = self._values[key][_RESOURCE_VALUE]
+        multiple = self._values[key][_RESOURCE_MULTIPLE]
+
+        if multiple:
+            return val
+
+        return val[0]
+
+    @property
+    def formatted_dynamic_resources(self) -> Union[List[str], str]:
+        """Returns all the resources (except the datasamples, the opener and the chainkeys) under the user format:
+        a dict where each input is an element where
+            - the key is the user identifier
+            - the value is a list of Path for a multiple resource and a Path for a non-multiple resource
+        """
+
+        return {
+            k: self.get_value(k)
+            for k in self._values.keys()
+            if k
+            not in (
+                StaticInputIdentifiers.opener.value,
+                StaticInputIdentifiers.datasamples.value,
+                StaticInputIdentifiers.chainkeys.value,
+            )
+        }
diff --git a/substra/tools/utils.py b/substra/tools/utils.py
new file mode 100644
index 00000000..02cf5da9
--- /dev/null
+++ b/substra/tools/utils.py
@@ -0,0 +1,127 @@
+import importlib
+import importlib.util
+import inspect
+import logging
+import os
+import sys
+import time
+
+from substratools import exceptions
+
+logger = logging.getLogger(__name__)
+
+MAPPING_LOG_LEVEL = {
+    "debug": logging.DEBUG,
+    "info": logging.INFO,
+    "warning": logging.WARNING,
+    "error": logging.ERROR,
+    "critical": logging.CRITICAL,
+}
+
+
+def configure_logging(path=None, log_level="info"):
+    level = MAPPING_LOG_LEVEL[log_level]
+
+    formatter = logging.Formatter(fmt="%(asctime)s %(levelname)-6s %(name)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
+
+    h = logging.StreamHandler()
+    h.setLevel(level)
+    h.setFormatter(formatter)
+
+    root = logging.getLogger("substratools")
+    root.setLevel(level)
+    root.addHandler(h)
+
+    if path and level == logging.DEBUG:
+        fh = logging.FileHandler(path)
+        fh.setLevel(level)
+        fh.setFormatter(formatter)
+
+        root.addHandler(fh)
+
+
+def get_logger(name, path=None, log_level="info"):
+    new_logger = logging.getLogger(f"substratools.{name}")
+    configure_logging(path, log_level)
+    return new_logger
+
+
+class Timer(object):
+    """This decorator prints the execution time for the decorated function."""
+
+    def __init__(self, module_logger):
+        self.module_logger = module_logger
+
+    def __call__(self, func):
+        def wrapper(*args, **kwargs):
+            start = time.time()
+            result = func(*args, **kwargs)
+            end = time.time()
+            self.module_logger.info("{} ran in {}s".format(func.__qualname__, round(end - start, 2)))
+            return result
+
+        return wrapper
+
+
+def import_module(module_name, code):
+    if module_name in sys.modules:
+        logging.warning("Module {} will be overwritten".format(module_name))
+    spec = importlib.util.spec_from_loader(module_name, loader=None, origin=module_name)
+    module = importlib.util.module_from_spec(spec)
+    sys.modules[module_name] = module
+    exec(code, module.__dict__)
+
+
+def import_module_from_path(path, module_name):
+    assert os.path.exists(path), "path '{}' not found".format(path)
+    spec = importlib.util.spec_from_file_location(module_name, path)
+    assert spec, "could not load spec from path '{}'".format(path)
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    return module
+
+
+# TODO: 'load_interface_from_module' is too complex, consider refactoring
+def load_interface_from_module(module_name, interface_class, interface_signature=None, path=None):  # noqa: C901
+    if path:
+        module = import_module_from_path(path, module_name)
+        logger.info(f"Module '{module_name}' loaded from path '{path}'")
+
else: + try: + module = importlib.import_module(module_name) + logger.info(f"Module '{module_name}' imported dynamically; module={module}") + except ImportError: + # XXX don't use ModuleNotFoundError for python3.5 compatibility + raise + + # check if module empty + if not inspect.getmembers(module, lambda m: inspect.isclass(m) or inspect.isfunction(m)): + raise exceptions.EmptyInterfaceError( + f"Module '{module_name}' seems empty: no method/class found in members: '{dir(module)}'" + ) + + # find interface class + for _, obj in inspect.getmembers(module, inspect.isclass): + if issubclass(obj, interface_class): + return obj() # return interface instance + + # backward compatibility; accept methods at module level directly + if interface_signature is None: + class_name = interface_class.__name__ + elements = str(dir(module)) + logger.info(f"Class '{class_name}' not found from: '{elements}'") + raise exceptions.InvalidInterfaceError("Expecting {} subclass in {}".format(class_name, module_name)) + + missing_functions = interface_signature.copy() + for name, obj in inspect.getmembers(module): + if not inspect.isfunction(obj): + continue + try: + missing_functions.remove(name) + except KeyError: + pass + + if missing_functions: + message = "Method(s) {} not implemented".format(", ".join(["'{}'".format(m) for m in missing_functions])) + raise exceptions.InvalidInterfaceError(message) + return module diff --git a/substra/tools/workspace.py b/substra/tools/workspace.py new file mode 100644 index 00000000..777315f2 --- /dev/null +++ b/substra/tools/workspace.py @@ -0,0 +1,95 @@ +import abc +import os + + +def makedir_safe(path): + """Create dir (no failure).""" + try: + os.makedirs(path) + except (FileExistsError, PermissionError): + pass + + +DEFAULT_INPUT_DATA_FOLDER_PATH = "data/" +DEFAULT_INPUT_PREDICTIONS_PATH = "pred/pred" +DEFAULT_OUTPUT_PERF_PATH = "pred/perf.json" +DEFAULT_LOG_PATH = "model/log_model.log" +DEFAULT_CHAINKEYS_PATH = "chainkeys/" + + +class Workspace(abc.ABC): + """Filesystem workspace for task execution.""" + + def __init__(self, dirpath=None): + self._workdir = dirpath if dirpath else os.getcwd() + + def _get_default_path(self, path): + return os.path.join(self._workdir, path) + + def _get_default_subpaths(self, path): + rootpath = os.path.join(self._workdir, path) + if os.path.isdir(rootpath): + return [ + os.path.join(rootpath, subfolder) + for subfolder in os.listdir(rootpath) + if os.path.isdir(os.path.join(rootpath, subfolder)) + ] + return [] + + +class OpenerWorkspace(Workspace): + """Filesystem workspace required by the opener.""" + + def __init__( + self, + dirpath=None, + input_data_folder_paths=None, + ): + super().__init__(dirpath=dirpath) + + assert input_data_folder_paths is None or isinstance(input_data_folder_paths, list) + + self.input_data_folder_paths = input_data_folder_paths or self._get_default_subpaths( + DEFAULT_INPUT_DATA_FOLDER_PATH + ) + + +class FunctionWorkspace(Workspace): + """Filesystem workspace for user defined function execution.""" + + def __init__( + self, + dirpath=None, + log_path=None, + chainkeys_path=None, + inputs=None, + outputs=None, + ): + + super().__init__(dirpath=dirpath) + + self.input_data_folder_paths = ( + self._get_default_subpaths(DEFAULT_INPUT_DATA_FOLDER_PATH) + if inputs is None + else inputs.input_data_folder_paths + ) + + self.log_path = log_path or self._get_default_path(DEFAULT_LOG_PATH) + self.chainkeys_path = chainkeys_path or self._get_default_path(DEFAULT_CHAINKEYS_PATH) + + self.opener_path = 
inputs.opener_path if inputs else None + + self.task_inputs = inputs.formatted_dynamic_resources if inputs else {} + self.task_outputs = outputs.formatted_dynamic_resources if outputs else {} + + dirs = [ + self.chainkeys_path, + ] + paths = [ + self.log_path, + ] + + dirs.extend([os.path.dirname(p) for p in paths]) + for d in dirs: + if d: + makedir_safe(d) diff --git a/tests/tools/__init__.py b/tests/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/tools/test_aggregatealgo.py b/tests/tools/test_aggregatealgo.py new file mode 100644 index 00000000..9b814db3 --- /dev/null +++ b/tests/tools/test_aggregatealgo.py @@ -0,0 +1,244 @@ +import json +from os import PathLike +from typing import Any +from typing import List +from typing import TypedDict +from uuid import uuid4 + +import pytest + +from substratools import exceptions +from substratools import function +from substratools import opener +from substratools.task_resources import TaskResources +from substratools.workspace import FunctionWorkspace +from tests.tools import utils +from tests.utils import InputIdentifiers +from tests.utils import OutputIdentifiers + + +@pytest.fixture(autouse=True) +def setup(valid_opener): + pass + + +@function.register +def aggregate( + inputs: TypedDict( + "inputs", + {InputIdentifiers.shared: List[PathLike]}, + ), + outputs: TypedDict("outputs", {OutputIdentifiers.shared: PathLike}), + task_properties: TypedDict("task_properties", {InputIdentifiers.rank: int}), +) -> None: + if inputs: + models = utils.load_models(paths=inputs.get(InputIdentifiers.shared, [])) + else: + models = [] + + new_model = {"value": 0} + for m in models: + new_model["value"] += m["value"] + + utils.save_model(model=new_model, path=outputs.get(OutputIdentifiers.shared)) + + +@function.register +def aggregate_predict( + inputs: TypedDict( + "inputs", + { + InputIdentifiers.datasamples: Any, + InputIdentifiers.shared: PathLike, + }, + ), + outputs: TypedDict("outputs", {OutputIdentifiers.shared: PathLike}), + task_properties: TypedDict("task_properties", {InputIdentifiers.rank: int}), +): + model = utils.load_model(path=inputs.get(OutputIdentifiers.shared)) + + # Predict + X = inputs.get(InputIdentifiers.datasamples)[0] + pred = X * model["value"] + + # save predictions + utils.save_predictions(predictions=pred, path=outputs.get(OutputIdentifiers.predictions)) + + +def no_saved_aggregate(inputs, outputs, task_properties): + if inputs: + models = utils.load_models(paths=inputs.get(InputIdentifiers.shared, [])) + else: + models = [] + + new_model = {"value": 0} + for m in models: + new_model["value"] += m["value"] + + utils.no_save_model(model=new_model, path=outputs.get(OutputIdentifiers.shared)) + + +def wrong_saved_aggregate(inputs, outputs, task_properties): + if inputs: + models = utils.load_models(paths=inputs.get(InputIdentifiers.shared, [])) + else: + models = [] + + new_model = {"value": 0} + for m in models: + new_model["value"] += m["value"] + + utils.wrong_save_model(model=new_model, path=outputs.get(OutputIdentifiers.shared)) + + +@pytest.fixture +def create_models(workdir): + model_a = {"value": 1} + model_b = {"value": 2} + + model_dir = workdir / OutputIdentifiers.shared + model_dir.mkdir() + + def _create_model(model_data): + model_name = model_data["value"] + filename = "{}.json".format(model_name) + path = model_dir / filename + path.write_text(json.dumps(model_data)) + return str(path) + + model_datas = [model_a, model_b] + model_filenames = [_create_model(d) for d in model_datas] + + 
return model_datas, model_filenames + + +def test_aggregate_no_model(valid_function_workspace): + wp = function.FunctionWrapper(workspace=valid_function_workspace, opener_wrapper=None) + wp.execute(function=aggregate) + model = utils.load_model(wp._workspace.task_outputs[OutputIdentifiers.shared]) + assert model["value"] == 0 + + +def test_aggregate_multiple_models(create_models, output_model_path): + _, model_filenames = create_models + + workspace_inputs = TaskResources( + json.dumps([{"id": InputIdentifiers.shared, "value": f, "multiple": True} for f in model_filenames]) + ) + workspace_outputs = TaskResources( + json.dumps([{"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}]) + ) + + workspace = FunctionWorkspace(inputs=workspace_inputs, outputs=workspace_outputs) + wp = function.FunctionWrapper(workspace, opener_wrapper=None) + + wp.execute(function=aggregate) + model = utils.load_model(wp._workspace.task_outputs[OutputIdentifiers.shared]) + + assert model["value"] == 3 + + +@pytest.mark.parametrize( + "fake_data,expected_pred,n_fake_samples", + [ + (False, "X", None), + (True, ["Xfake"], 1), + ], +) +def test_predict(fake_data, expected_pred, n_fake_samples, create_models): + _, model_filenames = create_models + + workspace_inputs = TaskResources( + json.dumps([{"id": InputIdentifiers.shared, "value": model_filenames[0], "multiple": False}]) + ) + workspace_outputs = TaskResources( + json.dumps([{"id": OutputIdentifiers.predictions, "value": model_filenames[0], "multiple": False}]) + ) + + workspace = FunctionWorkspace(inputs=workspace_inputs, outputs=workspace_outputs) + + wp = function.FunctionWrapper(workspace, opener_wrapper=opener.load_from_module()) + + wp.execute(function=aggregate_predict, fake_data=fake_data, n_fake_samples=n_fake_samples) + + pred = utils.load_predictions(wp._workspace.task_outputs[OutputIdentifiers.predictions]) + assert pred == expected_pred + + +def test_execute_aggregate(output_model_path): + assert not output_model_path.exists() + + outputs = [{"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}] + + function.execute(sysargs=["--function-name", "aggregate", "--outputs", json.dumps(outputs)]) + assert output_model_path.exists() + output_model_path.unlink() + function.execute( + sysargs=["--function-name", "aggregate", "--outputs", json.dumps(outputs), "--log-level", "debug"], + ) + assert output_model_path.exists() + + +def test_execute_aggregate_multiple_models(workdir, create_models, output_model_path): + _, model_filenames = create_models + + assert not output_model_path.exists() + + inputs = [ + {"id": InputIdentifiers.shared, "value": str(workdir / model), "multiple": True} for model in model_filenames + ] + outputs = [ + {"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}, + ] + options = ["--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)] + + command = ["--function-name", "aggregate"] + command.extend(options) + + function.execute(sysargs=command) + assert output_model_path.exists() + with open(output_model_path, "r") as f: + model = json.load(f) + assert model["value"] == 3 + + +def test_execute_predict(workdir, create_models, output_model_path, valid_opener_script): + _, model_filenames = create_models + assert not output_model_path.exists() + + inputs = [ + {"id": InputIdentifiers.shared, "value": str(workdir / model_name), "multiple": True} + for model_name in model_filenames + ] + outputs = [{"id": OutputIdentifiers.shared, 
"value": str(output_model_path), "multiple": False}] + options = ["--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)] + command = ["--function-name", "aggregate"] + command.extend(options) + function.execute(sysargs=command) + assert output_model_path.exists() + + # do predict on output model + pred_path = workdir / str(uuid4()) + assert not pred_path.exists() + + pred_inputs = [ + {"id": InputIdentifiers.shared, "value": str(output_model_path), "multiple": False}, + {"id": InputIdentifiers.opener, "value": valid_opener_script, "multiple": False}, + ] + pred_outputs = [{"id": OutputIdentifiers.predictions, "value": str(pred_path), "multiple": False}] + pred_options = ["--inputs", json.dumps(pred_inputs), "--outputs", json.dumps(pred_outputs)] + + function.execute(sysargs=["--function-name", "predict"] + pred_options) + assert pred_path.exists() + with open(pred_path, "r") as f: + pred = json.load(f) + assert pred == "XXX" + pred_path.unlink() + + +@pytest.mark.parametrize("function_to_run", (no_saved_aggregate, wrong_saved_aggregate)) +def test_model_check(function_to_run, valid_function_workspace): + wp = function.FunctionWrapper(valid_function_workspace, opener_wrapper=None) + + with pytest.raises(exceptions.MissingFileError): + wp.execute(function=function_to_run) diff --git a/tests/tools/test_compositealgo.py b/tests/tools/test_compositealgo.py new file mode 100644 index 00000000..89beb6d9 --- /dev/null +++ b/tests/tools/test_compositealgo.py @@ -0,0 +1,312 @@ +import json +import os +from typing import Any +from typing import Optional +from typing import TypedDict + +import pytest + +from substratools import exceptions +from substratools import function +from substratools import opener +from substratools.task_resources import TaskResources +from substratools.workspace import FunctionWorkspace +from tests.tools import utils +from tests.utils import InputIdentifiers +from tests.utils import OutputIdentifiers + + +@pytest.fixture(autouse=True) +def setup(valid_opener): + pass + + +def fake_data_train(inputs: dict, outputs: dict, task_properties: dict): + utils.save_model(model=inputs[InputIdentifiers.datasamples][0], path=outputs["local"]) + utils.save_model(model=inputs[InputIdentifiers.datasamples][1], path=outputs["shared"]) + + +def fake_data_predict(inputs: dict, outputs: dict, task_properties: dict) -> None: + utils.save_model(model=inputs[InputIdentifiers.datasamples][0], path=outputs["predictions"]) + + +def train( + inputs: TypedDict( + "inputs", + { + InputIdentifiers.datasamples: Any, + InputIdentifiers.local: Optional[os.PathLike], + InputIdentifiers.shared: Optional[os.PathLike], + }, + ), + outputs: TypedDict( + "outputs", + { + OutputIdentifiers.local: os.PathLike, + OutputIdentifiers.shared: os.PathLike, + }, + ), + task_properties: TypedDict("task_properties", {InputIdentifiers.rank: int}), +): + # init phase + # load models + head_model = utils.load_model(path=inputs.get(InputIdentifiers.local)) + trunk_model = utils.load_model(path=inputs.get(InputIdentifiers.shared)) + + if head_model and trunk_model: + new_head_model = dict(head_model) + new_trunk_model = dict(trunk_model) + else: + new_head_model = {"value": 0} + new_trunk_model = {"value": 0} + + # train models + new_head_model["value"] += 1 + new_trunk_model["value"] -= 1 + + # save model + utils.save_model(model=new_head_model, path=outputs.get(OutputIdentifiers.local)) + utils.save_model(model=new_trunk_model, path=outputs.get(OutputIdentifiers.shared)) + + +def predict( + inputs: TypedDict( + 
"inputs", + { + InputIdentifiers.datasamples: Any, + InputIdentifiers.local: os.PathLike, + InputIdentifiers.shared: os.PathLike, + }, + ), + outputs: TypedDict( + "outputs", + { + OutputIdentifiers.predictions: os.PathLike, + }, + ), + task_properties: TypedDict("task_properties", {InputIdentifiers.rank: int}), +): + + # init phase + # load models + head_model = utils.load_model(path=inputs.get(InputIdentifiers.local)) + trunk_model = utils.load_model(path=inputs.get(InputIdentifiers.shared)) + + pred = list(range(head_model["value"], trunk_model["value"])) + + # save predictions + utils.save_predictions(predictions=pred, path=outputs.get(OutputIdentifiers.predictions)) + + +def no_saved_trunk_train(inputs, outputs, task_properties): + # init phase + # load models + head_model = utils.load_model(path=inputs.get(InputIdentifiers.local)) + trunk_model = utils.load_model(path=inputs.get(InputIdentifiers.shared)) + + if head_model and trunk_model: + new_head_model = dict(head_model) + new_trunk_model = dict(trunk_model) + else: + new_head_model = {"value": 0} + new_trunk_model = {"value": 0} + + # train models + new_head_model["value"] += 1 + new_trunk_model["value"] -= 1 + + # save model + utils.save_model(model=new_head_model, path=outputs.get(OutputIdentifiers.local)) + utils.no_save_model(model=new_trunk_model, path=outputs.get(OutputIdentifiers.shared)) + + +def no_saved_head_train(inputs, outputs, task_properties): + # init phase + # load models + head_model = utils.load_model(path=inputs.get(InputIdentifiers.local)) + trunk_model = utils.load_model(path=inputs.get(InputIdentifiers.shared)) + + if head_model and trunk_model: + new_head_model = dict(head_model) + new_trunk_model = dict(trunk_model) + else: + new_head_model = {"value": 0} + new_trunk_model = {"value": 0} + + # train models + new_head_model["value"] += 1 + new_trunk_model["value"] -= 1 + + # save model + utils.no_save_model(model=new_head_model, path=outputs.get(OutputIdentifiers.local)) + utils.save_model(model=new_trunk_model, path=outputs.get(OutputIdentifiers.shared)) + + +def wrong_saved_trunk_train(inputs, outputs, task_properties): + # init phase + # load models + head_model = utils.load_model(path=inputs.get(InputIdentifiers.local)) + trunk_model = utils.load_model(path=inputs.get(InputIdentifiers.shared)) + + if head_model and trunk_model: + new_head_model = dict(head_model) + new_trunk_model = dict(trunk_model) + else: + new_head_model = {"value": 0} + new_trunk_model = {"value": 0} + + # train models + new_head_model["value"] += 1 + new_trunk_model["value"] -= 1 + + # save model + utils.save_model(model=new_head_model, path=outputs.get(OutputIdentifiers.local)) + utils.wrong_save_model(model=new_trunk_model, path=outputs.get(OutputIdentifiers.shared)) + + +def wrong_saved_head_train(inputs, outputs, task_properties): + # init phase + # load models + head_model = utils.load_model(path=inputs.get(InputIdentifiers.local)) + trunk_model = utils.load_model(path=inputs.get(InputIdentifiers.shared)) + + if head_model and trunk_model: + new_head_model = dict(head_model) + new_trunk_model = dict(trunk_model) + else: + new_head_model = {"value": 0} + new_trunk_model = {"value": 0} + + # train models + new_head_model["value"] += 1 + new_trunk_model["value"] -= 1 + + # save model + utils.wrong_save_model(model=new_head_model, path=outputs.get(OutputIdentifiers.local)) + utils.save_model(model=new_trunk_model, path=outputs.get(OutputIdentifiers.shared)) + + +@pytest.fixture +def train_outputs(output_model_path, 
output_model_path_2): + outputs = TaskResources( + json.dumps( + [ + {"id": "local", "value": str(output_model_path), "multiple": False}, + {"id": "shared", "value": str(output_model_path_2), "multiple": False}, + ] + ) + ) + return outputs + + +@pytest.fixture +def composite_inputs(create_models): + _, local_path, shared_path = create_models + inputs = TaskResources( + json.dumps( + [ + {"id": InputIdentifiers.local, "value": str(local_path), "multiple": False}, + {"id": InputIdentifiers.shared, "value": str(shared_path), "multiple": False}, + ] + ) + ) + + return inputs + + +@pytest.fixture +def predict_outputs(output_model_path): + outputs = TaskResources( + json.dumps([{"id": OutputIdentifiers.predictions, "value": str(output_model_path), "multiple": False}]) + ) + return outputs + + +@pytest.fixture +def create_models(workdir): + head_model = {"value": 1} + trunk_model = {"value": -1} + + def _create_model(model_data, name): + filename = "{}.json".format(name) + path = workdir / filename + path.write_text(json.dumps(model_data)) + return path + + head_path = _create_model(head_model, "head") + trunk_path = _create_model(trunk_model, "trunk") + + return ( + [head_model, trunk_model], + head_path, + trunk_path, + ) + + +def test_train_no_model(train_outputs): + + dummy_train_workspace = FunctionWorkspace(outputs=train_outputs) + dummy_train_wrapper = function.FunctionWrapper(dummy_train_workspace, None) + dummy_train_wrapper.execute(function=train) + local_model = utils.load_model(dummy_train_wrapper._workspace.task_outputs["local"]) + shared_model = utils.load_model(dummy_train_wrapper._workspace.task_outputs["shared"]) + + assert local_model["value"] == 1 + assert shared_model["value"] == -1 + + +def test_train_input_head_trunk_models(composite_inputs, train_outputs): + + dummy_train_workspace = FunctionWorkspace(inputs=composite_inputs, outputs=train_outputs) + dummy_train_wrapper = function.FunctionWrapper(dummy_train_workspace, None) + dummy_train_wrapper.execute(function=train) + local_model = utils.load_model(dummy_train_wrapper._workspace.task_outputs["local"]) + shared_model = utils.load_model(dummy_train_wrapper._workspace.task_outputs["shared"]) + + assert local_model["value"] == 2 + assert shared_model["value"] == -2 + + +@pytest.mark.parametrize("n_fake_samples", (0, 1, 2)) +def test_train_fake_data(train_outputs, n_fake_samples): + _opener = opener.load_from_module() + dummy_train_workspace = FunctionWorkspace(outputs=train_outputs) + dummy_train_wrapper = function.FunctionWrapper(dummy_train_workspace, _opener) + dummy_train_wrapper.execute(function=fake_data_train, fake_data=bool(n_fake_samples), n_fake_samples=n_fake_samples) + + local_model = utils.load_model(dummy_train_wrapper._workspace.task_outputs[OutputIdentifiers.local]) + shared_model = utils.load_model(dummy_train_wrapper._workspace.task_outputs[OutputIdentifiers.shared]) + + assert local_model == _opener.get_data(fake_data=bool(n_fake_samples), n_fake_samples=n_fake_samples)[0] + assert shared_model == _opener.get_data(fake_data=bool(n_fake_samples), n_fake_samples=n_fake_samples)[1] + + +@pytest.mark.parametrize("n_fake_samples", (0, 1, 2)) +def test_predict_fake_data(composite_inputs, predict_outputs, n_fake_samples): + _opener = opener.load_from_module() + dummy_train_workspace = FunctionWorkspace(inputs=composite_inputs, outputs=predict_outputs) + dummy_train_wrapper = function.FunctionWrapper(dummy_train_workspace, _opener) + dummy_train_wrapper.execute( + function=fake_data_predict, 
fake_data=bool(n_fake_samples), n_fake_samples=n_fake_samples + ) + + predictions = utils.load_model(dummy_train_wrapper._workspace.task_outputs[OutputIdentifiers.predictions]) + + assert predictions == _opener.get_data(fake_data=bool(n_fake_samples), n_fake_samples=n_fake_samples)[0] + + +@pytest.mark.parametrize( + "function_to_run", + ( + no_saved_head_train, + no_saved_trunk_train, + wrong_saved_head_train, + wrong_saved_trunk_train, + ), +) +def test_model_check(function_to_run, train_outputs): + dummy_train_workspace = FunctionWorkspace(outputs=train_outputs) + wp = function.FunctionWrapper(workspace=dummy_train_workspace, opener_wrapper=None) + + with pytest.raises(exceptions.MissingFileError): + wp.execute(function_to_run) diff --git a/tests/tools/test_function.py b/tests/tools/test_function.py new file mode 100644 index 00000000..c76c6170 --- /dev/null +++ b/tests/tools/test_function.py @@ -0,0 +1,343 @@ +import json +import shutil +from os import PathLike +from pathlib import Path +from typing import Any +from typing import List +from typing import Optional +from typing import Tuple +from typing import TypedDict + +import pytest + +from substratools import exceptions +from substratools import function +from substratools import opener +from substratools.task_resources import StaticInputIdentifiers +from substratools.task_resources import TaskResources +from substratools.workspace import FunctionWorkspace +from tests.tools import utils +from tests.utils import InputIdentifiers +from tests.utils import OutputIdentifiers + + +@pytest.fixture(autouse=True) +def setup(valid_opener): + pass + + +@function.register +def train( + inputs: TypedDict( + "inputs", + { + InputIdentifiers.datasamples: Tuple[List["str"], List[int]], # cf valid_opener_code + InputIdentifiers.shared: Optional[ + PathLike + ], # inputs contains a dict where keys are identifiers and values are paths on the disk + }, + ), + outputs: TypedDict( + "outputs", {OutputIdentifiers.shared: PathLike} + ), # outputs contains a dict where keys are identifiers and values are paths on disk + task_properties: TypedDict("task_properties", {InputIdentifiers.rank: int}), +) -> None: + # TODO: checks on data + # load models + if inputs: + models = utils.load_models(paths=inputs.get(InputIdentifiers.shared, [])) + else: + models = [] + # init model + new_model = {"value": 0} + + # train (just add the models values) + for m in models: + assert isinstance(m, dict) + assert "value" in m + new_model["value"] += m["value"] + + # save model + utils.save_model(model=new_model, path=outputs.get(OutputIdentifiers.shared)) + + +@function.register +def predict( + inputs: TypedDict("inputs", {InputIdentifiers.datasamples: Any, InputIdentifiers.shared: List[PathLike]}), + outputs: TypedDict("outputs", {OutputIdentifiers.predictions: PathLike}), + task_properties: TypedDict("task_properties", {InputIdentifiers.rank: int}), +) -> None: + # TODO: checks on data + + # load_model + model = utils.load_model(path=inputs.get(InputIdentifiers.shared)) + + # predict + X = inputs.get(InputIdentifiers.datasamples)[0] + pred = X * model["value"] + + # save predictions + utils.save_predictions(predictions=pred, path=outputs.get(OutputIdentifiers.predictions)) + + +@function.register +def no_saved_train(inputs, outputs, task_properties): + # TODO: checks on data + # load models + if inputs: + models = utils.load_models(paths=inputs.get(InputIdentifiers.shared, [])) + else: + models = [] + # init model + new_model = {"value": 0} + + # train (just add the models 
values) + for m in models: + assert isinstance(m, dict) + assert "value" in m + new_model["value"] += m["value"] + + # save model + utils.no_save_model(model=new_model, path=outputs.get(OutputIdentifiers.shared)) + + +@function.register +def wrong_saved_train(inputs, outputs, task_properties): + # TODO: checks on data + # load models + if inputs: + models = utils.load_models(paths=inputs.get(InputIdentifiers.shared, [])) + else: + models = [] + # init model + new_model = {"value": 0} + + # train (just add the models values) + for m in models: + assert isinstance(m, dict) + assert "value" in m + new_model["value"] += m["value"] + + # save model + utils.wrong_save_model(model=new_model, path=outputs.get(OutputIdentifiers.shared)) + + +@pytest.fixture +def create_models(workdir): + model_a = {"value": 1} + model_b = {"value": 2} + + model_dir = workdir / "model" + model_dir.mkdir() + + def _create_model(model_data): + model_name = model_data["value"] + filename = "{}.json".format(model_name) + path = model_dir / filename + path.write_text(json.dumps(model_data)) + return str(path) + + model_datas = [model_a, model_b] + model_filenames = [_create_model(d) for d in model_datas] + + return model_datas, model_filenames + + +def test_train_no_model(valid_function_workspace): + wp = function.FunctionWrapper(valid_function_workspace, opener_wrapper=None) + wp.execute(function=train) + model = utils.load_model(wp._workspace.task_outputs[OutputIdentifiers.shared]) + assert model["value"] == 0 + + +def test_train_multiple_models(output_model_path, create_models): + _, model_filenames = create_models + + workspace_inputs = TaskResources( + json.dumps([{"id": InputIdentifiers.shared, "value": str(f), "multiple": True} for f in model_filenames]) + ) + workspace_outputs = TaskResources( + json.dumps([{"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}]) + ) + + workspace = FunctionWorkspace(inputs=workspace_inputs, outputs=workspace_outputs) + wp = function.FunctionWrapper(workspace=workspace, opener_wrapper=None) + + wp.execute(function=train) + model = utils.load_model(wp._workspace.task_outputs[OutputIdentifiers.shared]) + + assert model["value"] == 3 + + +def test_train_fake_data(output_model_path): + workspace_outputs = TaskResources( + json.dumps([{"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}]) + ) + + workspace = FunctionWorkspace(outputs=workspace_outputs) + wp = function.FunctionWrapper(workspace=workspace, opener_wrapper=None) + wp.execute(function=train, fake_data=True, n_fake_samples=2) + model = utils.load_model(wp._workspace.task_outputs[OutputIdentifiers.shared]) + assert model["value"] == 0 + + +@pytest.mark.parametrize( + "fake_data,expected_pred,n_fake_samples", + [ + (False, "X", None), + (True, ["Xfake"], 1), + ], +) +def test_predict(fake_data, expected_pred, n_fake_samples, create_models, output_model_path): + _, model_filenames = create_models + + workspace_inputs = TaskResources( + json.dumps([{"id": InputIdentifiers.shared, "value": model_filenames[0], "multiple": False}]) + ) + workspace_outputs = TaskResources( + json.dumps([{"id": OutputIdentifiers.predictions, "value": str(output_model_path), "multiple": False}]) + ) + + workspace = FunctionWorkspace(inputs=workspace_inputs, outputs=workspace_outputs) + wp = function.FunctionWrapper(workspace=workspace, opener_wrapper=opener.load_from_module()) + wp.execute(function=predict, fake_data=fake_data, n_fake_samples=n_fake_samples) + + pred = 
utils.load_predictions(wp._workspace.task_outputs["predictions"]) + assert pred == expected_pred + + +def test_execute_train(workdir, output_model_path): + inputs = [ + { + "id": StaticInputIdentifiers.datasamples.value, + "value": str(workdir / "datasamples_unused"), + "multiple": True, + }, + ] + outputs = [ + {"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}, + ] + options = ["--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)] + + assert not output_model_path.exists() + + function.execute(sysargs=["--function-name", "train"] + options) + assert output_model_path.exists() + + function.execute( + sysargs=["--function-name", "train", "--fake-data", "--n-fake-samples", "1", "--outputs", json.dumps(outputs)] + ) + assert output_model_path.exists() + + function.execute(sysargs=["--function-name", "train", "--log-level", "debug"] + options) + assert output_model_path.exists() + + +def test_execute_train_multiple_models(workdir, output_model_path, create_models): + _, model_filenames = create_models + + output_model_path = Path(output_model_path) + + assert not output_model_path.exists() + pred_path = workdir / "pred" + assert not pred_path.exists() + + inputs = [ + {"id": InputIdentifiers.shared, "value": str(workdir / model), "multiple": True} for model in model_filenames + ] + + outputs = [ + {"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}, + ] + options = ["--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)] + + command = ["--function-name", "train"] + command.extend(options) + + function.execute(sysargs=command) + assert output_model_path.exists() + with open(output_model_path, "r") as f: + model = json.load(f) + assert model["value"] == 3 + + assert not pred_path.exists() + + +def test_execute_predict(workdir, output_model_path, create_models, valid_opener_script): + _, model_filenames = create_models + pred_path = workdir / "pred" + train_inputs = [ + {"id": InputIdentifiers.shared, "value": str(workdir / model), "multiple": True} for model in model_filenames + ] + + train_outputs = [{"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}] + train_options = ["--inputs", json.dumps(train_inputs), "--outputs", json.dumps(train_outputs)] + + output_model_path = Path(output_model_path) + # first train models + assert not pred_path.exists() + command = ["--function-name", "train"] + command.extend(train_options) + function.execute(sysargs=command) + assert output_model_path.exists() + + # do predict on output model + pred_inputs = [ + {"id": InputIdentifiers.opener, "value": valid_opener_script, "multiple": False}, + {"id": InputIdentifiers.shared, "value": str(output_model_path), "multiple": False}, + ] + pred_outputs = [{"id": OutputIdentifiers.predictions, "value": str(pred_path), "multiple": False}] + pred_options = ["--inputs", json.dumps(pred_inputs), "--outputs", json.dumps(pred_outputs)] + + assert not pred_path.exists() + function.execute(sysargs=["--function-name", "predict"] + pred_options) + assert pred_path.exists() + with open(pred_path, "r") as f: + pred = json.load(f) + assert pred == "XXX" + pred_path.unlink() + + # do predict with different model paths + input_models_dir = workdir / "other_models" + input_models_dir.mkdir() + input_model_path = input_models_dir / "supermodel" + shutil.move(output_model_path, input_model_path) + + pred_inputs = [ + {"id": InputIdentifiers.shared, "value": str(input_model_path), "multiple": False}, + {"id": 
InputIdentifiers.opener, "value": valid_opener_script, "multiple": False}, + ] + pred_outputs = [{"id": OutputIdentifiers.predictions, "value": str(pred_path), "multiple": False}] + pred_options = ["--inputs", json.dumps(pred_inputs), "--outputs", json.dumps(pred_outputs)] + + assert not pred_path.exists() + function.execute(sysargs=["--function-name", "predict"] + pred_options) + assert pred_path.exists() + with open(pred_path, "r") as f: + pred = json.load(f) + assert pred == "XXX" + + +@pytest.mark.parametrize("function_to_run", (no_saved_train, wrong_saved_train)) +def test_model_check(valid_function_workspace, function_to_run): + wp = function.FunctionWrapper(workspace=valid_function_workspace, opener_wrapper=None) + + with pytest.raises(exceptions.MissingFileError): + wp.execute(function=function_to_run) + + +def test_function_not_found(): + with pytest.raises(exceptions.FunctionNotFoundError): + function.execute(sysargs=["--function-name", "imaginary_function"]) + + +def test_function_name_already_register(): + @function.register + def fake_function(): + pass + + with pytest.raises(exceptions.ExistingRegisteredFunctionError): + + @function.register + def fake_function(): + pass diff --git a/tests/tools/test_genericalgo.py b/tests/tools/test_genericalgo.py new file mode 100644 index 00000000..6c20a6e3 --- /dev/null +++ b/tests/tools/test_genericalgo.py @@ -0,0 +1,2 @@ +# TODO: As the implementation is going to change from a class to a function +# decorator, those test will be added when the implementation is stable diff --git a/tests/tools/test_metrics.py b/tests/tools/test_metrics.py new file mode 100644 index 00000000..c8d4944c --- /dev/null +++ b/tests/tools/test_metrics.py @@ -0,0 +1,158 @@ +import json +import uuid +from os import PathLike +from typing import Any +from typing import TypedDict + +import numpy as np +import pytest + +from substratools import function +from substratools import load_performance +from substratools import opener +from substratools import save_performance +from substratools.task_resources import TaskResources +from substratools.workspace import FunctionWorkspace +from tests.tools import utils +from tests.utils import InputIdentifiers +from tests.utils import OutputIdentifiers + + +@pytest.fixture() +def write_pred_file(workdir): + pred_file = str(workdir / str(uuid.uuid4())) + data = list(range(3, 6)) + with open(pred_file, "w") as f: + json.dump(data, f) + return pred_file, data + + +@pytest.fixture +def inputs(workdir, valid_opener_script, write_pred_file): + return [ + {"id": InputIdentifiers.predictions, "value": str(write_pred_file[0]), "multiple": False}, + {"id": InputIdentifiers.datasamples, "value": str(workdir / "datasamples_unused"), "multiple": True}, + {"id": InputIdentifiers.opener, "value": str(valid_opener_script), "multiple": False}, + ] + + +@pytest.fixture +def outputs(workdir): + return [{"id": OutputIdentifiers.performance, "value": str(workdir / str(uuid.uuid4())), "multiple": False}] + + +@pytest.fixture(autouse=True) +def setup(valid_opener, write_pred_file): + pass + + +@function.register +def score( + inputs: TypedDict("inputs", {InputIdentifiers.datasamples: Any, InputIdentifiers.predictions: Any}), + outputs: TypedDict("outputs", {OutputIdentifiers.performance: PathLike}), + task_properties: TypedDict("task_properties", {InputIdentifiers.rank: int}), +): + y_true = inputs.get(InputIdentifiers.datasamples)[1] + y_pred_path = inputs.get(InputIdentifiers.predictions) + y_pred = utils.load_predictions(y_pred_path) + + score = 
sum(y_true) + sum(y_pred) + + save_performance(performance=score, path=outputs.get(OutputIdentifiers.performance)) + + +def test_score(workdir, write_pred_file): + inputs = TaskResources( + json.dumps( + [ + {"id": InputIdentifiers.predictions, "value": str(write_pred_file[0]), "multiple": False}, + ] + ) + ) + outputs = TaskResources( + json.dumps( + [{"id": OutputIdentifiers.performance, "value": str(workdir / str(uuid.uuid4())), "multiple": False}] + ) + ) + workspace = FunctionWorkspace(inputs=inputs, outputs=outputs) + wp = function.FunctionWrapper(workspace=workspace, opener_wrapper=opener.load_from_module()) + wp.execute(function=score) + s = load_performance(wp._workspace.task_outputs[OutputIdentifiers.performance]) + assert s == 15 + + +def test_execute(inputs, outputs): + perf_path = outputs[0]["value"] + function.execute( + sysargs=["--function-name", "score", "--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)], + ) + s = load_performance(perf_path) + assert s == 15 + + +@pytest.mark.parametrize( + "fake_data_mode,expected_score", + [ + ([], 15), + (["--fake-data", "--n-fake-samples", "3"], 12), + ], +) +def test_execute_fake_data_modes(fake_data_mode, expected_score, inputs, outputs): + perf_path = outputs[0]["value"] + function.execute( + sysargs=fake_data_mode + + ["--function-name", "score", "--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)], + ) + s = load_performance(perf_path) + assert s == expected_score + + +def test_execute_np(inputs, outputs): + @function.register + def float_np_score( + inputs, + outputs, + task_properties: dict, + ): + save_performance(np.float64(0.99), outputs.get(OutputIdentifiers.performance)) + + perf_path = outputs[0]["value"] + function.execute( + sysargs=["--function-name", "float_np_score", "--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)], + ) + s = load_performance(perf_path) + assert s == pytest.approx(0.99) + + +def test_execute_int(inputs, outputs): + @function.register + def int_score( + inputs, + outputs, + task_properties: dict, + ): + save_performance(int(1), outputs.get(OutputIdentifiers.performance)) + + perf_path = outputs[0]["value"] + function.execute( + sysargs=["--function-name", "int_score", "--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)], + ) + s = load_performance(perf_path) + assert s == 1 + + +def test_execute_dict(inputs, outputs): + @function.register + def dict_score( + inputs, + outputs, + task_properties: dict, + ): + save_performance({"a": 1}, outputs.get(OutputIdentifiers.performance)) + + perf_path = outputs[0]["value"] + function.execute( + sysargs=["--function-name", "dict_score", "--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)], + ) + s = load_performance(perf_path) + assert s["a"] == 1 diff --git a/tests/tools/test_opener.py b/tests/tools/test_opener.py new file mode 100644 index 00000000..58e88697 --- /dev/null +++ b/tests/tools/test_opener.py @@ -0,0 +1,98 @@ +import os + +import pytest + +from substratools import exceptions +from substratools.opener import Opener +from substratools.opener import OpenerWrapper +from substratools.opener import load_from_module +from substratools.utils import import_module +from substratools.utils import load_interface_from_module +from substratools.workspace import DEFAULT_INPUT_DATA_FOLDER_PATH + + +@pytest.fixture +def tmp_cwd(tmp_path): + # create a temporary current working directory + new_dir = tmp_path / "workspace" + new_dir.mkdir() + + old_dir = os.getcwd() + os.chdir(new_dir) + + yield 
new_dir + + os.chdir(old_dir) + + +def test_load_opener_not_found(tmp_cwd): + with pytest.raises(ImportError): + load_from_module() + + +def test_load_invalid_opener(tmp_cwd): + invalid_script = """ +def get_data(): + raise NotImplementedError +""" + + import_module("opener", invalid_script) + + with pytest.raises(exceptions.InvalidInterfaceError): + load_from_module() + + +def test_load_opener_as_class(tmp_cwd): + script = """ +from substratools import Opener +class MyOpener(Opener): + def get_data(self, folders): + return 'data_class' + def fake_data(self, n_samples): + return 'fake_data' +""" + + import_module("opener", script) + + o = load_from_module() + assert o.get_data() == "data_class" + + +def test_load_opener_from_path(tmp_cwd, valid_opener_code): + dirpath = tmp_cwd / "myopener" + dirpath.mkdir() + path = dirpath / "my_opener.py" + path.write_text(valid_opener_code) + + interface = load_interface_from_module( + "opener", + interface_class=Opener, + interface_signature=None, # XXX does not support interface for debugging + path=path, + ) + o = OpenerWrapper(interface, workspace=None) + assert o.get_data()[0] == "X" + + +def test_opener_check_folders(tmp_cwd): + script = """ +from substratools import Opener +class MyOpener(Opener): + def get_data(self, folders): + assert len(folders) == 5 + return 'data_class' + def fake_data(self, n_samples): + return 'fake_data_class' +""" + + import_module("opener", script) + + o = load_from_module() + + # create some data folders + data_root_path = os.path.join(o._workspace._workdir, DEFAULT_INPUT_DATA_FOLDER_PATH) + data_paths = [os.path.join(data_root_path, str(i)) for i in range(5)] + [os.makedirs(p) for p in data_paths] + + o._workspace.input_data_folder_paths = data_paths + assert o.get_data() == "data_class" diff --git a/tests/tools/test_task_resources.py b/tests/tools/test_task_resources.py new file mode 100644 index 00000000..4fe1b712 --- /dev/null +++ b/tests/tools/test_task_resources.py @@ -0,0 +1,86 @@ +import json + +import pytest + +from substratools.exceptions import InvalidCLIError +from substratools.exceptions import InvalidInputOutputsError +from substratools.task_resources import StaticInputIdentifiers +from substratools.task_resources import TaskResources + +_VALID_RESOURCES = [ + {"id": "foo", "value": "bar", "multiple": True}, + {"id": "foo", "value": "babar", "multiple": True}, + {"id": "fofo", "value": "bar", "multiple": False}, +] +_VALID_VALUES = {"foo": {"value": ["bar", "babar"], "multiple": True}, "fofo": {"value": ["bar"], "multiple": False}} + + +@pytest.mark.parametrize( + "invalid_arg", + ( + {"foo": "barr"}, + "foo and bar", + ["foo", "barr"], + [{"foo": "bar"}], + [{"foo": "bar"}, {"id": "foo", "value": "bar", "multiple": True}], + # [{_RESOURCE_ID: "foo", _RESOURCE_VALUE: "some path", _RESOURCE_MULTIPLE: "str"}], + ), +) +def test_task_resources_invalid_args(invalid_arg): + with pytest.raises(InvalidCLIError): + TaskResources(json.dumps(invalid_arg)) + + +@pytest.mark.parametrize( + "valid_arg,expected", + [ + ([], {}), + ([{"id": "foo", "value": "bar", "multiple": True}], {"foo": {"value": ["bar"], "multiple": True}}), + ( + [{"id": "foo", "value": "bar", "multiple": True}, {"id": "foo", "value": "babar", "multiple": True}], + {"foo": {"value": ["bar", "babar"], "multiple": True}}, + ), + (_VALID_RESOURCES, _VALID_VALUES), + ], +) +def test_task_resources_values(valid_arg, expected): + assert TaskResources(json.dumps(valid_arg))._values == expected + + +@pytest.mark.parametrize( + "static_resource_id", + ( +
StaticInputIdentifiers.chainkeys.value, + StaticInputIdentifiers.datasamples.value, + StaticInputIdentifiers.opener.value, + ), +) +def test_task_static_resources(static_resource_id): + "checks that the static keys opener, datasamples and chainkeys are excluded" + + assert TaskResources( + json.dumps(_VALID_RESOURCES + [{"id": static_resource_id, "value": "foo", "multiple": False}]) + )._values == _VALID_VALUES + + +@pytest.mark.parametrize("key", tuple(_VALID_VALUES.keys())) +def test_get_value(key): + "get_value returns a list of paths for a multiple resource and a single path for a non-multiple one" + expected = _VALID_VALUES[key]["value"] + + if not _VALID_VALUES[key]["multiple"]: + expected = expected[0] + + # get_value is assumed to take the resource id, as described in the docstring + assert TaskResources(json.dumps(_VALID_RESOURCES)).get_value(key) == expected + + +def test_multiple_resource_error(): + "a non-multiple resource can't have multiple values" + + with pytest.raises(InvalidInputOutputsError): + TaskResources( + json.dumps( + [ + {"id": "foo", "value": "bar", "multiple": False}, + {"id": "foo", "value": "babar", "multiple": False}, + ] + ) + ) diff --git a/tests/tools/test_utils.py b/tests/tools/test_utils.py new file mode 100644 index 00000000..c8698556 --- /dev/null +++ b/tests/tools/test_utils.py @@ -0,0 +1,44 @@ +import sys + +import pytest + +from substratools import exceptions +from substratools.opener import Opener +from substratools.utils import get_logger +from substratools.utils import import_module +from substratools.utils import load_interface_from_module + + +def test_invalid_interface(): + code = """ +def score(): + pass +""" + import_module("score", code) + with pytest.raises(exceptions.InvalidInterfaceError): + load_interface_from_module("score", interface_class=Opener) + + +@pytest.fixture +def syspaths(): + copy = sys.path[:] + yield sys.path + sys.path = copy + + +def test_empty_module(tmpdir, syspaths): + with tmpdir.as_cwd(): + # python allows importing an empty directory + # check that the error message would be helpful for debugging purposes + tmpdir.mkdir("foomod") + syspaths.append(str(tmpdir)) + + with pytest.raises(exceptions.EmptyInterfaceError): + load_interface_from_module("foomod", interface_class=Opener) + + +def test_get_logger(capfd): + logger = get_logger("test") + logger.info("message") + captured = capfd.readouterr() + assert "INFO substratools.test - message" in captured.err diff --git a/tests/tools/test_workflow.py b/tests/tools/test_workflow.py new file mode 100644 index 00000000..726439c5 --- /dev/null +++ b/tests/tools/test_workflow.py @@ -0,0 +1,140 @@ +import json +import os + +import pytest + +from substratools import load_performance +from substratools import opener +from substratools import save_performance +from substratools.function import FunctionWrapper +from substratools.task_resources import TaskResources +from substratools.utils import import_module +from substratools.workspace import FunctionWorkspace +from tests.tools import utils +from tests.utils import InputIdentifiers +from tests.utils import OutputIdentifiers + + +@pytest.fixture +def dummy_opener(): + script = """ +import json +from substratools import Opener + +class DummyOpener(Opener): + def get_data(self, folder): + return None + + def fake_data(self, n_samples): + raise NotImplementedError +""" + import_module("opener", script) + + +def train(inputs, outputs, task_properties): + models = utils.load_models(inputs.get(InputIdentifiers.shared, [])) + total = sum([m["i"] for m in models]) + new_model = {"i": len(models) + 1, "total": total} + + utils.save_model(new_model, outputs.get(OutputIdentifiers.shared)) + + +def predict(inputs,
outputs, task_properties): + model = utils.load_model(inputs.get(InputIdentifiers.shared)) + pred = {"sum": model["i"]} + utils.save_predictions(pred, outputs.get(OutputIdentifiers.predictions)) + + +def score(inputs, outputs, task_properties): + y_pred_path = inputs.get(InputIdentifiers.predictions) + y_pred = utils.load_predictions(y_pred_path) + + score = y_pred["sum"] + + save_performance(performance=score, path=outputs.get(OutputIdentifiers.performance)) + + +def test_workflow(workdir, dummy_opener): + loop1_model_path = workdir / "loop1model" + loop1_workspace_outputs = TaskResources( + json.dumps([{"id": OutputIdentifiers.shared, "value": str(loop1_model_path), "multiple": False}]) + ) + loop1_workspace = FunctionWorkspace(outputs=loop1_workspace_outputs) + loop1_wp = FunctionWrapper(workspace=loop1_workspace, opener_wrapper=None) + + # loop 1 (no input) + loop1_wp.execute(function=train) + model = utils.load_model(path=loop1_wp._workspace.task_outputs[OutputIdentifiers.shared]) + + assert model == {"i": 1, "total": 0} + assert os.path.exists(loop1_model_path) + + loop2_model_path = workdir / "loop2model" + + loop2_workspace_inputs = TaskResources( + json.dumps([{"id": InputIdentifiers.shared, "value": str(loop1_model_path), "multiple": True}]) + ) + loop2_workspace_outputs = TaskResources( + json.dumps([{"id": OutputIdentifiers.shared, "value": str(loop2_model_path), "multiple": False}]) + ) + loop2_workspace = FunctionWorkspace(inputs=loop2_workspace_inputs, outputs=loop2_workspace_outputs) + loop2_wp = FunctionWrapper(workspace=loop2_workspace, opener_wrapper=None) + + # loop 2 (one model as input) + loop2_wp.execute(function=train) + model = utils.load_model(path=loop2_wp._workspace.task_outputs[OutputIdentifiers.shared]) + assert model == {"i": 2, "total": 1} + assert os.path.exists(loop2_model_path) + + loop3_model_path = workdir / "loop3model" + loop3_workspace_inputs = TaskResources( + json.dumps( + [ + {"id": InputIdentifiers.shared, "value": str(loop1_model_path), "multiple": True}, + {"id": InputIdentifiers.shared, "value": str(loop2_model_path), "multiple": True}, + ] + ) + ) + loop3_workspace_outputs = TaskResources( + json.dumps([{"id": OutputIdentifiers.shared, "value": str(loop3_model_path), "multiple": False}]) + ) + loop3_workspace = FunctionWorkspace(inputs=loop3_workspace_inputs, outputs=loop3_workspace_outputs) + loop3_wp = FunctionWrapper(workspace=loop3_workspace, opener_wrapper=None) + + # loop 3 (two models as input) + loop3_wp.execute(function=train) + model = utils.load_model(path=loop3_wp._workspace.task_outputs[OutputIdentifiers.shared]) + assert model == {"i": 3, "total": 3} + assert os.path.exists(loop3_model_path) + + predictions_path = workdir / "predictions" + predict_workspace_inputs = TaskResources( + json.dumps([{"id": InputIdentifiers.shared, "value": str(loop3_model_path), "multiple": False}]) + ) + predict_workspace_outputs = TaskResources( + json.dumps([{"id": OutputIdentifiers.predictions, "value": str(predictions_path), "multiple": False}]) + ) + predict_workspace = FunctionWorkspace(inputs=predict_workspace_inputs, outputs=predict_workspace_outputs) + predict_wp = FunctionWrapper(workspace=predict_workspace, opener_wrapper=None) + + # predict + predict_wp.execute(function=predict) + pred = utils.load_predictions(path=predict_wp._workspace.task_outputs[OutputIdentifiers.predictions]) + assert pred == {"sum": 3} + + # metrics + performance_path = workdir / "performance" + metric_workspace_inputs = TaskResources( + json.dumps([{"id":
InputIdentifiers.predictions, "value": str(predictions_path), "multiple": False}]) + ) + metric_workspace_outputs = TaskResources( + json.dumps([{"id": OutputIdentifiers.performance, "value": str(performance_path), "multiple": False}]) + ) + metric_workspace = FunctionWorkspace( + inputs=metric_workspace_inputs, + outputs=metric_workspace_outputs, + ) + metrics_wp = FunctionWrapper(workspace=metric_workspace, opener_wrapper=opener.load_from_module()) + metrics_wp.execute(function=score) + res = load_performance(path=metrics_wp._workspace.task_outputs[OutputIdentifiers.performance]) + assert res == 3.0 diff --git a/tests/tools/tools_conftest.py b/tests/tools/tools_conftest.py new file mode 100644 index 00000000..a51e6fe9 --- /dev/null +++ b/tests/tools/tools_conftest.py @@ -0,0 +1,85 @@ +import json +import os +import sys +from pathlib import Path +from uuid import uuid4 + +import pytest + +from substratools.task_resources import TaskResources +from substratools.utils import import_module +from substratools.workspace import FunctionWorkspace +from tests.utils import OutputIdentifiers + + +@pytest.fixture +def workdir(tmp_path): + d = tmp_path / "substra-workspace" + d.mkdir() + return d + + +@pytest.fixture(autouse=True) +def patch_cwd(monkeypatch, workdir): + # this is needed to ensure the workspace is located in a tmpdir + def getcwd(): + return str(workdir) + + monkeypatch.setattr(os, "getcwd", getcwd) + + +@pytest.fixture() +def valid_opener_code(): + return """ +import json +from substratools import Opener + +class FakeOpener(Opener): + def get_data(self, folder): + return 'X', list(range(0, 3)) + + def fake_data(self, n_samples): + return ['Xfake'] * n_samples, [0] * n_samples +""" + + +@pytest.fixture() +def valid_opener(valid_opener_code): + import_module("opener", valid_opener_code) + yield + del sys.modules["opener"] + + +@pytest.fixture() +def valid_opener_script(workdir, valid_opener_code): + opener_path = workdir / "my_opener.py" + opener_path.write_text(valid_opener_code) + + return str(opener_path) + + +@pytest.fixture(autouse=True) +def output_model_path(workdir: Path) -> str: + path = workdir / str(uuid4()) + yield path + if path.exists(): + os.remove(path) + + +@pytest.fixture(autouse=True) +def output_model_path_2(workdir: Path) -> str: + path = workdir / str(uuid4()) + yield path + if path.exists(): + os.remove(path) + + +@pytest.fixture() +def valid_function_workspace(output_model_path: str) -> FunctionWorkspace: + workspace_outputs = TaskResources( + json.dumps([{"id": OutputIdentifiers.shared, "value": str(output_model_path), "multiple": False}]) + ) + + workspace = FunctionWorkspace(outputs=workspace_outputs) + + return workspace diff --git a/tests/tools/utils.py b/tests/tools/utils.py new file mode 100644 index 00000000..3b39a20b --- /dev/null +++ b/tests/tools/utils.py @@ -0,0 +1,65 @@ +import json +from enum import Enum +from os import PathLike +from typing import Any +from typing import List + +from substratools.task_resources import StaticInputIdentifiers + + +class InputIdentifiers(str, Enum): + local = "local" + shared = "shared" + predictions = "predictions" + opener = StaticInputIdentifiers.opener.value + datasamples = StaticInputIdentifiers.datasamples.value + rank = StaticInputIdentifiers.rank.value + + +class OutputIdentifiers(str, Enum): + local = "local" + shared = "shared" + predictions = "predictions" + performance = "performance" + + +def load_models(paths: List[PathLike]) -> dict: + models = [] + for model_path in paths: + with open(model_path, "r") 
as f: + models.append(json.load(f)) + + return models + + +def load_model(path: PathLike): + if path: + with open(path, "r") as f: + return json.load(f) + + +def save_model(model: dict, path: PathLike): + with open(path, "w") as f: + json.dump(model, f) + + +def save_predictions(predictions: Any, path: PathLike): + with open(path, "w") as f: + json.dump(predictions, f) + + +def load_predictions(path: PathLike) -> Any: + with open(path, "r") as f: + predictions = json.load(f) + return predictions + + +def no_save_model(path, model): + # do not save model at all + pass + + +def wrong_save_model(model, path): + # simulate numpy.save behavior + with open(path + ".npy", "w") as f: + json.dump(model, f)
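
Note (not part of the patch): the tests above drive the copied tooling through its public entry points (function.register, function.execute, save_performance, load_performance). The following is a minimal end-to-end sketch of that flow, assembled only from calls that already appear in the tests; the function name, temporary paths, and the identifier strings "opener" and "datasamples" (assumed to match StaticInputIdentifiers) are illustrative, while "performance" matches OutputIdentifiers.performance above.

import json
import tempfile
from pathlib import Path

from substratools import function, load_performance, save_performance

# Opener loaded by the executor; mirrors the valid_opener_code fixture above.
OPENER_CODE = """
from substratools import Opener

class SketchOpener(Opener):
    def get_data(self, folders):
        return "X", list(range(3))

    def fake_data(self, n_samples):
        return ["Xfake"] * n_samples, [0] * n_samples
"""


@function.register
def sketch_score(inputs, outputs, task_properties):
    # Ignore the data and write a constant performance, as test_execute_int does.
    save_performance(performance=1, path=outputs.get("performance"))


workdir = Path(tempfile.mkdtemp())
opener_path = workdir / "my_opener.py"
opener_path.write_text(OPENER_CODE)

# Same resource layout as the tests: a list of {"id", "value", "multiple"} dicts.
# "opener" and "datasamples" are assumed StaticInputIdentifiers values.
inputs = [
    {"id": "opener", "value": str(opener_path), "multiple": False},
    {"id": "datasamples", "value": str(workdir / "datasamples_unused"), "multiple": True},
]
outputs = [{"id": "performance", "value": str(workdir / "perf.json"), "multiple": False}]

function.execute(
    sysargs=["--function-name", "sketch_score", "--inputs", json.dumps(inputs), "--outputs", json.dumps(outputs)]
)
print(load_performance(path=outputs[0]["value"]))  # -> 1

As test_execute_fake_data_modes shows, appending "--fake-data" and "--n-fake-samples N" to sysargs would run the same function against Opener.fake_data instead of the real data samples.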