From fdaa9ca02a439d94b1137fd01377512700d6f03c Mon Sep 17 00:00:00 2001 From: saal Date: Mon, 26 Aug 2024 15:30:53 -0700 Subject: [PATCH 1/2] init commit for sex_split test --- src/pydisagg/ihme/splitter/age_splitter.py | 33 +++++-- src/pydisagg/ihme/splitter/sex_splitter.py | 6 +- src/pydisagg/ihme/validator.py | 16 +-- tests/test_age_splitter.py | 11 ++- tests/test_sex_splitter.py | 108 +++++++++++++++++++++ 5 files changed, 155 insertions(+), 19 deletions(-) create mode 100644 tests/test_sex_splitter.py diff --git a/src/pydisagg/ihme/splitter/age_splitter.py b/src/pydisagg/ihme/splitter/age_splitter.py index f673bd5..741a8f5 100644 --- a/src/pydisagg/ihme/splitter/age_splitter.py +++ b/src/pydisagg/ihme/splitter/age_splitter.py @@ -145,7 +145,9 @@ def parse_data(self, data: DataFrame, positive_strict: bool) -> DataFrame: data = data[self.data.columns].copy() validate_index(data, self.data.index, name) validate_nonan(data, name) - validate_positive(data, [self.data.val_sd], name, strict=positive_strict) + validate_positive( + data, [self.data.val_sd], name, strict=positive_strict + ) validate_interval( data, self.data.age_lwr, self.data.age_upr, self.data.index, name ) @@ -158,7 +160,8 @@ def parse_pattern( # Check if val and val_sd are missing, and generate them if necessary if not all( - col in pattern.columns for col in [self.pattern.val, self.pattern.val_sd] + col in pattern.columns + for col in [self.pattern.val, self.pattern.val_sd] ): if not self.pattern.draws: raise ValueError( @@ -169,7 +172,9 @@ def parse_pattern( # Generate val and val_sd from draws validate_columns(pattern, self.pattern.draws, name) pattern[self.pattern.val] = pattern[self.pattern.draws].mean(axis=1) - pattern[self.pattern.val_sd] = pattern[self.pattern.draws].std(axis=1) + pattern[self.pattern.val_sd] = pattern[self.pattern.draws].std( + axis=1 + ) # Validate columns after potential generation validate_columns(pattern, self.pattern.columns, name) @@ -212,7 +217,9 @@ def parse_pattern( return data_with_pattern - def _merge_with_pattern(self, data: DataFrame, pattern: DataFrame) -> DataFrame: + def _merge_with_pattern( + self, data: DataFrame, pattern: DataFrame + ) -> DataFrame: # Ensure the necessary columns are present before merging assert ( self.data.age_lwr in data.columns @@ -241,7 +248,9 @@ def _merge_with_pattern(self, data: DataFrame, pattern: DataFrame) -> DataFrame: ) return data_with_pattern - def parse_population(self, data: DataFrame, population: DataFrame) -> DataFrame: + def parse_population( + self, data: DataFrame, population: DataFrame + ) -> DataFrame: name = "Parsing Population" validate_columns(population, self.population.columns, name) @@ -367,18 +376,23 @@ def split( # If not propagating zeros, then positivity has to be strict data = self.parse_data(data, positive_strict=not propagate_zeros) - data = self.parse_pattern(data, pattern, positive_strict=not propagate_zeros) + data = self.parse_pattern( + data, pattern, positive_strict=not propagate_zeros + ) data = self.parse_population(data, population) data = self._align_pattern_and_population(data) # where split happens data["age_split_result"], data["age_split_result_se"] = np.nan, np.nan - data["age_split"] = 0 # Indicate that the row was not split by age initially + data["age_split"] = ( + 0 # Indicate that the row was not split by age initially + ) if propagate_zeros is True: data_zero = data[ - (data[self.data.val] == 0) | (data[self.pattern.val + "_aligned"] == 0) + (data[self.data.val] == 0) + | (data[self.pattern.val + "_aligned"] == 0) ] data = data[data[self.data.val] > 0] # Manually split zero values @@ -390,7 +404,8 @@ def split( num_zval = (data[self.data.val] == 0).sum() num_zpat = (data[self.pattern.val + "_aligned"] == 0).sum() num_overlap = ( - (data[self.data.val] == 0) * (data[self.pattern.val + "_aligned"] == 0) + (data[self.data.val] == 0) + * (data[self.pattern.val + "_aligned"] == 0) ).sum() if num_zval > 0: warnings.warn( diff --git a/src/pydisagg/ihme/splitter/sex_splitter.py b/src/pydisagg/ihme/splitter/sex_splitter.py index bda4fe6..4b63750 100644 --- a/src/pydisagg/ihme/splitter/sex_splitter.py +++ b/src/pydisagg/ihme/splitter/sex_splitter.py @@ -101,6 +101,7 @@ def get_population_by_sex(self, population, sex_value): def parse_data(self, data: DataFrame) -> DataFrame: name = "While parsing data" + # Validate core columns first try: validate_columns(data, self.data.columns, name) except KeyError as e: @@ -108,7 +109,10 @@ def parse_data(self, data: DataFrame) -> DataFrame: f"{name}: Missing columns in the input data. Details:\n{e}" ) - data = data[self.data.columns].copy() + if self.population.sex not in data.columns: + raise KeyError( + f"{name}: Missing column '{self.population.sex}' in the input data." + ) try: validate_index(data, self.data.index, name) diff --git a/src/pydisagg/ihme/validator.py b/src/pydisagg/ihme/validator.py index 5d564bf..1c86379 100644 --- a/src/pydisagg/ihme/validator.py +++ b/src/pydisagg/ihme/validator.py @@ -22,9 +22,7 @@ def validate_index(df: DataFrame, index: list[str], name: str) -> None: df[df[index].duplicated()][index] ).to_list() if duplicated_index: - error_message = ( - f"{name} has duplicated index with {len(duplicated_index)} indices \n" - ) + error_message = f"{name} has duplicated index with {len(duplicated_index)} indices \n" error_message += f"Index columns: ({', '.join(index)})\n" if len(duplicated_index) > 5: error_message += "First 5: \n" @@ -36,7 +34,9 @@ def validate_index(df: DataFrame, index: list[str], name: str) -> None: def validate_nonan(df: DataFrame, name: str) -> None: nan_columns = df.columns[df.isna().any(axis=0)].to_list() if nan_columns: - error_message = f"{name} has NaN values in {len(nan_columns)} columns. \n" + error_message = ( + f"{name} has NaN values in {len(nan_columns)} columns. \n" + ) error_message += f"Columns with NaN values: {', '.join(nan_columns)}\n" if len(nan_columns) > 5: error_message += "First 5 columns with NaN values: \n" @@ -80,7 +80,9 @@ def validate_noindexdiff( missing_index = index_ref.difference(index).to_list() if missing_index: - error_message = f"Missing {name} info for {len(missing_index)} indices \n" + error_message = ( + f"Missing {name} info for {len(missing_index)} indices \n" + ) error_message += f"Index columns: ({', '.join(index.names)})\n" if len(missing_index) > 5: error_message += "First 5: \n" @@ -153,7 +155,9 @@ def validate_realnumber(df: DataFrame, columns: list[str], name: str) -> None: invalid = [ col for col in columns - if not df[col].apply(lambda x: isinstance(x, (int, float)) and x != 0).all() + if not df[col] + .apply(lambda x: isinstance(x, (int, float)) and x != 0) + .all() ] if invalid: diff --git a/tests/test_age_splitter.py b/tests/test_age_splitter.py index 8ec4051..e10e453 100644 --- a/tests/test_age_splitter.py +++ b/tests/test_age_splitter.py @@ -87,7 +87,8 @@ def splitter(data, pattern, population): val_sd="val_sd", ) population_config = AgePopulationConfig( - index=["sex_id", "location_id", "age_group_id", "year_id"], val="population" + index=["sex_id", "location_id", "age_group_id", "year_id"], + val="population", ) return AgeSplitter( data=data_config, pattern=pattern_config, population=population_config @@ -162,7 +163,9 @@ def test_parse_pattern_missing_columns(splitter, data, pattern): with pytest.raises(ValueError, match="Must provide draws for pattern"): splitter.parse_pattern(data, pattern, positive_strict=True) else: - parsed_pattern = splitter.parse_pattern(data, pattern, positive_strict=True) + parsed_pattern = splitter.parse_pattern( + data, pattern, positive_strict=True + ) assert "val" in parsed_pattern.columns assert "val_sd" in parsed_pattern.columns @@ -213,7 +216,9 @@ def test_parse_pattern_nan_values(splitter, data, pattern): pattern.loc[0, "val"] = np.nan # Manually check if NaN is correctly set - assert pd.isna(pattern.loc[0, "val"]), "NaN not correctly set in 'val' column" + assert pd.isna( + pattern.loc[0, "val"] + ), "NaN not correctly set in 'val' column" # Ensure validate_nonan is called in parse_pattern with pytest.raises(ValueError, match="has NaN values"): diff --git a/tests/test_sex_splitter.py b/tests/test_sex_splitter.py new file mode 100644 index 0000000..e82b098 --- /dev/null +++ b/tests/test_sex_splitter.py @@ -0,0 +1,108 @@ +import pytest +import pandas as pd +from pydantic import ValidationError +from pydisagg.ihme.splitter import ( + SexSplitter, + SexDataConfig, + SexPatternConfig, + SexPopulationConfig, +) + +# Step 1: Setup Fixtures + + +@pytest.fixture +def sex_data_config(): + return SexDataConfig( + index=["age_group_id", "year_id", "location_id"], + val="val", + val_sd="val_sd", + ) + + +@pytest.fixture +def sex_pattern_config(): + return SexPatternConfig(by=["age_group_id", "year_id"]) + + +@pytest.fixture +def sex_population_config(): + return SexPopulationConfig( + index=["age_group_id", "year_id", "location_id"], + sex="sex_id", + sex_m=1, + sex_f=2, + val="population", + ) + + +@pytest.fixture +def valid_data(): + return pd.DataFrame( + { + "age_group_id": [1, 1, 2, 2], + "year_id": [2000, 2000, 2001, 2001], + "location_id": [10, 20, 10, 20], + "sex_id": [3, 3, 3, 3], + "val": [100, 200, 150, 250], + "val_sd": [10, 20, 15, 25], + } + ) + + +@pytest.fixture +def sex_splitter(sex_data_config, sex_pattern_config, sex_population_config): + return SexSplitter( + data=sex_data_config, + pattern=sex_pattern_config, + population=sex_population_config, + ) + + +# Step 2: Write Tests for parse_data + + +def test_parse_data_missing_columns(sex_splitter, valid_data): + """Test parse_data raises an error when columns are missing.""" + invalid_data = valid_data.drop(columns=["val"]) + with pytest.raises(KeyError, match="Missing columns"): + sex_splitter.parse_data(invalid_data) + + +def test_parse_data_duplicated_index(sex_splitter, valid_data): + """Test parse_data raises an error on duplicated index.""" + duplicated_data = pd.concat([valid_data, valid_data]) + with pytest.raises(ValueError, match="Duplicated index found"): + sex_splitter.parse_data(duplicated_data) + + +def test_parse_data_with_nan(sex_splitter, valid_data): + """Test parse_data raises an error when there are NaN values.""" + nan_data = valid_data.copy() + nan_data.loc[0, "val"] = None + with pytest.raises(ValueError, match="NaN values found"): + sex_splitter.parse_data(nan_data) + + +def test_parse_data_non_positive(sex_splitter, valid_data): + """Test parse_data raises an error for non-positive values in val or val_sd.""" + non_positive_data = valid_data.copy() + non_positive_data.loc[0, "val"] = -10 + with pytest.raises(ValueError, match="Non-positive values found"): + sex_splitter.parse_data(non_positive_data) + + +def test_parse_data_valid(sex_splitter, valid_data): + """Test that parse_data works correctly on valid data.""" + parsed_data = sex_splitter.parse_data(valid_data) + assert not parsed_data.empty + assert "val" in parsed_data.columns + assert "val_sd" in parsed_data.columns + + +def test_parse_data_invalid_sex_rows(sex_splitter, valid_data): + """Test parse_data raises an error if invalid sex_id rows are present.""" + invalid_sex_data = valid_data.copy() + invalid_sex_data.loc[0, "sex_id"] = 1 # Setting sex_id to sex_m + with pytest.raises(ValueError, match="Invalid rows"): + sex_splitter.parse_data(invalid_sex_data) From d3a6edd76aa83e7f4f4ab9c6af9777b16cc79df7 Mon Sep 17 00:00:00 2001 From: saal Date: Mon, 26 Aug 2024 15:39:25 -0700 Subject: [PATCH 2/2] Added parse data tests --- tests/test_sex_splitter.py | 90 +++++++++++++++++++++++++++++++++++++- 1 file changed, 88 insertions(+), 2 deletions(-) diff --git a/tests/test_sex_splitter.py b/tests/test_sex_splitter.py index e82b098..95815e5 100644 --- a/tests/test_sex_splitter.py +++ b/tests/test_sex_splitter.py @@ -50,6 +50,91 @@ def valid_data(): ) +@pytest.fixture +def valid_pattern(): + return pd.DataFrame( + { + "age_group_id": [1, 1, 2, 2], + "year_id": [2000, 2000, 2001, 2001], + "pattern_val": [1.5, 2.0, 1.2, 1.8], + "pattern_val_sd": [0.1, 0.2, 0.15, 0.25], + } + ) + + +@pytest.fixture +def pattern_with_draws(): + return pd.DataFrame( + { + "age_group_id": [1, 1, 2, 2], + "year_id": [2000, 2000, 2001, 2001], + "draw_1": [1.4, 1.9, 1.3, 1.7], + "draw_2": [1.6, 2.1, 1.1, 1.9], + "draw_3": [1.5, 2.0, 1.2, 1.8], + } + ) + + +@pytest.fixture +def invalid_pattern_missing_columns(): + return pd.DataFrame( + { + "age_group_id": [1, 1, 2, 2], + "year_id": [2000, 2000, 2001, 2001], + "pattern_val": [1.5, 2.0, 1.2, 1.8], + # Missing pattern_val_sd + } + ) + + +@pytest.fixture +def duplicated_index_pattern(): + return pd.DataFrame( + { + "age_group_id": [1, 1, 1, 1], + "year_id": [2000, 2000, 2000, 2000], + "pattern_val": [1.5, 2.0, 1.2, 1.8], + "pattern_val_sd": [0.1, 0.2, 0.15, 0.25], + } + ) + + +@pytest.fixture +def pattern_with_nan(): + return pd.DataFrame( + { + "age_group_id": [1, 1, 2, 2], + "year_id": [2000, 2000, 2001, 2001], + "pattern_val": [1.5, None, 1.2, 1.8], + "pattern_val_sd": [0.1, 0.2, None, 0.25], + } + ) + + +@pytest.fixture +def pattern_with_non_positive(): + return pd.DataFrame( + { + "age_group_id": [1, 1, 2, 2], + "year_id": [2000, 2000, 2001, 2001], + "pattern_val": [-1.5, 0, -1.2, 0], + "pattern_val_sd": [0.1, 0.2, 0.15, 0.25], + } + ) + + +@pytest.fixture +def pattern_with_invalid_realnumbers(): + return pd.DataFrame( + { + "age_group_id": [1, 1, 2, 2], + "year_id": [2000, 2000, 2001, 2001], + "pattern_val": [1.5, 2.0, 1.2, 1.8], + "pattern_val_sd": [0, 0.2, -0.15, 0.25], + } + ) + + @pytest.fixture def sex_splitter(sex_data_config, sex_pattern_config, sex_population_config): return SexSplitter( @@ -60,8 +145,6 @@ def sex_splitter(sex_data_config, sex_pattern_config, sex_population_config): # Step 2: Write Tests for parse_data - - def test_parse_data_missing_columns(sex_splitter, valid_data): """Test parse_data raises an error when columns are missing.""" invalid_data = valid_data.drop(columns=["val"]) @@ -106,3 +189,6 @@ def test_parse_data_invalid_sex_rows(sex_splitter, valid_data): invalid_sex_data.loc[0, "sex_id"] = 1 # Setting sex_id to sex_m with pytest.raises(ValueError, match="Invalid rows"): sex_splitter.parse_data(invalid_sex_data) + + +# Step 3: Write Tests for parse_pattern