Skip to content

Commit

Permalink
code review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
katxiao committed Feb 16, 2022
1 parent c38c0b4 commit a50c40f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 38 deletions.
10 changes: 1 addition & 9 deletions sdv/tabular/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,11 +411,6 @@ def _conditionally_sample_rows(self, dataframe, max_retries, max_rows_multiplier

return sampled_rows

def _set_fixed_seed(self, randomize_samples):
if randomize_samples:
# TODO: set random state on copulas.
return

@validate_sample_args
def sample(self, num_rows, randomize_samples=True):
"""Sample rows from this table.
Expand All @@ -432,10 +427,7 @@ def sample(self, num_rows, randomize_samples=True):
Sampled data.
"""
if num_rows is None:
raise ValueError(
'Error: You must specify the number of rows to sample (eg. num_rows=100).')

self._set_fixed_seed(randomize_samples)
raise ValueError('You must specify the number of rows to sample (e.g. num_rows=100).')

return self._sample_batch(num_rows)

Expand Down
38 changes: 9 additions & 29 deletions tests/unit/tabular/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def test_sample_valid_num_rows(self):
# Setup
gaussian_copula = Mock(spec_set=GaussianCopula)
valid_sampled_data = pd.DataFrame({
"column1": [28, 28, 21, 1, 2],
"column2": [37, 37, 1, 4, 5],
"column3": [93, 93, 6, 4, 12],
'column1': [28, 28, 21, 1, 2],
'column2': [37, 37, 1, 4, 5],
'column3': [93, 93, 6, 4, 12],
})
gaussian_copula._sample_batch.return_value = valid_sampled_data

Expand All @@ -93,7 +93,9 @@ def test_sample_no_num_rows(self):
model = BaseTabularModel()

# Run and assert
with pytest.raises(TypeError):
with pytest.raises(
TypeError,
match=r'sample\(\) missing 1 required positional argument: \'num_rows\''):
model.sample()

def test_sample_num_rows_none(self):
Expand All @@ -111,33 +113,11 @@ def test_sample_num_rows_none(self):
num_rows = None

# Run and assert
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match=r'You must specify the number of rows to sample \(e.g. num_rows=100\)'):
model.sample(num_rows)

def test_sample_randomize_samples_true(self):
"""Test the `BaseTabularModel.sample` method with `randomize_samples` set to True.
Expect that sequential calls return different sampled rows.
Input:
- num_rows = None
Output:
- randomized rows
"""
pass

def test_sample_randomize_samples_false(self):
"""Test the `BaseTabularModel.sample` method with `randomize_samples` set to False.
Expect that sequential calls return the same sampled rows.
Input:
- num_rows = 5
Output:
- deterministic rows
"""
pass


@patch('sdv.tabular.base.Table', spec_set=Table)
def test__init__passes_correct_parameters(metadata_mock):
Expand Down

0 comments on commit a50c40f

Please sign in to comment.