diff --git a/sdv/sequential/par.py b/sdv/sequential/par.py
index 0dbd28656..789401303 100644
--- a/sdv/sequential/par.py
+++ b/sdv/sequential/par.py
@@ -76,6 +76,9 @@ def _get_context_metadata(self):
         for column in context_columns:
             context_columns_dict[column] = self.metadata.columns[column]
+            # Context datetime SDTypes for PAR have already been converted to float timestamp
+            if context_columns_dict[column]['sdtype'] == 'datetime':
+                context_columns_dict[column] = {'sdtype': 'numerical'}

         for column, column_metadata in self._extra_context_columns.items():
             context_columns_dict[column] = column_metadata
@@ -83,6 +86,14 @@ def _get_context_metadata(self):
         context_metadata_dict = {'columns': context_columns_dict}
         return SingleTableMetadata.load_from_dict(context_metadata_dict)

+    def _get_context_datetime_columns(self):
+        datetime_columns = []
+        for column in self.context_columns:
+            if self.metadata.columns[column]['sdtype'] == 'datetime':
+                datetime_columns.append(column)
+
+        return datetime_columns
+
     def __init__(
         self,
         metadata,
@@ -352,12 +363,6 @@ def _fit_context_model(self, transformed):
             context[constant_column] = 0
             context_metadata.add_column(constant_column, sdtype='numerical')

-        for column in self.context_columns:
-            # Context datetime SDTypes for PAR have already been converted to float timestamp
-            if context_metadata.columns[column]['sdtype'] == 'datetime':
-                if pd.api.types.is_numeric_dtype(context[column]):
-                    context_metadata.update_column(column, sdtype='numerical')
-
         with warnings.catch_warnings():
             warnings.filterwarnings('ignore', message=".*The 'SingleTableMetadata' is deprecated.*")
             self._context_synthesizer = GaussianCopulaSynthesizer(
@@ -540,9 +545,15 @@ def sample_sequential_columns(self, context_columns, sequence_length=None):
                     set(context_columns.columns), set(self._context_synthesizer._model.columns)
                 )
             )
+
+        datetime_columns = self._get_context_datetime_columns()
+        if datetime_columns:
+            context_columns[datetime_columns] = self._data_processor.transform(
+                context_columns[datetime_columns]
+            )
+
         condition_columns = context_columns[condition_columns].to_dict('records')
-        context = self._context_synthesizer.sample_from_conditions([
-            Condition(conditions) for conditions in condition_columns
-        ])
+        synthesizer_conditions = [Condition(conditions) for conditions in condition_columns]
+        context = self._context_synthesizer.sample_from_conditions(synthesizer_conditions)
         context.update(context_columns)
         return self._sample(context, sequence_length)
diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py
index 2539ff98e..cc45da3df 100644
--- a/tests/integration/multi_table/test_hma.py
+++ b/tests/integration/multi_table/test_hma.py
@@ -1537,12 +1537,12 @@ def test_large_integer_ids(self):
             for col in table.columns:
                 if metadata.tables[table_name].columns[col].get('sdtype') == 'id':
                     values = table[col].astype(str)
-                    assert all(len(str(v)) == 17 for v in values), (
-                        f'ID length mismatch in {table_name}.{col}'
-                    )
-                    assert all(v.isdigit() for v in values), (
-                        f'Non-digit characters in {table_name}.{col}'
-                    )
+                    assert all(
+                        len(str(v)) == 17 for v in values
+                    ), f'ID length mismatch in {table_name}.{col}'
+                    assert all(
+                        v.isdigit() for v in values
+                    ), f'Non-digit characters in {table_name}.{col}'

         # Check relationships are preserved
         child_fks = set(synthetic_data['table_2']['col_A'])
@@ -1616,12 +1616,12 @@ def test_large_integer_ids_overflow(self):
             for col in table.columns:
                 if metadata.tables[table_name].columns[col].get('sdtype') == 'id':
                    values = table[col].astype(str)
-                    assert all(len(str(v)) == 21 for v in values), (
-                        f'ID length mismatch in {table_name}.{col}'
-                    )
-                    assert all(v.isdigit() for v in values), (
-                        f'Non-digit characters in {table_name}.{col}'
-                    )
+                    assert all(
+                        len(str(v)) == 21 for v in values
+                    ), f'ID length mismatch in {table_name}.{col}'
+                    assert all(
+                        v.isdigit() for v in values
+                    ), f'Non-digit characters in {table_name}.{col}'

         # Check relationships are preserved
         child_fks = set(synthetic_data['table_2']['col_A'])
@@ -1718,12 +1718,12 @@ def test_large_integer_ids_overflow_three_tables(self):
             for col in table.columns:
                 if metadata.tables[table_name].columns[col].get('sdtype') == 'id':
                     values = table[col].astype(str)
-                    assert all(len(str(v)) == 20 for v in values), (
-                        f'ID length mismatch in {table_name}.{col}'
-                    )
-                    assert all(v.isdigit() for v in values), (
-                        f'Non-digit characters in {table_name}.{col}'
-                    )
+                    assert all(
+                        len(str(v)) == 20 for v in values
+                    ), f'ID length mismatch in {table_name}.{col}'
+                    assert all(
+                        v.isdigit() for v in values
+                    ), f'Non-digit characters in {table_name}.{col}'

         # Check relationships are preserved
         child_fks = set(synthetic_data['table_1']['col_0'])
@@ -1802,12 +1802,12 @@ def test_ids_that_dont_fit_in_int64(self):
             for col in table.columns:
                 if metadata.tables[table_name].columns[col].get('sdtype') == 'id':
                     values = table[col].astype(str)
-                    assert all(len(str(v)) == 20 for v in values), (
-                        f'ID length mismatch in {table_name}.{col}'
-                    )
-                    assert all(v.isdigit() for v in values), (
-                        f'Non-digit characters in {table_name}.{col}'
-                    )
+                    assert all(
+                        len(str(v)) == 20 for v in values
+                    ), f'ID length mismatch in {table_name}.{col}'
+                    assert all(
+                        v.isdigit() for v in values
+                    ), f'Non-digit characters in {table_name}.{col}'

         # Check relationships are preserved
         child_fks = set(synthetic_data['table_2']['col_A'])
@@ -1877,12 +1877,12 @@ def test_large_real_ids_small_synthetic_ids(self):
             for col in table.columns:
                 if metadata.tables[table_name].columns[col].get('sdtype') == 'id':
                     values = table[col].astype(str)
-                    assert all(len(str(v)) == 1 for v in values), (
-                        f'ID length mismatch in {table_name}.{col}'
-                    )
-                    assert all(v.isdigit() for v in values), (
-                        f'Non-digit characters in {table_name}.{col}'
-                    )
+                    assert all(
+                        len(str(v)) == 1 for v in values
+                    ), f'ID length mismatch in {table_name}.{col}'
+                    assert all(
+                        v.isdigit() for v in values
+                    ), f'Non-digit characters in {table_name}.{col}'

         # Check relationships are preserved
         child_fks = set(synthetic_data['table_2']['col_A'])
diff --git a/tests/integration/sequential/test_par.py b/tests/integration/sequential/test_par.py
index 0193e7a34..69121f4c9 100644
--- a/tests/integration/sequential/test_par.py
+++ b/tests/integration/sequential/test_par.py
@@ -20,6 +20,7 @@ def _get_par_data_and_metadata():
         'column2': ['b', 'a', 'a', 'c'],
         'entity': [1, 1, 2, 2],
         'context': ['a', 'a', 'b', 'b'],
+        'context_date': [date, date, date, date],
     })
     metadata = Metadata.detect_from_dataframes({'table': data})
     metadata.update_column('entity', 'table', sdtype='id')
@@ -94,15 +95,21 @@ def test_column_after_date_complex():
     data, metadata = _get_par_data_and_metadata()

     # Run
-    model = PARSynthesizer(metadata=metadata, context_columns=['context'], epochs=1)
+    model = PARSynthesizer(metadata=metadata, context_columns=['context', 'context_date'], epochs=1)
     model.fit(data)
     sampled = model.sample(2)
+    context_columns = data[['context', 'context_date']]
+    sample_with_conditions = model.sample_sequential_columns(context_columns=context_columns)

     # Assert
     assert sampled.shape == data.shape
     assert (sampled.dtypes == data.dtypes).all()
     assert (sampled.notna().sum(axis=1) != 0).all()
+    expected_date = datetime.datetime.strptime('2020-01-01', '%Y-%m-%d')
+    assert all(sample_with_conditions['context_date'] == expected_date)
+    assert all(sample_with_conditions['context'].isin(['a', 'b']))
+

 def test_save_and_load(tmp_path):
     """Test that synthesizers can be saved and loaded properly."""
diff --git a/tests/unit/sequential/test_par.py b/tests/unit/sequential/test_par.py
index e1062fcfd..33459d024 100644
--- a/tests/unit/sequential/test_par.py
+++ b/tests/unit/sequential/test_par.py
@@ -487,7 +487,7 @@ def test__fit_context_model_with_datetime_context_column(self, gaussian_copula_m
         par = PARSynthesizer(metadata, context_columns=['time'])
         initial_synthesizer = Mock()
         context_metadata = SingleTableMetadata.load_from_dict({
-            'columns': {'time': {'sdtype': 'datetime'}, 'name': {'sdtype': 'id'}}
+            'columns': {'time': {'sdtype': 'numerical'}, 'name': {'sdtype': 'id'}}
         })
         par._context_synthesizer = initial_synthesizer
         par._get_context_metadata = Mock()
@@ -934,6 +934,7 @@ def test_sample_sequential_columns(self):
         """Test that the method uses the provided context columns to sample."""
         # Setup
         par = PARSynthesizer(metadata=self.get_metadata(), context_columns=['gender'])
+        par._get_context_datetime_columns = Mock(return_value=None)
         par._context_synthesizer = Mock()
         par._context_synthesizer._model.columns = ['gender', 'extra_col']
         par._context_synthesizer.sample_from_conditions.return_value = pd.DataFrame({
@@ -970,6 +971,7 @@ def test_sample_sequential_columns(self):
         call_args, _ = par._sample.call_args
         pd.testing.assert_frame_equal(call_args[0], expected_call_arg)
         assert call_args[1] == 5
+        par._get_context_datetime_columns.assert_called_once_with()

     def test_sample_sequential_columns_no_context_columns(self):
         """Test that the method raises an error if the synthesizer has no context columns.
@@ -1083,3 +1085,44 @@ def test___init__with_unified_metadata(self):
         with pytest.raises(InvalidMetadataError, match=error_msg):
             PARSynthesizer(multi_metadata)
+
+    def test_sample_sequential_columns_with_datetime_values(self):
+        """Test that the method converts datetime values to numerical space before sampling."""
+        # Setup
+        par = PARSynthesizer(metadata=self.get_metadata(), context_columns=['time'])
+        data = self.get_data()
+        par.fit(data)
+
+        par._context_synthesizer = Mock()
+        par._context_synthesizer._model.columns = ['time', 'extra_col']
+        par._context_synthesizer.sample_from_conditions.return_value = pd.DataFrame({
+            'id_col': ['A', 'A', 'A'],
+            'time': ['2020-01-01', '2020-01-02', '2020-01-03'],
+            'extra_col': [0, 1, 1],
+        })
+        par._sample = Mock()
+        context_columns = pd.DataFrame({
+            'id_col': ['ID-1', 'ID-2', 'ID-3'],
+            'time': ['2020-01-01', '2020-01-02', '2020-01-03'],
+        })
+
+        # Run
+        par.sample_sequential_columns(context_columns, 5)
+
+        # Assert
+        time_values = par._data_processor.transform(
+            pd.DataFrame({'time': ['2020-01-01', '2020-01-02', '2020-01-03']})
+        )
+
+        time_values = time_values['time'].tolist()
+        expected_conditions = [
+            Condition({'time': time_values[0]}),
+            Condition({'time': time_values[1]}),
+            Condition({'time': time_values[2]}),
+        ]
+        call_args, _ = par._context_synthesizer.sample_from_conditions.call_args
+
+        assert len(call_args[0]) == len(expected_conditions)
+        for arg, expected in zip(call_args[0], expected_conditions):
+            assert arg.column_values == expected.column_values
+            assert arg.num_rows == expected.num_rows
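Taken together, these changes let datetime context columns flow end to end through conditional sampling: `_get_context_metadata` maps them to `numerical` (since PAR has already converted them to float timestamps), and `sample_sequential_columns` transforms raw datetime condition values into that numerical space before building conditions. A minimal usage sketch, based on the integration test above; the `set_sequence_key` call and the data values are illustrative assumptions, not part of this diff:

```python
import datetime

import pandas as pd

from sdv.metadata import Metadata
from sdv.sequential import PARSynthesizer

# Toy sequential data with a datetime context column, mirroring
# _get_par_data_and_metadata() in the integration tests.
date = datetime.datetime.strptime('2020-01-01', '%Y-%m-%d')
data = pd.DataFrame({
    'column1': [1.0, 2.0, 1.5, 1.3],
    'column2': ['b', 'a', 'a', 'c'],
    'entity': [1, 1, 2, 2],
    'context': ['a', 'a', 'b', 'b'],
    'context_date': [date, date, date, date],
})

metadata = Metadata.detect_from_dataframes({'table': data})
metadata.update_column('entity', 'table', sdtype='id')
metadata.set_sequence_key('entity', 'table')  # assumed setup step; not shown in this diff

model = PARSynthesizer(metadata=metadata, context_columns=['context', 'context_date'], epochs=1)
model.fit(data)

# Raw datetime values can now be passed as conditions: sample_sequential_columns
# converts them to the numerical space the context synthesizer was trained on
# before building the Condition objects.
sampled = model.sample_sequential_columns(context_columns=data[['context', 'context_date']])
```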