Skip to content

Commit

Permalink
Set HyperTransformer configuration manually (#995)
Browse files Browse the repository at this point in the history
* Change from detect to set config

* Add tests constraints/base

* Fix unit tests

* fix data empty only for fit columns

* Fix lint

* rename hp to ht

* Remove unused transform / reverse transform
  • Loading branch information
pvk-developer authored Sep 7, 2022
1 parent 7a70db6 commit 0af85b8
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 116 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ test-tutorials: ## run the tutorial notebooks
invoke tutorials

.PHONY: test
test: test-unit test-readme test-tutorials ## test everything that needs test dependencies
test: test-unit test-integration test-readme test-tutorials ## test everything that needs test dependencies

.PHONY: test-all
test-all: ## run tests on every Python version with tox
Expand Down
38 changes: 35 additions & 3 deletions sdv/constraints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from copulas.multivariate.gaussian import GaussianMultivariate
from copulas.univariate import GaussianUnivariate
from rdt import HyperTransformer
from rdt.transformers import OneHotEncoder
from rdt.transformers import BinaryEncoder, FloatFormatter, OneHotEncoder, UnixTimestampEncoder

from sdv.constraints.errors import MissingConstraintColumnError
from sdv.errors import ConstraintsNotMetError
Expand Down Expand Up @@ -313,6 +313,36 @@ def __init__(self, constraint, constraint_columns):

self.constraint = constraint

@staticmethod
def _get_hyper_transformer_config(data_to_model):
sdtypes = {}
transformers = {}
for column_name, data in data_to_model.items():
dtype = data.dropna().infer_objects().dtype.kind
if dtype in ('i', 'f'):
sdtypes[column_name] = 'numerical'
transformers = FloatFormatter(
missing_value_replacement='mean',
model_missing_values=True,
)
elif dtype == 'O':
sdtypes[column_name] = 'categorical'
transformers[column_name] = OneHotEncoder
elif dtype == 'M':
sdtypes[column_name] = 'datetime'
transformers[column_name] = UnixTimestampEncoder(
missing_value_replacement='mean',
model_missing_values=True,
)
elif dtype == 'b':
sdtypes[column_name] = 'boolean'
transformers[column_name] = BinaryEncoder(
missing_value_replacement=-1,
model_missing_values=True
)

return {'sdtypes': sdtypes, 'transformers': transformers}

def fit(self, table_data):
"""Fit the ``ColumnsModel``.
Expand All @@ -324,10 +354,12 @@ def fit(self, table_data):
Table data.
"""
data_to_model = table_data[self.constraint_columns]
ht_config = self._get_hyper_transformer_config(data_to_model)

self._hyper_transformer = HyperTransformer()
self._hyper_transformer.detect_initial_config(data_to_model)
self._hyper_transformer.update_transformers_by_sdtype({'categorical': OneHotEncoder})
self._hyper_transformer.set_config(ht_config)
transformed_data = self._hyper_transformer.fit_transform(data_to_model)

self._model = GaussianMultivariate(distribution=GaussianUnivariate)
self._model.fit(transformed_data)

Expand Down
47 changes: 0 additions & 47 deletions sdv/metadata/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import numpy as np
import pandas as pd
from rdt import HyperTransformer

from sdv.constraints import Constraint
from sdv.metadata import visualization
Expand Down Expand Up @@ -72,7 +71,6 @@ class Metadata:
"""

_child_map = None
_hyper_transformers = None
_metadata = None
_parent_map = None

Expand Down Expand Up @@ -181,7 +179,6 @@ def __init__(self, metadata=None, root_path=None):
else:
self._metadata = {'tables': {}}

self._hyper_transformers = dict()
self._analyze_relationships()

def get_children(self, table_name):
Expand Down Expand Up @@ -401,50 +398,6 @@ def get_dtypes(self, table_name, ids=False, errors=None):

return dtypes

def transform(self, table_name, data):
"""Transform data for a given table.
If the ``HyperTransformer`` for a table is ``None`` it is created.
Args:
table_name (str):
Name of the table that is being transformer.
data (pandas.DataFrame):
Table data.
Returns:
pandas.DataFrame:
Transformed data.
"""
hyper_transformer = self._hyper_transformers.get(table_name)
if hyper_transformer is None:
hyper_transformer = HyperTransformer()
hyper_transformer.detect_initial_config(data)
hyper_transformer.fit(data)
self._hyper_transformers[table_name] = hyper_transformer

return hyper_transformer.transform(data)

def reverse_transform(self, table_name, data):
"""Reverse the transformed data for a given table.
Args:
table_name (str):
Name of the table to reverse transform.
data (pandas.DataFrame):
Data to be reversed.
Returns:
pandas.DataFrame
"""
hyper_transformer = self._hyper_transformers[table_name]
reversed_data = hyper_transformer.reverse_transform(data)

for name, dtype in self.get_dtypes(table_name, ids=True).items():
reversed_data[name] = reversed_data[name].dropna().astype(dtype)

return reversed_data

# ################### #
# Metadata Validation #
# ################### #
Expand Down
29 changes: 17 additions & 12 deletions sdv/metadata/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def _build_fields_metadata(self, data):

return fields_metadata

def _get_transformers(self, dtypes):
def _get_hypertransformer_config(self, dtypes):
"""Create the transformer instances needed to process the given dtypes.
Args:
Expand All @@ -398,15 +398,19 @@ def _get_transformers(self, dtypes):
Returns:
dict:
mapping of field names and transformer instances.
A dict containing the ``sdtypes`` and ``transformers`` config for the
``rdt.HyperTransformer``.
"""
transformers = dict()
sdtypes = dict()
for name, dtype in dtypes.items():
dtype = np.dtype(dtype).kind
field_metadata = self._fields_metadata.get(name, {})
transformer_template = field_metadata.get(
'transformer', self._dtype_transformers[np.dtype(dtype).kind])
'transformer', self._dtype_transformers[dtype])

if transformer_template is None:
sdtypes[name] = self._DTYPES_TO_TYPES.get(dtype, {}).get('type', 'categorical')
transformers[name] = None
continue

Expand All @@ -422,8 +426,9 @@ def _get_transformers(self, dtypes):
LOGGER.debug('Loading transformer %s for field %s',
transformer.__class__.__name__, name)
transformers[name] = transformer
sdtypes[name] = self._DTYPES_TO_TYPES.get(dtype, {}).get('type', 'categorical')

return transformers
return {'sdtypes': sdtypes, 'transformers': transformers}

def _fit_constraints(self, data):
errors = []
Expand Down Expand Up @@ -510,20 +515,20 @@ def _fit_hyper_transformer(self, data, extra_columns):
else:
dtypes[column] = dtype_kind

transformers_dict = self._get_transformers(dtypes)
ht_config = self._get_hypertransformer_config(dtypes)
for column in numerical_extras:
transformers_dict[column] = rdt.transformers.FloatFormatter(
dtypes[column] = 'numerical'
ht_config['sdtypes'][column] = 'numerical'
ht_config['transformers'][column] = rdt.transformers.FloatFormatter(
missing_value_replacement='mean',
model_missing_values=True,
)

self._hyper_transformer = rdt.HyperTransformer()
self._hyper_transformer.detect_initial_config(data)
if transformers_dict:
self._hyper_transformer.update_transformers(transformers_dict)

if not data.empty:
self._hyper_transformer.fit(data)
self._hyper_transformer.set_config(ht_config)
fit_columns = list(dtypes)
if not data[fit_columns].empty:
self._hyper_transformer.fit(data[fit_columns])

@staticmethod
def _get_key_subtype(field_meta):
Expand Down
25 changes: 13 additions & 12 deletions sdv/tabular/copulagan.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class CopulaGAN(CTGAN):
DEFAULT_DISTRIBUTION = 'truncated_gaussian'
_field_distributions = None
_default_distribution = None
_ht = None
_hyper_transformer = None

def __init__(self, field_names=None, field_types=None, field_transformers=None,
anonymize_fields=None, primary_key=None, constraints=None, table_metadata=None,
Expand Down Expand Up @@ -171,7 +171,7 @@ def get_distributions(self):
"""
return {
transformer.column_prefix: transformer._univariate.to_dict()['type']
for transformer in self._ht._transformers_sequence
for transformer in self._hyper_transformer._transformers_sequence
if isinstance(transformer, GaussianNormalizer)
}

Expand All @@ -185,23 +185,24 @@ def _fit(self, table_data):
distributions = self._field_distributions
fields = self._metadata.get_fields()

sdtypes = {}
transformers = {}
for field in table_data:
field_name = field.replace('.value', '')

if field_name in fields and fields.get(
field_name,
dict(),
).get('type') != 'categorical':
field_sdtype = fields.get(field_name, {}).get('type')
if field_name in fields and field_sdtype != 'categorical':
sdtypes[field] = 'numerical'
transformers[field] = GaussianNormalizer(
model_missing_values=True,
distribution=distributions.get(field_name, self._default_distribution)
)
else:
sdtypes[field] = field_sdtype or 'categorical'
transformers[field] = None

self._ht = HyperTransformer()
self._ht.detect_initial_config(table_data)
self._ht.update_transformers(transformers)
table_data = self._ht.fit_transform(table_data)
self._hyper_transformer = HyperTransformer()
self._hyper_transformer.set_config({'transformers': transformers, 'sdtypes': sdtypes})
table_data = self._hyper_transformer.fit_transform(table_data[list(transformers)])

super()._fit(table_data)

Expand All @@ -221,4 +222,4 @@ def _sample(self, num_rows, conditions=None):
Sampled data.
"""
sampled = super()._sample(num_rows, conditions)
return self._ht.reverse_transform(sampled)
return self._hyper_transformer.reverse_transform(sampled)
64 changes: 59 additions & 5 deletions tests/unit/constraints/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import pandas as pd
import pytest
from copulas.univariate import GaussianUnivariate
from rdt.transformers import OneHotEncoder

from sdv.constraints.base import (
ColumnsModel, Constraint, _get_qualified_name, _module_contains_callable_name, get_subclasses,
Expand Down Expand Up @@ -638,6 +637,64 @@ def test___init__list(self):
assert instance.constraint_columns == ['age', 'age_when_joined']
assert instance.constraint == constraint

@patch('sdv.constraints.base.FloatFormatter')
@patch('sdv.constraints.base.OneHotEncoder')
@patch('sdv.constraints.base.UnixTimestampEncoder')
@patch('sdv.constraints.base.BinaryEncoder')
def test__get_hyper_transformer_config(self, mock_binaryencoder, mock_unixtimestampencoder,
mock_onehotencoder, mock_floatformatter):
"""Test the ``_get_hyper_transformer_config``.
Test that the method ``_get_hyper_transformer_config`` returns the expected
``sdtypes`` and ``transformers`` for a given ``data_to_model``.
Setup:
- Create a ``pandas.DataFrame`` named ``data_to_model`` whit multiple ``dtypes``.
Mock:
- ``rdt.transformers`` from ``constraints.base``.
Input:
- ``data_to_model`` with boolean, datetime, numerical and categorical data.
Output:
- A dictionary containing ``sdtypes`` and ``transformers`` that match the expected
to the ``sdtypes``.
"""
# Setup
data_to_model = pd.DataFrame({
'age': [1, 2, 3],
'amount': [1, 2, None],
'name': [None, 'Doe', 'John Doe'],
'joindate': pd.to_datetime(['2021-02-05', None, '2021-12-21']),
'is_valid': [True, False, None],
})
age_float_formatter = Mock()
amount_float_formatter = Mock()
mock_floatformatter.side_effects = [age_float_formatter, amount_float_formatter]

# Run
ht_config = ColumnsModel._get_hyper_transformer_config(data_to_model)

# Assert
ht_config == {
'sdtypes': {
'age': 'numerical',
'amount': 'numerical',
'name': 'categorical',
'joindate': 'datetime',
'is_valid': 'boolean'
},
'transformers': {
'age': age_float_formatter,
'amount': amount_float_formatter,
'name': mock_onehotencoder,
'joindate': mock_unixtimestampencoder.return_value,
'is_valid': mock_binaryencoder.return_value
}

}

@patch('sdv.constraints.base.GaussianMultivariate')
@patch('sdv.constraints.base.HyperTransformer')
def test_fit(self, mock_hyper_transformer, mock_gaussian_multivariate):
Expand Down Expand Up @@ -677,10 +734,7 @@ def test_fit(self, mock_hyper_transformer, mock_gaussian_multivariate):

# Assert
mock_hyper_transformer.assert_called_once_with()
mock_hyper_transformer.return_value.detect_initial_config.assert_called_once()
mock_hyper_transformer.return_value.update_transformers_by_sdtype.assert_called_once_with(
{'categorical': OneHotEncoder}
)
mock_hyper_transformer.return_value.set_config.assert_called_once()
call_data = mock_hyper_transformer.return_value.fit_transform.call_args[0][0]
pd.testing.assert_frame_equal(table_data[['age', 'age_when_joined']], call_data)

Expand Down
Loading

0 comments on commit 0af85b8

Please sign in to comment.