diff --git a/src/pydisagg/ihme/splitter/age_splitter.py b/src/pydisagg/ihme/splitter/age_splitter.py index f673bd5..741a8f5 100644 --- a/src/pydisagg/ihme/splitter/age_splitter.py +++ b/src/pydisagg/ihme/splitter/age_splitter.py @@ -145,7 +145,9 @@ def parse_data(self, data: DataFrame, positive_strict: bool) -> DataFrame: data = data[self.data.columns].copy() validate_index(data, self.data.index, name) validate_nonan(data, name) - validate_positive(data, [self.data.val_sd], name, strict=positive_strict) + validate_positive( + data, [self.data.val_sd], name, strict=positive_strict + ) validate_interval( data, self.data.age_lwr, self.data.age_upr, self.data.index, name ) @@ -158,7 +160,8 @@ def parse_pattern( # Check if val and val_sd are missing, and generate them if necessary if not all( - col in pattern.columns for col in [self.pattern.val, self.pattern.val_sd] + col in pattern.columns + for col in [self.pattern.val, self.pattern.val_sd] ): if not self.pattern.draws: raise ValueError( @@ -169,7 +172,9 @@ def parse_pattern( # Generate val and val_sd from draws validate_columns(pattern, self.pattern.draws, name) pattern[self.pattern.val] = pattern[self.pattern.draws].mean(axis=1) - pattern[self.pattern.val_sd] = pattern[self.pattern.draws].std(axis=1) + pattern[self.pattern.val_sd] = pattern[self.pattern.draws].std( + axis=1 + ) # Validate columns after potential generation validate_columns(pattern, self.pattern.columns, name) @@ -212,7 +217,9 @@ def parse_pattern( return data_with_pattern - def _merge_with_pattern(self, data: DataFrame, pattern: DataFrame) -> DataFrame: + def _merge_with_pattern( + self, data: DataFrame, pattern: DataFrame + ) -> DataFrame: # Ensure the necessary columns are present before merging assert ( self.data.age_lwr in data.columns @@ -241,7 +248,9 @@ def _merge_with_pattern(self, data: DataFrame, pattern: DataFrame) -> DataFrame: ) return data_with_pattern - def parse_population(self, data: DataFrame, population: DataFrame) -> DataFrame: + def parse_population( + self, data: DataFrame, population: DataFrame + ) -> DataFrame: name = "Parsing Population" validate_columns(population, self.population.columns, name) @@ -367,18 +376,23 @@ def split( # If not propagating zeros, then positivity has to be strict data = self.parse_data(data, positive_strict=not propagate_zeros) - data = self.parse_pattern(data, pattern, positive_strict=not propagate_zeros) + data = self.parse_pattern( + data, pattern, positive_strict=not propagate_zeros + ) data = self.parse_population(data, population) data = self._align_pattern_and_population(data) # where split happens data["age_split_result"], data["age_split_result_se"] = np.nan, np.nan - data["age_split"] = 0 # Indicate that the row was not split by age initially + data["age_split"] = ( + 0 # Indicate that the row was not split by age initially + ) if propagate_zeros is True: data_zero = data[ - (data[self.data.val] == 0) | (data[self.pattern.val + "_aligned"] == 0) + (data[self.data.val] == 0) + | (data[self.pattern.val + "_aligned"] == 0) ] data = data[data[self.data.val] > 0] # Manually split zero values @@ -390,7 +404,8 @@ def split( num_zval = (data[self.data.val] == 0).sum() num_zpat = (data[self.pattern.val + "_aligned"] == 0).sum() num_overlap = ( - (data[self.data.val] == 0) * (data[self.pattern.val + "_aligned"] == 0) + (data[self.data.val] == 0) + * (data[self.pattern.val + "_aligned"] == 0) ).sum() if num_zval > 0: warnings.warn( diff --git a/src/pydisagg/ihme/validator.py b/src/pydisagg/ihme/validator.py index 5d564bf..1c86379 100644 --- a/src/pydisagg/ihme/validator.py +++ b/src/pydisagg/ihme/validator.py @@ -22,9 +22,7 @@ def validate_index(df: DataFrame, index: list[str], name: str) -> None: df[df[index].duplicated()][index] ).to_list() if duplicated_index: - error_message = ( - f"{name} has duplicated index with {len(duplicated_index)} indices \n" - ) + error_message = f"{name} has duplicated index with {len(duplicated_index)} indices \n" error_message += f"Index columns: ({', '.join(index)})\n" if len(duplicated_index) > 5: error_message += "First 5: \n" @@ -36,7 +34,9 @@ def validate_index(df: DataFrame, index: list[str], name: str) -> None: def validate_nonan(df: DataFrame, name: str) -> None: nan_columns = df.columns[df.isna().any(axis=0)].to_list() if nan_columns: - error_message = f"{name} has NaN values in {len(nan_columns)} columns. \n" + error_message = ( + f"{name} has NaN values in {len(nan_columns)} columns. \n" + ) error_message += f"Columns with NaN values: {', '.join(nan_columns)}\n" if len(nan_columns) > 5: error_message += "First 5 columns with NaN values: \n" @@ -80,7 +80,9 @@ def validate_noindexdiff( missing_index = index_ref.difference(index).to_list() if missing_index: - error_message = f"Missing {name} info for {len(missing_index)} indices \n" + error_message = ( + f"Missing {name} info for {len(missing_index)} indices \n" + ) error_message += f"Index columns: ({', '.join(index.names)})\n" if len(missing_index) > 5: error_message += "First 5: \n" @@ -153,7 +155,9 @@ def validate_realnumber(df: DataFrame, columns: list[str], name: str) -> None: invalid = [ col for col in columns - if not df[col].apply(lambda x: isinstance(x, (int, float)) and x != 0).all() + if not df[col] + .apply(lambda x: isinstance(x, (int, float)) and x != 0) + .all() ] if invalid: diff --git a/tests/test_age_splitter.py b/tests/test_age_splitter.py index 8ec4051..e10e453 100644 --- a/tests/test_age_splitter.py +++ b/tests/test_age_splitter.py @@ -87,7 +87,8 @@ def splitter(data, pattern, population): val_sd="val_sd", ) population_config = AgePopulationConfig( - index=["sex_id", "location_id", "age_group_id", "year_id"], val="population" + index=["sex_id", "location_id", "age_group_id", "year_id"], + val="population", ) return AgeSplitter( data=data_config, pattern=pattern_config, population=population_config @@ -162,7 +163,9 @@ def test_parse_pattern_missing_columns(splitter, data, pattern): with pytest.raises(ValueError, match="Must provide draws for pattern"): splitter.parse_pattern(data, pattern, positive_strict=True) else: - parsed_pattern = splitter.parse_pattern(data, pattern, positive_strict=True) + parsed_pattern = splitter.parse_pattern( + data, pattern, positive_strict=True + ) assert "val" in parsed_pattern.columns assert "val_sd" in parsed_pattern.columns @@ -213,7 +216,9 @@ def test_parse_pattern_nan_values(splitter, data, pattern): pattern.loc[0, "val"] = np.nan # Manually check if NaN is correctly set - assert pd.isna(pattern.loc[0, "val"]), "NaN not correctly set in 'val' column" + assert pd.isna( + pattern.loc[0, "val"] + ), "NaN not correctly set in 'val' column" # Ensure validate_nonan is called in parse_pattern with pytest.raises(ValueError, match="has NaN values"):