Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set HyperTransformer configuration manually #995

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ test-tutorials: ## run the tutorial notebooks
invoke tutorials

.PHONY: test
test: test-unit test-readme test-tutorials ## test everything that needs test dependencies
test: test-unit test-integration test-readme test-tutorials ## test everything that needs test dependencies

.PHONY: test-all
test-all: ## run tests on every Python version with tox
Expand Down
38 changes: 35 additions & 3 deletions sdv/constraints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from copulas.multivariate.gaussian import GaussianMultivariate
from copulas.univariate import GaussianUnivariate
from rdt import HyperTransformer
from rdt.transformers import OneHotEncoder
from rdt.transformers import BinaryEncoder, FloatFormatter, OneHotEncoder, UnixTimestampEncoder

from sdv.constraints.errors import MissingConstraintColumnError
from sdv.errors import ConstraintsNotMetError
Expand Down Expand Up @@ -313,6 +313,36 @@ def __init__(self, constraint, constraint_columns):

self.constraint = constraint

@staticmethod
def _get_hyper_transformer_config(data_to_model):
    """Build the ``HyperTransformer`` config for the columns being modeled.

    The dtype kind of every column is inferred after dropping missing values,
    and both the ``sdtype`` and the transformer for that column are chosen
    from it.

    Args:
        data_to_model (pandas.DataFrame):
            Data containing only the columns that will be modeled.

    Returns:
        dict:
            A dict with the keys ``sdtypes`` and ``transformers``, each of them
            mapping column names to the value selected for that column. Columns
            whose dtype kind is not handled below get no entry in either mapping.
    """
    sdtypes = {}
    transformers = {}
    for column_name, data in data_to_model.items():
        # Missing values are dropped first so that, for example, an object column
        # holding booleans and ``None`` is inferred as boolean.
        dtype_kind = data.dropna().infer_objects().dtype.kind
        if dtype_kind in ('i', 'f'):
            sdtypes[column_name] = 'numerical'
            # Assign per column: rebinding ``transformers`` itself would discard
            # the mapping built so far.
            transformers[column_name] = FloatFormatter(
                missing_value_replacement='mean',
                model_missing_values=True,
            )
        elif dtype_kind == 'O':
            sdtypes[column_name] = 'categorical'
            # NOTE(review): the class itself is stored here, not an instance,
            # unlike the other branches - confirm this is intended.
            transformers[column_name] = OneHotEncoder
        elif dtype_kind == 'M':
            sdtypes[column_name] = 'datetime'
            transformers[column_name] = UnixTimestampEncoder(
                missing_value_replacement='mean',
                model_missing_values=True,
            )
        elif dtype_kind == 'b':
            sdtypes[column_name] = 'boolean'
            transformers[column_name] = BinaryEncoder(
                missing_value_replacement=-1,
                model_missing_values=True,
            )

    return {'sdtypes': sdtypes, 'transformers': transformers}

def fit(self, table_data):
"""Fit the ``ColumnsModel``.

Expand All @@ -324,10 +354,12 @@ def fit(self, table_data):
Table data.
"""
data_to_model = table_data[self.constraint_columns]
ht_config = self._get_hyper_transformer_config(data_to_model)

self._hyper_transformer = HyperTransformer()
self._hyper_transformer.detect_initial_config(data_to_model)
self._hyper_transformer.update_transformers_by_sdtype({'categorical': OneHotEncoder})
self._hyper_transformer.set_config(ht_config)
transformed_data = self._hyper_transformer.fit_transform(data_to_model)

self._model = GaussianMultivariate(distribution=GaussianUnivariate)
self._model.fit(transformed_data)

Expand Down
47 changes: 0 additions & 47 deletions sdv/metadata/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import numpy as np
import pandas as pd
from rdt import HyperTransformer

from sdv.constraints import Constraint
from sdv.metadata import visualization
Expand Down Expand Up @@ -72,7 +71,6 @@ class Metadata:
"""

_child_map = None
_hyper_transformers = None
_metadata = None
_parent_map = None

Expand Down Expand Up @@ -181,7 +179,6 @@ def __init__(self, metadata=None, root_path=None):
else:
self._metadata = {'tables': {}}

self._hyper_transformers = dict()
self._analyze_relationships()

def get_children(self, table_name):
Expand Down Expand Up @@ -401,50 +398,6 @@ def get_dtypes(self, table_name, ids=False, errors=None):

return dtypes

def transform(self, table_name, data):
    """Transform the data of the given table.

    A ``HyperTransformer`` is created, configured from ``data`` and fitted on it
    the first time a table is seen; afterwards the stored instance is reused.

    Args:
        table_name (str):
            Name of the table that is being transformed.
        data (pandas.DataFrame):
            Table data.

    Returns:
        pandas.DataFrame:
            Transformed data.
    """
    fitted = self._hyper_transformers.get(table_name)
    if fitted is not None:
        return fitted.transform(data)

    # First time this table is seen: build, fit and remember its transformer.
    fitted = HyperTransformer()
    fitted.detect_initial_config(data)
    fitted.fit(data)
    self._hyper_transformers[table_name] = fitted

    return fitted.transform(data)

def reverse_transform(self, table_name, data):
    """Revert previously transformed data of the given table.

    The stored ``HyperTransformer`` of the table reverses the data, and then the
    non-null values of every field returned by ``get_dtypes`` (ids included) are
    cast to the dtype reported for that field.

    Args:
        table_name (str):
            Name of the table to reverse transform.
        data (pandas.DataFrame):
            Data to be reversed.

    Returns:
        pandas.DataFrame
    """
    restored = self._hyper_transformers[table_name].reverse_transform(data)

    target_dtypes = self.get_dtypes(table_name, ids=True)
    for column, target in target_dtypes.items():
        non_null = restored[column].dropna()
        restored[column] = non_null.astype(target)

    return restored

# ################### #
# Metadata Validation #
# ################### #
Expand Down
29 changes: 17 additions & 12 deletions sdv/metadata/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def _build_fields_metadata(self, data):

return fields_metadata

def _get_transformers(self, dtypes):
def _get_hypertransformer_config(self, dtypes):
"""Create the transformer instances needed to process the given dtypes.

