Skip to content

Commit

Permalink
Update preset param name (#752)
Browse files Browse the repository at this point in the history
  • Loading branch information
katxiao authored Mar 31, 2022
1 parent 2e7bc13 commit ccd323b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 16 deletions.
14 changes: 7 additions & 7 deletions sdv/lite/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class TabularPreset(BaseTabularModel):
"""Class for all tabular model presets.
Args:
optimize_for (str):
name (str):
The preset to use.
metadata (dict or metadata.Table):
Table metadata instance or dict representation.
Expand All @@ -31,19 +31,19 @@ class TabularPreset(BaseTabularModel):
_model = None
_null_percentages = None

def __init__(self, optimize_for=None, metadata=None):
if optimize_for is None:
raise ValueError('You must provide the name of a preset using the `optimize_for` '
def __init__(self, name=None, metadata=None):
if name is None:
raise ValueError('You must provide the name of a preset using the `name` '
'parameter. Use `TabularPreset.list_available_presets()` to browse '
'through the options.')
if optimize_for not in PRESETS:
raise ValueError(f'`optimize_for` must be one of {PRESETS}.')
if name not in PRESETS:
raise ValueError(f'`name` must be one of {PRESETS}.')
if metadata is None:
warnings.warn('No metadata provided. Metadata will be automatically '
'detected from your data. This process may not be accurate. '
'We recommend writing metadata to ensure correct data handling.')

if optimize_for == SPEED_PRESET:
if name == SPEED_PRESET:
self._model = GaussianCopula(
table_metadata=metadata,
categorical_transformer='label_encoding',
Expand Down
18 changes: 9 additions & 9 deletions tests/unit/lite/test_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

class TestTabularPreset:

def test___init__missing_optimize_for(self):
def test___init__missing_name(self):
"""Test the ``TabularPreset.__init__`` method with no parameters.
Side Effects:
Expand All @@ -21,23 +21,23 @@ def test___init__missing_optimize_for(self):
# Run and Assert
with pytest.raises(
ValueError,
match=('You must provide the name of a preset using the `optimize_for` parameter. '
match=('You must provide the name of a preset using the `name` parameter. '
r'Use `TabularPreset.list_available_presets\(\)` to browse through '
'the options.')):
TabularPreset()

def test___init__invalid_optimize_for(self):
def test___init__invalid_name(self):
"""Test the ``TabularPreset.__init__`` method with an invalid arg value.
Input:
- optimize_for = invalid parameter
- name = invalid parameter
Side Effects:
- ValueError should be thrown
"""
# Run and Assert
with pytest.raises(ValueError, match=r'`optimize_for` must be one of *'):
TabularPreset(optimize_for='invalid')
with pytest.raises(ValueError, match=r'`name` must be one of *'):
TabularPreset(name='invalid')

@patch('sdv.lite.tabular.GaussianCopula', spec_set=GaussianCopula)
def test__init__speed_passes_correct_parameters(self, gaussian_copula_mock):
Expand All @@ -46,12 +46,12 @@ def test__init__speed_passes_correct_parameters(self, gaussian_copula_mock):
The method should pass the parameters to the ``GaussianCopula`` class.
Input:
- optimize_for = speed
- name of the speed preset
Side Effects:
- GaussianCopula should receive the correct parameters
"""
# Run
TabularPreset(optimize_for='SPEED')
TabularPreset(name='SPEED')

# Assert
gaussian_copula_mock.assert_called_once_with(
Expand Down Expand Up @@ -186,7 +186,7 @@ def test_list_available_presets(self):
'custom presets? Contact the SDV team to learn more an SDV Premium license.')

# Run
TabularPreset(optimize_for='SPEED').list_available_presets(out)
TabularPreset(name='SPEED').list_available_presets(out)

# Assert
assert out.getvalue().strip() == expected

0 comments on commit ccd323b

Please sign in to comment.