Skip to content

Commit

Permalink
Create Unique Constraint (#540)
Browse files Browse the repository at this point in the history
* Create Unique Constraint

* adding unit tests

* making change in base

* fixing tests and adding docs

* pr comments

* adding constraint to init
  • Loading branch information
amontanez24 authored Aug 5, 2021
1 parent 0d4163c commit 8e3ae3b
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 4 deletions.
16 changes: 16 additions & 0 deletions docs/api_reference/constraints/tabular.rst
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,19 @@ OneHotEncoding
OneHotEncoding.filter_valid
OneHotEncoding.from_dict
OneHotEncoding.to_dict

Unique
~~~~~~~~~~~~~~~~

.. autosummary::
:toctree: api/

Unique
Unique.fit
Unique.transform
Unique.fit_transform
Unique.reverse_transform
Unique.is_valid
Unique.filter_valid
Unique.from_dict
Unique.to_dict
5 changes: 3 additions & 2 deletions sdv/constraints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sdv.constraints.base import Constraint
from sdv.constraints.tabular import (
Between, ColumnFormula, CustomConstraint, GreaterThan, Negative, OneHotEncoding, Positive,
Rounding, UniqueCombinations)
Rounding, Unique, UniqueCombinations)

__all__ = [
'Constraint',
Expand All @@ -15,5 +15,6 @@
'Negative',
'Positive',
'Rounding',
'OneHotEncoding'
'OneHotEncoding',
'Unique'
]
37 changes: 37 additions & 0 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,3 +962,40 @@ def reverse_transform(self, table_data):
table_data[self._columns] = transformed_data

return table_data


class Unique(Constraint):
    """Ensure that each value for a specified column/group of columns is unique.

    This constraint is provided a list of columns, and guarantees that every
    unique combination of those columns appears at most once in the sampled
    data.

    Args:
        columns (str or list[str]):
            Name of the column(s) to keep unique. A single name is wrapped
            into a one-element list.
    """

    def __init__(self, columns):
        # Normalize to a list so the rest of the class can always treat
        # ``self.columns`` as a collection of column names.
        self.columns = columns if isinstance(columns, list) else [columns]
        super().__init__(handling_strategy='reject_sampling', fit_columns_model=False)

    def is_valid(self, table_data):
        """Flag the first occurrence of each combination of column values.

        If a row is the first instance of that combination of column
        values, it is valid. Every later repetition is invalid.

        Args:
            table_data (pandas.DataFrame):
                Table data.

        Returns:
            pandas.Series:
                Boolean series, sharing the index of ``table_data``, that
                tells whether each row is valid.
        """
        # ``duplicated`` works on row positions and returns a mask aligned
        # with ``table_data.index``, so the result is correct for any index
        # (non-default, non-sorted, repeated labels) and for tables that
        # already contain a column named ``index``. It also treats missing
        # values as ordinary keys, so the first row holding a NaN is kept
        # instead of being rejected as a group-by would do by default.
        return ~table_data.duplicated(subset=self.columns, keep='first')
2 changes: 1 addition & 1 deletion sdv/tabular/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def _sample_rows(self, num_rows, conditions=None, transformed_conditions=None,
sampled = self._metadata.reverse_transform(sampled)

if previous_rows is not None:
sampled = previous_rows.append(sampled)
sampled = previous_rows.append(sampled, ignore_index=True)

sampled = self._metadata.filter_valid(sampled)

Expand Down
98 changes: 97 additions & 1 deletion tests/unit/constraints/test_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sdv.constraints.errors import MissingConstraintColumnError
from sdv.constraints.tabular import (
Between, ColumnFormula, CustomConstraint, GreaterThan, Negative, OneHotEncoding, Positive,
Rounding, UniqueCombinations)
Rounding, Unique, UniqueCombinations)


def dummy_transform():
Expand Down Expand Up @@ -3436,3 +3436,99 @@ def test_sample_constraint_columns_all_zeros_but_one(self):
'b': [1.0] * 10
})
pd.testing.assert_frame_equal(out, expected_output)


