Skip to content

Commit

Permalink
-L added fixes and suggested changes, still need tests
Browse files Browse the repository at this point in the history
  • Loading branch information
saalUW committed May 1, 2024
1 parent fd8c687 commit e5251f1
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 126 deletions.
13 changes: 7 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
repos:
- repo: local
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.4.2
hooks:
# Run the linter.
- id: ruff
name: ruff
files: '(.*\.py$|.*\.pyi$)'
language: python
entry: ruff check .
types: [python]
args: [ --fix ]
# Run the formatter.
- id: ruff-format
40 changes: 22 additions & 18 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,35 +1,39 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
requires = [
"setuptools>=61",
]

[project]
name = "pydisagg"
version = "0.3.4.5"
authors = [{ name = "IHME Math Sciences", email = "ihme.math.sciences@gmail.com" }]
version = "0.3.4"
description = ""
readme = "README.md"
license = { text = "BSD 2-Clause License" }
authors = [{ name = "IHME Math Sciences", email = "ihme.math.sciences@gmail.com" }]
requires-python = ">=3.6"
classifiers = [
"Development Status :: 4 - Beta",
"Programming Language :: Python :: 3",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
"Development Status :: 4 - Beta",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]

dependencies = [
"matplotlib",
"numpy",
"pandas",
"scipy",
"matplotlib",
"numpy",
"pandas",
"scipy",
]
requires-python = ">=3.6"

[project.urls]
homepage = "https://github.com/ihmeuw-msca/pydisagg"

[tool.setuptools.packages.find]
where = ["src"]

[tool.pytest.ini_options]
testpaths = ["tests", "integration"]
addopts = "-v -ra -q"
Expand All @@ -38,4 +42,4 @@ log_cli_level = "INFO"
log_format = "%(asctime)s %(levelname)s %(message)s"
log_date_format = "%Y-%m-%d %H:%M:%S"
minversion = "6.0"
filterwarnings = "ignore"
filterwarnings = "ignore"
14 changes: 0 additions & 14 deletions src/pydisagg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +0,0 @@
# pydisagg/__init__.py
# from . import (
# DisaggModel,
# ParameterTransformation,
# disaggregate,
# models,
# )

# __all__ = [
# "DisaggModel",
# "disaggregate",
# "models",
# "ParameterTransformation",
# ]
2 changes: 0 additions & 2 deletions src/pydisagg/ihme/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +0,0 @@
# ihme/__init__.py
"""This is the ihme package."""
6 changes: 3 additions & 3 deletions src/pydisagg/ihme/age_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import numpy as np
import pandas as pd

from .age_var import match_cols
from ..disaggregate import split_datapoint
from ..models import LogOdds_model, RateMultiplicativeModel
from pydisagg.ihme.age_var import match_cols
from pydisagg.disaggregate import split_datapoint
from pydisagg.models import LogOdds_model, RateMultiplicativeModel


def split_row(
Expand Down
61 changes: 36 additions & 25 deletions src/pydisagg/ihme/helper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .age_var import rename_dict_dis
from pydisagg.ihme.age_var import rename_dict_dis
import matplotlib.pyplot as plt
import numpy as np

Expand All @@ -7,17 +7,21 @@

def rename_df(frozen_df, rename_dict=rename_dict_dis, drop=True):
"""
Renames columns of a DataFrame based on a given dictionary and optionally drops columns.
Parameters:
frozen_df (DataFrame): The input DataFrame to be renamed.
rename_dict (dict): A dictionary mapping old column names to new column names.
drop (bool, optional): Whether to drop columns not present in the rename_dict.
Defaults to True.
Returns:
return_df: The renamed dataframe
frozen_df: The original dataframe with 'row_id' column added
Parameters
----------
frozen_df : DataFrame
The input DataFrame to be renamed.
rename_dict : dict
A dictionary mapping old column names to new column names.
drop : bool, optional
Whether to drop columns not present in the rename_dict. Defaults to True.
Returns
-------
return_df : DataFrame
The renamed dataframe
frozen_df : DataFrame
The original dataframe with 'row_id' column added
"""
# Create a copy of the input DataFrame to avoid changing it in place
df = frozen_df.copy()
Expand All @@ -38,12 +42,17 @@ def glue_back(df, frozen_df):
"""
Appends the columns of a frozen_df to a df based on the "row_id" column.
Parameters:
df (DataFrame): The main DataFrame to which columns will be appended.
frozen_df (DataFrame): The DataFrame whose columns will be appended to df.
Returns:
DataFrame: The merged dataframe
Parameters
----------
df : DataFrame
The main DataFrame to which columns will be appended.
frozen_df : DataFrame
The DataFrame whose columns will be appended to df.
Returns
-------
DataFrame
The merged dataframe
"""
merged_df = df.merge(
frozen_df, on="row_id", how="left", suffixes=("", "_frozen")
Expand All @@ -57,7 +66,7 @@ def plot_results(
result_df, pattern_df, row_id, title="Default Title", y_label="Some Measure"
):
# Filter result_df for the given row_id
sub_df = result_df[result_df["row_id"] == row_id]
sub_df = result_df.query(f"row_id == {row_id}")

# For each row in sub_df
for _, row in sub_df.iterrows():
Expand Down Expand Up @@ -85,10 +94,11 @@ def plot_results(
)

# Look up the mean_draw value in pattern_df matching this row's age_group_id and sex_id
mean_draw_df = pattern_df[
(pattern_df["age_group_id"] == row["age_group_id"])
& (pattern_df["sex_id"] == row["sex_id"])
]
mean_draw_df = pattern_df.query(
" and ".join(
[f"{col} == {row[col]}" for col in ["age_group_id", "sex_id"]]
)
)

if not mean_draw_df.empty:
mean_draw = mean_draw_df["mean_draw"].values[0]
Expand Down Expand Up @@ -157,5 +167,6 @@ def plot_results(
# Set the plot title from the `title` argument
plt.title(title)

# Show the plot
plt.show()
fig, ax = plt.subplots()

return fig
Loading

0 comments on commit e5251f1

Please sign in to comment.