Support multi-column specification for positive and negative constraint #533

Merged
merged 18 commits into from
Aug 6, 2021
Changes from 14 commits
39 changes: 31 additions & 8 deletions docs/user_guides/single_table/constraints.rst
@@ -174,25 +174,46 @@ this functionality, we can pass:
handling_strategy='reject_sampling'
)

Optionally, when constructing a ``GreaterThan`` constraint for scalar
comparisons, we can specify more than one column in either
the ``high`` or the ``low`` argument. For example, we can create a
``GreaterThan`` constraint that ensures that each of the experience
columns is at least one year.

.. ipython:: python
:okwarning:

experience_years_gt_one_constraint = GreaterThan(
low=1,
high=['years_in_the_company', 'prior_years_experience'],
handling_strategy='reject_sampling'
)

.. note::

To specify more than one column, either ``high`` or ``low`` must
be a scalar value; otherwise the constraint cannot be evaluated
correctly.
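
A minimal sketch of what the multi-column scalar comparison amounts to, assuming the non-strict default and plain NumPy broadcasting. The sample data is hypothetical and the snippet does not call SDV itself; it only mirrors the row-wise check used by the constraint.

```python
import numpy as np
import pandas as pd

# Hypothetical data; the column names follow the docs example above.
data = pd.DataFrame({
    'years_in_the_company': [3, 0, 5],
    'prior_years_experience': [2, 4, 1],
})

low = 1  # scalar bound, broadcast against every listed column
high = data[['years_in_the_company', 'prior_years_experience']].values

# Non-strict comparison (strict=False): a row is valid only if
# every listed column is greater than or equal to the scalar.
valid = np.greater_equal(high, low).all(axis=1)
print(valid)
```

Because one side is a scalar, the comparison broadcasts cleanly over any number of columns. With column lists on both sides there is no obvious pairing between them, which is presumably why the constraint limits that case to one column each.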

Positive and Negative Constraints
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Similar to the ``GreaterThan`` constraint, we can use the ``Positive``
or ``Negative`` constraints. These constraints enforce that a specified
column is always positive or negative. We can create an instance passing:
or ``Negative`` constraints. These constraints enforce that specified
column(s) are always positive or negative. We can create an instance passing:

- the name of the ``low`` column for ``Negative`` or the name of the ``high`` column for ``Positive``
- the name of the column(s) for ``Negative`` or ``Positive`` constraints
- a boolean specifying whether to make the data strictly above or below 0,
or include 0 as a possible value
- the handling strategy that we want to use
- a boolean specifying whether to make the data strictly above or below 0, or include 0 as a possible value

.. ipython:: python
:okwarning:

from sdv.constraints import Positive

positive_prior_exp_constraint = Positive(
high='prior_years_experience',
positive_age_constraint = Positive(
columns='age',
strict=False,
handling_strategy='reject_sampling'
)
@@ -319,9 +340,10 @@ constraints that we just defined as a ``list``:
constraints = [
unique_company_department_constraint,
age_gt_age_when_joined_constraint,
years_in_the_company_constraint,
salary_gt_30000_constraint,
positive_prior_exp_constraint,
experience_years_gt_one_constraint,
positive_age_constraint,
years_in_the_company_constraint,
salary_rounding_constraint,
reasonable_age_constraint,
one_hot_constraint
@@ -345,3 +367,4 @@ we defined:
:okwarning:

sampled

152 changes: 83 additions & 69 deletions sdv/constraints/tabular.py
@@ -13,6 +13,8 @@
across several columns are the same after sampling.
* GreaterThan: Ensure that the value in one column is always greater than
the value in another column.
* Positive: Ensure that the values in given columns are always positive.
* Negative: Ensure that the values in given columns are always negative.
* ColumnFormula: Compute the value of a column based on applying a formula
on the other columns of the table.
* Between: Ensure that the value in one column is always between the values
@@ -195,11 +197,11 @@ class GreaterThan(Constraint):
will be added to the diff to reconstruct the ``high`` column.

Args:
low (str or int):
Either the name of the column that contains the low value,
low (str or list[str]):
Either the name of the column(s) that contains the low value,
or a scalar that is the low value.
high (str or int):
Either the name of the column that contains the high value,
high (str or list[str]):
Either the name of the column(s) that contains the high value,
or a scalar that is the high value.
strict (bool):
Whether the comparison of the values should be strict ``>`` or
@@ -221,40 +223,50 @@ class GreaterThan(Constraint):
by checking if the value provided is a column name.
"""

_diff_column = None
_column_list = False
_diff_columns = None
_is_datetime = None
_column_to_reconstruct = None

def __init__(self, low, high, strict=False, handling_strategy='transform',
fit_columns_model=True, drop=None, high_is_scalar=None,
low_is_scalar=None):
self.constraint_columns = (low, high)
low = [low] if isinstance(low, str) else low
high = [high] if isinstance(high, str) else high

self._low = low
self._high = high
self._strict = strict
self.constraint_columns = (low, high)
self._drop = drop
self._high_is_scalar = high_is_scalar
self._low_is_scalar = low_is_scalar

if strict:
self.operator = np.greater
else:
self.operator = np.greater_equal

super().__init__(handling_strategy=handling_strategy,
fit_columns_model=fit_columns_model)

def _get_low_value(self, table_data):
if self._low_is_scalar:
return self._low
elif self._low in table_data.columns:
return table_data[self._low]
elif any([low in table_data.columns for low in self._low]):
return table_data[self._low].values

return None

def _get_high_value(self, table_data):
if self._high_is_scalar:
return self._high
elif self._high in table_data.columns:
return table_data[self._high]
elif any([high in table_data.columns for high in self._high]):
return table_data[self._high].values

return None

def _get_column_to_reconstruct(self):
def _get_columns_to_reconstruct(self):
if self._drop == 'high':
column = self._high
elif self._drop == 'low':
@@ -266,19 +278,23 @@ def _get_column_to_reconstruct(self):

return column

def _get_diff_column_name(self, table_data):
def _get_diff_columns_name(self, table_data):
token = '#'
if len(self.constraint_columns) == 1:
name = self.constraint_columns[0] + token
names = []
for column in self.constraint_columns:
name = column + token
while name in table_data.columns:
name += '#'

return name
names.append(name)

while token.join(self.constraint_columns) in table_data.columns:
token += '#'
if not self._high_is_scalar and not self._low_is_scalar:
while token.join(self.constraint_columns) in table_data.columns:
token += '#'

names = [token.join(self.constraint_columns)]

return token.join(self.constraint_columns)
return names

def _get_is_datetime(self, table_data):
low = self._get_low_value(table_data)
@@ -301,23 +317,24 @@ def _fit(self, table_data):
The Table data.
"""
if self._high_is_scalar is None:
self._high_is_scalar = self._high not in table_data.columns
self._high_is_scalar = not isinstance(self._high, list)
if self._low_is_scalar is None:
self._low_is_scalar = self._low not in table_data.columns
self._low_is_scalar = not isinstance(self._low, list)

if self._high_is_scalar and self._low_is_scalar:
raise TypeError('`low` and `high` cannot be both scalars at the same time')
elif self._low_is_scalar:
self.constraint_columns = (self._high,)
self._dtype = table_data[self._high].dtype
self.constraint_columns = tuple(self._high)
elif self._high_is_scalar:
self.constraint_columns = (self._low,)
self._dtype = table_data[self._low].dtype
self.constraint_columns = tuple(self._low)
else:
self._dtype = table_data[self._high].dtype
self.constraint_columns = tuple(self._low + self._high)
if len(self.constraint_columns) > 2:
raise ValueError('`low` and `high` cannot be more than one column.')

self._column_to_reconstruct = self._get_column_to_reconstruct()
self._diff_column = self._get_diff_column_name(table_data)
self._columns_to_reconstruct = self._get_columns_to_reconstruct()
self._dtype = [table_data[column].dtype for column in self._columns_to_reconstruct]
self._diff_columns = self._get_diff_columns_name(table_data)
self._is_datetime = self._get_is_datetime(table_data)

def is_valid(self, table_data):
@@ -333,10 +350,8 @@ def is_valid(self, table_data):
"""
low = self._get_low_value(table_data)
high = self._get_high_value(table_data)
if self._strict:
return high > low

return high >= low
return self.operator(high, low).all(axis=1)

def _transform(self, table_data):
"""Transform the table data.
@@ -359,9 +374,9 @@ def _transform(self, table_data):
diff = self._get_high_value(table_data) - self._get_low_value(table_data)

if self._is_datetime:
diff = pd.to_numeric(diff)
diff = diff.astype(np.float64)

table_data[self._diff_column] = np.log(diff + 1)
table_data[self._diff_columns] = np.log(diff + 1)
if self._drop == 'high':
table_data = table_data.drop(self._high, axis=1)
elif self._drop == 'low':
@@ -388,92 +403,91 @@ def reverse_transform(self, table_data):
Transformed data.
"""
table_data = table_data.copy()
diff = (np.exp(table_data[self._diff_column]).round() - 1).clip(0)
diff = (np.exp(table_data[self._diff_columns].values).round() - 1).clip(0)
if self._is_datetime:
diff = pd.to_timedelta(diff)
diff = diff.astype('timedelta64[ns]')

high = self._get_high_value(table_data)
low = self._get_low_value(table_data)

if self._drop == 'high':
table_data[self._high] = (low + diff).astype(self._dtype)
new_values = pd.DataFrame(diff + low, columns=self._high)
table_data[self._high] = new_values.astype(dict(zip(self._high, self._dtype)))
elif self._drop == 'low':
table_data[self._low] = (high - diff).astype(self._dtype)
new_values = pd.DataFrame(high - diff, columns=self._low)
table_data[self._low] = new_values.astype(dict(zip(self._low, self._dtype)))
else:
invalid = ~self.is_valid(table_data)
if not self._high_is_scalar and not self._low_is_scalar:
new_values = low.loc[invalid] + diff.loc[invalid]
new_values = low[invalid] + diff[invalid]
elif self._high_is_scalar:
new_values = high - diff.loc[invalid]
new_values = high - diff[invalid]
else:
new_values = low + diff.loc[invalid]
new_values = low + diff[invalid]

table_data[self._column_to_reconstruct].loc[invalid] = new_values.astype(self._dtype)
for i, column in enumerate(self._columns_to_reconstruct):
table_data.at[invalid, column] = new_values[:, i].astype(self._dtype[i])

table_data = table_data.drop(self._diff_column, axis=1)
table_data = table_data.drop(self._diff_columns, axis=1)

return table_data
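
The transform and reverse-transform steps above can be sketched on their own, assuming a scalar ``low``, integer-valued columns, and rows that already satisfy the constraint (so the difference is never negative). The data and variable names here are illustrative and are not part of SDV.

```python
import numpy as np
import pandas as pd

# Hypothetical integer data that already satisfies ``value >= low``.
data = pd.DataFrame({'a': [1, 4, 10], 'b': [2, 2, 7]})
columns = ['a', 'b']
low = 1

# Forward step: one diff column per constrained column, log-scaled.
diff = data[columns].values - low
transformed = np.log(diff + 1)

# Reverse step: undo the log, round to the nearest integer, clip at 0.
recovered_diff = (np.exp(transformed).round() - 1).clip(0)

# Rebuild the columns and restore each column's original dtype.
dtypes = dict(zip(columns, data[columns].dtypes))
recovered = pd.DataFrame(recovered_diff + low, columns=columns).astype(dtypes)
print(recovered)
```

The rounding in the reverse step means this round trip is exact only for integer-valued differences; for float columns the recovered values would be rounded.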


class Positive(GreaterThan):
"""Ensure that the ``high`` column is always positive.
"""Ensure that the given columns are always positive.

The transformation strategy works by creating a column with the
difference between ``high`` and 0 value and then computing back the ``high``
value by adding the difference to 0 when reversing the transformation.
The transformation strategy works by creating columns with the
difference between the given columns and zero, then computing back
the necessary columns using the difference.

Args:
high (str or int):
Either the name of the column that contains the high value,
or a scalar that is the high value.
columns (str or list[str]):
The name of the column(s) that are constrained to be positive.
strict (bool):
Whether the comparison of the values should be strict ``>=`` or
not ``>`` when comparing them. Currently, this is only respected
Whether the comparison of the values should be strict, excluding
zero (``>``), or not, including it (``>=``). Currently, this is only respected
if ``reject_sampling`` or ``all`` handling strategies are used.
handling_strategy (str):
How this Constraint should be handled, which can be ``transform``
or ``reject_sampling``. Defaults to ``transform``.
drop (str):
Which column to drop during transformation. Can be ``'high'``
or ``None``.
drop (bool):
Whether to drop columns during transformation.
"""

def __init__(self, high, strict=False, handling_strategy='transform',
def __init__(self, columns, strict=False, handling_strategy='transform',
fit_columns_model=True, drop=None):
super().__init__(handling_strategy=handling_strategy,
fit_columns_model=fit_columns_model,
high=high, low=0, high_is_scalar=False,
high=columns, low=0, high_is_scalar=False,
low_is_scalar=True, drop=drop, strict=strict)


class Negative(GreaterThan):
"""Ensure that the ``low`` column is always negative.
"""Ensure that the given columns are always negative.

The transformation strategy works by creating a column with the
difference between ``low`` and 0 and then computing back the ``low``
value by subtracting the difference from 0 when reversing the transformation.
The transformation strategy works by creating columns with the
difference between zero and the given columns, then computing back
the necessary columns using the difference.

Args:
high (str or int):
Either the name of the column that contains the high value,
or a scalar that is the high value.
columns (str or list[str]):
The name of the column(s) that are constrained to be negative.
strict (bool):
Whether the comparison of the values should be strict ``>=`` or
not ``>`` when comparing them. Currently, this is only respected
Whether the comparison of the values should be strict, excluding
zero (``<``), or not, including it (``<=``). Currently, this is only respected
if ``reject_sampling`` or ``all`` handling strategies are used.
handling_strategy (str):
How this Constraint should be handled, which can be ``transform``
or ``reject_sampling``. Defaults to ``transform``.
drop (str):
Which column to drop during transformation. Can be ``'low'``
or ``None``.
drop (bool):
Whether to drop columns during transformation.
"""

def __init__(self, low, strict=False, handling_strategy='transform',
def __init__(self, columns, strict=False, handling_strategy='transform',
fit_columns_model=True, drop=None):
super().__init__(handling_strategy=handling_strategy,
fit_columns_model=fit_columns_model,
high=0, low=low, high_is_scalar=True,
high=0, low=columns, high_is_scalar=True,
low_is_scalar=False, drop=drop, strict=strict)
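
As a rough illustration of how ``Positive`` and ``Negative`` reduce to a ``GreaterThan`` comparison against zero, the sketch below reproduces only the validity check with NumPy. The helper names ``is_positive`` and ``is_negative`` are hypothetical and do not exist in SDV.

```python
import numpy as np
import pandas as pd


def _as_list(columns):
    # Accept a single column name or a list of names, as the constraint does.
    return [columns] if isinstance(columns, str) else list(columns)


def is_positive(table, columns, strict=False):
    """Row-wise check that every listed column is > 0 (or >= 0)."""
    operator = np.greater if strict else np.greater_equal
    # Positive: the columns play the role of ``high`` and 0 is ``low``.
    return operator(table[_as_list(columns)].values, 0).all(axis=1)


def is_negative(table, columns, strict=False):
    """Row-wise check that every listed column is < 0 (or <= 0)."""
    operator = np.greater if strict else np.greater_equal
    # Negative: 0 plays the role of ``high`` and the columns are ``low``.
    return operator(0, table[_as_list(columns)].values).all(axis=1)


# Hypothetical data for illustration.
data = pd.DataFrame({'x': [1, 0, -2], 'y': [3, 0, -1]})
print(is_positive(data, ['x', 'y']))
print(is_negative(data, 'x', strict=True))
```

Swapping which side holds the zero is all that distinguishes the two checks, which matches how both classes delegate to the same parent constructor with ``low=0`` or ``high=0``.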