class TestUnique():

    def test___init__(self):
        """Check ``Unique.__init__`` when a list of columns is given.

        The provided list must be stored in ``columns`` and the instance
        must be configured for reject sampling, which leaves ``transform``
        and ``reverse_transform`` pointing at ``instance._identity``.

        Input:
        - list of column names to keep unique.
        Output:
        - Instance with ``columns`` set, ``fit_columns_model`` disabled and
          identity ``transform`` / ``reverse_transform`` methods.
        """
        # Setup
        column_names = ['a', 'b']

        # Run
        constraint = Unique(columns=column_names)

        # Assert
        assert constraint.columns == ['a', 'b']
        assert constraint.fit_columns_model is False
        assert constraint.transform == constraint._identity
        assert constraint.reverse_transform == constraint._identity

    def test___init__one_column(self):
        """Check ``Unique.__init__`` when a single column name is given.

        A bare string must be wrapped so that ``columns`` is always a list.

        Input:
        - string that is the name of a column.
        Output:
        - Instance with ``columns`` set to a list of one element.
        """
        # Run
        constraint = Unique(columns='a')

        # Assert
        assert constraint.columns == ['a']

    def test_is_valid(self):
        """Check ``Unique.is_valid`` with several constrained columns.

        The returned ``pd.Series`` must be ``True`` at the first occurrence
        of each unique combination of ``instance.columns`` and ``False`` at
        every other occurrence.

        Input:
        - DataFrame with repeated combinations of the constrained columns.
        Output:
        - Series with only the first occurrences set to ``True``.
        """
        # Setup
        constraint = Unique(columns=['a', 'b', 'c'])
        table = pd.DataFrame({
            'a': [1, 1, 2, 2, 3, 4],
            'b': [5, 5, 6, 6, 7, 8],
            'c': [9, 9, 10, 10, 12, 13]
        })
        expected_mask = pd.Series([True, False, True, False, True, True])

        # Run
        mask = constraint.is_valid(table)

        # Assert
        pd.testing.assert_series_equal(mask, expected_mask)

    def test_is_valid_one_column(self):
        """Check ``Unique.is_valid`` with a single constrained column.

        The returned ``pd.Series`` must be ``True`` at the first occurrence
        of each unique value of the column and ``False`` at every other
        occurrence, regardless of what the remaining columns contain.

        Input:
        - DataFrame with repeated values in the one column listed in
          ``instance.columns``.
        Output:
        - Series with only the first occurrences set to ``True``.
        """
        # Setup
        constraint = Unique(columns='a')
        table = pd.DataFrame({
            'a': [1, 1, 1, 2, 3, 2],
            'b': [1, 2, 3, 4, 5, 6],
            'c': [False, False, True, False, False, True]
        })
        expected_mask = pd.Series([True, False, False, True, True, False])

        # Run
        mask = constraint.is_valid(table)

        # Assert
        pd.testing.assert_series_equal(mask, expected_mask)
45 changes: 45 additions & 0 deletions tests/unit/tabular/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,51 @@ def test_conditional_sampling_graceful_reject_sampling(model):
model.sample(5, conditions=conditions, graceful_reject_sampling=False)


def test__sample_rows_previous_rows_appended_correctly():
    """Check that ``_sample_rows`` re-indexes when extending ``previous_rows``.

    When ``BaseTabularModel._sample_rows`` receives ``previous_rows``, the
    freshly sampled rows are appended to them and the combined frame must
    carry a clean, consecutive index.

    Input:
    - num_rows is 5
    - previous_rows is a DataFrame of 3 existing rows.
    Output:
    - 5 sampled rows with index set to [0, 1, 2, 3, 4]
    """
    # Setup
    existing_rows = pd.DataFrame({
        'column1': [1, 2, 3],
        'column2': [4, 5, 6],
        'column3': [7, 8, 9]
    })
    fresh_rows = pd.DataFrame({
        'column1': [4, 5],
        'column2': [7, 8],
        'column3': [10, 11]
    })

    # The metadata mock hands back the fresh rows untouched and accepts
    # every row as valid, so only the append logic is exercised.
    metadata = Mock()
    metadata.reverse_transform.return_value = fresh_rows
    metadata.filter_valid = lambda x: x

    model = GaussianCopula()
    model._metadata = metadata
    model._sample = Mock(return_value=fresh_rows)

    # Run
    sampled, num_valid = model._sample_rows(5, previous_rows=existing_rows)

    # Assert
    expected_rows = pd.DataFrame({
        'column1': [1, 2, 3, 4, 5],
        'column2': [4, 5, 6, 7, 8],
        'column3': [7, 8, 9, 10, 11]
    })
    assert num_valid == 5
    pd.testing.assert_frame_equal(sampled, expected_rows)


def test_sample_empty_transformed_conditions():
"""Test that None is passed to ``_sample_batch`` if transformed conditions are empty.
Expand Down

0 comments on commit 8e3ae3b

Please sign in to comment.