Args:
Expand All @@ -398,15 +398,19 @@ def _get_transformers(self, dtypes):

Returns:
dict:
mapping of field names and transformer instances.
A dict containing the ``sdtypes`` and ``transformers`` config for the
``rdt.HyperTransformer``.
"""
transformers = dict()
sdtypes = dict()
for name, dtype in dtypes.items():
dtype = np.dtype(dtype).kind
field_metadata = self._fields_metadata.get(name, {})
transformer_template = field_metadata.get(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this take the metadata into account at all? I'm trying to figure out why it wasn't before and what has changed now. For example, in the student_placements demo it wasn't making the duration column categorical even though the metadata said it was.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old code read the dtype properly but printed the wrong config, which led to confusion because the transformers were updated after the detect step. Now both sdtypes and transformers are set according to the dtype from the metadata with the set_config method.

'transformer', self._dtype_transformers[np.dtype(dtype).kind])
'transformer', self._dtype_transformers[dtype])

if transformer_template is None:
sdtypes[name] = self._DTYPES_TO_TYPES.get(dtype, {}).get('type', 'categorical')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a unit test that checks this case (when transformer_template is None)?

transformers[name] = None
continue

Expand All @@ -422,8 +426,9 @@ def _get_transformers(self, dtypes):
LOGGER.debug('Loading transformer %s for field %s',
transformer.__class__.__name__, name)
transformers[name] = transformer
sdtypes[name] = self._DTYPES_TO_TYPES.get(dtype, {}).get('type', 'categorical')

return transformers
return {'sdtypes': sdtypes, 'transformers': transformers}

def _fit_constraints(self, data):
errors = []
Expand Down Expand Up @@ -510,20 +515,20 @@ def _fit_hyper_transformer(self, data, extra_columns):
else:
dtypes[column] = dtype_kind

transformers_dict = self._get_transformers(dtypes)
ht_config = self._get_hypertransformer_config(dtypes)
for column in numerical_extras:
transformers_dict[column] = rdt.transformers.FloatFormatter(
dtypes[column] = 'numerical'
ht_config['sdtypes'][column] = 'numerical'
ht_config['transformers'][column] = rdt.transformers.FloatFormatter(
missing_value_replacement='mean',
model_missing_values=True,
)

self._hyper_transformer = rdt.HyperTransformer()
self._hyper_transformer.detect_initial_config(data)
if transformers_dict:
self._hyper_transformer.update_transformers(transformers_dict)

if not data.empty:
self._hyper_transformer.fit(data)
self._hyper_transformer.set_config(ht_config)
fit_columns = list(dtypes)
if not data[fit_columns].empty:
self._hyper_transformer.fit(data[fit_columns])

@staticmethod
def _get_key_subtype(field_meta):
Expand Down
25 changes: 13 additions & 12 deletions sdv/tabular/copulagan.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class CopulaGAN(CTGAN):
DEFAULT_DISTRIBUTION = 'truncated_gaussian'
_field_distributions = None
_default_distribution = None
_ht = None
_hyper_transformer = None

def __init__(self, field_names=None, field_types=None, field_transformers=None,
anonymize_fields=None, primary_key=None, constraints=None, table_metadata=None,
Expand Down Expand Up @@ -171,7 +171,7 @@ def get_distributions(self):
"""
return {
transformer.column_prefix: transformer._univariate.to_dict()['type']
for transformer in self._ht._transformers_sequence
for transformer in self._hyper_transformer._transformers_sequence
if isinstance(transformer, GaussianNormalizer)
}

Expand All @@ -185,23 +185,24 @@ def _fit(self, table_data):
distributions = self._field_distributions
fields = self._metadata.get_fields()

sdtypes = {}
transformers = {}
for field in table_data:
field_name = field.replace('.value', '')

