Skip to content

Commit

Permalink
ruff formatting
Browse files (browse the repository at this point in the history)
  • Loading branch information
saalUW committed Aug 26, 2024
1 parent 506e0ca commit 0664142
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 18 deletions.
33 changes: 24 additions & 9 deletions src/pydisagg/ihme/splitter/age_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ def parse_data(self, data: DataFrame, positive_strict: bool) -> DataFrame:
data = data[self.data.columns].copy()
validate_index(data, self.data.index, name)
validate_nonan(data, name)
validate_positive(data, [self.data.val_sd], name, strict=positive_strict)
validate_positive(
data, [self.data.val_sd], name, strict=positive_strict
)
validate_interval(
data, self.data.age_lwr, self.data.age_upr, self.data.index, name
)
Expand All @@ -158,7 +160,8 @@ def parse_pattern(

# Check if val and val_sd are missing, and generate them if necessary
if not all(
col in pattern.columns for col in [self.pattern.val, self.pattern.val_sd]
col in pattern.columns
for col in [self.pattern.val, self.pattern.val_sd]
):
if not self.pattern.draws:
raise ValueError(
Expand All @@ -169,7 +172,9 @@ def parse_pattern(
# Generate val and val_sd from draws
validate_columns(pattern, self.pattern.draws, name)
pattern[self.pattern.val] = pattern[self.pattern.draws].mean(axis=1)
pattern[self.pattern.val_sd] = pattern[self.pattern.draws].std(axis=1)
pattern[self.pattern.val_sd] = pattern[self.pattern.draws].std(
axis=1
)

# Validate columns after potential generation
validate_columns(pattern, self.pattern.columns, name)
Expand Down Expand Up @@ -212,7 +217,9 @@ def parse_pattern(

return data_with_pattern

def _merge_with_pattern(self, data: DataFrame, pattern: DataFrame) -> DataFrame:
def _merge_with_pattern(
self, data: DataFrame, pattern: DataFrame
) -> DataFrame:
# Ensure the necessary columns are present before merging
assert (
self.data.age_lwr in data.columns
Expand Down Expand Up @@ -241,7 +248,9 @@ def _merge_with_pattern(self, data: DataFrame, pattern: DataFrame) -> DataFrame:
)
return data_with_pattern

def parse_population(self, data: DataFrame, population: DataFrame) -> DataFrame:
def parse_population(
self, data: DataFrame, population: DataFrame
) -> DataFrame:
name = "Parsing Population"
validate_columns(population, self.population.columns, name)

Expand Down Expand Up @@ -367,18 +376,23 @@ def split(

# If not propagating zeros, then positivity has to be strict
data = self.parse_data(data, positive_strict=not propagate_zeros)
data = self.parse_pattern(data, pattern, positive_strict=not propagate_zeros)
data = self.parse_pattern(
data, pattern, positive_strict=not propagate_zeros
)
data = self.parse_population(data, population)

data = self._align_pattern_and_population(data)

# where split happens
data["age_split_result"], data["age_split_result_se"] = np.nan, np.nan
data["age_split"] = 0 # Indicate that the row was not split by age initially
data["age_split"] = (
0 # Indicate that the row was not split by age initially
)

if propagate_zeros is True:
data_zero = data[
(data[self.data.val] == 0) | (data[self.pattern.val + "_aligned"] == 0)
(data[self.data.val] == 0)
| (data[self.pattern.val + "_aligned"] == 0)
]
data = data[data[self.data.val] > 0]
# Manually split zero values
Expand All @@ -390,7 +404,8 @@ def split(
num_zval = (data[self.data.val] == 0).sum()
num_zpat = (data[self.pattern.val + "_aligned"] == 0).sum()
num_overlap = (
(data[self.data.val] == 0) * (data[self.pattern.val + "_aligned"] == 0)
(data[self.data.val] == 0)
* (data[self.pattern.val + "_aligned"] == 0)
).sum()
if num_zval > 0:
warnings.warn(
Expand Down
16 changes: 10 additions & 6 deletions src/pydisagg/ihme/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ def validate_index(df: DataFrame, index: list[str], name: str) -> None:
df[df[index].duplicated()][index]
).to_list()
if duplicated_index:
error_message = (
f"{name} has duplicated index with {len(duplicated_index)} indices \n"
)
error_message = f"{name} has duplicated index with {len(duplicated_index)} indices \n"
error_message += f"Index columns: ({', '.join(index)})\n"
if len(duplicated_index) > 5:
error_message += "First 5: \n"
Expand All @@ -36,7 +34,9 @@ def validate_index(df: DataFrame, index: list[str], name: str) -> None:
def validate_nonan(df: DataFrame, name: str) -> None:
nan_columns = df.columns[df.isna().any(axis=0)].to_list()
if nan_columns:
error_message = f"{name} has NaN values in {len(nan_columns)} columns. \n"
error_message = (
f"{name} has NaN values in {len(nan_columns)} columns. \n"
)
error_message += f"Columns with NaN values: {', '.join(nan_columns)}\n"
if len(nan_columns) > 5:
error_message += "First 5 columns with NaN values: \n"
Expand Down Expand Up @@ -80,7 +80,9 @@ def validate_noindexdiff(
missing_index = index_ref.difference(index).to_list()

if missing_index:
error_message = f"Missing {name} info for {len(missing_index)} indices \n"
error_message = (
f"Missing {name} info for {len(missing_index)} indices \n"
)
error_message += f"Index columns: ({', '.join(index.names)})\n"
if len(missing_index) > 5:
error_message += "First 5: \n"
Expand Down Expand Up @@ -153,7 +155,9 @@ def validate_realnumber(df: DataFrame, columns: list[str], name: str) -> None:
invalid = [
col
for col in columns
if not df[col].apply(lambda x: isinstance(x, (int, float)) and x != 0).all()
if not df[col]
.apply(lambda x: isinstance(x, (int, float)) and x != 0)
.all()
]

if invalid:
Expand Down
11 changes: 8 additions & 3 deletions tests/test_age_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ def splitter(data, pattern, population):
val_sd="val_sd",
)
population_config = AgePopulationConfig(
index=["sex_id", "location_id", "age_group_id", "year_id"], val="population"
index=["sex_id", "location_id", "age_group_id", "year_id"],
val="population",
)
return AgeSplitter(
data=data_config, pattern=pattern_config, population=population_config
Expand Down Expand Up @@ -162,7 +163,9 @@ def test_parse_pattern_missing_columns(splitter, data, pattern):
with pytest.raises(ValueError, match="Must provide draws for pattern"):
splitter.parse_pattern(data, pattern, positive_strict=True)
else:
parsed_pattern = splitter.parse_pattern(data, pattern, positive_strict=True)
parsed_pattern = splitter.parse_pattern(
data, pattern, positive_strict=True
)
assert "val" in parsed_pattern.columns
assert "val_sd" in parsed_pattern.columns

Expand Down Expand Up @@ -213,7 +216,9 @@ def test_parse_pattern_nan_values(splitter, data, pattern):
pattern.loc[0, "val"] = np.nan

# Manually check if NaN is correctly set
assert pd.isna(pattern.loc[0, "val"]), "NaN not correctly set in 'val' column"
assert pd.isna(
pattern.loc[0, "val"]
), "NaN not correctly set in 'val' column"

# Ensure validate_nonan is called in parse_pattern
with pytest.raises(ValueError, match="has NaN values"):
Expand Down

0 comments on commit 0664142

Please sign in to comment.