Skip to content

Commit

Permalink
Raise constraint errors together + misc. (#807)
Browse files Browse the repository at this point in the history
* Add documentation that was forgotten

* Fix typo

* IN PROGRESS

* Working version without tests

* .

* Remove unnecessary files

* Add test case + print correctly

* Move validate in other validate

* Add validation when handling_strategy is reject_sampling

* Move validation out of other validation method + rename _identity_transformer

* Add old validate_data_on_constraints back
  • Loading branch information
fealho authored May 26, 2022
1 parent 67614bd commit 04cfd65
Show file tree
Hide file tree
Showing 9 changed files with 225 additions and 129 deletions.
33 changes: 32 additions & 1 deletion sdv/constraints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from rdt import HyperTransformer

from sdv.constraints.errors import MissingConstraintColumnError
from sdv.errors import ConstraintsNotMetError

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -120,13 +121,17 @@ class Constraint(metaclass=ConstraintMeta):
def _identity(self, table_data):
    """Return ``table_data`` unchanged.

    Pass-through used wherever a hook must exist but should do nothing.

    Args:
        table_data:
            Table data; it is neither inspected nor copied.

    Returns:
        The very same ``table_data`` object that was passed in.
    """
    return table_data

def _identity_with_validation(self, table_data):
    """Validate ``table_data`` against this constraint, then return it unchanged.

    Args:
        table_data (pandas.DataFrame):
            Table data.

    Returns:
        pandas.DataFrame:
            Input data unmodified.

    Raises:
        Whatever ``self._validate_data_on_constraint`` raises for invalid data.
    """
    # Validation runs first so invalid data never reaches the caller.
    self._validate_data_on_constraint(table_data)
    return table_data

def __init__(self, handling_strategy, fit_columns_model=False):
    """Configure the constraint for the requested handling strategy.

    Args:
        handling_strategy (str):
            One of ``'transform'``, ``'reject_sampling'`` or ``'all'``.
        fit_columns_model (bool):
            Stored as ``self.fit_columns_model``. Defaults to ``False``.

    Raises:
        ValueError:
            If ``handling_strategy`` is not one of the values listed above.
    """
    self.fit_columns_model = fit_columns_model
    if handling_strategy == 'transform':
        # Only transforming: ``filter_valid`` becomes a pass-through.
        self.filter_valid = self._identity
    elif handling_strategy == 'reject_sampling':
        # Only reject sampling: ``transform`` leaves the data as is but still
        # validates it, and ``reverse_transform`` becomes a pass-through.
        self.rebuild_columns = ()
        self.transform = self._identity_with_validation
        self.reverse_transform = self._identity
    elif handling_strategy != 'all':
        raise ValueError('Unknown handling strategy: {}'.format(handling_strategy))
Expand Down Expand Up @@ -220,6 +225,31 @@ def _sample_constraint_columns(self, table_data):
sampled_data = pd.concat(all_sampled_rows, ignore_index=True)
return sampled_data

def _validate_data_on_constraint(self, table_data):
    """Make sure the given data is valid for this constraint.

    The check is skipped when ``table_data`` does not contain every column
    listed in ``self.constraint_columns``.

    Args:
        table_data (pandas.DataFrame):
            Table data.

    Raises:
        ConstraintsNotMetError:
            If any row of ``table_data`` is not valid for this constraint. The
            message shows at most five offending rows (constraint columns only)
            followed by a count of the remaining ones.
    """
    if not set(self.constraint_columns).issubset(table_data.columns.values):
        return

    is_valid_data = self.is_valid(table_data)
    if is_valid_data.all():
        return

    # Report only the columns this constraint cares about.
    constraint_data = table_data[list(self.constraint_columns)]
    invalid_rows = constraint_data[~is_valid_data]
    err_msg = (
        f"Data is not valid for the '{self.__class__.__name__}' constraint:\n"
        f'{invalid_rows[:5]}'
    )
    if len(invalid_rows) > 5:
        err_msg += f'\n+{len(invalid_rows) - 5} more'

    raise ConstraintsNotMetError(err_msg)

def _validate_constraint_columns(self, table_data):
"""Validate the columns in ``table_data``.
Expand Down Expand Up @@ -277,6 +307,7 @@ def transform(self, table_data):
pandas.DataFrame:
Input data unmodified.
"""
self._validate_data_on_constraint(table_data)
table_data = self._validate_constraint_columns(table_data)
return self._transform(table_data)

Expand Down
4 changes: 4 additions & 0 deletions sdv/constraints/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@

class MissingConstraintColumnError(Exception):
    """Error to use when a constraint is given a table with missing columns."""


class MultipleConstraintsErrors(Exception):
    """Error used to represent a list of constraint errors."""
3 changes: 3 additions & 0 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
on the other columns of the table.
* Between: Ensure that the value in one column is always between the values
of two other columns/scalars.
* Rounding: Round a column based on the specified number of digits.
* OneHotEncoding: Ensure the rows of the specified columns are one hot encoded.
* Unique: Ensure that each value for a specified column/group of columns is unique.
"""

import operator
Expand Down Expand Up @@ -1135,6 +1137,7 @@ class Unique(Constraint):

def __init__(self, columns):
    """Record the columns to check and configure reject sampling.

    Args:
        columns:
            A list is stored as given; any other value is wrapped in a
            single-element list.
    """
    self.columns = columns if isinstance(columns, list) else [columns]
    # Set before the parent initializer runs, as a tuple of the column names.
    self.constraint_columns = tuple(self.columns)
    super().__init__(handling_strategy='reject_sampling', fit_columns_model=False)

def is_valid(self, table_data):
Expand Down
36 changes: 8 additions & 28 deletions sdv/metadata/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from faker import Faker

from sdv.constraints.base import Constraint
from sdv.constraints.errors import MissingConstraintColumnError
from sdv.errors import ConstraintsNotMetError
from sdv.constraints.errors import MissingConstraintColumnError, MultipleConstraintsErrors
from sdv.metadata.errors import MetadataError, MetadataNotFittedError
from sdv.metadata.utils import strings_from_regex

Expand Down Expand Up @@ -443,9 +442,15 @@ def _get_transformers(self, dtypes):
return transformers

def _fit_transform_constraints(self, data):
    """Fit and apply every constraint, reporting all failures together.

    Each constraint in ``self._constraints`` is fitted and applied in order.
    When one raises, its error is recorded and ``data`` keeps the output of
    the last constraint that succeeded, so the remaining constraints still run.

    Args:
        data:
            Table data handed to each constraint's ``fit_transform``.

    Returns:
        The data after every constraint has been applied.

    Raises:
        MultipleConstraintsErrors:
            If at least one constraint raised. The message is the text of
            every collected error, separated by blank lines.
    """
    errors = []
    for constraint in self._constraints:
        try:
            data = constraint.fit_transform(data)
        except Exception as e:
            # Deliberately broad: the point is to collect every failure
            # and surface them in a single error below.
            errors.append(e)

    if errors:
        raise MultipleConstraintsErrors('\n' + '\n\n'.join(map(str, errors)))

    return data

def _fit_hyper_transformer(self, data, extra_columns):
Expand Down Expand Up @@ -610,25 +615,6 @@ def _transform_constraints(self, data, on_missing_column='error'):

return data

def _validate_data_on_constraints(self, data):
    """Make sure the given data is valid for the given constraints.

    A constraint is only checked when ``data`` contains every column listed
    in its ``constraint_columns``.

    Args:
        data (pandas.DataFrame):
            Table data.

    Returns:
        None

    Raises:
        ConstraintsNotMetError:
            If the table data is not valid for the provided constraints.
    """
    for constraint in self._constraints:
        if set(constraint.constraint_columns).issubset(data.columns.values):
            if not constraint.is_valid(data).all():
                raise ConstraintsNotMetError('Data is not valid for the given constraints')

def transform(self, data, on_missing_column='error'):
"""Transform the given data.
Expand All @@ -643,10 +629,6 @@ def transform(self, data, on_missing_column='error'):
Returns:
pandas.DataFrame:
Transformed data.
Raises:
ConstraintsNotMetError:
If the table data is not valid for the provided constraints.
"""
if not self.fitted:
raise MetadataNotFittedError()
Expand All @@ -655,8 +637,6 @@ def transform(self, data, on_missing_column='error'):
LOGGER.debug('Anonymizing table %s', self.name)
data = self._anonymize(data[fields])

self._validate_data_on_constraints(data)

LOGGER.debug('Transforming constraints for table %s', self.name)
data = self._transform_constraints(data, on_missing_column)

Expand Down
86 changes: 85 additions & 1 deletion tests/integration/test_constraints.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
from sdv.constraints import ColumnFormula, FixedCombinations, GreaterThan
import re

import pandas as pd
import pytest

from sdv.constraints import (
Between, ColumnFormula, FixedCombinations, GreaterThan, Negative, OneHotEncoding, Positive,
Rounding, Unique)
from sdv.constraints.errors import MultipleConstraintsErrors
from sdv.demo import load_tabular_demo
from sdv.tabular import GaussianCopula

Expand Down Expand Up @@ -40,3 +48,79 @@ def test_constraints(tmpdir):
gc.save(tmpdir / 'test.pkl')
gc = gc.load(tmpdir / 'test.pkl')
gc.sample(10)


def test_failing_constraints():
# Checks that ``GaussianCopula.fit`` raises one ``MultipleConstraintsErrors``
# whose message matches the concatenated text built in ``err_msg`` below.
data = pd.DataFrame({
'a': [0, 0, 0, 0, 0, 0, 0],
'b': [1, -1, 2, -2, 3, -3, 5],
'c': [-1, -1, -1, -1, -1, -1, -1],
'd': [1, -1, 2, -2, 3, -3, 5],
'e': [1, 2, 3, 4, 5, 6, 'a'],
'f': [1, 1, 2, 2, 3, 3, -1],
'g': [1, 0, 1, 0, 0, 1, 0],
'h': [1, 1, 1, 0, 0, 10, 0],
'i': [1, 1, 1, 1, 1, 1, 1]
})

# One constraint per column (or column pair) defined above.
constraints = [
GreaterThan('a', 'b'),
Positive('c'),
Negative('d'),
Rounding('e', 2),
Between('f', 0, 3),
OneHotEncoding(['g', 'h']),
Unique('i')
]
gc = GaussianCopula(constraints=constraints)

# Expected message: one section per failing constraint, separated by blank
# lines; sections with more than five invalid rows end with '+N more'.
# The first entry is a raw operand-type error rather than a constraint
# message, presumably from Rounding on the string in column 'e' (TODO confirm).
# NOTE(review): the column-alignment whitespace inside the table rows looks
# collapsed compared with what pandas prints; confirm against the real file.
err_msg = re.escape(
"\nunsupported operand type(s) for -: 'str' and 'str'"
'\n'
"\nData is not valid for the 'OneHotEncoding' constraint:"
'\n g h'
'\n0 1 1'
'\n2 1 1'
'\n3 0 0'
'\n4 0 0'
'\n5 1 10'
'\n+1 more'
'\n'
"\nData is not valid for the 'Unique' constraint:"
'\n i'
'\n1 1'
'\n2 1'
'\n3 1'
'\n4 1'
'\n5 1'
'\n+1 more'
'\n'
"\nData is not valid for the 'GreaterThan' constraint:"
'\n a b'
'\n1 0 -1'
'\n3 0 -2'
'\n5 0 -3'
'\n'
"\nData is not valid for the 'Positive' constraint:"
'\n c'
'\n0 -1'
'\n1 -1'
'\n2 -1'
'\n3 -1'
'\n4 -1'
'\n+2 more'
'\n'
"\nData is not valid for the 'Negative' constraint:"
'\n d'
'\n0 1'
'\n2 2'
'\n4 3'
'\n6 5'
'\n'
"\nData is not valid for the 'Between' constraint:"
'\n f'
'\n6 -1'
)

# ``err_msg`` was passed through ``re.escape`` because ``match`` takes a regex.
with pytest.raises(MultipleConstraintsErrors, match=err_msg):
gc.fit(data)
Loading

0 comments on commit 04cfd65

Please sign in to comment.