Skip to content

Commit

Permalink
Fix PAR synthesizer not being able to conditionally sample with datetimes
Browse files Browse the repository at this point in the history
  • Loading branch information
pvk-developer committed Jan 15, 2025
1 parent 6b47f9d commit 8b0cfb1
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 41 deletions.
29 changes: 20 additions & 9 deletions sdv/sequential/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,24 @@ def _get_context_metadata(self):

for column in context_columns:
context_columns_dict[column] = self.metadata.columns[column]
# Context datetime SDTypes for PAR have already been converted to float timestamp
if context_columns_dict[column]['sdtype'] == 'datetime':
context_columns_dict[column] = {'sdtype': 'numerical'}

for column, column_metadata in self._extra_context_columns.items():
context_columns_dict[column] = column_metadata

context_metadata_dict = {'columns': context_columns_dict}
return SingleTableMetadata.load_from_dict(context_metadata_dict)

def _get_context_datetime_columns(self):
datetime_columns = []
for column in self.context_columns:
if self.metadata.columns[column]['sdtype'] == 'datetime':
datetime_columns.append(column)

return datetime_columns

def __init__(
self,
metadata,
Expand Down Expand Up @@ -352,12 +363,6 @@ def _fit_context_model(self, transformed):
context[constant_column] = 0
context_metadata.add_column(constant_column, sdtype='numerical')

for column in self.context_columns:
# Context datetime SDTypes for PAR have already been converted to float timestamp
if context_metadata.columns[column]['sdtype'] == 'datetime':
if pd.api.types.is_numeric_dtype(context[column]):
context_metadata.update_column(column, sdtype='numerical')

with warnings.catch_warnings():
warnings.filterwarnings('ignore', message=".*The 'SingleTableMetadata' is deprecated.*")
self._context_synthesizer = GaussianCopulaSynthesizer(
Expand Down Expand Up @@ -540,9 +545,15 @@ def sample_sequential_columns(self, context_columns, sequence_length=None):
set(context_columns.columns), set(self._context_synthesizer._model.columns)
)
)

datetime_columns = self._get_context_datetime_columns()
if datetime_columns:
context_columns[datetime_columns] = self._data_processor.transform(
context_columns[datetime_columns]
)

condition_columns = context_columns[condition_columns].to_dict('records')
context = self._context_synthesizer.sample_from_conditions([
Condition(conditions) for conditions in condition_columns
])
synthesizer_conditions = [Condition(conditions) for conditions in condition_columns]
context = self._context_synthesizer.sample_from_conditions(synthesizer_conditions)
context.update(context_columns)
return self._sample(context, sequence_length)
60 changes: 30 additions & 30 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1537,12 +1537,12 @@ def test_large_integer_ids(self):
for col in table.columns:
if metadata.tables[table_name].columns[col].get('sdtype') == 'id':
values = table[col].astype(str)
assert all(len(str(v)) == 17 for v in values), (
f'ID length mismatch in {table_name}.{col}'
)
assert all(v.isdigit() for v in values), (
f'Non-digit characters in {table_name}.{col}'
)
assert all(
len(str(v)) == 17 for v in values
), f'ID length mismatch in {table_name}.{col}'
assert all(
v.isdigit() for v in values
), f'Non-digit characters in {table_name}.{col}'

# Check relationships are preserved
child_fks = set(synthetic_data['table_2']['col_A'])
Expand Down Expand Up @@ -1616,12 +1616,12 @@ def test_large_integer_ids_overflow(self):
for col in table.columns:
if metadata.tables[table_name].columns[col].get('sdtype') == 'id':
values = table[col].astype(str)
assert all(len(str(v)) == 21 for v in values), (
f'ID length mismatch in {table_name}.{col}'
)
assert all(v.isdigit() for v in values), (
f'Non-digit characters in {table_name}.{col}'
)
assert all(
len(str(v)) == 21 for v in values
), f'ID length mismatch in {table_name}.{col}'
assert all(
v.isdigit() for v in values
), f'Non-digit characters in {table_name}.{col}'

# Check relationships are preserved
child_fks = set(synthetic_data['table_2']['col_A'])
Expand Down Expand Up @@ -1718,12 +1718,12 @@ def test_large_integer_ids_overflow_three_tables(self):
for col in table.columns:
if metadata.tables[table_name].columns[col].get('sdtype') == 'id':
values = table[col].astype(str)
assert all(len(str(v)) == 20 for v in values), (
f'ID length mismatch in {table_name}.{col}'
)
assert all(v.isdigit() for v in values), (
f'Non-digit characters in {table_name}.{col}'
)
assert all(
len(str(v)) == 20 for v in values
), f'ID length mismatch in {table_name}.{col}'
assert all(
v.isdigit() for v in values
), f'Non-digit characters in {table_name}.{col}'

# Check relationships are preserved
child_fks = set(synthetic_data['table_1']['col_0'])
Expand Down Expand Up @@ -1802,12 +1802,12 @@ def test_ids_that_dont_fit_in_int64(self):
for col in table.columns:
if metadata.tables[table_name].columns[col].get('sdtype') == 'id':
values = table[col].astype(str)
assert all(len(str(v)) == 20 for v in values), (
f'ID length mismatch in {table_name}.{col}'
)
assert all(v.isdigit() for v in values), (
f'Non-digit characters in {table_name}.{col}'
)
assert all(
len(str(v)) == 20 for v in values
), f'ID length mismatch in {table_name}.{col}'
assert all(
v.isdigit() for v in values
), f'Non-digit characters in {table_name}.{col}'

