Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove deprecated SQLAlchemy v1 calls in preparation for v2 #73

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions rctab/crud/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ async def load_cache(oid: str) -> msal.SerializableTokenCache:
"""Load a user's token cache from the database."""
cache = msal.SerializableTokenCache()
value = await database.fetch_val(
select([user_cache.c.cache]).where(user_cache.c.oid == oid)
select(user_cache.c.cache).where(user_cache.c.oid == oid)
)
if value:
cache.deserialize(value)
Expand Down Expand Up @@ -56,12 +56,10 @@ async def check_user_access(
raise_http_exception: Raise an exception if the user isn't found.
"""
statement = select(
[
user_rbac.c.oid,
user_rbac.c.username,
user_rbac.c.has_access,
user_rbac.c.is_admin,
]
user_rbac.c.oid,
user_rbac.c.username,
user_rbac.c.has_access,
user_rbac.c.is_admin,
).where(user_rbac.c.oid == oid)

user_status = await database.fetch_one(statement)
Expand Down
10 changes: 4 additions & 6 deletions rctab/routers/accounting/abolishment.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,8 @@ async def get_inactive_subs() -> Optional[List[UUID]]:
# The most recent time_created for each subscription's subscription_detail
query_1 = (
select(
[
subscription_details.c.subscription_id,
func.max(subscription_details.c.time_created).label("time_created"),
]
subscription_details.c.subscription_id,
func.max(subscription_details.c.time_created).label("time_created"),
).group_by(subscription_details.c.subscription_id)
).alias()

Expand All @@ -54,7 +52,7 @@ async def get_inactive_subs() -> Optional[List[UUID]]:
# subscriptions that have been inactive for more than 90 days
# and have not been abolished yet
query_2_result = await database.fetch_all(
select([subscription_details.c.subscription_id])
select(subscription_details.c.subscription_id)
.select_from(query_2)
.where(
and_(
Expand All @@ -78,7 +76,7 @@ async def adjust_budgets_to_zero(admin_oid: UUID, sub_ids: List[UUID]) -> List[d
sub_query = get_subscriptions_summary(execute=False).alias()

summaries = (
select([sub_query])
select(sub_query)
.where(sub_query.c.subscription_id.in_([str(sub_id) for sub_id in sub_ids]))
.alias()
)
Expand Down
40 changes: 18 additions & 22 deletions rctab/routers/accounting/cost_recovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def calc_cost_recovery(
but not for this month.
"""
last_recovered_day = await database.fetch_one(
select([cost_recovery_log]).order_by(desc(cost_recovery_log.c.month))
select(cost_recovery_log).order_by(desc(cost_recovery_log.c.month))
)
last_recovered_month = (
CostRecoveryMonth(first_day=last_recovered_day["month"])
Expand All @@ -86,9 +86,7 @@ async def calc_cost_recovery(
subscription_ids = [
r["sub_id"]
for r in await database.fetch_all(
select(
[func.distinct(finance.c.subscription_id).label("sub_id")]
).where(
select(func.distinct(finance.c.subscription_id).label("sub_id")).where(
between(
recovery_month.first_day, finance.c.date_from, finance.c.date_to
)
Expand All @@ -99,7 +97,7 @@ async def calc_cost_recovery(
for subscription_id in subscription_ids:

usage_row = await database.fetch_one(
select([func.sum(usage.c.total_cost).label("the_sum")])
select(func.sum(usage.c.total_cost).label("the_sum"))
.where(
func.date_trunc(
"month", sqlalchemy.cast(usage.c.date, sqlalchemy.Date)
Expand All @@ -118,7 +116,7 @@ async def calc_cost_recovery(

# The lower the value of priority, the higher importance
finance_periods = await database.fetch_all(
select([finance])
select(finance)
.where(
and_(
between(
Expand All @@ -136,7 +134,7 @@ async def calc_cost_recovery(
for finance_period in finance_periods:

cost_recovery_row = await database.fetch_one(
select([func.sum(cost_recovery.c.amount).label("the_sum")]).where(
select(func.sum(cost_recovery.c.amount).label("the_sum")).where(
cost_recovery.c.finance_id == finance_period["id"]
)
)
Expand All @@ -160,32 +158,30 @@ async def calc_cost_recovery(
cost_recovery_id = await database.execute(
insert(
cost_recovery,
{
"subscription_id": finance_period["subscription_id"],
"month": recovery_month.first_day,
"finance_code": finance_period["finance_code"],
"amount": recoverable_amount,
"date_recovered": None,
"finance_id": finance_period["id"],
"admin": admin,
},
)
),
{
"subscription_id": finance_period["subscription_id"],
"month": recovery_month.first_day,
"finance_code": finance_period["finance_code"],
"amount": recoverable_amount,
"date_recovered": None,
"finance_id": finance_period["id"],
"admin": admin,
},
)
cost_recovery_ids.append(cost_recovery_id)

inserted_rows = await database.fetch_all(
select([cost_recovery]).where(cost_recovery.c.id.in_(cost_recovery_ids))
select(cost_recovery).where(cost_recovery.c.id.in_(cost_recovery_ids))
)

# Note that we patch CostRecovery as a unit testing hack
cost_recoveries = [CostRecovery(**dict(cr)) for cr in inserted_rows]

if commit_transaction:
await database.execute(
insert(
cost_recovery_log,
{"month": recovery_month.first_day, "admin": admin},
)
insert(cost_recovery_log),
{"month": recovery_month.first_day, "admin": admin},
)
await transaction.commit()
else:
Expand Down
140 changes: 61 additions & 79 deletions rctab/routers/accounting/desired_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,20 +85,16 @@ async def get_desired_states(
summaries = get_subscriptions_summary(execute=False).alias()

to_be_changed = select(
[
summaries.c.subscription_id,
case(
[
# If the desired status is False, we want to disable it
(
summaries.c.desired_status == False,
SubscriptionState("Disabled"),
),
# If the desired status is True, we want to enable it
(summaries.c.desired_status == True, SubscriptionState("Enabled")),
]
).label("desired_state"),
]
summaries.c.subscription_id,
case(
# If the desired status is False, we want to disable it
(
summaries.c.desired_status == False,
SubscriptionState("Disabled"),
),
# If the desired status is True, we want to enable it
(summaries.c.desired_status == True, SubscriptionState("Enabled")),
).label("desired_state"),
).where(
or_(
and_(
Expand Down Expand Up @@ -151,7 +147,7 @@ async def refresh_desired_states(

if subscription_ids:
summaries = (
select([sub_query])
select(sub_query)
.where(
sub_query.c.subscription_id.in_(
[str(sub_id) for sub_id in subscription_ids]
Expand All @@ -165,7 +161,7 @@ async def refresh_desired_states(
# Subscriptions without an approved_to date or whose approved_to
# date is in the past, that currently have a desired_status of True
over_time = (
select([summaries]).where(
select(summaries).where(
and_(
or_(
summaries.c.approved_to == None,
Expand Down Expand Up @@ -206,7 +202,7 @@ async def refresh_desired_states(
# Subscriptions with more usage than allocated budget
# that currently have a desired_status of True
over_budget = (
select([summaries])
select(summaries)
.where(
and_(
# To gracelessly sidestep rounding errors, allow a tolerance
Expand All @@ -222,58 +218,50 @@ async def refresh_desired_states(

over_time_or_over_budget = (
select(
[
literal_column("1").label("reason_enum"),
over_time.c.subscription_id,
literal_column("uuid('" + str(admin_oid) + "')").label("admin_oid"),
literal_column("False").label("active"),
over_time.c.desired_status,
over_time.c.desired_status_info,
]
literal_column("1").label("reason_enum"),
over_time.c.subscription_id,
literal_column("uuid('" + str(admin_oid) + "')").label("admin_oid"),
literal_column("False").label("active"),
over_time.c.desired_status,
over_time.c.desired_status_info,
)
.union(
select(
[
literal_column("2").label("reason_enum"),
over_budget.c.subscription_id,
literal_column("uuid('" + str(admin_oid) + "')").label("admin_oid"),
literal_column("False").label("active"),
over_budget.c.desired_status,
over_budget.c.desired_status_info,
]
literal_column("2").label("reason_enum"),
over_budget.c.subscription_id,
literal_column("uuid('" + str(admin_oid) + "')").label("admin_oid"),
literal_column("False").label("active"),
over_budget.c.desired_status,
over_budget.c.desired_status_info,
)
)
.alias()
)

over_time_or_over_budget_reason = (
select(
[
over_time_or_over_budget.c.subscription_id,
over_time_or_over_budget.c.admin_oid,
over_time_or_over_budget.c.active,
case(
[
(
func.sum(over_time_or_over_budget.c.reason_enum) == 1,
cast(BillingStatus.EXPIRED, Enum(BillingStatus)),
),
(
func.sum(over_time_or_over_budget.c.reason_enum) == 2,
cast(BillingStatus.OVER_BUDGET, Enum(BillingStatus)),
),
(
func.sum(over_time_or_over_budget.c.reason_enum) == 3,
cast(
BillingStatus.OVER_BUDGET_AND_EXPIRED,
Enum(BillingStatus),
),
),
],
).label("reason"),
over_time_or_over_budget.c.desired_status_info.label("old_reason"),
over_time_or_over_budget.c.desired_status.label("old_desired_status"),
]
over_time_or_over_budget.c.subscription_id,
over_time_or_over_budget.c.admin_oid,
over_time_or_over_budget.c.active,
case(
(
func.sum(over_time_or_over_budget.c.reason_enum) == 1,
cast(BillingStatus.EXPIRED, Enum(BillingStatus)),
),
(
func.sum(over_time_or_over_budget.c.reason_enum) == 2,
cast(BillingStatus.OVER_BUDGET, Enum(BillingStatus)),
),
(
func.sum(over_time_or_over_budget.c.reason_enum) == 3,
cast(
BillingStatus.OVER_BUDGET_AND_EXPIRED,
Enum(BillingStatus),
),
),
).label("reason"),
over_time_or_over_budget.c.desired_status_info.label("old_reason"),
over_time_or_over_budget.c.desired_status.label("old_desired_status"),
)
.group_by(
over_time_or_over_budget.c.subscription_id,
Expand All @@ -289,13 +277,11 @@ async def refresh_desired_states(
# desired status or has the wrong reason or a missing reason
over_time_or_over_budget_desired_on = (
select(
[
over_time_or_over_budget_reason.c.subscription_id,
over_time_or_over_budget_reason.c.admin_oid,
over_time_or_over_budget_reason.c.active,
over_time_or_over_budget_reason.c.reason,
over_time_or_over_budget_reason.c.old_reason,
]
over_time_or_over_budget_reason.c.subscription_id,
over_time_or_over_budget_reason.c.admin_oid,
over_time_or_over_budget_reason.c.active,
over_time_or_over_budget_reason.c.reason,
over_time_or_over_budget_reason.c.old_reason,
)
.where(
or_(
Expand All @@ -319,12 +305,10 @@ async def refresh_desired_states(
status_table.c.reason,
],
select(
[
over_time_or_over_budget_desired_on.c.subscription_id,
over_time_or_over_budget_desired_on.c.admin_oid,
over_time_or_over_budget_desired_on.c.active,
over_time_or_over_budget_desired_on.c.reason,
]
over_time_or_over_budget_desired_on.c.subscription_id,
over_time_or_over_budget_desired_on.c.admin_oid,
over_time_or_over_budget_desired_on.c.active,
over_time_or_over_budget_desired_on.c.reason,
),
)

Expand Down Expand Up @@ -353,17 +337,15 @@ async def refresh_desired_states(
# but aren't currently. These are all of our subscriptions
# that are disabled but aren't over time or budget.
should_be_enabled_but_are_not = select(
[
summaries.c.subscription_id,
literal_column("uuid('" + str(admin_oid) + "')"),
literal_column("True"),
literal_column("NULL"),
]
summaries.c.subscription_id,
literal_column("uuid('" + str(admin_oid) + "')"),
literal_column("True"),
literal_column("NULL"),
).where(
and_(
not_(
summaries.c.subscription_id.in_(
select([over_time_or_over_budget.c.subscription_id])
select(over_time_or_over_budget.c.subscription_id)
)
),
or_(
Expand Down
Loading