Skip to content

Commit

Permalink
added GridSearch sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
mpvanderschelling committed Dec 18, 2023
1 parent 0991fae commit 430669b
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 6 deletions.
5 changes: 1 addition & 4 deletions docs/source/rst_doc_files/classes/sampling/sampling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ The function should return the samples (``input_data``) in one of the following
* A :class:`~pandas.DataFrame` object
* A :class:`~numpy.ndarray` object

.. note::

...


.. _implemented samplers:
Expand All @@ -36,7 +33,6 @@ To use the sampler in the data-driven process, you should pass the function to t
.. code-block:: python
from f3dasm.design import ExperimentData, Domain
from
domain = Domain(...)
# Create the ExperimentData object
Expand All @@ -63,4 +59,5 @@ Name Method
``"random"`` Random Uniform sampling `numpy.random.uniform <https://numpy.org/doc/stable/reference/random/generated/numpy.random.uniform.html>`_
``"latin"`` Latin Hypercube sampling `SALib.latin <https://salib.readthedocs.io/en/latest/api/SALib.sample.html?highlight=latin%20hypercube#SALib.sample.latin.sample>`_
``"sobol"`` Sobol Sequence sampling `SALib.sobol_sequence <https://salib.readthedocs.io/en/latest/api/SALib.sample.html?highlight=sobol%20sequence#SALib.sample.sobol_sequence.sample>`_
``"grid"`` Grid Search sampling `itertools.product <https://docs.python.org/3/library/itertools.html#itertools.product>`_
======================== ====================================================================== ===========================================================================================================
57 changes: 57 additions & 0 deletions src/f3dasm/_src/design/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from __future__ import annotations

# Standard
from itertools import product
from typing import Optional

# Third-party
Expand Down Expand Up @@ -39,6 +40,9 @@ def _sampler_factory(sampler: str, domain: Domain) -> Sampler:
elif sampler.lower() == 'sobol':
return SobolSequence(domain)

elif sampler.lower() == 'grid':
return GridSampler(domain)

else:
raise KeyError(f"Sampler {sampler} not found!"
f"Available built-in samplers are: 'random',"
Expand Down Expand Up @@ -283,3 +287,56 @@ def sample_continuous(self, numsamples: int) -> np.ndarray:
# stretch samples
samples = self._stretch_samples(samples)
return samples


class GridSampler(Sampler):
"""Sampling via Grid Sampling
All the combination of the discrete and categorical parameters are
sampled. The argument number_of_samples is ignored.
Notes
-----
This sampler is at the moment only applicable for
discrete and categorical parameters.
"""

def get_samples(self, numsamples: Optional[int] = None) -> pd.DataFrame:
"""Receive samples of the search space
Parameters
----------
numsamples
number of samples
Returns
-------
Data objects with the samples
"""

self.set_seed(self.seed)

# If numsamples is None, take the object attribute number_of_samples
if numsamples is None:
numsamples = self.number_of_samples

continuous = self.domain.get_continuous_parameters()

if continuous:
raise ValueError("Grid sampling is only possible for domains \
strictly with only discrete and \
categorical parameters")

discrete = self.domain.get_discrete_parameters()
categorical = self.domain.get_categorical_parameters()

_iterdict = {}

for k, v in categorical.items():
_iterdict[k] = v.categories

for k, v, in discrete.items():
_iterdict[k] = range(v.lower_bound, v.upper_bound+1)

return pd.DataFrame(list(product(*_iterdict.values())),
columns=_iterdict)
17 changes: 15 additions & 2 deletions src/f3dasm/_src/experimentdata/experimentdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,9 @@ def from_sampling(cls, sampler: Sampler | str, domain: Domain | DictConfig,
Parameters
----------
sampler : Sampler
Sampler object containing the sampling strategy.
sampler : Sampler | str
Sampler object containing the sampling strategy or one of the
built-in sampler names.
domain : Domain | DictConfig
Domain object containing the domain of the experiment or hydra
DictConfig object containing the configuration.
Expand All @@ -291,6 +292,17 @@ def from_sampling(cls, sampler: Sampler | str, domain: Domain | DictConfig,
-------
ExperimentData
ExperimentData object containing the sampled data.
Note
----
If a string is passed for the sampler argument, it should be one
of the built-in samplers:
* 'random' : Random sampling
* 'latin' : Latin Hypercube Sampling
* 'sobol' : Sobol Sequence Sampling
* 'grid' : Grid Search Sampling
"""
experimentdata = cls(domain=domain)
experimentdata.sample(sampler=sampler, n_samples=n_samples, seed=seed)
Expand Down Expand Up @@ -1359,6 +1371,7 @@ def sample(self, sampler: Sampler | str, n_samples: int = 1,
* 'random' : Random sampling
* 'latin' : Latin Hypercube Sampling
* 'sobol' : Sobol Sequence Sampling
* 'grid' : Grid Search Sampling
Raises
------
Expand Down

0 comments on commit 430669b

Please sign in to comment.