# Check relationships are preserved
child_fks = set(synthetic_data['table_2']['col_A'])
Expand Down Expand Up @@ -1877,12 +1877,12 @@ def test_large_real_ids_small_synthetic_ids(self):
for col in table.columns:
if metadata.tables[table_name].columns[col].get('sdtype') == 'id':
values = table[col].astype(str)
assert all(len(str(v)) == 1 for v in values), (
f'ID length mismatch in {table_name}.{col}'
)
assert all(v.isdigit() for v in values), (
f'Non-digit characters in {table_name}.{col}'
)
assert all(
len(str(v)) == 1 for v in values
), f'ID length mismatch in {table_name}.{col}'
assert all(
v.isdigit() for v in values
), f'Non-digit characters in {table_name}.{col}'

# Check relationships are preserved
child_fks = set(synthetic_data['table_2']['col_A'])
Expand Down
9 changes: 8 additions & 1 deletion tests/integration/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def _get_par_data_and_metadata():
'column2': ['b', 'a', 'a', 'c'],
'entity': [1, 1, 2, 2],
'context': ['a', 'a', 'b', 'b'],
'context_date': [date, date, date, date],
})
metadata = Metadata.detect_from_dataframes({'table': data})
metadata.update_column('entity', 'table', sdtype='id')
Expand Down Expand Up @@ -94,15 +95,21 @@ def test_column_after_date_complex():
data, metadata = _get_par_data_and_metadata()

# Run
model = PARSynthesizer(metadata=metadata, context_columns=['context'], epochs=1)
model = PARSynthesizer(metadata=metadata, context_columns=['context', 'context_date'], epochs=1)
model.fit(data)
sampled = model.sample(2)
context_columns = data[['context', 'context_date']]
sample_with_conditions = model.sample_sequential_columns(context_columns=context_columns)

# Assert
assert sampled.shape == data.shape
assert (sampled.dtypes == data.dtypes).all()
assert (sampled.notna().sum(axis=1) != 0).all()

expected_date = datetime.datetime.strptime('2020-01-01', '%Y-%m-%d')
assert all(sample_with_conditions['context_date'] == expected_date)
assert all(sample_with_conditions['context'].isin(['a', 'b']))


def test_save_and_load(tmp_path):
"""Test that synthesizers can be saved and loaded properly."""
Expand Down
45 changes: 44 additions & 1 deletion tests/unit/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def test__fit_context_model_with_datetime_context_column(self, gaussian_copula_m
par = PARSynthesizer(metadata, context_columns=['time'])
initial_synthesizer = Mock()
context_metadata = SingleTableMetadata.load_from_dict({
'columns': {'time': {'sdtype': 'datetime'}, 'name': {'sdtype': 'id'}}
'columns': {'time': {'sdtype': 'numerical'}, 'name': {'sdtype': 'id'}}
})
par._context_synthesizer = initial_synthesizer
par._get_context_metadata = Mock()
Expand Down Expand Up @@ -934,6 +934,7 @@ def test_sample_sequential_columns(self):
"""Test that the method uses the provided context columns to sample."""
# Setup
par = PARSynthesizer(metadata=self.get_metadata(), context_columns=['gender'])
par._get_context_datetime_columns = Mock(return_value=None)
par._context_synthesizer = Mock()
par._context_synthesizer._model.columns = ['gender', 'extra_col']
par._context_synthesizer.sample_from_conditions.return_value = pd.DataFrame({
Expand Down Expand Up @@ -970,6 +971,7 @@ def test_sample_sequential_columns(self):
call_args, _ = par._sample.call_args
pd.testing.assert_frame_equal(call_args[0], expected_call_arg)
assert call_args[1] == 5
par._get_context_datetime_columns.assert_called_once_with()

def test_sample_sequential_columns_no_context_columns(self):
"""Test that the method raises an error if the synthesizer has no context columns.
Expand Down Expand Up @@ -1083,3 +1085,44 @@ def test___init__with_unified_metadata(self):

with pytest.raises(InvalidMetadataError, match=error_msg):
PARSynthesizer(multi_metadata)

def test_sample_sequential_columns_with_datetime_values(self):
    """Test that datetime context values are converted to numerical space before sampling."""
    # Setup
    dates = ['2020-01-01', '2020-01-02', '2020-01-03']
    par = PARSynthesizer(metadata=self.get_metadata(), context_columns=['time'])
    par.fit(self.get_data())

    par._context_synthesizer = Mock()
    par._context_synthesizer._model.columns = ['time', 'extra_col']
    par._context_synthesizer.sample_from_conditions.return_value = pd.DataFrame({
        'id_col': ['A', 'A', 'A'],
        'time': dates,
        'extra_col': [0, 1, 1],
    })
    par._sample = Mock()
    context_columns = pd.DataFrame({
        'id_col': ['ID-1', 'ID-2', 'ID-3'],
        'time': dates,
    })

    # Run
    par.sample_sequential_columns(context_columns, 5)

    # Assert
    # The conditions passed to the context synthesizer must carry the
    # transformed (numerical) datetime values, not the raw strings.
    transformed = par._data_processor.transform(pd.DataFrame({'time': dates}))
    expected_conditions = [
        Condition({'time': value}) for value in transformed['time'].tolist()
    ]
    passed_args, _ = par._context_synthesizer.sample_from_conditions.call_args

    assert len(passed_args[0]) == len(expected_conditions)
    for actual, expected in zip(passed_args[0], expected_conditions):
        assert actual.column_values == expected.column_values
        assert actual.num_rows == expected.num_rows

0 comments on commit 8b0cfb1

Please sign in to comment.