Skip to content

Commit

Permalink
Add handling and column_model back
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed May 26, 2022
1 parent cd867c9 commit d0913a3
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
5 changes: 3 additions & 2 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,8 @@ def _validate_inputs(low_column_name, high_column_name, strict_boundaries):
if not isinstance(strict_boundaries, bool):
raise ValueError('`strict_boundaries` must be a boolean.')

def __init__(self, low_column_name, high_column_name, strict_boundaries=False):
def __init__(self, low_column_name, high_column_name, strict_boundaries=False,
handling_strategy='transform', fit_columns_model=False):
self._validate_inputs(low_column_name, high_column_name, strict_boundaries)
self._low_column_name = low_column_name
self._high_column_name = high_column_name
Expand All @@ -286,7 +287,7 @@ def __init__(self, low_column_name, high_column_name, strict_boundaries=False):
self.constraint_columns = tuple([low_column_name, high_column_name])
self._dtype = None
self._is_datetime = None
super().__init__(handling_strategy='transform', fit_columns_model=False)
super().__init__(handling_strategy=handling_strategy, fit_columns_model=fit_columns_model)

def _get_data(self, table_data):
low = table_data[self._low_column_name].to_numpy()
Expand Down
7 changes: 5 additions & 2 deletions tests/integration/tabular/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def test_conditional_sampling_constraint_uses_columns_model(gm_mock, isinstance_
expected_states = pd.Series(['CA', 'CA', 'CA', 'CA', 'CA'], name='state')
expected_ages = pd.Series([30, 30, 30, 30, 30], name='age')
sample_calls = model._model.sample.mock_calls
assert len(sample_calls) >= 2 and len(sample_calls) <= 3
assert 2 <= len(sample_calls) <= 3
assert all(c[2]['conditions']['age.value'] == 30 for c in sample_calls)
assert all('city#state.value' in c[2]['conditions'] for c in sample_calls)
pd.testing.assert_series_equal(sampled_data['age'], expected_ages)
Expand Down Expand Up @@ -401,7 +401,10 @@ def test_conditional_sampling_constraint_uses_columns_model_reject_sampling(colu
- Correct columns to condition on are passed to underlying sample method
"""
# Setup
constraint = Inequality(low_column_name='age_joined', high_column_name='age')
constraint = Inequality(
low_column_name='age_joined',
high_column_name='age',
fit_columns_model=True)
data = pd.DataFrame({
'age_joined': [22.0, 21.0, 15.0, 18.0, 29.0],
'age': [27.0, 28.0, 26.0, 21.0, 30.0],
Expand Down
5 changes: 3 additions & 2 deletions tests/integration/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from sdv.constraints import (
Between, ColumnFormula, FixedCombinations, Negative, OneHotEncoding, Positive, Inequality,
Between, ColumnFormula, FixedCombinations, Inequality, Negative, OneHotEncoding, Positive,
Rounding, Unique)
from sdv.constraints.errors import MultipleConstraintsErrors
from sdv.constraints.tabular import ScalarInequality
Expand All @@ -29,6 +29,7 @@ def test_constraints(tmpdir):
age_gt_age_when_joined_constraint = Inequality(
low_column_name='age_when_joined',
high_column_name='age',
handling_strategy='reject_sampling'
)

years_in_the_company_constraint = ColumnFormula(
Expand All @@ -39,8 +40,8 @@ def test_constraints(tmpdir):

constraints = [
fixed_company_department_constraint,
years_in_the_company_constraint,
age_gt_age_when_joined_constraint,
years_in_the_company_constraint,
]
gc = GaussianCopula(constraints=constraints)
gc.fit(employees)
Expand Down

0 comments on commit d0913a3

Please sign in to comment.