Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added sex splitting tests #77

Merged
merged 2 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions src/pydisagg/ihme/splitter/age_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ def parse_data(self, data: DataFrame, positive_strict: bool) -> DataFrame:
data = data[self.data.columns].copy()
validate_index(data, self.data.index, name)
validate_nonan(data, name)
validate_positive(data, [self.data.val_sd], name, strict=positive_strict)
validate_positive(
data, [self.data.val_sd], name, strict=positive_strict
)
validate_interval(
data, self.data.age_lwr, self.data.age_upr, self.data.index, name
)
Expand All @@ -158,7 +160,8 @@ def parse_pattern(

# Check if val and val_sd are missing, and generate them if necessary
if not all(
col in pattern.columns for col in [self.pattern.val, self.pattern.val_sd]
col in pattern.columns
for col in [self.pattern.val, self.pattern.val_sd]
):
if not self.pattern.draws:
raise ValueError(
Expand All @@ -169,7 +172,9 @@ def parse_pattern(
# Generate val and val_sd from draws
validate_columns(pattern, self.pattern.draws, name)
pattern[self.pattern.val] = pattern[self.pattern.draws].mean(axis=1)
pattern[self.pattern.val_sd] = pattern[self.pattern.draws].std(axis=1)
pattern[self.pattern.val_sd] = pattern[self.pattern.draws].std(
axis=1
)

# Validate columns after potential generation
validate_columns(pattern, self.pattern.columns, name)
Expand Down Expand Up @@ -212,7 +217,9 @@ def parse_pattern(

return data_with_pattern

def _merge_with_pattern(self, data: DataFrame, pattern: DataFrame) -> DataFrame:
def _merge_with_pattern(
self, data: DataFrame, pattern: DataFrame
) -> DataFrame:
# Ensure the necessary columns are present before merging
assert (
self.data.age_lwr in data.columns
Expand Down Expand Up @@ -241,7 +248,9 @@ def _merge_with_pattern(self, data: DataFrame, pattern: DataFrame) -> DataFrame:
)
return data_with_pattern

def parse_population(self, data: DataFrame, population: DataFrame) -> DataFrame:
def parse_population(
self, data: DataFrame, population: DataFrame
) -> DataFrame:
name = "Parsing Population"
validate_columns(population, self.population.columns, name)

Expand Down Expand Up @@ -367,18 +376,23 @@ def split(

# If not propagating zeros, then positivity has to be strict
data = self.parse_data(data, positive_strict=not propagate_zeros)
data = self.parse_pattern(data, pattern, positive_strict=not propagate_zeros)
data = self.parse_pattern(
data, pattern, positive_strict=not propagate_zeros
)
data = self.parse_population(data, population)

data = self._align_pattern_and_population(data)

# where split happens
data["age_split_result"], data["age_split_result_se"] = np.nan, np.nan
data["age_split"] = 0 # Indicate that the row was not split by age initially
data["age_split"] = (
0 # Indicate that the row was not split by age initially
)

if propagate_zeros is True:
data_zero = data[
(data[self.data.val] == 0) | (data[self.pattern.val + "_aligned"] == 0)
(data[self.data.val] == 0)
| (data[self.pattern.val + "_aligned"] == 0)
]
data = data[data[self.data.val] > 0]
# Manually split zero values
Expand All @@ -390,7 +404,8 @@ def split(
num_zval = (data[self.data.val] == 0).sum()
num_zpat = (data[self.pattern.val + "_aligned"] == 0).sum()
num_overlap = (
(data[self.data.val] == 0) * (data[self.pattern.val + "_aligned"] == 0)
(data[self.data.val] == 0)
* (data[self.pattern.val + "_aligned"] == 0)
).sum()
if num_zval > 0:
warnings.warn(
Expand Down
6 changes: 5 additions & 1 deletion src/pydisagg/ihme/splitter/sex_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,18 @@ def get_population_by_sex(self, population, sex_value):
def parse_data(self, data: DataFrame) -> DataFrame:
name = "While parsing data"

# Validate core columns first
try:
validate_columns(data, self.data.columns, name)
except KeyError as e:
raise KeyError(
f"{name}: Missing columns in the input data. Details:\n{e}"
)

data = data[self.data.columns].copy()
if self.population.sex not in data.columns:
raise KeyError(
f"{name}: Missing column '{self.population.sex}' in the input data."
)

try:
validate_index(data, self.data.index, name)
Expand Down
16 changes: 10 additions & 6 deletions src/pydisagg/ihme/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ def validate_index(df: DataFrame, index: list[str], name: str) -> None:
df[df[index].duplicated()][index]
).to_list()
if duplicated_index:
error_message = (
f"{name} has duplicated index with {len(duplicated_index)} indices \n"
)
error_message = f"{name} has duplicated index with {len(duplicated_index)} indices \n"
error_message += f"Index columns: ({', '.join(index)})\n"
if len(duplicated_index) > 5:
error_message += "First 5: \n"
Expand All @@ -36,7 +34,9 @@ def validate_index(df: DataFrame, index: list[str], name: str) -> None:
def validate_nonan(df: DataFrame, name: str) -> None:
nan_columns = df.columns[df.isna().any(axis=0)].to_list()
if nan_columns:
error_message = f"{name} has NaN values in {len(nan_columns)} columns. \n"
error_message = (
f"{name} has NaN values in {len(nan_columns)} columns. \n"
)
error_message += f"Columns with NaN values: {', '.join(nan_columns)}\n"
if len(nan_columns) > 5:
error_message += "First 5 columns with NaN values: \n"
Expand Down Expand Up @@ -80,7 +80,9 @@ def validate_noindexdiff(
missing_index = index_ref.difference(index).to_list()

if missing_index:
error_message = f"Missing {name} info for {len(missing_index)} indices \n"
error_message = (
f"Missing {name} info for {len(missing_index)} indices \n"
)
error_message += f"Index columns: ({', '.join(index.names)})\n"
if len(missing_index) > 5:
error_message += "First 5: \n"
Expand Down Expand Up @@ -153,7 +155,9 @@ def validate_realnumber(df: DataFrame, columns: list[str], name: str) -> None:
invalid = [
col
for col in columns
if not df[col].apply(lambda x: isinstance(x, (int, float)) and x != 0).all()
if not df[col]
.apply(lambda x: isinstance(x, (int, float)) and x != 0)
.all()
]

if invalid:
Expand Down
11 changes: 8 additions & 3 deletions tests/test_age_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ def splitter(data, pattern, population):
val_sd="val_sd",
)
population_config = AgePopulationConfig(
index=["sex_id", "location_id", "age_group_id", "year_id"], val="population"
index=["sex_id", "location_id", "age_group_id", "year_id"],
val="population",
)
return AgeSplitter(
data=data_config, pattern=pattern_config, population=population_config
Expand Down Expand Up @@ -162,7 +163,9 @@ def test_parse_pattern_missing_columns(splitter, data, pattern):
with pytest.raises(ValueError, match="Must provide draws for pattern"):
splitter.parse_pattern(data, pattern, positive_strict=True)
else:
parsed_pattern = splitter.parse_pattern(data, pattern, positive_strict=True)
parsed_pattern = splitter.parse_pattern(
data, pattern, positive_strict=True
)
assert "val" in parsed_pattern.columns
assert "val_sd" in parsed_pattern.columns

Expand Down Expand Up @@ -213,7 +216,9 @@ def test_parse_pattern_nan_values(splitter, data, pattern):
pattern.loc[0, "val"] = np.nan

# Manually check if NaN is correctly set
assert pd.isna(pattern.loc[0, "val"]), "NaN not correctly set in 'val' column"
assert pd.isna(
pattern.loc[0, "val"]
), "NaN not correctly set in 'val' column"

# Ensure validate_nonan is called in parse_pattern
with pytest.raises(ValueError, match="has NaN values"):
Expand Down
194 changes: 194 additions & 0 deletions tests/test_sex_splitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import pytest
import pandas as pd
from pydantic import ValidationError
from pydisagg.ihme.splitter import (
SexSplitter,
SexDataConfig,
SexPatternConfig,
SexPopulationConfig,
)

# Step 1: Setup Fixtures


@pytest.fixture
def sex_data_config():
return SexDataConfig(
index=["age_group_id", "year_id", "location_id"],
val="val",
val_sd="val_sd",
)


@pytest.fixture
def sex_pattern_config():
return SexPatternConfig(by=["age_group_id", "year_id"])


@pytest.fixture
def sex_population_config():
return SexPopulationConfig(
index=["age_group_id", "year_id", "location_id"],
sex="sex_id",
sex_m=1,
sex_f=2,
val="population",
)


@pytest.fixture
def valid_data():
return pd.DataFrame(
{
"age_group_id": [1, 1, 2, 2],
"year_id": [2000, 2000, 2001, 2001],
"location_id": [10, 20, 10, 20],
"sex_id": [3, 3, 3, 3],
"val": [100, 200, 150, 250],
"val_sd": [10, 20, 15, 25],
}
)


@pytest.fixture
def valid_pattern():
return pd.DataFrame(
{
"age_group_id": [1, 1, 2, 2],
"year_id": [2000, 2000, 2001, 2001],
"pattern_val": [1.5, 2.0, 1.2, 1.8],
"pattern_val_sd": [0.1, 0.2, 0.15, 0.25],
}
)


@pytest.fixture
def pattern_with_draws():
return pd.DataFrame(
{
"age_group_id": [1, 1, 2, 2],
"year_id": [2000, 2000, 2001, 2001],
"draw_1": [1.4, 1.9, 1.3, 1.7],
"draw_2": [1.6, 2.1, 1.1, 1.9],
"draw_3": [1.5, 2.0, 1.2, 1.8],
}
)


@pytest.fixture
def invalid_pattern_missing_columns():
return pd.DataFrame(
{
"age_group_id": [1, 1, 2, 2],
"year_id": [2000, 2000, 2001, 2001],
"pattern_val": [1.5, 2.0, 1.2, 1.8],
# Missing pattern_val_sd
}
)


@pytest.fixture
def duplicated_index_pattern():
return pd.DataFrame(
{
"age_group_id": [1, 1, 1, 1],
"year_id": [2000, 2000, 2000, 2000],
"pattern_val": [1.5, 2.0, 1.2, 1.8],
"pattern_val_sd": [0.1, 0.2, 0.15, 0.25],
}
)


@pytest.fixture
def pattern_with_nan():
return pd.DataFrame(
{
"age_group_id": [1, 1, 2, 2],
"year_id": [2000, 2000, 2001, 2001],
"pattern_val": [1.5, None, 1.2, 1.8],
"pattern_val_sd": [0.1, 0.2, None, 0.25],
}
)


@pytest.fixture
def pattern_with_non_positive():
return pd.DataFrame(
{
"age_group_id": [1, 1, 2, 2],
"year_id": [2000, 2000, 2001, 2001],
"pattern_val": [-1.5, 0, -1.2, 0],
"pattern_val_sd": [0.1, 0.2, 0.15, 0.25],
}
)


@pytest.fixture
def pattern_with_invalid_realnumbers():
return pd.DataFrame(
{
"age_group_id": [1, 1, 2, 2],
"year_id": [2000, 2000, 2001, 2001],
"pattern_val": [1.5, 2.0, 1.2, 1.8],
"pattern_val_sd": [0, 0.2, -0.15, 0.25],
}
)


@pytest.fixture
def sex_splitter(sex_data_config, sex_pattern_config, sex_population_config):
return SexSplitter(
data=sex_data_config,
pattern=sex_pattern_config,
population=sex_population_config,
)


# Step 2: Write Tests for parse_data
def test_parse_data_missing_columns(sex_splitter, valid_data):
"""Test parse_data raises an error when columns are missing."""
invalid_data = valid_data.drop(columns=["val"])
with pytest.raises(KeyError, match="Missing columns"):
sex_splitter.parse_data(invalid_data)


def test_parse_data_duplicated_index(sex_splitter, valid_data):
"""Test parse_data raises an error on duplicated index."""
duplicated_data = pd.concat([valid_data, valid_data])
with pytest.raises(ValueError, match="Duplicated index found"):
sex_splitter.parse_data(duplicated_data)


def test_parse_data_with_nan(sex_splitter, valid_data):
"""Test parse_data raises an error when there are NaN values."""
nan_data = valid_data.copy()
nan_data.loc[0, "val"] = None
with pytest.raises(ValueError, match="NaN values found"):
sex_splitter.parse_data(nan_data)


def test_parse_data_non_positive(sex_splitter, valid_data):
"""Test parse_data raises an error for non-positive values in val or val_sd."""
non_positive_data = valid_data.copy()
non_positive_data.loc[0, "val"] = -10
with pytest.raises(ValueError, match="Non-positive values found"):
sex_splitter.parse_data(non_positive_data)


def test_parse_data_valid(sex_splitter, valid_data):
"""Test that parse_data works correctly on valid data."""
parsed_data = sex_splitter.parse_data(valid_data)
assert not parsed_data.empty
assert "val" in parsed_data.columns
assert "val_sd" in parsed_data.columns


def test_parse_data_invalid_sex_rows(sex_splitter, valid_data):
"""Test parse_data raises an error if invalid sex_id rows are present."""
invalid_sex_data = valid_data.copy()
invalid_sex_data.loc[0, "sex_id"] = 1 # Setting sex_id to sex_m
with pytest.raises(ValueError, match="Invalid rows"):
sex_splitter.parse_data(invalid_sex_data)


# Step 3: Write Tests for parse_pattern