Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Multi-Table modeling #1403

Merged
merged 3 commits into from
May 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,15 +273,6 @@ def update_transformers(self, table_name, column_name_to_transformer):
self._validate_table_name(table_name)
self._table_synthesizers[table_name].update_transformers(column_name_to_transformer)

def _fit(self, processed_data):
"""Fit the model to the tables.

Args:
processed_data (dict):
Dictionary mapping each table name to a preprocessed ``pandas.DataFrame``.
"""
raise NotImplementedError()

def preprocess(self, data):
"""Transform the raw data to numerical space.

Expand All @@ -308,14 +299,33 @@ def preprocess(self, data):

return processed_data

def _model_tables(self, augmented_data):
"""Model the augmented tables.

Args:
augmented_data (dict):
Dictionary mapping each table name to an augmented ``pandas.DataFrame``.
"""
raise NotImplementedError()

def _augment_tables(self, processed_data):
"""Augment the processed data.

Args:
processed_data (dict):
Dictionary mapping each table name to a preprocessed ``pandas.DataFrame``.
"""
raise NotImplementedError()

def fit_processed_data(self, processed_data):
"""Fit this model to the transformed data.

Args:
processed_data (dict):
Dictionary mapping each table name to a preprocessed ``pandas.DataFrame``.
"""
self._fit(processed_data.copy())
augmented_data = self._augment_tables(processed_data)
self._model_tables(augmented_data)
self._fitted = True
self._fitted_date = datetime.datetime.today().strftime('%Y-%m-%d')
self._fitted_sdv_version = pkg_resources.get_distribution('sdv').version
Expand Down
87 changes: 39 additions & 48 deletions sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, metadata, locales=None, synthesizer_kwargs=None):
self._synthesizer_kwargs = synthesizer_kwargs or self.DEFAULT_SYNTHESIZER_KWARGS
self._table_sizes = {}
self._max_child_rows = {}
self._modeled_tables = []
self._augmented_tables = []
for table_name in self.metadata.tables:
self.set_table_parameters(table_name, self._synthesizer_kwargs)

Expand Down Expand Up @@ -96,6 +96,17 @@ def _get_extension(self, child_name, child_table, foreign_key):

return pd.DataFrame(extension_rows, index=index)

@staticmethod
def _clear_nans(table_data):
for column in table_data.columns:
column_data = table_data[column]
if column_data.dtype in (int, float):
fill_value = 0 if column_data.isna().all() else column_data.mean()
else:
fill_value = column_data.mode()[0]

table_data[column] = table_data[column].fillna(fill_value)

def _get_foreign_keys(self, table_name, child_name):
foreign_keys = []
for relation in self.metadata.relationships:
Expand All @@ -105,7 +116,7 @@ def _get_foreign_keys(self, table_name, child_name):

return foreign_keys

def _extend_table(self, table, tables, table_name):
def _augment_table(self, table, tables, table_name):
"""Generate the extension columns for this table.

For each of the table's foreign keys, generate the related extension columns,
Expand All @@ -123,10 +134,12 @@ def _extend_table(self, table, tables, table_name):
pandas.DataFrame:
The extended table.
"""
self._table_sizes[table_name] = len(table)
LOGGER.info('Computing extensions for table %s', table_name)
for child_name in self.metadata._get_child_map()[table_name]:
if child_name not in self._modeled_tables:
child_table = self._model_table(child_name, tables)
if child_name not in self._augmented_tables:
child_table = self._augment_table(tables[child_name], tables, child_name)

else:
child_table = tables[child_name]

Expand All @@ -137,7 +150,10 @@ def _extend_table(self, table, tables, table_name):
num_rows_key = f'__{child_name}__{foreign_key}__num_rows'
table[num_rows_key] = table[num_rows_key].fillna(0)
self._max_child_rows[num_rows_key] = table[num_rows_key].max()
tables[table_name] = table

self._augmented_tables.append(table_name)
self._clear_nans(table)
return table

def _pop_foreign_keys(self, table_data, table_name):
Expand All @@ -160,66 +176,41 @@ def _pop_foreign_keys(self, table_data, table_name):

return keys

@staticmethod
def _clear_nans(table_data):
for column in table_data.columns:
column_data = table_data[column]
if column_data.dtype in (int, float):
fill_value = 0 if column_data.isna().all() else column_data.mean()
else:
fill_value = column_data.mode()[0]

table_data[column] = table_data[column].fillna(fill_value)

def _model_table(self, table_name, tables):
"""Model the indicated table and its children.
def _model_tables(self, augmented_data):
"""Model the augmented tables.

Args:
table_name (str):
Name of the table to model.
tables (dict):
Dict of original tables.

Returns:
pandas.DataFrame:
table data with the extensions created while modeling its children.
augmented_data (dict):
Dictionary mapping each table name to an augmented ``pandas.DataFrame``.
"""
LOGGER.info('Modeling %s', table_name)
for table_name, table in augmented_data.items():
keys = self._pop_foreign_keys(table, table_name)
self._clear_nans(table)
LOGGER.info('Fitting %s for table %s; shape: %s', self._synthesizer.__name__,
table_name, table.shape)

table = tables.get(table_name)
self._table_sizes[table_name] = len(table)

table = self._extend_table(table, tables, table_name)
keys = self._pop_foreign_keys(table, table_name)
self._clear_nans(table)
LOGGER.info('Fitting %s for table %s; shape: %s', self._synthesizer.__name__,
table_name, table.shape)
if not table.empty:
self._table_synthesizers[table_name].fit_processed_data(table)

if not table.empty:
self._table_synthesizers[table_name].fit_processed_data(table)

for name, values in keys.items():
table[name] = values

tables[table_name] = table
self._modeled_tables.append(table_name)

return table
for name, values in keys.items():
table[name] = values

def _fit(self, processed_data):
def _augment_tables(self, processed_data):
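"""Augment the processed tables with the extension columns computed from their children.

Args:
processed_data (dict):
Dictionary mapping each table name to a preprocessed ``pandas.DataFrame``.

Returns:
dict:
Dictionary mapping each table name to its augmented ``pandas.DataFrame``.
"""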
self._modeled_tables = []
augmented_data = deepcopy(processed_data)
self._augmented_tables = []
parent_map = self.metadata._get_parent_map()
for table_name in processed_data:
if not parent_map.get(table_name):
self._model_table(table_name, processed_data)
self._augment_table(augmented_data[table_name], augmented_data, table_name)

LOGGER.info('Modeling Complete')
LOGGER.info('Augmentation Complete')
return augmented_data

def _finalize(self, sampled_data):
"""Do the final touches to the generated data.
Expand Down
15 changes: 10 additions & 5 deletions tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,8 +484,8 @@ def test_update_transformers_missing_table(self):
with pytest.raises(InvalidDataError, match=err_msg):
instance.update_transformers('not_seen', {})

def test__fit(self):
"""Test that ``_fit`` raises a ``NotImplementedError``."""
def test__model_tables(self):
"""Test that ``_model_tables`` raises a ``NotImplementedError``."""
# Setup
metadata = get_multi_table_metadata()
instance = BaseMultiTableSynthesizer(metadata)
Expand All @@ -505,7 +505,7 @@ def test__fit(self):

# Run and Assert
with pytest.raises(NotImplementedError, match=''):
instance._fit(data)
instance._model_tables(data)

def test__assign_table_transformers(self):
"""Test the ``_assign_table_transformers`` method.
Expand Down Expand Up @@ -642,7 +642,11 @@ def test_preprocess_warning(self, mock_warnings):
)

def test_fit_processed_data(self):
"""Test that fit processed data calls ``_fit`` and sets ``_fitted`` to ``True``."""
"""Test that fit processed data calls ``_augment_tables`` and ``_model_tables``.

Ensure that the ``fit_processed_data`` augments the tables and then models those using
the ``_model_tables`` method. Then sets the state to fitted.
"""
# Setup
instance = Mock()
data = Mock()
Expand All @@ -652,7 +656,8 @@ def test_fit_processed_data(self):
BaseMultiTableSynthesizer.fit_processed_data(instance, data)

# Assert
instance._fit.assert_called_once_with(data)
instance._augment_tables.assert_called_once_with(data)
instance._model_tables.assert_called_once_with(instance._augment_tables.return_value)
assert instance._fitted

def test_fit(self):
Expand Down
78 changes: 39 additions & 39 deletions tests/unit/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ def test__get_all_foreign_keys(self):
# Assert
assert result == ['upravna_enota']

def test__extend_table(self):
"""Test that ``extend_table`` extends the current table with extra columns.
def test__augment_table(self):
"""Test that ``augment_table`` extends the current table with extra columns.

This also updates ``self._modeled_tables`` and ``self._max_child_rows``.
This also updates ``self._augmented_tables`` and ``self._max_child_rows``.
"""
# Setup
metadata = get_multi_table_metadata()
Expand All @@ -128,7 +128,7 @@ def test__extend_table(self):
data['oseba']['oseba_value'] = [0, 1, 2, 3]

# Run
result = instance._extend_table(data['nesreca'], data, 'nesreca')
result = instance._augment_table(data['nesreca'], data, 'nesreca')

# Assert
expected_result = pd.DataFrame({
Expand All @@ -140,16 +140,16 @@ def test__extend_table(self):
'__oseba__id_nesreca__univariates__oseba_val__a': [1.] * 4,
'__oseba__id_nesreca__univariates__oseba_val__b': [1.] * 4,
'__oseba__id_nesreca__univariates__oseba_val__loc': [0., 1., 2., 3.],
'__oseba__id_nesreca__univariates__oseba_val__scale': [np.nan] * 4,
'__oseba__id_nesreca__univariates__oseba_val__scale': [0.] * 4,
'__oseba__id_nesreca__univariates__oseba_value__a': [1.] * 4,
'__oseba__id_nesreca__univariates__oseba_value__b': [1.] * 4,
'__oseba__id_nesreca__univariates__oseba_value__loc': [0., 1., 2., 3.],
'__oseba__id_nesreca__univariates__oseba_value__scale': [np.nan] * 4,
'__oseba__id_nesreca__univariates__oseba_value__scale': [0.] * 4,
'__oseba__id_nesreca__num_rows': [1.] * 4,
})

pd.testing.assert_frame_equal(expected_result, result)
assert instance._modeled_tables == ['oseba']
assert instance._augmented_tables == ['oseba', 'nesreca']
assert instance._max_child_rows['__oseba__id_nesreca__num_rows'] == 1

def test__pop_foreign_keys(self):
Expand Down Expand Up @@ -189,40 +189,35 @@ def test__clear_nans(self):
})
pd.testing.assert_frame_equal(expected_data, data)

def test__model_table(self):
"""Test that ``_model_table`` performs the modeling.
def test__model_tables(self):
"""Test that ``_model_tables`` performs the modeling.

Modeling consists of iterating over the augmented tables and, for each one, removing
the foreign keys and clearing any null values by using the ``_clear_nans`` method. Then,
fitting the table model by calling ``fit_processed_data`` and adding back the foreign
keys.
"""
# Setup
nesreca_model = Mock()
instance = Mock()
instance._synthesizer = GaussianCopulaSynthesizer

instance._modeled_tables = []
instance._table_sizes = {}
instance._augmented_tables = ['nesreca']
instance._table_sizes = {'nesreca': 3}
instance._table_synthesizers = {'nesreca': nesreca_model}
instance._pop_foreign_keys.return_value = {'fk': [1, 2, 3]}
data = pd.DataFrame({
'id_nesreca': [0, 1, 2],
'upravna_enota': [0, 1, 2]
})
extended_data = pd.DataFrame({
'id_nesreca': [0, 1, 2],
'upravna_enota': [0, 1, 2],
'extended': ['a', 'b', 'c']
})

instance._extend_table.return_value = extended_data

tables = {'nesreca': data}
input_data = {
'nesreca': pd.DataFrame({
'id_nesreca': [0, 1, 2],
'upravna_enota': [0, 1, 2],
'extended': ['a', 'b', 'c']
})
}
augmented_data = input_data.copy()

# Run
result = HMASynthesizer._model_table(instance, 'nesreca', tables)
HMASynthesizer._model_tables(instance, augmented_data)

# Assert
expected_result = pd.DataFrame({
Expand All @@ -231,31 +226,36 @@ def test__model_table(self):
'extended': ['a', 'b', 'c'],
'fk': [1, 2, 3]
})
pd.testing.assert_frame_equal(expected_result, result)
pd.testing.assert_frame_equal(expected_result, augmented_data['nesreca'])

instance._extend_table.assert_called_once_with(data, tables, 'nesreca')
instance._pop_foreign_keys.assert_called_once_with(extended_data, 'nesreca')
instance._clear_nans(extended_data)
nesreca_model.fit_processed_data.assert_called_once_with(extended_data)
instance._pop_foreign_keys.assert_called_once_with(input_data['nesreca'], 'nesreca')
instance._clear_nans.assert_called_once_with(input_data['nesreca'])
nesreca_model.fit_processed_data.assert_called_once_with(augmented_data['nesreca'])

assert instance._modeled_tables == ['nesreca']
assert instance._table_sizes == {'nesreca': 3}

def test__fit(self):
"""Test that ``_fit`` calls ``_model_table`` only if the table has no parents."""
def test__augment_tables(self):
"""Test that ``_augment_tables`` calls ``_augment_table`` only if the table has no parents."""
# Setup
metadata = get_multi_table_metadata()
instance = HMASynthesizer(metadata)
instance._model_table = Mock()
instance._augment_table = Mock()
data = get_multi_table_data()
data['nesreca']['value'] = [0, 1, 2, 3]
data['oseba']['oseba_value'] = [0, 1, 2, 3]

# Run
instance._fit(data)
instance._augment_tables(data)

# Assert
instance._model_table.assert_called_once_with('upravna_enota', data)
call_table = instance._augment_table.call_args[0][0]
call_augmented_data = instance._augment_table.call_args[0][1]
call_table_name = instance._augment_table.call_args[0][2]

pd.testing.assert_frame_equal(call_table, data['upravna_enota'])
for input_table, orig_table in zip(call_augmented_data.values(), data.values()):
pd.testing.assert_frame_equal(input_table, orig_table)

assert list(call_augmented_data) == list(data)
assert call_table_name == 'upravna_enota'

def test__finalize(self):
"""Test that the finalize method applies the final touches to the generated data.
Expand Down