Skip to content

Commit

Permalink
Merge pull request #77 from ihmeuw-msca/test/sex_splitter
Browse files Browse the repository at this point in the history
Added sex splitting tests
  • Loading branch information
saalUW authored Aug 26, 2024
2 parents b5e3124 + d3a6edd commit d4dc800
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/pydisagg/ihme/splitter/sex_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,18 @@ def get_population_by_sex(self, population, sex_value):
def parse_data(self, data: DataFrame) -> DataFrame:
name = "While parsing data"

# Validate core columns first
try:
validate_columns(data, self.data.columns, name)
except KeyError as e:
raise KeyError(
f"{name}: Missing columns in the input data. Details:\n{e}"
)

data = data[self.data.columns].copy()
if self.population.sex not in data.columns:
raise KeyError(
f"{name}: Missing column '{self.population.sex}' in the input data."
)

try:
validate_index(data, self.data.index, name)
Expand Down
194 changes: 194 additions & 0 deletions tests/test_sex_splitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import pytest
import pandas as pd
from pydantic import ValidationError
from pydisagg.ihme.splitter import (
SexSplitter,
SexDataConfig,
SexPatternConfig,
SexPopulationConfig,
)

# Step 1: Set up fixtures


@pytest.fixture
def sex_data_config():
    """Build the ``SexDataConfig`` shared by the tests in this module.

    The config uses ``age_group_id``, ``year_id`` and ``location_id`` as
    index columns and names its ``val`` / ``val_sd`` columns ``"val"`` and
    ``"val_sd"``.
    """
    index_columns = ["age_group_id", "year_id", "location_id"]
    return SexDataConfig(index=index_columns, val="val", val_sd="val_sd")


@pytest.fixture
def sex_pattern_config():
    """Build a ``SexPatternConfig`` keyed by ``age_group_id`` and ``year_id``."""
    by_columns = ["age_group_id", "year_id"]
    return SexPatternConfig(by=by_columns)


@pytest.fixture
def sex_population_config():
    """Build the ``SexPopulationConfig`` shared by the tests in this module.

    The sex column is ``sex_id`` with ``sex_m=1`` and ``sex_f=2``; the value
    column is ``population``.
    """
    settings = {
        "index": ["age_group_id", "year_id", "location_id"],
        "sex": "sex_id",
        "sex_m": 1,
        "sex_f": 2,
        "val": "population",
    }
    return SexPopulationConfig(**settings)


@pytest.fixture
def valid_data():
    """Return a four-row frame with positive ``val`` and ``val_sd`` values.

    Every (age_group_id, year_id, location_id) combination is unique and
    every row carries ``sex_id`` 3.
    """
    columns = dict(
        age_group_id=[1, 1, 2, 2],
        year_id=[2000, 2000, 2001, 2001],
        location_id=[10, 20, 10, 20],
        sex_id=[3, 3, 3, 3],
        val=[100, 200, 150, 250],
        val_sd=[10, 20, 15, 25],
    )
    return pd.DataFrame(columns)


@pytest.fixture
def valid_pattern():
    """Return a four-row pattern frame with ``pattern_val`` / ``pattern_val_sd``.

    NOTE(review): the (age_group_id, year_id) pairs repeat -- (1, 2000) and
    (2, 2001) each appear twice. If pattern parsing enforces uniqueness on
    those columns this fixture would not pass as "valid"; confirm before
    relying on it.
    """
    columns = dict(
        age_group_id=[1, 1, 2, 2],
        year_id=[2000, 2000, 2001, 2001],
        pattern_val=[1.5, 2.0, 1.2, 1.8],
        pattern_val_sd=[0.1, 0.2, 0.15, 0.25],
    )
    return pd.DataFrame(columns)


@pytest.fixture
def pattern_with_draws():
    """Return a pattern frame holding ``draw_1`` .. ``draw_3`` columns.

    There are no ``pattern_val`` / ``pattern_val_sd`` columns in this frame.
    """
    columns = dict(
        age_group_id=[1, 1, 2, 2],
        year_id=[2000, 2000, 2001, 2001],
        draw_1=[1.4, 1.9, 1.3, 1.7],
        draw_2=[1.6, 2.1, 1.1, 1.9],
        draw_3=[1.5, 2.0, 1.2, 1.8],
    )
    return pd.DataFrame(columns)


@pytest.fixture
def invalid_pattern_missing_columns():
    """Return a pattern frame that deliberately omits ``pattern_val_sd``."""
    columns = dict(
        age_group_id=[1, 1, 2, 2],
        year_id=[2000, 2000, 2001, 2001],
        pattern_val=[1.5, 2.0, 1.2, 1.8],
        # Missing pattern_val_sd
    )
    return pd.DataFrame(columns)


@pytest.fixture
def duplicated_index_pattern():
    """Return a pattern frame whose four rows all share one key.

    Every row has ``age_group_id`` 1 and ``year_id`` 2000.
    """
    columns = dict(
        age_group_id=[1, 1, 1, 1],
        year_id=[2000, 2000, 2000, 2000],
        pattern_val=[1.5, 2.0, 1.2, 1.8],
        pattern_val_sd=[0.1, 0.2, 0.15, 0.25],
    )
    return pd.DataFrame(columns)


@pytest.fixture
def pattern_with_nan():
    """Return a pattern frame with one missing entry in each value column.

    ``pattern_val`` is missing in the second row and ``pattern_val_sd`` in
    the third.
    """
    columns = dict(
        age_group_id=[1, 1, 2, 2],
        year_id=[2000, 2000, 2001, 2001],
        pattern_val=[1.5, None, 1.2, 1.8],
        pattern_val_sd=[0.1, 0.2, None, 0.25],
    )
    return pd.DataFrame(columns)


@pytest.fixture
def pattern_with_non_positive():
    """Return a pattern frame whose ``pattern_val`` entries are all <= 0.

    ``pattern_val`` mixes negative values and zeros; ``pattern_val_sd`` stays
    strictly positive.
    """
    columns = dict(
        age_group_id=[1, 1, 2, 2],
        year_id=[2000, 2000, 2001, 2001],
        pattern_val=[-1.5, 0, -1.2, 0],
        pattern_val_sd=[0.1, 0.2, 0.15, 0.25],
    )
    return pd.DataFrame(columns)


@pytest.fixture
def pattern_with_invalid_realnumbers():
    """Return a pattern frame with a zero and a negative ``pattern_val_sd``.

    ``pattern_val`` itself is strictly positive in every row.
    """
    columns = dict(
        age_group_id=[1, 1, 2, 2],
        year_id=[2000, 2000, 2001, 2001],
        pattern_val=[1.5, 2.0, 1.2, 1.8],
        pattern_val_sd=[0, 0.2, -0.15, 0.25],
    )
    return pd.DataFrame(columns)


@pytest.fixture
def sex_splitter(sex_data_config, sex_pattern_config, sex_population_config):
    """Assemble a ``SexSplitter`` from the three config fixtures."""
    configs = {
        "data": sex_data_config,
        "pattern": sex_pattern_config,
        "population": sex_population_config,
    }
    return SexSplitter(**configs)


# Step 2: Write Tests for parse_data
def test_parse_data_missing_columns(sex_splitter, valid_data):
    """Check that parse_data raises KeyError when the ``val`` column is absent."""
    without_val = valid_data.drop(columns="val")
    expected_error = pytest.raises(KeyError, match="Missing columns")
    with expected_error:
        sex_splitter.parse_data(without_val)


def test_parse_data_duplicated_index(sex_splitter, valid_data):
    """Check that parse_data raises ValueError when index rows are repeated."""
    # Stacking the frame on itself makes every index combination appear twice.
    doubled = pd.concat([valid_data] * 2)
    expected_error = pytest.raises(ValueError, match="Duplicated index found")
    with expected_error:
        sex_splitter.parse_data(doubled)


def test_parse_data_with_nan(sex_splitter, valid_data):
    """Test parse_data raises an error when there are NaN values.

    A single NaN is placed in the first row of ``val``; parse_data is
    expected to raise ``ValueError`` with a "NaN values found" message.
    """
    # ``val`` is an integer column in the fixture. Assigning a missing value
    # into it via ``.loc`` relies on silent upcasting, which pandas >= 2.1
    # deprecates (FutureWarning) and plans to turn into an error. Cast to
    # float explicitly first; ``astype`` also returns a new frame, so the
    # fixture data is left untouched.
    nan_data = valid_data.astype({"val": "float64"})
    nan_data.loc[0, "val"] = float("nan")
    with pytest.raises(ValueError, match="NaN values found"):
        sex_splitter.parse_data(nan_data)


def test_parse_data_non_positive(sex_splitter, valid_data):
    """Check that parse_data raises ValueError when ``val`` holds a negative value."""
    negative_val = valid_data.copy()
    negative_val.at[0, "val"] = -10
    expected_error = pytest.raises(ValueError, match="Non-positive values found")
    with expected_error:
        sex_splitter.parse_data(negative_val)


def test_parse_data_valid(sex_splitter, valid_data):
    """Check that parse_data returns a non-empty frame keeping ``val`` and ``val_sd``."""
    result = sex_splitter.parse_data(valid_data)
    assert not result.empty
    for column in ("val", "val_sd"):
        assert column in result.columns


def test_parse_data_invalid_sex_rows(sex_splitter, valid_data):
    """Check that parse_data raises ValueError when a row has ``sex_id`` 1."""
    mixed_sex = valid_data.copy()
    # Setting sex_id to sex_m
    mixed_sex.at[0, "sex_id"] = 1
    expected_error = pytest.raises(ValueError, match="Invalid rows")
    with expected_error:
        sex_splitter.parse_data(mixed_sex)


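# Step 3: Write Tests for parse_pattern
# TODO: no parse_pattern tests exist yet; the pattern fixtures defined above
# are not used by any test in this file.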

0 comments on commit d4dc800

Please sign in to comment.