diff --git a/sdv/constraints/tabular.py b/sdv/constraints/tabular.py index 5cd2034a6..15d843407 100644 --- a/sdv/constraints/tabular.py +++ b/sdv/constraints/tabular.py @@ -276,7 +276,8 @@ def _validate_inputs(low_column_name, high_column_name, strict_boundaries): if not isinstance(strict_boundaries, bool): raise ValueError('`strict_boundaries` must be a boolean.') - def __init__(self, low_column_name, high_column_name, strict_boundaries=False): + def __init__(self, low_column_name, high_column_name, strict_boundaries=False, + handling_strategy='transform', fit_columns_model=False): self._validate_inputs(low_column_name, high_column_name, strict_boundaries) self._low_column_name = low_column_name self._high_column_name = high_column_name @@ -286,7 +287,7 @@ def __init__(self, low_column_name, high_column_name, strict_boundaries=False): self.constraint_columns = tuple([low_column_name, high_column_name]) self._dtype = None self._is_datetime = None - super().__init__(handling_strategy='transform', fit_columns_model=False) + super().__init__(handling_strategy=handling_strategy, fit_columns_model=fit_columns_model) def _get_data(self, table_data): low = table_data[self._low_column_name].to_numpy() diff --git a/tests/integration/tabular/test_base.py b/tests/integration/tabular/test_base.py index c1f99b20f..eda553c75 100644 --- a/tests/integration/tabular/test_base.py +++ b/tests/integration/tabular/test_base.py @@ -370,7 +370,7 @@ def test_conditional_sampling_constraint_uses_columns_model(gm_mock, isinstance_ expected_states = pd.Series(['CA', 'CA', 'CA', 'CA', 'CA'], name='state') expected_ages = pd.Series([30, 30, 30, 30, 30], name='age') sample_calls = model._model.sample.mock_calls - assert len(sample_calls) >= 2 and len(sample_calls) <= 3 + assert 2 <= len(sample_calls) <= 3 assert all(c[2]['conditions']['age.value'] == 30 for c in sample_calls) assert all('city#state.value' in c[2]['conditions'] for c in sample_calls) pd.testing.assert_series_equal(sampled_data['age'], expected_ages) @@ -401,7 +401,10 @@ def test_conditional_sampling_constraint_uses_columns_model_reject_sampling(colu - Correct columns to condition on are passed to underlying sample method """ # Setup - constraint = Inequality(low_column_name='age_joined', high_column_name='age') + constraint = Inequality( + low_column_name='age_joined', + high_column_name='age', + fit_columns_model=True) data = pd.DataFrame({ 'age_joined': [22.0, 21.0, 15.0, 18.0, 29.0], 'age': [27.0, 28.0, 26.0, 21.0, 30.0], diff --git a/tests/integration/test_constraints.py b/tests/integration/test_constraints.py index 4137ff1da..8b05a2e50 100644 --- a/tests/integration/test_constraints.py +++ b/tests/integration/test_constraints.py @@ -4,7 +4,7 @@ import pytest from sdv.constraints import ( - Between, ColumnFormula, FixedCombinations, Negative, OneHotEncoding, Positive, Inequality, + Between, ColumnFormula, FixedCombinations, Inequality, Negative, OneHotEncoding, Positive, Rounding, Unique) from sdv.constraints.errors import MultipleConstraintsErrors from sdv.constraints.tabular import ScalarInequality @@ -29,6 +29,7 @@ def test_constraints(tmpdir): age_gt_age_when_joined_constraint = Inequality( low_column_name='age_when_joined', high_column_name='age', + handling_strategy='reject_sampling' ) years_in_the_company_constraint = ColumnFormula( @@ -39,8 +40,8 @@ def test_constraints(tmpdir): constraints = [ fixed_company_department_constraint, - years_in_the_company_constraint, age_gt_age_when_joined_constraint, + years_in_the_company_constraint, ] gc = GaussianCopula(constraints=constraints) gc.fit(employees)