Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed May 26, 2022
1 parent d0913a3 commit e7ce5a5
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 42 deletions.
20 changes: 5 additions & 15 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,6 @@ def __init__(self, low_column_name, high_column_name, strict_boundaries=False,
def _get_data(self, table_data):
low = table_data[self._low_column_name].to_numpy()
high = table_data[self._high_column_name].to_numpy()

return low, high

def _get_is_datetime(self, table_data):
Expand Down Expand Up @@ -323,7 +322,7 @@ def _fit(self, table_data):
self._dtype = table_data[self._high_column_name].dtypes

def is_valid(self, table_data):
"""Say whether ``high`` is greater than ``low`` in each row.
"""Check whether ``high`` is greater than ``low`` in each row.
Args:
table_data (pandas.DataFrame):
Expand All @@ -335,7 +334,6 @@ def is_valid(self, table_data):
"""
low, high = self._get_data(table_data)
valid = np.isnan(low) | np.isnan(high) | self._operator(high, low)

return valid

def _transform(self, table_data):
Expand All @@ -358,12 +356,10 @@ def _transform(self, table_data):
table_data = table_data.copy()
low, high = self._get_data(table_data)
diff_column = high - low

if self._is_datetime:
diff_column = diff_column.astype(np.float64)

table_data[self._diff_column_name] = np.log(diff_column + 1)

return table_data.drop(self._high_column_name, axis=1)

def reverse_transform(self, table_data):
Expand All @@ -389,17 +385,15 @@ def reverse_transform(self, table_data):

low = table_data[self._low_column_name].to_numpy()
table_data[self._high_column_name] = pd.Series(diff_column + low).astype(self._dtype)

return table_data.drop(self._diff_column_name, axis=1)


class ScalarInequality(Constraint):
"""Ensure an inequality between the ``column_name`` column and a scalar ``value``.
The transformation works by creating a column with the difference between the
``column_name`` and ``value`` and storing it in the ``column_name``'s place.
The reverse transform adds the difference column and the ``value``
to reconstruct the ``column_name``.
The transformation works by creating a column with the difference between the ``column_name``
and ``value`` and storing it in the ``column_name``'s place. The reverse transform adds the
difference column and the ``value`` to reconstruct the ``column_name``.
Args:
column_name (str):
Expand All @@ -408,7 +402,7 @@ class ScalarInequality(Constraint):
Scalar value to compare.
relation (str):
Describes the relation between ``column_name`` and ``value``.
Choose one among ``>``, ``>=``, ``<``, ``<=``.
Choose one among ``'>'``, ``'>='``, ``'<'``, ``'<='``.
"""

@staticmethod
Expand Down Expand Up @@ -473,7 +467,6 @@ def is_valid(self, table_data):
"""
column = table_data[self._column_name].to_numpy()
valid = np.isnan(column) | self._operator(column, self._value)

return valid

def _transform(self, table_data):
Expand All @@ -496,12 +489,10 @@ def _transform(self, table_data):
table_data = table_data.copy()
column = table_data[self._column_name].to_numpy()
diff_column = abs(column - self._value)

if self._is_datetime:
diff_column = diff_column.astype(np.float64)

table_data[self._diff_column_name] = np.log(diff_column + 1)

return table_data.drop(self._column_name, axis=1)

def reverse_transform(self, table_data):
Expand All @@ -526,7 +517,6 @@ def reverse_transform(self, table_data):
diff_column = diff_column.astype('timedelta64[ns]')

table_data[self._column_name] = pd.Series(diff_column + self._value).astype(self._dtype)

return table_data.drop(self._diff_column_name, axis=1)


Expand Down
61 changes: 34 additions & 27 deletions tests/unit/constraints/test_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ def test__fit(self):
def test__fit_floats(self):
"""Test the ``Inequality._fit`` method.
The method should learn the ``dtype`` to be float when ``high_column_name`` contains floats.
The attribute ``_dtype`` should be float when ``high_column_name`` contains floats.
Input:
- Table data with floats.
Expand All @@ -950,7 +950,7 @@ def test__fit_floats(self):
def test__fit_datetime(self):
"""Test the ``Inequality._fit`` method.
The method should learn the ``dtype`` to be datetime when ``high_column_name`` contains datetimes.
The attribute ``_dtype`` should be datetime when ``high_column_name`` contains datetimes.
Input:
- Table data with datetimes.
Expand All @@ -973,11 +973,13 @@ def test__fit_datetime(self):
def test_is_valid(self):
"""Test the ``Inequality.is_valid`` method.
The method should return True when ``high_column_name`` column is greater or equal to
``low_column_name`` or the row contains nan, otherwise return False.
Input:
- Table with a mixture of valid and invalid rows, as well as np.nans.
Output:
- False should be returned for the strictly invalid rows and True
for the rest.
- False should be returned for the strictly invalid rows and True for the rest.
"""
# Setup
instance = Inequality(low_column_name='a', high_column_name='b')
Expand All @@ -997,11 +999,13 @@ def test_is_valid(self):
def test_is_valid_strict_boundaries_True(self):
"""Test the ``Inequality.is_valid`` method with ``strict_boundaries = True``.
The method should return True when ``high_column_name`` column is greater than
``low_column_name`` or the row contains nan, otherwise return False.
Input:
- Table with a mixture of valid and invalid rows, as well as np.nans.
Output:
- False should be returned for the non-strictly invalid rows and True
for the rest.
- False should be returned for the non-strictly invalid rows and True for the rest.
"""
# Setup
instance = Inequality(low_column_name='a', high_column_name='b', strict_boundaries=True)
Expand All @@ -1021,11 +1025,13 @@ def test_is_valid_strict_boundaries_True(self):
def test_is_valid_datetimes(self):
"""Test the ``Inequality.is_valid`` method with datetimes.
The method should return True when ``high_column_name`` column is greater or equal to
``low_column_name`` or the row contains nan, otherwise return False.
Input:
- Table with datetimes and np.nans.
Output:
- False should be returned for the strictly invalid rows and True
for the rest.
- False should be returned for the strictly invalid rows and True for the rest.
"""
# Setup
instance = Inequality(low_column_name='a', high_column_name='b')
Expand Down Expand Up @@ -1053,8 +1059,7 @@ def test__transform(self):
Input:
- Table with two columns at a constant distance of 3 and one additional dummy column.
Output:
- Same table with a diff column of the logarithms of the distances + 1,
which is np.log(4).
- Same table with a diff column of the log of distances + 1, which is np.log(4).
"""
# Setup
instance = Inequality(low_column_name='a', high_column_name='b')
Expand Down Expand Up @@ -1087,8 +1092,7 @@ def test__transform_datetime(self):
Input:
- Table with two datetime columns at a distance of 3 and one additional dummy column.
Output:
- Same table with a diff column of the logarithms of the distances + 1,
which is np.log(4).
- Same table with a diff column of the log of distances + 1, which is np.log(4).
"""
# Setup
instance = Inequality(low_column_name='a', high_column_name='b')
Expand Down Expand Up @@ -1127,8 +1131,7 @@ def test_reverse_transform(self):
Input:
- Table with a diff column that contains the constant np.log(4).
Output:
- Same table with the high column replaced by the low one + 3, as int
and the diff column dropped.
- Same table with the high column replaced by the low one + 3 with diff column dropped.
"""
# Setup
instance = Inequality(low_column_name='a', high_column_name='b')
Expand Down Expand Up @@ -1167,8 +1170,7 @@ def test_reverse_transform_floats(self):
Input:
- Table with a diff column that contains the constant np.log(4).
Output:
- Same table with the high column replaced by the low one + 3, as int
and the diff column dropped.
- Same table with the high column replaced by the low one + 3 with diff column dropped.
"""
# Setup
instance = Inequality(low_column_name='a', high_column_name='b')
Expand Down Expand Up @@ -1207,8 +1209,7 @@ def test_reverse_transform_datetime(self):
Input:
- Table with a diff column that contains the constant np.log(4).
Output:
- Same table with the high column replaced by the low one + 3, as int
and the diff column dropped.
- Same table with the high column replaced by the low one + 1sec with diff column dropped.
"""
# Setup
instance = Inequality(low_column_name='a', high_column_name='b')
Expand Down Expand Up @@ -1368,7 +1369,7 @@ def test__fit(self):
def test__fit_floats(self):
"""Test the ``ScalarInequality._fit`` method.
The method should learn the ``dtype`` to be float when ``column_name`` contains floats.
The attribute ``_dtype`` should be float when ``column_name`` contains floats.
Input:
- Table data with floats.
Expand All @@ -1391,7 +1392,7 @@ def test__fit_floats(self):
def test__fit_datetime(self):
"""Test the ``ScalarInequality._fit`` method.
The method should learn the ``dtype`` to be datetime when ``column_name`` contains datetimes.
The attribute ``_dtype`` should be datetime when ``column_name`` contains datetimes.
Input:
- Table data with datetimes.
Expand All @@ -1413,14 +1414,16 @@ def test__fit_datetime(self):
# Assert
assert instance._dtype == np.dtype('<M8[ns]')

def test_is_valid_greater(self):
def test_is_valid(self):
"""Test the ``ScalarInequality.is_valid`` method with ``relation = '>'``.
The method should return True when ``column_name`` is greater than
``value`` or the row contains nan, otherwise return False.
Input:
- Table with a mixture of valid and invalid rows, as well as np.nans.
Output:
- False should be returned for the strictly invalid rows and True
for the rest.
- False should be returned for the strictly invalid rows and True for the rest.
"""
# Setup
instance = ScalarInequality(column_name='b', value=2, relation='>')
Expand All @@ -1439,11 +1442,13 @@ def test_is_valid_greater(self):
def test_is_valid_datetimes(self):
"""Test the ``ScalarInequality.is_valid`` method with datetimes and ``relation = '<='``.
The method should return True when ``column_name`` is greater or equal to
``value`` or the row contains nan, otherwise return False.
Input:
- Table with datetimes and np.nans.
Output:
- False should be returned for the strictly invalid rows and True
for the rest.
- False should be returned for the strictly invalid rows and True for the rest.
"""
# Setup
instance = ScalarInequality(
Expand Down Expand Up @@ -1473,7 +1478,8 @@ def test__transform(self):
Input:
- Table data.
Output:
- Same table with a diff column of the logarithms of the distances + 1 in the ``column_name``'s place.
- Same table with a diff column of the log of the distances + 1
in the ``column_name``'s place.
"""
# Setup
instance = ScalarInequality(column_name='a', value=1, relation='>=')
Expand Down Expand Up @@ -1504,7 +1510,8 @@ def test__transform_datetime(self):
Input:
- Table data with datetimes.
Output:
- Same table with a diff column of the logarithms of the distances + 1 in the ``column_name``'s place.
- Same table with a diff column of the logarithms of the distances + 1
in the ``column_name``'s place.
"""
# Setup
instance = ScalarInequality(
Expand Down

0 comments on commit e7ce5a5

Please sign in to comment.