Skip to content

Commit

Permalink
Add condition class (#704)
Browse files Browse the repository at this point in the history
* Add condition class

* add unit tests
  • Loading branch information
katxiao authored Feb 11, 2022
1 parent c045868 commit 7f238bf
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 0 deletions.
7 changes: 7 additions & 0 deletions sdv/sampling/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""SDV Sampling module."""

from sdv.sampling.condition import Condition

__all__ = [
'Condition',
]
31 changes: 31 additions & 0 deletions sdv/sampling/condition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""SDV Condition class for sampling."""


class Condition():
"""Condition class.
This class represents a condition that is used for sampling.
Attributes:
column_values (dict):
A dictionary representing the desired conditions. A mapping of
the column name to column value, which will be satisfied in this
condition.
num_rows (int):
The number of rows to generate for this condition. Defaults to 1.
"""

column_values = {}
num_rows = 1

def __init__(self, column_values, num_rows=1):
self.column_values = column_values
self.num_rows = num_rows

def get_column_values(self):
"""Get the column value mappings in this condition."""
return self.column_values.copy()

def get_num_rows(self):
"""Get the desired number of rows for this condition."""
return self.num_rows
82 changes: 82 additions & 0 deletions tests/unit/sampling/test_condition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Tests for the sdv.sampling.condition module."""
from sdv.sampling.condition import Condition


class TestCondition():

def test___init__(self):
"""Test ```Condition.__init__`` method.
Expect that `column_values` and `num_rows` are defined correctly.
Input:
- column_values
- num_rows
"""
# Setup
column_values = {'a': 1, 'b': 2}
num_rows = 5

# Run
condition = Condition(column_values=column_values, num_rows=num_rows)

# Assert
assert condition.column_values == column_values
assert condition.num_rows == num_rows

def test_get_column_values(self):
"""Test ```Condition.get_column_values`` method.
Expect that the correct `column_values` value is returned.
Input:
- column_values
"""
# Setup
column_values = {'a': 1, 'b': 2}
condition = Condition(column_values=column_values)

# Run
condition_column_values = condition.get_column_values()

# Assert
assert condition_column_values == column_values

def test_get_num_rows_default(self):
"""Test ```Condition.get_num_rows`` method.
Expect that the default `num_rows` value is returned, and
that the default value is 1.
Input:
- column_values
"""
# Setup
column_values = {'a': 1, 'b': 2}
condition = Condition(column_values=column_values)

# Run
default_num_rows = condition.get_num_rows()

# Assert
assert default_num_rows == 1

def test_get_num_rows(self):
"""Test ```Condition.get_num_rows`` method.
Expect that the correct `num_rows` value is returned.
Input:
- column_values
- num_rows
"""
# Setup
column_values = {'a': 1, 'b': 2}
num_rows = 100
condition = Condition(column_values=column_values, num_rows=num_rows)

# Run
condition_num_rows = condition.get_num_rows()

# Assert
assert condition_num_rows == num_rows

0 comments on commit 7f238bf

Please sign in to comment.