Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove handling_strategy parameter #843

Merged
merged 13 commits into from
Jun 28, 2022
27 changes: 1 addition & 26 deletions docs/developer_guides/sdv/constraints.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,6 @@ The following public methods are implemented in this class:
* ``from_dict``: Build a ``Constraint`` from its dict representation.
* ``to_dict``: Return a dict representing the ``Constraint``.

handling_strategy
~~~~~~~~~~~~~~~~~

Additionally, the ``Constraint.__init__`` method sets up the class based on the value of the
argument ``handling_strategy`` as follows:

* If ``handling_strategy`` equals ``'transform'``, the ``filter_valid`` method is disabled by
replacing it with an identity function.
* If ``handling_strategy`` equals ``'reject_sampling'``, both the ``transform`` and
``reverse_transform`` methods are disabled by replacing them with an identity function.

Because of this, any subclass has the option to implement both the transformation and reject
sampling strategies and later on give the user the choice to choose between the two by just
calling the ``super().__init__`` method passing the corresponding ``handling_strategy`` value.

Implementing a Custom Constraint
--------------------------------

Expand Down Expand Up @@ -131,23 +116,13 @@ modeling and sampling the number of `pairs of legs` instead of the number of `le
table_data[self._column_name] = table_data[self._column_name] * 2
return table_data

With this new implementation, our Constraint would be ready to handle both strategies,
`reject sampling` and `transform`, but in some cases we might want to let the user
chose only one of them, so the other is skipped.

In a situation like this, we can simply add a ``handling_strategy`` parameter to our
``__init__`` method and call ``super().__init__`` passing it, so the base ``Constraint`` class
can handle it adequately:


.. code-block:: python

class PositiveEven(Constraint):
"""Ensure that values are positive and even."""

def __init__(self, column_name, handling_strategy='transform'):
def __init__(self, column_name):
self._column_name = column_name
super().__init__(handling_strategy=handling_strategy)

def is_valid(self, table_data):
"""Say if values are positive and even."""
Expand Down
5 changes: 1 addition & 4 deletions docs/user_guides/single_table/handling_constraints.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,7 @@ order to use this constraint we will need to import it from the

from sdv.constraints import FixedCombinations

fixed_company_department_constraint = FixedCombinations(
column_names=['company', 'department'],
handling_strategy='transform'
)
fixed_company_department_constraint = FixedCombinations(column_names=['company', 'department'])

Inequality and ScalarInequality Constraints
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
90 changes: 38 additions & 52 deletions sdv/constraints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,61 +91,19 @@ class Constraint(metaclass=ConstraintMeta):
This class is not intended to be used directly and should rather be
subclassed to create different types of constraints.

If ``handling_strategy`` is passed with the value ``transform``
or ``reject_sampling``, the ``filter_valid`` or ``transform`` and
``reverse_transform`` methods will be replaced respectively by a simple
identity function.

Attributes:
constraint_columns (tuple[str]):
The names of the columns used by this constraint.
rebuild_columns (tuple[str]):
The names of the columns that this constraint will rebuild during
``reverse_transform``.
Args:
handling_strategy (str):
How this Constraint should be handled, which can be ``transform``,
``reject_sampling`` or ``all``.
"""

constraint_columns = ()
rebuild_columns = ()
_hyper_transformer = None

def _identity(self, table_data):
return table_data

def _identity_with_validation(self, table_data):
self._validate_data_on_constraint(table_data)
return table_data

def __init__(self, handling_strategy):
if handling_strategy == 'transform':
self.filter_valid = self._identity
elif handling_strategy == 'reject_sampling':
self.rebuild_columns = ()
self.transform = self._identity_with_validation
self.reverse_transform = self._identity
elif handling_strategy != 'all':
raise ValueError('Unknown handling strategy: {}'.format(handling_strategy))

def _fit(self, table_data):
del table_data

def fit(self, table_data):
"""Fit ``Constraint`` class to data.

Args:
table_data (pandas.DataFrame):
Table data.
"""
self._fit(table_data)

def _transform(self, table_data):
return table_data

def _validate_data_on_constraint(self, table_data):
"""Make sure the given data is valid for the given constraints.
def _validate_data_meets_constraint(self, table_data):
"""Make sure the given data is valid for the constraint.

Args:
data (pandas.DataFrame):
Expand All @@ -169,16 +127,37 @@ def _validate_data_on_constraint(self, table_data):

raise ConstraintsNotMetError(err_msg)

def check_missing_columns(self, table_data):
"""Check ``table_data`` for missing columns.
def _fit(self, table_data):
del table_data
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we need this. Couldn't we just use pass or I'm missing something ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could, this is just from before


def fit(self, table_data):
"""Fit ``Constraint`` class to data.

Args:
table_data (pandas.DataFrame):
Table data.
"""
self._fit(table_data)
self._validate_data_meets_constraint(table_data)

def _transform(self, table_data):
return table_data

def _validate_all_columns_present(self, table_data):
"""Validate that all required columns are in ``table_data``.

Args:
table_data (pandas.DataFrame):
Table data.

Raises:
MissingConstraintColumnError:
If the data is missing any columns needed for the constraint transformation,
a ``MissingConstraintColumnError`` is raised.
"""
missing_columns = [col for col in self.constraint_columns if col not in table_data.columns]
if missing_columns:
raise MissingConstraintColumnError()
raise MissingConstraintColumnError(missing_columns=missing_columns)

def transform(self, table_data):
"""Perform necessary transformations needed by constraint.
Expand All @@ -188,7 +167,8 @@ def transform(self, table_data):
should overwrite the ``_transform`` method instead. This method raises a
``MissingConstraintColumnError`` if the ``table_data`` is missing any columns
needed to do the transformation. If columns are present, this method will call
the ``_transform`` method.
the ``_transform`` method. If ``_transform`` fails, the data will be returned
unchanged.

Args:
table_data (pandas.DataFrame):
Expand All @@ -198,8 +178,7 @@ def transform(self, table_data):
pandas.DataFrame:
Input data unmodified.
"""
self._validate_data_on_constraint(table_data)
self.check_missing_columns(table_data)
self._validate_all_columns_present(table_data)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can just copy paste the function logic here, instead of creating it separetely. The docstrings already describe that method either way, and it doesn't really have any test cases.

return self._transform(table_data)

def fit_transform(self, table_data):
Expand All @@ -216,8 +195,15 @@ def fit_transform(self, table_data):
self.fit(table_data)
return self.transform(table_data)

def _reverse_transform(self, table_data):
return table_data

def reverse_transform(self, table_data):
"""Identity method for completion. To be optionally overwritten by subclasses.
"""Handle logic around reverse transforming constraints.

If the ``transform`` method was skipped, then this method should be too.
Otherwise attempt to reverse transform and if that fails, return the data
unchanged to fall back on reject sampling.

Args:
table_data (pandas.DataFrame):
Expand All @@ -227,7 +213,7 @@ def reverse_transform(self, table_data):
pandas.DataFrame:
Input data unmodified.
"""
return table_data
return self._reverse_transform(table_data)

def is_valid(self, table_data):
"""Say whether the given table rows are valid.
Expand Down
3 changes: 3 additions & 0 deletions sdv/constraints/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
class MissingConstraintColumnError(Exception):
"""Error to use when constraint is provided a table with missing columns."""

def __init__(self, missing_columns):
self.missing_columns = missing_columns


class MultipleConstraintsErrors(Exception):
"""Error used to represent a list of constraint errors."""
Loading