if field_name in fields and fields.get(
field_name,
dict(),
).get('type') != 'categorical':
field_sdtype = fields.get(field_name, {}).get('type')
if field_name in fields and field_sdtype != 'categorical':
sdtypes[field] = 'numerical'
transformers[field] = GaussianNormalizer(
model_missing_values=True,
distribution=distributions.get(field_name, self._default_distribution)
)
else:
sdtypes[field] = field_sdtype or 'categorical'
transformers[field] = None

self._ht = HyperTransformer()
self._ht.detect_initial_config(table_data)
self._ht.update_transformers(transformers)
table_data = self._ht.fit_transform(table_data)
self._hyper_transformer = HyperTransformer()
self._hyper_transformer.set_config({'transformers': transformers, 'sdtypes': sdtypes})
table_data = self._hyper_transformer.fit_transform(table_data[list(transformers)])

super()._fit(table_data)

Expand All @@ -221,4 +222,4 @@ def _sample(self, num_rows, conditions=None):
Sampled data.
"""
sampled = super()._sample(num_rows, conditions)
return self._ht.reverse_transform(sampled)
return self._hyper_transformer.reverse_transform(sampled)
64 changes: 59 additions & 5 deletions tests/unit/constraints/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import pandas as pd
import pytest
from copulas.univariate import GaussianUnivariate
from rdt.transformers import OneHotEncoder

from sdv.constraints.base import (
ColumnsModel, Constraint, _get_qualified_name, _module_contains_callable_name, get_subclasses,
Expand Down Expand Up @@ -638,6 +637,64 @@ def test___init__list(self):
assert instance.constraint_columns == ['age', 'age_when_joined']
assert instance.constraint == constraint

@patch('sdv.constraints.base.FloatFormatter')
@patch('sdv.constraints.base.OneHotEncoder')
@patch('sdv.constraints.base.UnixTimestampEncoder')
@patch('sdv.constraints.base.BinaryEncoder')
def test__get_hyper_transformer_config(self, mock_binaryencoder, mock_unixtimestampencoder,
                                       mock_onehotencoder, mock_floatformatter):
    """Test the ``_get_hyper_transformer_config``.

    Test that the method ``_get_hyper_transformer_config`` returns the expected
    ``sdtypes`` and ``transformers`` for a given ``data_to_model``.

    Setup:
        - Create a ``pandas.DataFrame`` named ``data_to_model`` with multiple ``dtypes``.

    Mock:
        - ``rdt.transformers`` from ``constraints.base``.

    Input:
        - ``data_to_model`` with boolean, datetime, numerical and categorical data.

    Output:
        - A dictionary containing ``sdtypes`` and ``transformers`` that match the
          ``dtypes`` of the input columns.
    """
    # Setup
    data_to_model = pd.DataFrame({
        'age': [1, 2, 3],
        'amount': [1, 2, None],
        'name': [None, 'Doe', 'John Doe'],
        'joindate': pd.to_datetime(['2021-02-05', None, '2021-12-21']),
        'is_valid': [True, False, None],
    })
    age_float_formatter = Mock()
    amount_float_formatter = Mock()
    # ``side_effect`` (singular) is the attribute the mock honours; it makes each
    # ``FloatFormatter(...)`` call return the next prepared instance.
    mock_floatformatter.side_effect = [age_float_formatter, amount_float_formatter]

    # Run
    ht_config = ColumnsModel._get_hyper_transformer_config(data_to_model)

    # Assert
    assert ht_config == {
        'sdtypes': {
            'age': 'numerical',
            'amount': 'numerical',
            'name': 'categorical',
            'joindate': 'datetime',
            'is_valid': 'boolean'
        },
        'transformers': {
            'age': age_float_formatter,
            'amount': amount_float_formatter,
            'name': mock_onehotencoder,
            'joindate': mock_unixtimestampencoder.return_value,
            'is_valid': mock_binaryencoder.return_value
        }
    }
    assert mock_floatformatter.call_count == 2
    mock_unixtimestampencoder.assert_called_once_with(
        missing_value_replacement='mean',
        model_missing_values=True,
    )
    mock_binaryencoder.assert_called_once_with(
        missing_value_replacement=-1,
        model_missing_values=True,
    )

@patch('sdv.constraints.base.GaussianMultivariate')
@patch('sdv.constraints.base.HyperTransformer')
def test_fit(self, mock_hyper_transformer, mock_gaussian_multivariate):
Expand Down Expand Up @@ -677,10 +734,7 @@ def test_fit(self, mock_hyper_transformer, mock_gaussian_multivariate):

# Assert
mock_hyper_transformer.assert_called_once_with()
mock_hyper_transformer.return_value.detect_initial_config.assert_called_once()
mock_hyper_transformer.return_value.update_transformers_by_sdtype.assert_called_once_with(
{'categorical': OneHotEncoder}
)
mock_hyper_transformer.return_value.set_config.assert_called_once()
call_data = mock_hyper_transformer.return_value.fit_transform.call_args[0][0]
pd.testing.assert_frame_equal(table_data[['age', 'age_when_joined']], call_data)

Expand Down
Loading