diff --git a/docs/source/rst_doc_files/classes/sampling/sampling.rst b/docs/source/rst_doc_files/classes/sampling/sampling.rst index 168777e5..96e46baa 100644 --- a/docs/source/rst_doc_files/classes/sampling/sampling.rst +++ b/docs/source/rst_doc_files/classes/sampling/sampling.rst @@ -21,9 +21,6 @@ The function should return the samples (``input_data``) in one of the following * A :class:`~pandas.DataFrame` object * A :class:`~numpy.ndarray` object -.. note:: - - ... .. _implemented samplers: @@ -36,7 +33,6 @@ To use the sampler in the data-driven process, you should pass the function to t .. code-block:: python from f3dasm.design import ExperimentData, Domain - from domain = Domain(...) # Create the ExperimentData object @@ -63,4 +59,5 @@ Name Method ``"random"`` Random Uniform sampling `numpy.random.uniform `_ ``"latin"`` Latin Hypercube sampling `SALib.latin `_ ``"sobol"`` Sobol Sequence sampling `SALib.sobol_sequence `_ +``"grid"`` Grid Search sampling `itertools.product `_ ======================== ====================================================================== =========================================================================================================== diff --git a/src/f3dasm/_src/design/samplers.py b/src/f3dasm/_src/design/samplers.py index 8d9468d0..05a78410 100644 --- a/src/f3dasm/_src/design/samplers.py +++ b/src/f3dasm/_src/design/samplers.py @@ -6,6 +6,7 @@ from __future__ import annotations # Standard +from itertools import product from typing import Optional # Third-party @@ -39,6 +40,9 @@ def _sampler_factory(sampler: str, domain: Domain) -> Sampler: elif sampler.lower() == 'sobol': return SobolSequence(domain) + elif sampler.lower() == 'grid': + return GridSampler(domain) + else: raise KeyError(f"Sampler {sampler} not found!" f"Available built-in samplers are: 'random'," @@ -283,3 +287,56 @@ def sample_continuous(self, numsamples: int) -> np.ndarray: # stretch samples samples = self._stretch_samples(samples) return samples + + +class GridSampler(Sampler): + """Sampling via Grid Sampling + + All the combination of the discrete and categorical parameters are + sampled. The argument number_of_samples is ignored. + Notes + ----- + This sampler is at the moment only applicable for + discrete and categorical parameters. + + """ + + def get_samples(self, numsamples: Optional[int] = None) -> pd.DataFrame: + """Receive samples of the search space + + Parameters + ---------- + numsamples + number of samples + + Returns + ------- + Data objects with the samples + """ + + self.set_seed(self.seed) + + # If numsamples is None, take the object attribute number_of_samples + if numsamples is None: + numsamples = self.number_of_samples + + continuous = self.domain.get_continuous_parameters() + + if continuous: + raise ValueError("Grid sampling is only possible for domains \ + strictly with only discrete and \ + categorical parameters") + + discrete = self.domain.get_discrete_parameters() + categorical = self.domain.get_categorical_parameters() + + _iterdict = {} + + for k, v in categorical.items(): + _iterdict[k] = v.categories + + for k, v, in discrete.items(): + _iterdict[k] = range(v.lower_bound, v.upper_bound+1) + + return pd.DataFrame(list(product(*_iterdict.values())), + columns=_iterdict) diff --git a/src/f3dasm/_src/experimentdata/experimentdata.py b/src/f3dasm/_src/experimentdata/experimentdata.py index 0963f97d..fb49bb80 100644 --- a/src/f3dasm/_src/experimentdata/experimentdata.py +++ b/src/f3dasm/_src/experimentdata/experimentdata.py @@ -277,8 +277,9 @@ def from_sampling(cls, sampler: Sampler | str, domain: Domain | DictConfig, Parameters ---------- - sampler : Sampler - Sampler object containing the sampling strategy. + sampler : Sampler | str + Sampler object containing the sampling strategy or one of the + built-in sampler names. domain : Domain | DictConfig Domain object containing the domain of the experiment or hydra DictConfig object containing the configuration. @@ -291,6 +292,17 @@ def from_sampling(cls, sampler: Sampler | str, domain: Domain | DictConfig, ------- ExperimentData ExperimentData object containing the sampled data. + + Note + ---- + + If a string is passed for the sampler argument, it should be one + of the built-in samplers: + + * 'random' : Random sampling + * 'latin' : Latin Hypercube Sampling + * 'sobol' : Sobol Sequence Sampling + * 'grid' : Grid Search Sampling """ experimentdata = cls(domain=domain) experimentdata.sample(sampler=sampler, n_samples=n_samples, seed=seed) @@ -1359,6 +1371,7 @@ def sample(self, sampler: Sampler | str, n_samples: int = 1, * 'random' : Random sampling * 'latin' : Latin Hypercube Sampling * 'sobol' : Sobol Sequence Sampling + * 'grid' : Grid Search Sampling Raises ------