Skip to content

Commit

Permalink
Cannot sample after loading a model with custom constraint (#991)
Browse files Browse the repository at this point in the history
* Undo changes to CustomConstraint

* Change pickle to dill

* Update docs

* fix lint

* Lower dill version to suit 3.6 and improve tests

* Add missing comma

* Switch to cloudpickle

* Update tutorials links

* Fix cloudpickle version
  • Loading branch information
pvk-developer authored Sep 2, 2022
1 parent 2f62aa6 commit 7a70db6
Show file tree
Hide file tree
Showing 23 changed files with 131 additions and 131 deletions.
4 changes: 2 additions & 2 deletions docs/developer_guides/sdv/tabular.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ Apart from the previous steps, the ``BaseTabularModel`` also offers a couple of
functionalities:

* ``get_metadata``: Returns the Table metadata object that has been fitted to the data.
* ``save``: Saves the complete Tabular Model in a file using ``pickle``.
* ``load``: Loads a previously saved model from a ``pickle`` file.
* ``save``: Saves the complete Tabular Model in a file using ``cloudpickle``.
* ``load``: Loads a previously saved model from a ``cloudpickle`` file.

Implementing a Tabular Model
----------------------------
Expand Down
2 changes: 1 addition & 1 deletion docs/user_guides/relational/hma1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ method passing the name of the file in which you want to save the model.
Note that the extension of the filename is not relevant, but we will be
using the ``.pkl`` extension to highlight that the serialization
protocol used is
`pickle <https://docs.python.org/3/library/pickle.html>`__.
`cloudpickle <https://github.com/cloudpipe/cloudpickle>`__.

.. ipython:: python
:okwarning:
Expand Down
2 changes: 1 addition & 1 deletion docs/user_guides/single_table/copulagan.rst
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ method passing the name of the file in which you want to save the model.
Note that the extension of the filename is not relevant, but we will be
using the ``.pkl`` extension to highlight that the serialization
protocol used is
`pickle <https://docs.python.org/3/library/pickle.html>`__.
`cloudpickle <https://github.com/cloudpipe/cloudpickle>`__.

.. ipython:: python
:okwarning:
Expand Down
2 changes: 1 addition & 1 deletion docs/user_guides/single_table/ctgan.rst
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ method passing the name of the file in which you want to save the model.
Note that the extension of the filename is not relevant, but we will be
using the ``.pkl`` extension to highlight that the serialization
protocol used is
`pickle <https://docs.python.org/3/library/pickle.html>`__.
`cloudpickle <https://github.com/cloudpipe/cloudpickle>`__.

.. ipython:: python
:okwarning:
Expand Down
2 changes: 1 addition & 1 deletion docs/user_guides/single_table/gaussian_copula.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ method passing the name of the file in which you want to save the model.
Note that the extension of the filename is not relevant, but we will be
using the ``.pkl`` extension to highlight that the serialization
protocol used is
`pickle <https://docs.python.org/3/library/pickle.html>`__.
`cloudpickle <https://github.com/cloudpipe/cloudpickle>`__.

.. ipython:: python
:okwarning:
Expand Down
2 changes: 1 addition & 1 deletion docs/user_guides/single_table/tvae.rst
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ method passing the name of the file in which you want to save the model.
Note that the extension of the filename is not relevant, but we will be
using the ``.pkl`` extension to highlight that the serialization
protocol used is
`pickle <https://docs.python.org/3/library/pickle.html>`__.
`cloudpickle <https://github.com/cloudpipe/cloudpickle>`__.

.. ipython:: python
:okwarning:
Expand Down
2 changes: 1 addition & 1 deletion docs/user_guides/timeseries/par.rst
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ method passing the name of the file in which you want to save the model.
Note that the extension of the filename is not relevant, but we will be
using the ``.pkl`` extension to highlight that the serialization
protocol used is
`pickle <https://docs.python.org/3/library/pickle.html>`__.
`cloudpickle <https://github.com/cloudpipe/cloudpickle>`__.

.. ipython:: python
:okwarning:
Expand Down
155 changes: 75 additions & 80 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,104 +85,99 @@ def create_custom_constraint(is_valid_fn, transform_fn=None, reverse_transform_f
"""
_validate_inputs_custom_constraint(is_valid_fn, transform_fn, reverse_transform_fn)

def constructor(self, column_names, **kwargs):
self.column_names = column_names
self.kwargs = kwargs
self.constraint_columns = tuple(column_names)

def is_valid(self, data):
"""Check whether the column values are valid.
class CustomConstraint(Constraint):
"""CustomConstraint class.
Args:
table_data (pandas.DataFrame):
Table data.
Returns:
pandas.Series:
Whether each row is valid.
transform (callable):
Function to replace the ``transform`` method.
reverse_transform (callable):
Function to replace the ``reverse_transform`` method.
is_valid (callable):
Function to replace the ``is_valid`` method.
"""
valid = is_valid_fn(self.column_names, data, **self.kwargs)
if len(valid) != data.shape[0]:
raise InvalidFunctionError(
'`is_valid_fn` did not produce exactly 1 True/False value for each row.')

if not isinstance(valid, pd.Series):
raise ValueError(
"The custom 'is_valid' function returned an unsupported type. "
'The returned object must be a pandas.Series'
)
def __init__(self, column_names, **kwargs):
self.column_names = column_names
self.kwargs = kwargs
self.constraint_columns = tuple(column_names)

return valid
def is_valid(self, data):
"""Check whether the column values are valid.
def transform(self, data):
"""Transform the table data.
Args:
table_data (pandas.DataFrame):
Table data.
Args:
table_data (pandas.DataFrame):
Table data.
Returns:
pandas.Series:
Whether each row is valid.
"""
valid = is_valid_fn(self.column_names, data, **self.kwargs)
if len(valid) != data.shape[0]:
raise InvalidFunctionError(
'`is_valid_fn` did not produce exactly 1 True/False value for each row.')

Returns:
pandas.DataFrame:
Transformed data.
"""
data = data.copy()
if transform_fn is None:
return data
if not isinstance(valid, pd.Series):
raise ValueError(
"The custom 'is_valid' function returned an unsupported type. "
'The returned object must be a pandas.Series'
)

try:
transformed_data = transform_fn(self.column_names, data, **self.kwargs)
if data.shape[0] != transformed_data.shape[0]:
raise InvalidFunctionError(
'Transformation did not produce the same number of rows as the original')
return valid

self.reverse_transform(transformed_data.copy())
return transformed_data
def transform(self, data):
"""Transform the table data.
except InvalidFunctionError as e:
raise e
Args:
table_data (pandas.DataFrame):
Table data.
except Exception:
raise FunctionError
Returns:
pandas.DataFrame:
Transformed data.
"""
data = data.copy()
if transform_fn is None:
return data

def reverse_transform(self, data):
"""Reverse transform the table data.
try:
transformed_data = transform_fn(self.column_names, data, **self.kwargs)
if data.shape[0] != transformed_data.shape[0]:
raise InvalidFunctionError(
'Transformation did not produce the same number of rows as the original')

Args:
table_data (pandas.DataFrame):
Table data.
self.reverse_transform(transformed_data.copy())
return transformed_data

Returns:
pandas.DataFrame:
Transformed data.
"""
data = data.copy()
if reverse_transform_fn is None:
return data

transformed_data = reverse_transform_fn(self.column_names, data, **self.kwargs)
if data.shape[0] != transformed_data.shape[0]:
raise InvalidFunctionError(
'Reverse transform did not produce the same number of rows as the original.'
)
except InvalidFunctionError as e:
raise e

return transformed_data
except Exception:
raise FunctionError

def _reduce(self):
"""Overwrite ``__reduce__`` function.
def reverse_transform(self, data):
"""Reverse transform the table data.
The ``__reduce__`` function returns a tuple with the method that creates this class,
and a tuple with the arguments that this function takes.
"""
return (create_custom_constraint, (is_valid_fn, transform_fn, reverse_transform_fn))

# Dynamic Class Creation
CustomConstraint = type('CustomConstraint', (Constraint, ), {
'__init__': constructor,
'__reduce__': _reduce,
'is_valid': is_valid,
'transform': transform,
'reverse_transform': reverse_transform,
})
Args:
table_data (pandas.DataFrame):
Table data.
Returns:
pandas.DataFrame:
Transformed data.
"""
data = data.copy()
if reverse_transform_fn is None:
return data

transformed_data = reverse_transform_fn(self.column_names, data, **self.kwargs)
if data.shape[0] != transformed_data.shape[0]:
raise InvalidFunctionError(
'Reverse transform did not produce the same number of rows as the original.'
)

return transformed_data

return CustomConstraint

Expand Down
8 changes: 4 additions & 4 deletions sdv/lite/tabular.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Base class for tabular model presets."""

import logging
import pickle
import sys
import warnings

import cloudpickle
import numpy as np
import rdt

Expand Down Expand Up @@ -233,7 +233,7 @@ def sample_remaining_columns(self, known_columns, max_tries_per_batch=100, batch
return self._postprocess_sampled(sampled)

def save(self, path):
"""Save this model instance to the given path using pickle.
"""Save this model instance to the given path using cloudpickle.
Args:
path (str):
Expand All @@ -242,7 +242,7 @@ def save(self, path):
self._package_versions = get_package_versions(getattr(self, '_model', None))

with open(path, 'wb') as output:
pickle.dump(self, output)
cloudpickle.dump(self, output)

@classmethod
def load(cls, path):
Expand All @@ -257,7 +257,7 @@ def load(cls, path):
The loaded tabular model.
"""
with open(path, 'rb') as f:
model = pickle.load(f)
model = cloudpickle.load(f)
throw_version_mismatch_warning(getattr(model, '_package_versions', None))

return model
Expand Down
8 changes: 4 additions & 4 deletions sdv/relational/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import itertools
import logging
import pickle

import cloudpickle
import numpy as np
import pandas as pd

Expand Down Expand Up @@ -185,7 +185,7 @@ def sample(self, table_name=None, num_rows=None,
return self._sample(table_name, num_rows, sample_children)

def save(self, path):
"""Save this instance to the given path using pickle.
"""Save this instance to the given path using cloudpickle.
Args:
path (str):
Expand All @@ -194,7 +194,7 @@ def save(self, path):
self._package_versions = get_package_versions(getattr(self, '_model', None))

with open(path, 'wb') as output:
pickle.dump(self, output)
cloudpickle.dump(self, output)

@classmethod
def load(cls, path):
Expand All @@ -205,7 +205,7 @@ def load(cls, path):
Path from which to load the instance.
"""
with open(path, 'rb') as f:
model = pickle.load(f)
model = cloudpickle.load(f)
throw_version_mismatch_warning(getattr(model, '_package_versions', None))

return model
9 changes: 5 additions & 4 deletions sdv/sdv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

"""Main SDV module."""

import pickle
import warnings

import cloudpickle

from sdv.errors import NotFittedError
from sdv.relational.hma import HMA1
from sdv.tabular.copulas import GaussianCopula
Expand Down Expand Up @@ -142,7 +143,7 @@ def sample_all(self, num_rows=None, reset_primary_keys=False):
return self.sample(num_rows=num_rows, reset_primary_keys=reset_primary_keys)

def save(self, path):
"""Save this SDV instance to the given path using pickle.
"""Save this SDV instance to the given path using cloudpickle.
Args:
path (str):
Expand All @@ -151,7 +152,7 @@ def save(self, path):
self._package_versions = get_package_versions(getattr(self, '_model', None))

with open(path, 'wb') as output:
pickle.dump(self, output)
cloudpickle.dump(self, output)

@classmethod
def load(cls, path):
Expand All @@ -162,7 +163,7 @@ def load(cls, path):
Path from which to load the SDV instance.
"""
with open(path, 'rb') as f:
model = pickle.load(f)
model = cloudpickle.load(f)
throw_version_mismatch_warning(getattr(model, '_package_versions', None))

return model
8 changes: 4 additions & 4 deletions sdv/tabular/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import logging
import math
import os
import pickle
import uuid
from collections import defaultdict
from copy import deepcopy

import cloudpickle
import copulas
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -864,7 +864,7 @@ def set_parameters(self, parameters):
self._set_parameters(parameters)

def save(self, path):
"""Save this model instance to the given path using pickle.
"""Save this model instance to the given path using cloudpickle.
Args:
path (str):
Expand All @@ -873,7 +873,7 @@ def save(self, path):
self._package_versions = get_package_versions(getattr(self, '_model', None))

with open(path, 'wb') as output:
pickle.dump(self, output)
cloudpickle.dump(self, output)

@classmethod
def load(cls, path):
Expand All @@ -888,7 +888,7 @@ def load(cls, path):
The loaded tabular model.
"""
with open(path, 'rb') as f:
model = pickle.load(f)
model = cloudpickle.load(f)
throw_version_mismatch_warning(getattr(model, '_package_versions', None))

return model
Loading

0 comments on commit 7a70db6

Please sign in to comment.