Skip to content

Commit

Permalink
adding integration test and tracking if reverse transform should use reject sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
amontanez24 committed Jun 23, 2022
1 parent 90b6003 commit 1ff251b
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 30 deletions.
28 changes: 11 additions & 17 deletions sdv/constraints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ class Constraint(metaclass=ConstraintMeta):

constraint_columns = ()
_hyper_transformer = None
_use_reject_sampling = False

def _validate_data_meets_constraint(self, table_data):
"""Make sure the given data is valid for the constraint.
Expand Down Expand Up @@ -143,22 +144,6 @@ def fit(self, table_data):
def _transform(self, table_data):
    """Return ``table_data`` as received.

    The very same object that was passed in is handed back; nothing is
    copied or modified here.

    Args:
        table_data (pandas.DataFrame):
            Table data.

    Returns:
        pandas.DataFrame:
            Input data unmodified.
    """
    return table_data

def _validate_all_columns_present(self, table_data):
    """Validate that all required columns are in ``table_data``.

    Args:
        table_data (pandas.DataFrame):
            Table data.

    Raises:
        MissingConstraintColumnError:
            If the data is missing any columns needed for the constraint transformation,
            a ``MissingConstraintColumnError`` is raised.
    """
    # Keep the order of ``constraint_columns`` so the error lists the missing
    # names in the order the constraint declares them.
    missing_columns = [col for col in self.constraint_columns if col not in table_data.columns]
    if missing_columns:
        raise MissingConstraintColumnError(missing_columns=missing_columns)

def transform(self, table_data):
"""Perform necessary transformations needed by constraint.
Expand All @@ -178,7 +163,13 @@ def transform(self, table_data):
pandas.DataFrame:
Input data unmodified.
"""
self._validate_all_columns_present(table_data)
self._use_reject_sampling = False

missing_columns = [col for col in self.constraint_columns if col not in table_data.columns]
if missing_columns:
self._use_reject_sampling = True
raise MissingConstraintColumnError(missing_columns=missing_columns)

return self._transform(table_data)

def fit_transform(self, table_data):
Expand Down Expand Up @@ -213,6 +204,9 @@ def reverse_transform(self, table_data):
pandas.DataFrame:
Input data unmodified.
"""
if self._use_reject_sampling:
return table_data

return self._reverse_transform(table_data)

def is_valid(self, table_data):
Expand Down
2 changes: 1 addition & 1 deletion sdv/metadata/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def _get_transformers(self, dtypes):
def _warn_of_missing_columns(constraint, error):
    """Warn that ``constraint`` cannot be transformed because columns are missing.

    Args:
        constraint:
            Object whose class name is reported at the start of the warning.
        error:
            Object exposing the missing column names as ``missing_columns``.
    """
    # Only one wording of the middle fragment may be present: with both the
    # "are not found" and "were not found" lines the message repeats itself.
    warnings.warn(
        f'{constraint.__class__.__name__} cannot be transformed because columns: '
        f'{error.missing_columns} were not found. Using the reject sampling approach '
        'instead.'
    )

Expand Down
22 changes: 12 additions & 10 deletions tests/integration/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,32 @@
from sdv.tabular import GaussianCopula


def years_in_the_company(data):
    """Return ``age`` minus ``age_when_joined`` for the given data.

    Args:
        data:
            Object indexable by the keys ``'age'`` and ``'age_when_joined'``
            whose values support subtraction.

    Returns:
        The result of ``data['age'] - data['age_when_joined']``.
    """
    return data['age'] - data['age_when_joined']


def test_constraints(tmpdir):
    """Check that sampled data satisfies every constraint after a save/load cycle.

    A ``GaussianCopula`` is built with three constraints, fitted on the demo
    data, saved to ``tmpdir`` and loaded back; the sampled rows must then be
    valid for each constraint.
    """
    # Setup
    employees = load_tabular_demo()

    fixed_company_department_constraint = FixedCombinations(column_names=['company', 'department'])

    age_gt_age_when_joined_constraint = Inequality(
        low_column_name='age_when_joined',
        high_column_name='age'
    )

    age_range_constraint = ScalarRange('age', 29, 50)
    constraints = [
        fixed_company_department_constraint,
        age_gt_age_when_joined_constraint,
        age_range_constraint
    ]

    # Run
    # Build the model once: a second construction that is immediately
    # overwritten, or a sample whose result is discarded, adds nothing here.
    gc = GaussianCopula(constraints=constraints, min_value=None, max_value=None)
    gc.fit(employees)
    gc.save(tmpdir / 'test.pkl')
    gc = gc.load(tmpdir / 'test.pkl')
    sampled = gc.sample(10)

    # Assert
    assert all(age_gt_age_when_joined_constraint.is_valid(sampled))
    assert all(age_range_constraint.is_valid(sampled))
    assert all(fixed_company_department_constraint.is_valid(sampled))


def test_failing_constraints():
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/constraints/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,10 +286,12 @@ def test_transform(self):
"""
# Run
instance = Constraint()
instance._use_reject_sampling = True
output = instance.transform(pd.DataFrame({'col': ['input']}))

# Assert
pd.testing.assert_frame_equal(output, pd.DataFrame({'col': ['input']}))
assert instance._use_reject_sampling is False

def test_transform_calls__transform(self):
"""Test that the ``Constraint.transform`` method calls ``_transform``.
Expand All @@ -305,14 +307,14 @@ def test_transform_calls__transform(self):
"""
# Setup
constraint_mock = Mock()
constraint_mock.constraint_columns = []
constraint_mock._transform.return_value = 'the_transformed_data'

# Run
output = Constraint.transform(constraint_mock, 'input')

# Assert
assert output == 'the_transformed_data'
constraint_mock._validate_all_columns_present.assert_called_once()

def test_transform__transform_errors(self):
"""Test that the ``transform`` method handles any errors.
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/metadata/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def test_fit_constraint_transform_missing_columns_error(self, warnings_mock):
constraint1.fit.assert_called_once_with(data)
constraint2.fit.assert_called_once_with(data)
warning_message = (
"Mock cannot be transformed because columns: ['column'] are not found. Using the "
"Mock cannot be transformed because columns: ['column'] were not found. Using the "
'reject sampling approach instead.'
)
warnings_mock.warn.assert_called_once_with(warning_message)
Expand Down

0 comments on commit 1ff251b

Please sign in to comment.