Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove handling_strategy parameter #843

Merged
merged 13 commits into from
Jun 28, 2022
23 changes: 1 addition & 22 deletions sdv/constraints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import importlib
import inspect
import logging
import warnings

import pandas as pd
from copulas.multivariate.gaussian import GaussianMultivariate
Expand Down Expand Up @@ -101,10 +100,7 @@ class Constraint(metaclass=ConstraintMeta):
"""

constraint_columns = ()
rebuild_columns = ()
_hyper_transformer = None
_use_reject_sampling = False
IS_CUSTOM = False

def _validate_data_meets_constraint(self, table_data):
"""Make sure the given data is valid for the constraint.
Expand Down Expand Up @@ -182,22 +178,8 @@ def transform(self, table_data):
pandas.DataFrame:
Input data unmodified.
"""
self._use_reject_sampling = False
self._validate_all_columns_present(table_data)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can just copy paste the function logic here, instead of creating it separetely. The docstrings already describe that method either way, and it doesn't really have any test cases.


try:
transformed = self._transform(table_data)
if self.IS_CUSTOM:
self.reverse_transform(transformed)
return transformed

except Exception:
warnings.warn(
f'Error transforming {self.__class__.__name__}. Using the reject sampling '
'approach instead.'
)
self._use_reject_sampling = True
return table_data
return self._transform(table_data)

def fit_transform(self, table_data):
"""Fit this Constraint to the data and then transform it.
Expand Down Expand Up @@ -231,9 +213,6 @@ def reverse_transform(self, table_data):
pandas.DataFrame:
Input data unmodified.
"""
if self._use_reject_sampling:
return table_data

return self._reverse_transform(table_data)

def is_valid(self, table_data):
Expand Down
2 changes: 0 additions & 2 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def __init__(self, column_names):

self._columns = column_names
self.constraint_columns = tuple(column_names)
self.rebuild_columns = tuple(column_names)

def _fit(self, table_data):
"""Fit this Constraint to the data.
Expand Down Expand Up @@ -280,7 +279,6 @@ def __init__(self, low_column_name, high_column_name, strict_boundaries=False):
self._high_column_name = high_column_name
self._diff_column_name = f'{self._low_column_name}#{self._high_column_name}'
self._operator = np.greater if strict_boundaries else np.greater_equal
self.rebuild_columns = tuple(high_column_name)
self.constraint_columns = tuple([low_column_name, high_column_name])
self._dtype = None
self._is_datetime = None
Expand Down
16 changes: 11 additions & 5 deletions sdv/metadata/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,14 @@ def _get_transformers(self, dtypes):

return transformers

