-
Notifications
You must be signed in to change notification settings - Fork 325
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add condition class * add unit tests
- Loading branch information
Showing
3 changed files
with
120 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
"""SDV Sampling module.""" | ||
|
||
from sdv.sampling.condition import Condition | ||
|
||
__all__ = [ | ||
'Condition', | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
"""SDV Condition class for sampling.""" | ||
|
||
|
||
class Condition(): | ||
"""Condition class. | ||
This class represents a condition that is used for sampling. | ||
Attributes: | ||
column_values (dict): | ||
A dictionary representing the desired conditions. A mapping of | ||
the column name to column value, which will be satisfied in this | ||
condition. | ||
num_rows (int): | ||
The number of rows to generate for this condition. Defaults to 1. | ||
""" | ||
|
||
column_values = {} | ||
num_rows = 1 | ||
|
||
def __init__(self, column_values, num_rows=1): | ||
self.column_values = column_values | ||
self.num_rows = num_rows | ||
|
||
def get_column_values(self): | ||
"""Get the column value mappings in this condition.""" | ||
return self.column_values.copy() | ||
|
||
def get_num_rows(self): | ||
"""Get the desired number of rows for this condition.""" | ||
return self.num_rows |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
"""Tests for the sdv.sampling.condition module.""" | ||
from sdv.sampling.condition import Condition | ||
|
||
|
||
class TestCondition(): | ||
|
||
def test___init__(self): | ||
"""Test ```Condition.__init__`` method. | ||
Expect that `column_values` and `num_rows` are defined correctly. | ||
Input: | ||
- column_values | ||
- num_rows | ||
""" | ||
# Setup | ||
column_values = {'a': 1, 'b': 2} | ||
num_rows = 5 | ||
|
||
# Run | ||
condition = Condition(column_values=column_values, num_rows=num_rows) | ||
|
||
# Assert | ||
assert condition.column_values == column_values | ||
assert condition.num_rows == num_rows | ||
|
||
def test_get_column_values(self): | ||
"""Test ```Condition.get_column_values`` method. | ||
Expect that the correct `column_values` value is returned. | ||
Input: | ||
- column_values | ||
""" | ||
# Setup | ||
column_values = {'a': 1, 'b': 2} | ||
condition = Condition(column_values=column_values) | ||
|
||
# Run | ||
condition_column_values = condition.get_column_values() | ||
|
||
# Assert | ||
assert condition_column_values == column_values | ||
|
||
def test_get_num_rows_default(self): | ||
"""Test ```Condition.get_num_rows`` method. | ||
Expect that the default `num_rows` value is returned, and | ||
that the default value is 1. | ||
Input: | ||
- column_values | ||
""" | ||
# Setup | ||
column_values = {'a': 1, 'b': 2} | ||
condition = Condition(column_values=column_values) | ||
|
||
# Run | ||
default_num_rows = condition.get_num_rows() | ||
|
||
# Assert | ||
assert default_num_rows == 1 | ||
|
||
def test_get_num_rows(self): | ||
"""Test ```Condition.get_num_rows`` method. | ||
Expect that the correct `num_rows` value is returned. | ||
Input: | ||
- column_values | ||
- num_rows | ||
""" | ||
# Setup | ||
column_values = {'a': 1, 'b': 2} | ||
num_rows = 100 | ||
condition = Condition(column_values=column_values, num_rows=num_rows) | ||
|
||
# Run | ||
condition_num_rows = condition.get_num_rows() | ||
|
||
# Assert | ||
assert condition_num_rows == num_rows |