-
Notifications
You must be signed in to change notification settings - Fork 325
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* initial commit * fixing skeleton for parsynthesizer * adding unit tests * pr comments
- Loading branch information
1 parent
04dc0bd
commit ab8f913
Showing
5 changed files
with
202 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Synthesizers for sequential data.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
"""PAR Synthesizer class.""" | ||
|
||
import inspect | ||
|
||
from sdv.data_processing import DataProcessor | ||
from sdv.metadata.single_table import SingleTableMetadata | ||
from sdv.single_table import GaussianCopulaSynthesizer | ||
|
||
|
||
class PARSynthesizer: | ||
"""Synthesizer for sequential data. | ||
This synthesizer uses the ``deepecho.models.par.PARModel`` class as the core model. | ||
Additionally, it uses a separate synthesizer to model and sample the context columns | ||
to be passed into PAR. | ||
Args: | ||
metadata (sdv.metadata.SingleTableMetadata): | ||
Single table metadata representing the data that this synthesizer will be used for. | ||
enforce_min_max_values (bool): | ||
Specify whether or not to clip the data returned by ``reverse_transform`` of | ||
the numerical transformer, ``FloatFormatter``, to the min and max values seen | ||
during ``fit``. Defaults to ``True``. | ||
enforce_rounding (bool): | ||
Define rounding scheme for ``numerical`` columns. If ``True``, the data returned | ||
by ``reverse_transform`` will be rounded as in the original data. Defaults to ``True``. | ||
context_columns (list[str]): | ||
A list of strings, representing the columns that do not vary in a sequence. | ||
segment_size (int): | ||
If specified, cut each training sequence in several segments of | ||
the indicated size. The size can be passed as an integer | ||
value, which will interpreted as the number of data points to | ||
put on each segment. | ||
epochs (int): | ||
The number of epochs to train for. Defaults to 128. | ||
sample_size (int): | ||
The number of times to sample (before choosing and | ||
returning the sample which maximizes the likelihood). | ||
Defaults to 1. | ||
cuda (bool): | ||
Whether to attempt to use cuda for GPU computation. | ||
If this is False or CUDA is not available, CPU will be used. | ||
Defaults to ``True``. | ||
verbose (bool): | ||
Whether to print progress to console or not. | ||
""" | ||
|
||
def _get_context_metadata(self): | ||
context_columns_dict = {} | ||
context_columns = self.context_columns or [] | ||
for column in context_columns: | ||
context_columns_dict[column] = self.metadata._columns[column] | ||
|
||
context_metadata_dict = {'columns': context_columns_dict} | ||
return SingleTableMetadata._load_from_dict(context_metadata_dict) | ||
|
||
def __init__(self, metadata, enforce_min_max_values, enforce_rounding, context_columns=None, | ||
segment_size=None, epochs=128, sample_size=1, cuda=True, verbose=False): | ||
self.metadata = metadata | ||
self.enforce_min_max_values = enforce_min_max_values | ||
self.enforce_rounding = enforce_rounding | ||
self._data_processor = DataProcessor(metadata) | ||
self.context_columns = context_columns | ||
self.segment_size = segment_size | ||
self._model_kwargs = { | ||
'epochs': epochs, | ||
'sample_size': sample_size, | ||
'cuda': cuda, | ||
'verbose': verbose, | ||
} | ||
context_metadata = self._get_context_metadata() | ||
self._context_synthesizer = GaussianCopulaSynthesizer( | ||
metadata=context_metadata, | ||
enforce_min_max_values=enforce_min_max_values, | ||
enforce_rounding=enforce_rounding | ||
) | ||
|
||
def get_parameters(self): | ||
"""Return the parameters used to instantiate the synthesizer.""" | ||
parameters = inspect.signature(self.__init__).parameters | ||
instantiated_parameters = {} | ||
for parameter_name in parameters: | ||
if parameter_name != 'metadata': | ||
instantiated_parameters[parameter_name] = self.__dict__.get(parameter_name) | ||
|
||
for parameter_name, value in self._model_kwargs.items(): | ||
instantiated_parameters[parameter_name] = value | ||
|
||
return instantiated_parameters | ||
|
||
def get_metadata(self): | ||
"""Return the ``SingleTableMetadata`` for this synthesizer.""" | ||
return self.metadata |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Tests for synthesizers for sequential data.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
from sdv.data_processing.data_processor import DataProcessor | ||
from sdv.metadata.single_table import SingleTableMetadata | ||
from sdv.sequential.par import PARSynthesizer | ||
from sdv.single_table.copulas import GaussianCopulaSynthesizer | ||
|
||
|
||
class TestPARSynthesizer: | ||
|
||
def test___init__(self): | ||
"""Test that the parameters are set correctly. | ||
The parameters passed in the ``__init__`` should be set on the instance. Additionally, | ||
a context synthesizer should be created with the correct metadata and parameters. | ||
""" | ||
# Setup | ||
metadata = SingleTableMetadata() | ||
metadata.add_column('time', sdtype='datetime') | ||
metadata.add_column('gender', sdtype='categorical') | ||
metadata.add_column('name', sdtype='text') | ||
metadata.add_column('measurement', sdtype='numerical') | ||
|
||
# Run | ||
synthesizer = PARSynthesizer( | ||
metadata=metadata, | ||
enforce_min_max_values=True, | ||
enforce_rounding=True, | ||
context_columns=['gender', 'name'], | ||
segment_size=10, | ||
epochs=10, | ||
sample_size=5, | ||
cuda=False, | ||
verbose=False | ||
) | ||
|
||
# Assert | ||
assert synthesizer.context_columns == ['gender', 'name'] | ||
assert synthesizer.enforce_min_max_values is True | ||
assert synthesizer.enforce_rounding is True | ||
assert synthesizer.segment_size == 10 | ||
assert synthesizer._model_kwargs == { | ||
'epochs': 10, | ||
'sample_size': 5, | ||
'cuda': False, | ||
'verbose': False | ||
} | ||
assert isinstance(synthesizer._data_processor, DataProcessor) | ||
assert synthesizer._data_processor.metadata == metadata | ||
assert isinstance(synthesizer._context_synthesizer, GaussianCopulaSynthesizer) | ||
assert synthesizer._context_synthesizer.metadata._columns == { | ||
'gender': {'sdtype': 'categorical'}, | ||
'name': {'sdtype': 'text'} | ||
} | ||
|
||
def test_get_parameters(self): | ||
"""Test that it returns every ``init`` parameter without the ``metadata``.""" | ||
# Setup | ||
metadata = SingleTableMetadata() | ||
instance = PARSynthesizer( | ||
metadata=metadata, | ||
enforce_min_max_values=True, | ||
enforce_rounding=True, | ||
context_columns=None, | ||
segment_size=10, | ||
epochs=10, | ||
sample_size=5, | ||
cuda=False, | ||
verbose=False | ||
) | ||
|
||
# Run | ||
parameters = instance.get_parameters() | ||
|
||
# Assert | ||
assert 'metadata' not in parameters | ||
assert parameters == { | ||
'enforce_min_max_values': True, | ||
'enforce_rounding': True, | ||
'context_columns': None, | ||
'segment_size': 10, | ||
'epochs': 10, | ||
'sample_size': 5, | ||
'cuda': False, | ||
'verbose': False | ||
} | ||
|
||
def test_get_metadata(self): | ||
"""Test that it returns the ``metadata`` object.""" | ||
# Setup | ||
metadata = SingleTableMetadata() | ||
instance = PARSynthesizer( | ||
metadata=metadata, | ||
enforce_min_max_values=True, | ||
enforce_rounding=True, | ||
context_columns=None, | ||
segment_size=10, | ||
epochs=10, | ||
sample_size=5, | ||
cuda=False, | ||
verbose=False | ||
) | ||
|
||
# Run | ||
result = instance.get_metadata() | ||
|
||
# Assert | ||
assert result == metadata |