@staticmethod
def _warn_of_missing_columns(constraint, error):
warnings.warn(
f'{constraint.__class__.__name__} cannot be transformed because columns: '
f'{error.missing_columns} are not found. Using the reject sampling approach '
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

were instead of `are, no?

'instead.'
)

def _fit_transform_constraints(self, data):
errors = []
# Fit and validate all constraints first because `transform` might change columns
Expand All @@ -424,6 +432,8 @@ def _fit_transform_constraints(self, data):
for constraint in self._constraints:
try:
data = constraint.transform(data)
except MissingConstraintColumnError as e:
Table._warn_of_missing_columns(constraint, e)
except Exception as e:
errors.append(e)

Expand Down Expand Up @@ -581,11 +591,7 @@ def _transform_constraints(self, data):
try:
data = constraint.transform(data)
except MissingConstraintColumnError as e:
warnings.warn(
f'{constraint.__class__.__name__} cannot be transformed because columns: '
f'{e.missing_columns} are not found. Using the reject sampling approach '
'instead.'
)
Table._warn_of_missing_columns(constraint, e)
indices_to_drop = data.columns.isin(constraint.constraint_columns)
columns_to_drop = data.columns.where(indices_to_drop).dropna()
data = data.drop(columns_to_drop, axis=1)
Expand Down
93 changes: 9 additions & 84 deletions tests/unit/constraints/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,12 +286,10 @@ def test_transform(self):
"""
# Run
instance = Constraint()
instance._use_reject_sampling = True
output = instance.transform(pd.DataFrame({'col': ['input']}))

# Assert
pd.testing.assert_frame_equal(output, pd.DataFrame({'col': ['input']}))
assert instance._use_reject_sampling is False

def test_transform_calls__transform(self):
"""Test that the ``Constraint.transform`` method calls ``_transform``.
Expand All @@ -316,45 +314,13 @@ def test_transform_calls__transform(self):
assert output == 'the_transformed_data'
constraint_mock._validate_all_columns_present.assert_called_once()

def test_transform_calls__transform_and_reverse_transform_if_custom(self):
"""Test that the ``Constraint.transform`` method calls ``_reverse_transform`` if it's custom.

The ``Constraint.transform`` method is expected to:
- Return value returned by ``_transform``.

Setup:
- Set ``IS_CUSTOM`` to True.

Input:
- Anything

Output:
- Result of ``_transform(input)``
"""
# Setup
instance = Constraint()
instance.IS_CUSTOM = True
instance._transform = Mock()
instance.reverse_transform = Mock()
instance._transform.return_value = 'the_transformed_data'

# Run
output = instance.transform('input')

# Assert
assert output == 'the_transformed_data'
instance.reverse_transform.assert_called_once()

@patch('sdv.constraints.base.warnings')
def test_transform__transform_errors(self, warnings_mock):
def test_transform__transform_errors(self):
"""Test that the ``transform`` method handles any errors.

If the ``_transform`` method raises an error, the data should be return unchanged
and a warning should be raised.
If the ``_transform`` method raises an error, the error should be raised.

Setup:
- Make ``_transform`` raise an error.
- Mock warnings.

Input:
- ``pandas.DataFrame``.
Expand All @@ -363,58 +329,17 @@ def test_transform__transform_errors(self, warnings_mock):
- Same ``pandas.DataFrame``.

Side effects:
- Warning should be raised.
- Exception should be raised
"""
# Setup
constraint_mock = Mock()
constraint_mock._transform.side_effect = Exception()
data = pd.DataFrame({'a': [1, 2, 3]})

# Run
output = Constraint.transform(constraint_mock, data)

# Assert
pd.testing.assert_frame_equal(data, output)
expected_message = 'Error transforming Mock. Using the reject sampling approach instead.'
warnings_mock.warn.assert_called_with(expected_message)

@patch('sdv.constraints.base.warnings')
def test_transform_reverse_transform_errors(self, warnings_mock):
"""Test that the ``transform`` method handles any errors.

If the ``reverse_transform`` method raises an error, the data should be return unchanged
and a warning should be raised.

Setup:
- Make ``reverse_transform`` raise an error.
- Mock warnings.
- Set ``IS_CUSTOM`` to True.

Input:
- ``pandas.DataFrame``.

Output:
- Same ``pandas.DataFrame``.

Side effects:
- Warning should be raised.
"""
# Setup
constraint = Constraint()
constraint.IS_CUSTOM = True
constraint.reverse_transform = Mock()
constraint.reverse_transform.side_effect = Exception()
instance = Constraint()
instance._transform = Mock()
instance._transform.side_effect = Exception()
data = pd.DataFrame({'a': [1, 2, 3]})

# Run
output = constraint.transform(data)

# Assert
pd.testing.assert_frame_equal(data, output)
expected_message = (
'Error transforming Constraint. Using the reject sampling approach instead.'
)
warnings_mock.warn.assert_called_with(expected_message)
# Run / Assert
with pytest.raises(Exception):
instance.transform(data)

def test_transform_columns_missing(self):
"""Test the ``Constraint.transform`` method with invalid data.
Expand Down
20 changes: 1 addition & 19 deletions tests/unit/constraints/test_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,23 +398,7 @@ def test___init__(self):

# Assert
assert instance._columns == columns

def test___init__sets_rebuild_columns(self):
"""Test the ``FixedCombinations.__init__`` method.

The rebuild columns should be set.

Side effects:
- instance.rebuild_columns are set
"""
# Setup
columns = ['b', 'c']

# Run
instance = FixedCombinations(column_names=columns)

# Assert
assert instance.rebuild_columns == tuple(columns)
assert instance.constraint_columns == tuple(columns)

def test___init__with_one_column(self):
"""Test the ``FixedCombinations.__init__`` method with only one constraint column.
Expand Down Expand Up @@ -820,7 +804,6 @@ def test___init___(self, mock_validate):
- _low_column_name and _high_column_name are set to the input column names
- _diff_column_name is set to '_low_column_name#_high_column_name'
- _operator is set to the default np.greater_equal
- rebuild_columns is a tuple of _igh_column_name
- _dtype and _is_datetime are None
- _validate_inputs is called once
"""
Expand All @@ -832,7 +815,6 @@ def test___init___(self, mock_validate):
assert instance._high_column_name == 'b'
assert instance._diff_column_name == 'a#b'
assert instance._operator == np.greater_equal
assert instance.rebuild_columns == tuple('b')
assert instance._dtype is None
assert instance._is_datetime is None
mock_validate.assert_called_once_with('a', 'b', False)
Expand Down
41 changes: 41 additions & 0 deletions tests/unit/metadata/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,47 @@ def test_fit_constraint_transform_errors(self):
constraint1.fit.assert_called_once_with(data)
constraint2.fit.assert_called_once_with(data)

@patch('sdv.metadata.table.warnings')
def test_fit_constraint_transform_missing_columns_error(self, warnings_mock):
"""Test the ``fit`` method when transform raises a ``MissingConstraintColumnError``.

The ``fit`` method should loop through all the constraints and try to fit them. Then it
should loop through again and try to transform. If a ``MissingConstraintColumnError`` is
raised, a warning should be raised and reject sampling should be used.

Setup:
- Set the ``_constraints`` to be a list of mocked constraints.
- Set constraint mocks to raise ``MissingConstraintColumnError`` when calling
transform.
- Mock warnings module.

Input:
- A ``pandas.DataFrame``.

Side effect:
- A ``MissingConstraintColumnError`` should be raised.
"""
# Setup
data = pd.DataFrame({'a': [1, 2, 3]})
instance = Table()
constraint1 = Mock()
constraint2 = Mock()
constraint1.transform.return_value = data
constraint2.transform.side_effect = MissingConstraintColumnError(['column'])
instance._constraints = [constraint1, constraint2]

# Run
instance.fit(data)

# Assert
constraint1.fit.assert_called_once_with(data)
constraint2.fit.assert_called_once_with(data)
warning_message = (
"Mock cannot be transformed because columns: ['column'] are not found. Using the "
'reject sampling approach instead.'
)
warnings_mock.warn.assert_called_once_with(warning_message)

def test_transform_calls__transform_constraints(self):
"""Test that the `transform` method calls `_transform_constraints` with right parameters

Expand Down