Skip to content

Commit

Permalink
Refactor code and simplify Domain class
Browse files Browse the repository at this point in the history
  • Loading branch information
mpvanderschelling committed May 13, 2024
1 parent 0662704 commit e1c4491
Show file tree
Hide file tree
Showing 13 changed files with 613 additions and 839 deletions.
239 changes: 1 addition & 238 deletions src/f3dasm/_src/design/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,21 +262,6 @@ def _cast_types_dataframe(self) -> dict:
return {name: parameter._type for
name, parameter in self.space.items()}

def _create_empty_dataframe(self) -> pd.DataFrame:
"""Create an empty DataFrame with input columns.
Returns
-------
pd.DataFrame
DataFrame containing "input" columns.
"""
# input columns
input_columns = [name for name in self.space.keys()]

return pd.DataFrame(columns=input_columns).astype(
self._cast_types_dataframe()
)

# Append and remove parameters
# =============================================================================

Expand Down Expand Up @@ -390,23 +375,6 @@ def add_constant(self, name: str, value: Any):
"""
self._add(name, _ConstantParameter(value))

def add_parameter(self, name: str):
"""Add a new parameter to the domain.
Parameters
----------
name : str
Name of the input parameter.
Example
-------
>>> domain = Domain()
>>> domain.add_parameter('param1')
>>> domain.space
{'param1': Parameter()}
"""
self._add(name, _Parameter())

def add(self, name: str,
type: Literal['float', 'int', 'category', 'constant'],
**kwargs):
Expand Down Expand Up @@ -476,186 +444,6 @@ def add_output(self, name: str, to_disk: bool, exist_ok=False):
# Getters
# =============================================================================

def get_continuous_parameters(self) -> Dict[str, _ContinuousParameter]:
"""Get all continuous input parameters.
Returns
-------
Dict[str, _ContinuousParameter]
Space of continuous input parameters.
Example
-------
>>> domain = Domain()
>>> domain.space = {
... 'param1': _ContinuousParameter(lower_bound=0., upper_bound=1.),
... 'param2': CategoricalParameter(categories=['A', 'B', 'C']),
... 'param3': _ContinuousParameter(lower_bound=2., upper_bound=5.)
... }
>>> continuous_input_params = domain.get_continuous_input_parameters()
>>> continuous_input_params
{'param1': _ContinuousParameter(lower_bound=0., upper_bound=1.),
'param3': _ContinuousParameter(lower_bound=2., upper_bound=5.)}
"""
return self._filter(_ContinuousParameter).space

def get_continuous_names(self) -> List[str]:
"""Get the names of continuous input parameters in the input space.
Returns
-------
List[str]
List of names of continuous input parameters.
Example
-------
>>> domain = Domain()
>>> domain.space = {
... 'param1': _ContinuousParameter(lower_bound=0., upper_bound=1.),
... 'param2': _DiscreteParameter(lower_bound=1, upper_bound=3),
... 'param3': _ContinuousParameter(lower_bound=2., upper_bound=5.)
... }
>>> continuous_input_names = domain.get_continuous_input_names()
>>> continuous_input_names
['param1', 'param3']
"""
return self._filter(_ContinuousParameter).names

def get_discrete_parameters(self) -> Dict[str, _DiscreteParameter]:
"""Retrieve all discrete input parameters.
Returns
-------
Dict[str, _DiscreteParameter]
Space of discrete input parameters.
Example
-------
>>> domain = Domain()
>>> domain.space = {
... 'param1': _DiscreteParameter(lower_bound=1, upperBound=4),
... 'param2': CategoricalParameter(categories=['A', 'B', 'C']),
... 'param3': _DiscreteParameter(lower_bound=4, upperBound=6)
... }
>>> discrete_input_params = domain.get_discrete_input_parameters()
>>> discrete_input_params
{'param1': _DiscreteParameter(lower_bound=1, upperBound=4)),
'param3': _DiscreteParameter(lower_bound=4, upperBound=6)}
"""
return self._filter(_DiscreteParameter).space

def get_discrete_names(self) -> List[str]:
"""Retrieve the names of all discrete input parameters.
Returns
-------
List[str]
List of names of discrete input parameters.
Example
-------
>>> domain = Domain()
>>> domain.space = {
... 'param1': _DiscreteParameter(lower_bound=1, upperBound=4),
... 'param2': _ContinuousParameter(lower_bound=0, upper_bound=1),
... 'param3': _DiscreteParameter(lower_bound=4, upperBound=6)
... }
>>> discrete_input_names = domain.get_discrete_input_names()
>>> discrete_input_names
['param1', 'param3']
"""
return self._filter(_DiscreteParameter).names

def get_categorical_parameters(self) -> Dict[str, _CategoricalParameter]:
"""Retrieve all categorical input parameters.
Returns
-------
Dict[str, CategoricalParameter]
Space of categorical input parameters.
Example
-------
>>> domain = Domain()
>>> domain.space = {
... 'param1': CategoricalParameter(categories=['A', 'B', 'C']),
... 'param2': _ContinuousParameter(lower_bound=0, upper_bound=1),
... 'param3': CategoricalParameter(categories=['X', 'Y', 'Z'])
... }
>>> categorical_input_params =
domain.get_categorical_input_parameters()
>>> categorical_input_params
{'param1': CategoricalParameter(categories=['A', 'B', 'C']),
'param3': CategoricalParameter(categories=['X', 'Y', 'Z'])}
"""
return self._filter(_CategoricalParameter).space

def get_categorical_names(self) -> List[str]:
"""Retrieve the names of categorical input parameters.
Returns
-------
List[str]
List of names of categorical input parameters.
Example
-------
>>> domain = Domain()
>>> domain.space = {
... 'param1': CategoricalParameter(categories=['A', 'B', 'C']),
... 'param2': _ContinuousParameter(lower_bound=0, upper_bound=1),
... 'param3': CategoricalParameter(categories=['X', 'Y', 'Z'])
... }
>>> categorical_input_names = domain.get_categorical_input_names()
>>> categorical_input_names
['param1', 'param3']
"""
return self._filter(_CategoricalParameter).names

def get_constant_parameters(self) -> Dict[str, _ConstantParameter]:
"""Retrieve all constant input parameters.
Returns
-------
Dict[str, ConstantParameter]
Space of constant input parameters.
Example
-------
>>> domain = Domain()
>>> domain.space = {
... 'param1': ConstantParameter(value=0),
... 'param2': CategoricalParameter(categories=['A', 'B', 'C']),
... 'param3': ConstantParameter(value=1)
... }
>>> constant_input_params = domain.get_constant_input_parameters()
>>> constant_input_params
{'param1': ConstantParameter(value=0),
'param3': ConstantParameter(value=1)}
"""
return self._filter(_ConstantParameter).space

def get_constant_names(self) -> List[str]:
"""Receive the names of the constant input parameters
Returns
-------
list of names of constant input parameters
Example
-------
>>> domain = Domain()
>>> domain.space = {
... 'param1': ConstantParameter(value=0),
... 'param2': ConstantParameter(value=1),
... 'param3': _ContinuousParameter(lower_bound=0, upper_bound=1)
... }
>>> constant_input_names = domain.get_constant_input_names()
>>> constant_input_names
['param1', 'param2']
"""
return self._filter(_ConstantParameter).names

def get_bounds(self) -> np.ndarray:
"""Return the boundary constraints of the continuous input parameters
Expand All @@ -680,7 +468,7 @@ def get_bounds(self) -> np.ndarray:
"""
return np.array(
[[parameter.lower_bound, parameter.upper_bound]
for _, parameter in self.get_continuous_parameters().items()]
for _, parameter in self.continuous.space.items()]
)

def _filter(self, type: Type[_Parameter]) -> Domain:
Expand Down Expand Up @@ -788,30 +576,6 @@ def _all_input_continuous(self) -> bool:
"""Check if all input parameters are continuous"""
return len(self) == len(self._filter(_ContinuousParameter))

def _check_output(self, names: List[str]):
"""Check if output is in the domain and add it if not
Parameters
----------
names : list of str
Names of the outputs to be checked
Example
-------
>>> domain = Domain()
>>> domain.add_output('output1')
>>> domain.add_output('output2')
>>> domain._check_output(['output1', 'output2', 'output3'])
>>> domain.output_space
{'output1': _ContinuousParameter(lower_bound=-inf, upper_bound=inf),
'output2': _ContinuousParameter(lower_bound=-inf, upper_bound=inf),
'output3': _ContinuousParameter(lower_bound=-inf, upper_bound=inf)}
"""
for output_name in names:
if not self.is_in_output(output_name):
self.add_output(output_name, to_disk=False)

def is_in_output(self, output_name: str) -> bool:
"""Check if output is in the domain
Expand Down Expand Up @@ -886,7 +650,6 @@ def _domain_factory(domain: Domain | DictConfig | None,
input_data: pd.DataFrame,
output_data: pd.DataFrame) -> Domain:
if isinstance(domain, Domain):
# domain._check_output(output_data.columns)
return domain

elif isinstance(domain, (Path, str)):
Expand Down
Loading

0 comments on commit e1c4491

Please sign in to comment.