Cannot fit twice if I modify transformers: ValueError: There are non-numerical values in your data. #1259

Closed
npatki opened this issue Feb 12, 2023 · 1 comment
Labels: bug (Something isn't working), data:multi-table (Related to multi-table, relational datasets)

Comments

@npatki
Contributor

npatki commented Feb 12, 2023

Environment Details

  • SDV version: 1.0.0 (in progress)
  • Python version: 3.8
  • Operating System: Linux (Colab Notebook)

Error Description

If I modify the transformations in the HMASynthesizer, then I am unable to call fit more than once. I get a ValueError every time I try to fit a second time.

This may be related to #1258, as the Stack Trace seems to be similar.
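One possible way to sidestep the problem while it is open is to build a brand-new synthesizer for every fit instead of refitting the same instance. The sketch below assumes the failure comes from state kept on the synthesizer object, which this report suggests but does not confirm. `fit_fresh` and `FitOnceSynthesizer` are hypothetical names used only for illustration; the toy class merely imitates the reported behavior and is not part of SDV.

```python
def fit_fresh(build_synthesizer, data):
    """Build a brand-new synthesizer and fit it exactly once.

    `build_synthesizer` is a hypothetical zero-argument factory supplied by
    the caller; it should return a fully configured, unfitted synthesizer.
    """
    synthesizer = build_synthesizer()
    synthesizer.fit(data)
    return synthesizer


class FitOnceSynthesizer:
    """Toy stand-in (not SDV) that imitates the report: the second fit raises."""

    def __init__(self):
        self.fit_calls = 0

    def fit(self, data):
        self.fit_calls += 1
        if self.fit_calls > 1:
            raise ValueError('There are non-numerical values in your data.')


# Each call gets its own instance, so no instance is ever fitted twice.
first = fit_fresh(FitOnceSynthesizer, data={})
second = fit_fresh(FitOnceSynthesizer, data={})
print(first.fit_calls, second.fit_calls)
```

With the real objects, the factory would presumably construct `HMASynthesizer(metadata)` and repeat the `auto_assign_transformers` and `update_transformers` calls before returning the synthesizer.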

Steps to reproduce

Observe that the first fit call succeeds but the second one fails.

from rdt.transformers.pii import PseudoAnonymizedFaker
from sdv.multi_table import HMASynthesizer
from sdv.datasets.demo import download_demo

real_data, metadata = download_demo(
    modality='multi_table',
    dataset_name='fake_hotels'
)

synthesizer = HMASynthesizer(metadata)
synthesizer.auto_assign_transformers(real_data)

address_pseudo_anonymizer = PseudoAnonymizedFaker(provider_name='address', function_name='address')

synthesizer.update_transformers(
    table_name='guests',
    column_name_to_transformer={
        'billing_address': address_pseudo_anonymizer
    }
)

# The first fit succeeds.
synthesizer.fit(real_data)
print('FITTING 1: DONE')
# The second fit on the same synthesizer raises the ValueError.
synthesizer.fit(real_data)
print('FITTING 2: DONE')

Stack Trace

FITTING 1: DONE
/usr/local/lib/python3.8/dist-packages/rdt/hyper_transformer.py:400: UserWarning: For this change to take effect, please refit your data using 'fit' or 'fit_transform'.
  warnings.warn(self._REFIT_MESSAGE)
/usr/local/lib/python3.8/dist-packages/sdv/single_table/base.py:398: UserWarning: This model has already been fitted. To use the new preprocessed data, please refit the model using 'fit' or 'fit_processed_data'.
  warnings.warn(
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-11-8c09db8697b6> in <module>
     26 synthesizer.fit(real_data)
     27 print('FITTING 1: DONE')
---> 28 synthesizer.fit(real_data)
     29 print('FITTING 2: DONE')

8 frames
/usr/local/lib/python3.8/dist-packages/sdv/multi_table/base.py in fit(self, data)
    324         self._fitted = False
    325         processed_data = self.preprocess(data)
--> 326         self.fit_processed_data(processed_data)
    327 
    328     def reset_sampling(self):

/usr/local/lib/python3.8/dist-packages/sdv/multi_table/base.py in fit_processed_data(self, processed_data)
    309                 Dictionary mapping each table name to a preprocessed ``pandas.DataFrame``.
    310         """
--> 311         self._fit(processed_data)
    312         self._fitted = True
    313         self._fitted_date = datetime.datetime.today().strftime('%Y-%m-%d')

/usr/local/lib/python3.8/dist-packages/sdv/multi_table/hma.py in _fit(self, processed_data)
    207         for table_name in processed_data:
    208             if not parent_map.get(table_name):
--> 209                 self._model_table(table_name, processed_data)
    210 
    211         LOGGER.info('Modeling Complete')

/usr/local/lib/python3.8/dist-packages/sdv/multi_table/hma.py in _model_table(self, table_name, tables)
    180         self._table_sizes[table_name] = len(table)
    181 
--> 182         table = self._extend_table(table, tables, table_name)
    183         keys = self._pop_foreign_keys(table, table_name)
    184         self._clear_nans(table)

/usr/local/lib/python3.8/dist-packages/sdv/multi_table/hma.py in _extend_table(self, table, tables, table_name)
    117         for child_name in self.metadata._get_child_map()[table_name]:
    118             if child_name not in self._modeled_tables:
--> 119                 child_table = self._model_table(child_name, tables)
    120             else:
    121                 child_table = tables[child_name]

/usr/local/lib/python3.8/dist-packages/sdv/multi_table/hma.py in _model_table(self, table_name, tables)
    186                     table_name, table.shape)
    187 
--> 188         self._table_synthesizers[table_name].fit_processed_data(table)
    189 
    190         for name, values in keys.items():

/usr/local/lib/python3.8/dist-packages/sdv/single_table/base.py in fit_processed_data(self, processed_data)
    419                 The transformed data used to fit the model to.
    420         """
--> 421         self._fit(processed_data)
    422         self._fitted = True
    423         self._fitted_date = datetime.datetime.today().strftime('%Y-%m-%d')

/usr/local/lib/python3.8/dist-packages/sdv/single_table/copulas.py in _fit(self, processed_data)
    130         with warnings.catch_warnings():
    131             warnings.filterwarnings('ignore', module='scipy')
--> 132             self._model.fit(processed_data)
    133 
    134     def _warn_for_update_transformers(self, column_name_to_transformer):

/usr/local/lib/python3.8/dist-packages/copulas/__init__.py in decorated(self, X, *args, **kwargs)
    251 
    252         if not (np.issubdtype(W.dtype, np.floating) or np.issubdtype(W.dtype, np.integer)):
--> 253             raise ValueError('There are non-numerical values in your data.')
    254 
    255         if np.isnan(W).any().any():

ValueError: There are non-numerical values in your data.
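The last frame of the trace shows the input being rejected because the array dtype is neither floating point nor integer. A minimal sketch of that dtype test is below. The condition is copied from the trace, but how `W` is built from the input is not shown there, so the `np.asarray` conversion and the `is_numerical` helper name are assumptions made for illustration.

```python
import numpy as np
import pandas as pd


def is_numerical(data):
    """Return True if the data converts to a float or integer NumPy array.

    Hypothetical helper: only the dtype condition is taken from the trace;
    converting with np.asarray is an assumption.
    """
    W = np.asarray(data)
    return bool(
        np.issubdtype(W.dtype, np.floating) or np.issubdtype(W.dtype, np.integer)
    )


# Toy tables: one fully numerical, one that still holds raw strings.
transformed = pd.DataFrame({'city': [0.996857, 0.080440], 'rating': [4.8, 4.1]})
untransformed = pd.DataFrame({'city': ['Boston', 'Boston'], 'rating': [4.8, 4.1]})

print(is_numerical(transformed))
print(is_numerical(untransformed))
```

A single string column is enough to push the whole array to a non-numerical dtype, which is consistent with the error appearing as soon as any untransformed column reaches the model.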
@npatki npatki added bug Something isn't working data:multi-table Related to multi-table, relational datasets labels Feb 12, 2023
@npatki npatki added this to the 1.0.0 milestone Feb 12, 2023
@frances-h
Contributor

Hi @npatki, I looked into this more, and it seems the problem is that when synthesizer.preprocess is run multiple times, it only processes the data correctly the first time it's run. This is probably also what's causing #1258, since any attempt to refit will hit the same error if the data contains non-numerical values.

from rdt.transformers.pii import PseudoAnonymizedFaker
from sdv.multi_table import HMASynthesizer
from sdv.datasets.demo import download_demo

real_data, metadata = download_demo(
    modality='multi_table',
    dataset_name='fake_hotels'
)

synthesizer = HMASynthesizer(metadata)
synthesizer.auto_assign_transformers(real_data)

address_pseudo_anonymizer = PseudoAnonymizedFaker(provider_name='address', function_name='address')

synthesizer.update_transformers(
    table_name='guests',
    column_name_to_transformer={
        'billing_address': address_pseudo_anonymizer
    }
)

# First call: the columns come back transformed to numerical values.
synthesizer.preprocess(real_data)['hotels']

              city     state  rating  classification
hotel_id
HID_000   0.996857  0.996857     4.8        0.996857
HID_001   0.080440  1.080440     4.1        1.080440
HID_002   1.284600  2.284600     3.8        2.284600
HID_003   1.535250  2.535250     4.0        1.535250
HID_004   2.845802  3.845802     3.7        2.845802
HID_005   3.200124  4.200124     4.6        0.200124
HID_006   4.812044  2.812044     4.9        0.812044
HID_007   4.372937  2.372937     4.3        2.372937
HID_008   3.809641  4.809641     4.2        1.809641
HID_009   0.011995  0.011995     4.6        0.011995

# Second call on the same synthesizer: the categorical columns come back as raw strings.
synthesizer.preprocess(real_data)['hotels']

         hotel_id           city           state  rating  classification
hotel_id
HID_000   HID_000         Boston   Massachusetts     4.8          RESORT
HID_001   HID_001         Boston  Massachuesetts     4.1           CHAIN
HID_002   HID_002  San Francisco      California     3.8           MOTEL
HID_003   HID_003  San Francisco      California     4.0           CHAIN
HID_004   HID_004  New York City        New York     3.7           MOTEL
HID_005   HID_005         Austin           Texas     4.6          RESORT
HID_006   HID_006    Los Angeles      California     4.9          RESORT
HID_007   HID_007    Los Angeles      California     NaN           MOTEL
HID_008   HID_008         Austin           Texas     4.2           CHAIN
HID_009   HID_009         Boston   Massachusetts     4.6          RESORT
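The difference between the two outputs above can also be checked programmatically by listing, per table, the columns that are not numerical. This is only a sketch: `non_numeric_columns` is a hypothetical helper, and the two dictionaries are toy stand-ins for the output of `synthesizer.preprocess`, not real SDV results.

```python
import pandas as pd
from pandas.api.types import is_numeric_dtype


def non_numeric_columns(tables):
    """Map each table name to the list of its columns that are not numerical.

    Hypothetical helper; `tables` is a dict of table name -> pandas.DataFrame.
    Tables whose columns are all numerical are left out of the result.
    """
    report = {}
    for table_name, table in tables.items():
        bad = [col for col in table.columns if not is_numeric_dtype(table[col])]
        if bad:
            report[table_name] = bad
    return report


# Toy stand-ins shaped like the first and second preprocess outputs.
first_call = {'hotels': pd.DataFrame({'city': [0.996857, 0.080440], 'rating': [4.8, 4.1]})}
second_call = {'hotels': pd.DataFrame({'city': ['Boston', 'Boston'], 'rating': [4.8, 4.1]})}

print(non_numeric_columns(first_call))
print(non_numeric_columns(second_call))
```

Against the real objects, the same check would be `non_numeric_columns(synthesizer.preprocess(real_data))`, run once after each call; an empty result means every column reached the model as numbers.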
