Skip to content

Commit

Permalink
Merge pull request #254 from bessagroup/mpvanderschelling/issue252
Browse files Browse the repository at this point in the history
Improving usability by implementing helper function to convert arbitrary functions to DataGenerators
  • Loading branch information
mpvanderschelling authored Dec 13, 2023
2 parents 60b4d24 + 89fb54c commit ebe8d20
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 2 deletions.
68 changes: 67 additions & 1 deletion src/f3dasm/_src/datageneration/datagenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import sys
from abc import abstractmethod
from functools import partial
from typing import Any, Callable, Optional
from typing import Any, Callable, Dict, List, Optional

if sys.version_info < (3, 8): # NOQA
from typing_extensions import Protocol # NOQA
Expand Down Expand Up @@ -187,3 +187,69 @@ def add_post_process(self, func: Callable, **kwargs):
The keyword arguments to pass to the post-processing function
"""
self.post_process = partial(func, **kwargs)


def convert_function(f: Callable,
input: List[str],
output: Optional[List[str]] = None,
kwargs: Optional[Dict[str, Any]] = None,
to_disk: Optional[List[str]] = None) -> DataGenerator:
"""
Converts a given function `f` into a `DataGenerator` object.
Parameters
----------
f : Callable
The function to be converted.
input : List[str]
A list of argument names required by the function.
output : Optional[List[str]], optional
A list of names for the return values of the function.
Defaults to None.
kwargs : Optional[Dict[str, Any]], optional
Additional keyword arguments passed to the function. Defaults to None.
to_disk : Optional[List[str]], optional
The list of output names where the value needs to be stored on disk.
Defaults to None.
Returns
-------
DataGenerator
A converted `DataGenerator` object.
Notes
-----
The function `f` can have any number of arguments and any number of returns
as long as they are consistent with the `input` and `output` arguments that
are given to this function.
"""

kwargs = kwargs if kwargs is not None else {}
to_disk = to_disk if to_disk is not None else []
output = output if output is not None else []

class TempDataGenerator(DataGenerator):
def execute(self, **_kwargs) -> None:
_input = {input_name: self.experiment_sample.get(input_name)
for input_name in input}
_output = f(**_input, **kwargs)

# check if output is empty
if output is None:
return

if len(output) == 1:
_output = (_output,)

for name, value in zip(output, _output):
if name in to_disk:
self.experiment_sample.store(name=name,
object=value,
to_disk=True)
else:
self.experiment_sample.store(name=name,
object=value,
to_disk=False)

return TempDataGenerator()
2 changes: 1 addition & 1 deletion src/f3dasm/datageneration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Modules
# =============================================================================

from .._src.datageneration.datagenerator import DataGenerator
from .._src.datageneration.datagenerator import DataGenerator, convert_function

# Authorship & Credits
# =============================================================================
Expand Down
Empty file.
35 changes: 35 additions & 0 deletions tests/datageneration/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Callable

import pytest

from f3dasm import ExperimentData
from f3dasm.design import Domain


@pytest.fixture(scope="package")
def experiment_data() -> ExperimentData:
domain = Domain()
domain.add_float('x', low=0.0, high=1.0)

experiment_data = ExperimentData(domain=domain)

experiment_data.sample(sampler='random', n_samples=10, seed=2023)
return experiment_data


def example_function(x: int, s: int):
return x + s, x - s


def example_function2(x: int):
return x, -x


@pytest.fixture(scope="package")
def function_1() -> Callable:
return example_function


@pytest.fixture(scope="package")
def function_2() -> Callable:
return example_function2
28 changes: 28 additions & 0 deletions tests/datageneration/test_datagenerator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Callable

import pytest

from f3dasm import ExperimentData
from f3dasm.datageneration import DataGenerator, convert_function

pytestmark = pytest.mark.smoke


def test_convert_function(
experiment_data: ExperimentData, function_1: Callable):
data_generator = convert_function(f=function_1, input=['x'], output=[
'y0', 'y1'], kwargs={'s': 103})

assert isinstance(data_generator, DataGenerator)

experiment_data.evaluate(data_generator)


def test_convert_function2(
experiment_data: ExperimentData, function_2: Callable):
data_generator = convert_function(f=function_2, input=['x'], output=[
'y0', 'y1'])

assert isinstance(data_generator, DataGenerator)

experiment_data.evaluate(data_generator)

0 comments on commit ebe8d20

Please sign in to comment.