diff --git a/rctab/crud/auth.py b/rctab/crud/auth.py index 38ff155..ad4e6ad 100644 --- a/rctab/crud/auth.py +++ b/rctab/crud/auth.py @@ -18,7 +18,7 @@ async def load_cache(oid: str) -> msal.SerializableTokenCache: """Load a user's token cache from the database.""" cache = msal.SerializableTokenCache() value = await database.fetch_val( - select([user_cache.c.cache]).where(user_cache.c.oid == oid) + select(user_cache.c.cache).where(user_cache.c.oid == oid) ) if value: cache.deserialize(value) @@ -56,12 +56,10 @@ async def check_user_access( raise_http_exception: Raise an exception if the user isn't found. """ statement = select( - [ - user_rbac.c.oid, - user_rbac.c.username, - user_rbac.c.has_access, - user_rbac.c.is_admin, - ] + user_rbac.c.oid, + user_rbac.c.username, + user_rbac.c.has_access, + user_rbac.c.is_admin, ).where(user_rbac.c.oid == oid) user_status = await database.fetch_one(statement) diff --git a/rctab/routers/accounting/abolishment.py b/rctab/routers/accounting/abolishment.py index 37cf074..1d393c2 100644 --- a/rctab/routers/accounting/abolishment.py +++ b/rctab/routers/accounting/abolishment.py @@ -32,10 +32,8 @@ async def get_inactive_subs() -> Optional[List[UUID]]: # The most recent time_created for each subscription's subscription_detail query_1 = ( select( - [ - subscription_details.c.subscription_id, - func.max(subscription_details.c.time_created).label("time_created"), - ] + subscription_details.c.subscription_id, + func.max(subscription_details.c.time_created).label("time_created"), ).group_by(subscription_details.c.subscription_id) ).alias() @@ -54,7 +52,7 @@ async def get_inactive_subs() -> Optional[List[UUID]]: # subscriptions that have been inactive for more than 90 days # and have not been abolished yet query_2_result = await database.fetch_all( - select([subscription_details.c.subscription_id]) + select(subscription_details.c.subscription_id) .select_from(query_2) .where( and_( @@ -78,7 +76,7 @@ async def adjust_budgets_to_zero(admin_oid: UUID, sub_ids: List[UUID]) -> List[d sub_query = get_subscriptions_summary(execute=False).alias() summaries = ( - select([sub_query]) + select(sub_query) .where(sub_query.c.subscription_id.in_([str(sub_id) for sub_id in sub_ids])) .alias() ) diff --git a/rctab/routers/accounting/cost_recovery.py b/rctab/routers/accounting/cost_recovery.py index affec4f..f815301 100644 --- a/rctab/routers/accounting/cost_recovery.py +++ b/rctab/routers/accounting/cost_recovery.py @@ -67,7 +67,7 @@ async def calc_cost_recovery( but not for this month. """ last_recovered_day = await database.fetch_one( - select([cost_recovery_log]).order_by(desc(cost_recovery_log.c.month)) + select(cost_recovery_log).order_by(desc(cost_recovery_log.c.month)) ) last_recovered_month = ( CostRecoveryMonth(first_day=last_recovered_day["month"]) @@ -86,9 +86,7 @@ async def calc_cost_recovery( subscription_ids = [ r["sub_id"] for r in await database.fetch_all( - select( - [func.distinct(finance.c.subscription_id).label("sub_id")] - ).where( + select(func.distinct(finance.c.subscription_id).label("sub_id")).where( between( recovery_month.first_day, finance.c.date_from, finance.c.date_to ) @@ -99,7 +97,7 @@ async def calc_cost_recovery( for subscription_id in subscription_ids: usage_row = await database.fetch_one( - select([func.sum(usage.c.total_cost).label("the_sum")]) + select(func.sum(usage.c.total_cost).label("the_sum")) .where( func.date_trunc( "month", sqlalchemy.cast(usage.c.date, sqlalchemy.Date) @@ -118,7 +116,7 @@ async def calc_cost_recovery( # The lower the value of priority, the higher importance finance_periods = await database.fetch_all( - select([finance]) + select(finance) .where( and_( between( @@ -136,7 +134,7 @@ async def calc_cost_recovery( for finance_period in finance_periods: cost_recovery_row = await database.fetch_one( - select([func.sum(cost_recovery.c.amount).label("the_sum")]).where( + select(func.sum(cost_recovery.c.amount).label("the_sum")).where( cost_recovery.c.finance_id == finance_period["id"] ) ) @@ -160,21 +158,21 @@ async def calc_cost_recovery( cost_recovery_id = await database.execute( insert( cost_recovery, - { - "subscription_id": finance_period["subscription_id"], - "month": recovery_month.first_day, - "finance_code": finance_period["finance_code"], - "amount": recoverable_amount, - "date_recovered": None, - "finance_id": finance_period["id"], - "admin": admin, - }, - ) + ), + { + "subscription_id": finance_period["subscription_id"], + "month": recovery_month.first_day, + "finance_code": finance_period["finance_code"], + "amount": recoverable_amount, + "date_recovered": None, + "finance_id": finance_period["id"], + "admin": admin, + }, ) cost_recovery_ids.append(cost_recovery_id) inserted_rows = await database.fetch_all( - select([cost_recovery]).where(cost_recovery.c.id.in_(cost_recovery_ids)) + select(cost_recovery).where(cost_recovery.c.id.in_(cost_recovery_ids)) ) # Note that we patch CostRecovery as a unit testing hack @@ -182,10 +180,8 @@ async def calc_cost_recovery( if commit_transaction: await database.execute( - insert( - cost_recovery_log, - {"month": recovery_month.first_day, "admin": admin}, - ) + insert(cost_recovery_log), + {"month": recovery_month.first_day, "admin": admin}, ) await transaction.commit() else: diff --git a/rctab/routers/accounting/desired_states.py b/rctab/routers/accounting/desired_states.py index 5166ce7..f68bb57 100644 --- a/rctab/routers/accounting/desired_states.py +++ b/rctab/routers/accounting/desired_states.py @@ -85,20 +85,16 @@ async def get_desired_states( summaries = get_subscriptions_summary(execute=False).alias() to_be_changed = select( - [ - summaries.c.subscription_id, - case( - [ - # If the desired status is False, we want to disable it - ( - summaries.c.desired_status == False, - SubscriptionState("Disabled"), - ), - # If the desired status is True, we want to enable it - (summaries.c.desired_status == True, SubscriptionState("Enabled")), - ] - ).label("desired_state"), - ] + summaries.c.subscription_id, + case( + # If the desired status is False, we want to disable it + ( + summaries.c.desired_status == False, + SubscriptionState("Disabled"), + ), + # If the desired status is True, we want to enable it + (summaries.c.desired_status == True, SubscriptionState("Enabled")), + ).label("desired_state"), ).where( or_( and_( @@ -151,7 +147,7 @@ async def refresh_desired_states( if subscription_ids: summaries = ( - select([sub_query]) + select(sub_query) .where( sub_query.c.subscription_id.in_( [str(sub_id) for sub_id in subscription_ids] @@ -165,7 +161,7 @@ async def refresh_desired_states( # Subscriptions without an approved_to date or whose approved_to # date is in the past, that currently have a desired_status of True over_time = ( - select([summaries]).where( + select(summaries).where( and_( or_( summaries.c.approved_to == None, @@ -206,7 +202,7 @@ async def refresh_desired_states( # Subscriptions with more usage than allocated budget # that currently have a desired_status of True over_budget = ( - select([summaries]) + select(summaries) .where( and_( # To gracelessly sidestep rounding errors, allow a tolerance @@ -222,25 +218,21 @@ async def refresh_desired_states( over_time_or_over_budget = ( select( - [ - literal_column("1").label("reason_enum"), - over_time.c.subscription_id, - literal_column("uuid('" + str(admin_oid) + "')").label("admin_oid"), - literal_column("False").label("active"), - over_time.c.desired_status, - over_time.c.desired_status_info, - ] + literal_column("1").label("reason_enum"), + over_time.c.subscription_id, + literal_column("uuid('" + str(admin_oid) + "')").label("admin_oid"), + literal_column("False").label("active"), + over_time.c.desired_status, + over_time.c.desired_status_info, ) .union( select( - [ - literal_column("2").label("reason_enum"), - over_budget.c.subscription_id, - literal_column("uuid('" + str(admin_oid) + "')").label("admin_oid"), - literal_column("False").label("active"), - over_budget.c.desired_status, - over_budget.c.desired_status_info, - ] + literal_column("2").label("reason_enum"), + over_budget.c.subscription_id, + literal_column("uuid('" + str(admin_oid) + "')").label("admin_oid"), + literal_column("False").label("active"), + over_budget.c.desired_status, + over_budget.c.desired_status_info, ) ) .alias() @@ -248,32 +240,28 @@ async def refresh_desired_states( over_time_or_over_budget_reason = ( select( - [ - over_time_or_over_budget.c.subscription_id, - over_time_or_over_budget.c.admin_oid, - over_time_or_over_budget.c.active, - case( - [ - ( - func.sum(over_time_or_over_budget.c.reason_enum) == 1, - cast(BillingStatus.EXPIRED, Enum(BillingStatus)), - ), - ( - func.sum(over_time_or_over_budget.c.reason_enum) == 2, - cast(BillingStatus.OVER_BUDGET, Enum(BillingStatus)), - ), - ( - func.sum(over_time_or_over_budget.c.reason_enum) == 3, - cast( - BillingStatus.OVER_BUDGET_AND_EXPIRED, - Enum(BillingStatus), - ), - ), - ], - ).label("reason"), - over_time_or_over_budget.c.desired_status_info.label("old_reason"), - over_time_or_over_budget.c.desired_status.label("old_desired_status"), - ] + over_time_or_over_budget.c.subscription_id, + over_time_or_over_budget.c.admin_oid, + over_time_or_over_budget.c.active, + case( + ( + func.sum(over_time_or_over_budget.c.reason_enum) == 1, + cast(BillingStatus.EXPIRED, Enum(BillingStatus)), + ), + ( + func.sum(over_time_or_over_budget.c.reason_enum) == 2, + cast(BillingStatus.OVER_BUDGET, Enum(BillingStatus)), + ), + ( + func.sum(over_time_or_over_budget.c.reason_enum) == 3, + cast( + BillingStatus.OVER_BUDGET_AND_EXPIRED, + Enum(BillingStatus), + ), + ), + ).label("reason"), + over_time_or_over_budget.c.desired_status_info.label("old_reason"), + over_time_or_over_budget.c.desired_status.label("old_desired_status"), ) .group_by( over_time_or_over_budget.c.subscription_id, @@ -289,13 +277,11 @@ async def refresh_desired_states( # desired status or has the wrong reason or a missing reason over_time_or_over_budget_desired_on = ( select( - [ - over_time_or_over_budget_reason.c.subscription_id, - over_time_or_over_budget_reason.c.admin_oid, - over_time_or_over_budget_reason.c.active, - over_time_or_over_budget_reason.c.reason, - over_time_or_over_budget_reason.c.old_reason, - ] + over_time_or_over_budget_reason.c.subscription_id, + over_time_or_over_budget_reason.c.admin_oid, + over_time_or_over_budget_reason.c.active, + over_time_or_over_budget_reason.c.reason, + over_time_or_over_budget_reason.c.old_reason, ) .where( or_( @@ -319,12 +305,10 @@ async def refresh_desired_states( status_table.c.reason, ], select( - [ - over_time_or_over_budget_desired_on.c.subscription_id, - over_time_or_over_budget_desired_on.c.admin_oid, - over_time_or_over_budget_desired_on.c.active, - over_time_or_over_budget_desired_on.c.reason, - ] + over_time_or_over_budget_desired_on.c.subscription_id, + over_time_or_over_budget_desired_on.c.admin_oid, + over_time_or_over_budget_desired_on.c.active, + over_time_or_over_budget_desired_on.c.reason, ), ) @@ -353,17 +337,15 @@ async def refresh_desired_states( # but aren't currently. These are all of our subscriptions # that are disabled but aren't over time or budget. should_be_enabled_but_are_not = select( - [ - summaries.c.subscription_id, - literal_column("uuid('" + str(admin_oid) + "')"), - literal_column("True"), - literal_column("NULL"), - ] + summaries.c.subscription_id, + literal_column("uuid('" + str(admin_oid) + "')"), + literal_column("True"), + literal_column("NULL"), ).where( and_( not_( summaries.c.subscription_id.in_( - select([over_time_or_over_budget.c.subscription_id]) + select(over_time_or_over_budget.c.subscription_id) ) ), or_( diff --git a/rctab/routers/accounting/finances.py b/rctab/routers/accounting/finances.py index 284d633..466426f 100644 --- a/rctab/routers/accounting/finances.py +++ b/rctab/routers/accounting/finances.py @@ -36,7 +36,7 @@ async def check_create_finance( ) query = ( - select([cost_recovery]) + select(cost_recovery) .where(cost_recovery.c.subscription_id == new_finance.subscription_id) .order_by(desc(cost_recovery.c.id)) ) @@ -63,7 +63,7 @@ async def get_subscription_finances( return [ FinanceWithID(**dict(x)) for x in await database.fetch_all( - select([finance]).where(finance.c.subscription_id == subscription.sub_id) + select(finance).where(finance.c.subscription_id == subscription.sub_id) ) ] @@ -81,7 +81,7 @@ async def post_finance( {"admin": user.oid, **new_finance.model_dump()}, ) new_row = await database.fetch_one( - select([finance]).where(finance.c.id == new_primary_key) + select(finance).where(finance.c.id == new_primary_key) ) assert new_row newly_created_finance = FinanceWithID(**dict(new_row)) @@ -98,7 +98,7 @@ async def check_update_finance(new_finance: FinanceWithID) -> None: raise HTTPException(status_code=400, detail="amount < 0") old_finance_row = await database.fetch_one( - select([finance]).where(finance.c.id == new_finance.id) + select(finance).where(finance.c.id == new_finance.id) ) assert old_finance_row old_finance = FinanceWithID(**dict(old_finance_row)) @@ -106,7 +106,7 @@ async def check_update_finance(new_finance: FinanceWithID) -> None: if old_finance.subscription_id != new_finance.subscription_id: raise HTTPException(status_code=400, detail="Subscription IDs should match") - query = select([cost_recovery_log]).order_by(desc(cost_recovery_log.c.month)) + query = select(cost_recovery_log).order_by(desc(cost_recovery_log.c.month)) last_cost_recovery = await database.fetch_one(query) if last_cost_recovery: last_cost_recovery_dict = {**dict(last_cost_recovery)} @@ -163,7 +163,7 @@ async def get_finance( finance_id: int, _: UserRBAC = Depends(token_admin_verified) ) -> FinanceWithID: """Returns a Finance if given a finance table ID.""" - row = await database.fetch_one(select([finance]).where(finance.c.id == finance_id)) + row = await database.fetch_one(select(finance).where(finance.c.id == finance_id)) if not row: raise HTTPException(status_code=404, detail="Finance not found") return FinanceWithID(**dict(row)) @@ -177,7 +177,7 @@ async def delete_finance( ) -> FinanceWithID: """Deletes a Finance record.""" finance_row = await database.fetch_one( - select([finance]).where(finance.c.id == finance_id) + select(finance).where(finance.c.id == finance_id) ) # Check that we recognise this finance row diff --git a/rctab/routers/accounting/routes.py b/rctab/routers/accounting/routes.py index 55e5526..7cebb5d 100644 --- a/rctab/routers/accounting/routes.py +++ b/rctab/routers/accounting/routes.py @@ -53,7 +53,7 @@ class SubscriptionSummary(BaseModel): @db_select def get_subscriptions() -> Select: """Returns all subscriptions.""" - return select([subscription.c.subscription_id, subscription.c.abolished]) + return select(subscription.c.subscription_id, subscription.c.abolished) @db_select @@ -64,11 +64,9 @@ def get_subscription_details(sub_id: Optional[UUID] = None) -> Select: lateral = ( select( - [ - subscription_details.c.display_name.label("name"), - subscription_details.c.role_assignments, - subscription_details.c.state.label("status"), - ] + subscription_details.c.display_name.label("name"), + subscription_details.c.role_assignments, + subscription_details.c.state.label("status"), ) .where(subscription_details.c.subscription_id == all_subs_sq.c.subscription_id) .order_by(subscription_details.c.id.desc()) @@ -76,7 +74,7 @@ def get_subscription_details(sub_id: Optional[UUID] = None) -> Select: .lateral("o2") ) - query = select([all_subs_sq.c.subscription_id, lateral]).select_from( + query = select(all_subs_sq.c.subscription_id, lateral).select_from( all_subs_sq.join(lateral, true(), isouter=True) ) @@ -96,10 +94,8 @@ def get_sub_allocations_summary(sub_id: Optional[UUID] = None) -> Select: all_subs_sq = get_subscriptions(execute=False).alias() query = select( - [ - all_subs_sq.c.subscription_id, - func.coalesce(func.sum(allocations.c.amount), 0.0).label("allocated"), - ] + all_subs_sq.c.subscription_id, + func.coalesce(func.sum(allocations.c.amount), 0.0).label("allocated"), ).select_from( all_subs_sq.join( allocations, @@ -131,12 +127,10 @@ def get_sub_approvals_summary(sub_id: Optional[UUID] = None) -> Select: all_subs_sq = get_subscriptions(execute=False).alias() query = select( - [ - all_subs_sq.c.subscription_id, - func.min(approvals.c.date_from).label("approved_from"), - func.max(approvals.c.date_to).label("approved_to"), - func.coalesce(func.sum(approvals.c.amount), 0.0).label("approved"), - ] + all_subs_sq.c.subscription_id, + func.min(approvals.c.date_from).label("approved_from"), + func.max(approvals.c.date_to).label("approved_to"), + func.coalesce(func.sum(approvals.c.amount), 0.0).label("approved"), ).select_from( all_subs_sq.join( approvals, @@ -166,14 +160,12 @@ def get_sub_usage_summary( all_subs_sq = get_subscriptions(execute=False).alias() query = select( - [ - all_subs_sq.c.subscription_id, - usage_view.c.first_usage, - usage_view.c.latest_usage, - func.coalesce(usage_view.c.total_cost, 0.0).label("total_cost"), - func.coalesce(usage_view.c.amortised_cost, 0.0).label("amortised_cost"), - func.coalesce(usage_view.c.cost, 0.0).label("cost"), - ] + all_subs_sq.c.subscription_id, + usage_view.c.first_usage, + usage_view.c.latest_usage, + func.coalesce(usage_view.c.total_cost, 0.0).label("total_cost"), + func.coalesce(usage_view.c.amortised_cost, 0.0).label("amortised_cost"), + func.coalesce(usage_view.c.cost, 0.0).label("cost"), ).select_from( all_subs_sq.join( usage_view, @@ -195,14 +187,14 @@ def sub_persistency_status(sub_id: Optional[UUID] = None) -> Select: all_subs_sq = get_subscriptions(execute=False).alias() lateral = ( - select([persistence.c.always_on]) + select(persistence.c.always_on) .where(persistence.c.subscription_id == all_subs_sq.c.subscription_id) .order_by(persistence.c.id.desc()) .limit(1) .lateral("o2") ) - query = select([all_subs_sq.c.subscription_id, lateral]).select_from( + query = select(all_subs_sq.c.subscription_id, lateral).select_from( all_subs_sq.join(lateral, true(), isouter=True) ) @@ -220,10 +212,8 @@ def get_desired_status(sub_id: Optional[Union[UUID, List[UUID]]] = None) -> Sele lateral = ( select( - [ - status.c.active.label("desired_status"), - status.c.reason.label("desired_status_info"), - ] + status.c.active.label("desired_status"), + status.c.reason.label("desired_status_info"), ) .where(status.c.subscription_id == all_subs_sq.c.subscription_id) .order_by(status.c.id.desc()) @@ -231,7 +221,7 @@ def get_desired_status(sub_id: Optional[Union[UUID, List[UUID]]] = None) -> Sele .lateral("o2") ) - query = select([all_subs_sq.c.subscription_id, lateral]).select_from( + query = select(all_subs_sq.c.subscription_id, lateral).select_from( all_subs_sq.join(lateral, true(), isouter=True) ) @@ -277,25 +267,23 @@ def get_subscriptions_summary( all_persistence_sq = sub_persistency_status(execute=False).alias() query = select( - [ - all_subs_sq.c.subscription_id, - all_subs_sq.c.abolished, - all_details_sq.c.name, - all_details_sq.c.role_assignments, - all_details_sq.c.status, - all_approvals_sq.c.approved_from, - all_approvals_sq.c.approved_to, - all_approvals_sq.c.approved, - all_allocations_sq.c.allocated, - all_usage_sq.c.cost, - all_usage_sq.c.amortised_cost, - all_usage_sq.c.total_cost, - all_usage_sq.c.first_usage, - all_usage_sq.c.latest_usage, - all_persistence_sq.c.always_on, - all_desired_status_sq.c.desired_status, - all_desired_status_sq.c.desired_status_info, - ] + all_subs_sq.c.subscription_id, + all_subs_sq.c.abolished, + all_details_sq.c.name, + all_details_sq.c.role_assignments, + all_details_sq.c.status, + all_approvals_sq.c.approved_from, + all_approvals_sq.c.approved_to, + all_approvals_sq.c.approved, + all_allocations_sq.c.allocated, + all_usage_sq.c.cost, + all_usage_sq.c.amortised_cost, + all_usage_sq.c.total_cost, + all_usage_sq.c.first_usage, + all_usage_sq.c.latest_usage, + all_persistence_sq.c.always_on, + all_desired_status_sq.c.desired_status, + all_desired_status_sq.c.desired_status_info, ).select_from( all_subs_sq.join( all_details_sq, @@ -346,13 +334,10 @@ def get_subscriptions_with_disable( ).alias() return select( - [ - subscription_summary_sq, - ( - subscription_summary_sq.c.allocated - - subscription_summary_sq.c.total_cost - ).label("remaining"), - ] + subscription_summary_sq, + ( + subscription_summary_sq.c.allocated - subscription_summary_sq.c.total_cost + ).label("remaining"), ) @@ -370,13 +355,11 @@ def get_total_usage( Select: [description] """ query = select( - [ - func.min(usage_view.c.first_usage).label("first_usage"), - func.max(usage_view.c.latest_usage).label("latest_usage"), - func.sum(usage_view.c.cost).label("cost"), - func.sum(usage_view.c.amortised_cost).label("amortised_cost"), - func.sum(usage_view.c.total_cost).label("total_cost"), - ] + func.min(usage_view.c.first_usage).label("first_usage"), + func.max(usage_view.c.latest_usage).label("latest_usage"), + func.sum(usage_view.c.cost).label("cost"), + func.sum(usage_view.c.amortised_cost).label("amortised_cost"), + func.sum(usage_view.c.total_cost).label("total_cost"), ) if start_date: @@ -393,12 +376,10 @@ def get_allocations(sub_id: UUID) -> Select: """Get all allocations for a subscription.""" return ( select( - [ - allocations.c.ticket, - allocations.c.amount, - allocations.c.currency, - allocations.c.time_created, - ] + allocations.c.ticket, + allocations.c.amount, + allocations.c.currency, + allocations.c.time_created, ) .where(allocations.c.subscription_id == sub_id) .order_by(desc(allocations.c.time_created)) @@ -410,14 +391,12 @@ def get_approvals(sub_id: UUID) -> Select: """Get all approvals for a subscription.""" return ( select( - [ - approvals.c.ticket, - approvals.c.amount, - approvals.c.currency, - approvals.c.date_from, - approvals.c.date_to, - approvals.c.time_created, - ] + approvals.c.ticket, + approvals.c.amount, + approvals.c.currency, + approvals.c.date_from, + approvals.c.date_to, + approvals.c.time_created, ) .where(approvals.c.subscription_id == sub_id) .order_by(desc(approvals.c.date_to)) @@ -429,15 +408,13 @@ def get_finance(sub_id: UUID) -> Select: """Get all finance items for a subscription.""" return ( select( - [ - finance.c.ticket, - finance.c.amount, - finance.c.priority, - finance.c.finance_code, - finance.c.date_from, - finance.c.date_to, - finance.c.time_created, - ] + finance.c.ticket, + finance.c.amount, + finance.c.priority, + finance.c.finance_code, + finance.c.date_from, + finance.c.date_to, + finance.c.time_created, ) .where(finance.c.subscription_id == sub_id) .order_by(desc(finance.c.date_to)) @@ -449,14 +426,12 @@ def get_costrecovery(sub_id: UUID) -> Select: """Get all cost recovery items for a subscription.""" return ( select( - [ - cost_recovery.c.subscription_id, - cost_recovery.c.finance_id, - cost_recovery.c.month, - cost_recovery.c.finance_code, - cost_recovery.c.amount, - cost_recovery.c.date_recovered, - ] + cost_recovery.c.subscription_id, + cost_recovery.c.finance_id, + cost_recovery.c.month, + cost_recovery.c.finance_code, + cost_recovery.c.amount, + cost_recovery.c.date_recovered, ) .where(cost_recovery.c.subscription_id == sub_id) .order_by(desc(cost_recovery.c.month)) @@ -468,55 +443,53 @@ def get_usage(sub_id: UUID, target_date: datetime.datetime) -> Select: """Get all usage items for a subscription.""" return ( select( - [ - usage.c.id, - usage.c.name, - usage.c.type, - usage.c.tags, - usage.c.billing_account_id, - usage.c.billing_account_name, - usage.c.billing_period_start_date, - usage.c.billing_period_end_date, - usage.c.billing_profile_id, - usage.c.billing_profile_name, - usage.c.account_owner_id, - usage.c.account_name, - usage.c.subscription_id, - usage.c.subscription_name, - usage.c.date, - usage.c.product, - usage.c.part_number, - usage.c.meter_id, - usage.c.quantity, - usage.c.effective_price, - usage.c.cost, - usage.c.amortised_cost, - usage.c.total_cost, - usage.c.unit_price, - usage.c.billing_currency, - usage.c.resource_location, - usage.c.consumed_service, - usage.c.resource_id, - usage.c.resource_name, - usage.c.service_info1, - usage.c.service_info2, - usage.c.additional_info, - usage.c.invoice_section, - usage.c.cost_center, - usage.c.resource_group, - usage.c.reservation_id, - usage.c.reservation_name, - usage.c.product_order_id, - usage.c.offer_id, - usage.c.is_azure_credit_eligible, - usage.c.term, - usage.c.publisher_name, - usage.c.publisher_type, - usage.c.plan_name, - usage.c.charge_type, - usage.c.frequency, - usage.c.monthly_upload, - ] + usage.c.id, + usage.c.name, + usage.c.type, + usage.c.tags, + usage.c.billing_account_id, + usage.c.billing_account_name, + usage.c.billing_period_start_date, + usage.c.billing_period_end_date, + usage.c.billing_profile_id, + usage.c.billing_profile_name, + usage.c.account_owner_id, + usage.c.account_name, + usage.c.subscription_id, + usage.c.subscription_name, + usage.c.date, + usage.c.product, + usage.c.part_number, + usage.c.meter_id, + usage.c.quantity, + usage.c.effective_price, + usage.c.cost, + usage.c.amortised_cost, + usage.c.total_cost, + usage.c.unit_price, + usage.c.billing_currency, + usage.c.resource_location, + usage.c.consumed_service, + usage.c.resource_id, + usage.c.resource_name, + usage.c.service_info1, + usage.c.service_info2, + usage.c.additional_info, + usage.c.invoice_section, + usage.c.cost_center, + usage.c.resource_group, + usage.c.reservation_id, + usage.c.reservation_name, + usage.c.product_order_id, + usage.c.offer_id, + usage.c.is_azure_credit_eligible, + usage.c.term, + usage.c.publisher_name, + usage.c.publisher_type, + usage.c.plan_name, + usage.c.charge_type, + usage.c.frequency, + usage.c.monthly_upload, ) .where((usage.c.subscription_id == sub_id) & (usage.c.date >= target_date)) .order_by( @@ -535,6 +508,6 @@ def get_subscription_name(sub_id: Optional[UUID] = None) -> Select: Returns: A SELECT query for all current and former names of the subscription. """ - return select([subscription_details.c.display_name.label("name")]).where( + return select(subscription_details.c.display_name.label("name")).where( subscription_details.c.subscription_id == sub_id ) diff --git a/rctab/routers/accounting/send_emails.py b/rctab/routers/accounting/send_emails.py index 9bdf748..2b666aa 100644 --- a/rctab/routers/accounting/send_emails.py +++ b/rctab/routers/accounting/send_emails.py @@ -318,13 +318,13 @@ def sub_time_based_emails() -> Select: SELECT statement for database query. """ most_recent = ( - select([func.max_(emails.c.id).label("max_id")]) + select(func.max_(emails.c.id).label("max_id")) .where(emails.c.type == EMAIL_TYPE_TIMEBASED) .group_by(emails.c.subscription_id) .alias() ) - return select([emails]).select_from( + return select(emails).select_from( emails.join(most_recent, emails.c.id == most_recent.c.max_id) ) @@ -335,7 +335,7 @@ def sub_usage_emails() -> Select: Warning emails include over-budget, usage alert and expiry-looming emails. """ most_recent = ( - select([func.max_(emails.c.id).label("max_id")]) + select(func.max_(emails.c.id).label("max_id")) .where( or_( emails.c.type == EMAIL_TYPE_OVERBUDGET, @@ -347,7 +347,7 @@ def sub_usage_emails() -> Select: .alias() ) - return select([emails]).select_from( + return select(emails).select_from( emails.join(most_recent, emails.c.id == most_recent.c.max_id) ) @@ -411,7 +411,7 @@ async def check_for_subs_nearing_expiry(database: Database) -> None: # We don't _have_ to filter on approved_to, but it might make things slightly quicker expiry_query = ( - select([summary.c.subscription_id, summary.c.approved_to, summary.c.status]) + select(summary.c.subscription_id, summary.c.approved_to, summary.c.status) .where(summary.c.approved_to <= date.today() + timedelta(days=30)) .order_by(summary.c.approved_to) ) @@ -435,7 +435,7 @@ async def check_for_overbudget_subs(database: Database) -> None: summary = get_subscriptions_summary(execute=False).alias() overbudget_query = ( - select([summary.c.subscription_id, summary.c.allocated, summary.c.total_cost]) + select(summary.c.subscription_id, summary.c.allocated, summary.c.total_cost) .where(summary.c.total_cost > summary.c.allocated) .where( or_( @@ -665,7 +665,7 @@ async def get_new_subscriptions_since( database: Database, since_this_datetime: datetime ) -> List: """Returns a list of all the subscriptions created since the provided datetime.""" - subscription_query = select([subscription]).where( + subscription_query = select(subscription).where( subscription.c.time_created > since_this_datetime ) rows = await database.fetch_all(subscription_query) @@ -689,7 +689,7 @@ async def get_subscription_details_since( """ logger.info("Looking for subscription details since %s", since_this_datetime) status_query = ( - select([subscription_details]) + select(subscription_details) .where(subscription_details.c.subscription_id == subscription_id) .where(subscription_details.c.time_created > since_this_datetime) .order_by(asc(subscription_details.c.id)) @@ -716,7 +716,7 @@ async def get_emails_sent_since( was sent since the specified datetime. """ emails_query = ( - select([emails]) + select(emails) .where(emails.c.type != EMAIL_TYPE_SUMMARY) .where(emails.c.time_created > since_this_datetime) ) @@ -726,7 +726,7 @@ async def get_emails_sent_since( emails_by_subscription = [] for key, value in groupby(all_emails_sent, extract_sub_id): name_query = ( - select([subscription_details.c.display_name.label("name")]) + select(subscription_details.c.display_name.label("name")) .where(subscription_details.c.subscription_id == key) .order_by(desc(subscription_details.c.time_created)) .limit(1) @@ -763,15 +763,13 @@ async def get_finance_entries_since( Each element of the list is a subscription with one or more finance items. """ - finance_query = select([finance]).where( - finance.c.time_created > since_this_datetime - ) + finance_query = select(finance).where(finance.c.time_created > since_this_datetime) rows = await database.fetch_all(finance_query) all_new_entries = [{**row._mapping} for row in rows] entries_by_subscription = [] for key, value in groupby(all_new_entries, extract_sub_id): name_query = ( - select([subscription_details.c.display_name.label("name")]) + select(subscription_details.c.display_name.label("name")) .where(subscription_details.c.subscription_id == key) .order_by(desc(subscription_details.c.time_created)) .limit(1) @@ -905,7 +903,7 @@ async def get_allocations_since( A list with all allocations for the given subscription id since the provided datetime. """ query = ( - select([allocations.c.amount]) + select(allocations.c.amount) .where(allocations.c.subscription_id == subscription_id) .where(allocations.c.time_created > since_this_datetime) ) @@ -928,7 +926,7 @@ async def get_approvals_since( A list with all approvals for the given subscription id since the provided datetime. """ query = ( - select([approvals.c.amount]) + select(approvals.c.amount) .where(approvals.c.subscription_id == subscription_id) .where(approvals.c.time_created > since_this_datetime) ) diff --git a/rctab/routers/accounting/status.py b/rctab/routers/accounting/status.py index fec9e85..6619d7b 100644 --- a/rctab/routers/accounting/status.py +++ b/rctab/routers/accounting/status.py @@ -99,7 +99,7 @@ async def post_status( old_status = SubscriptionStatus(**dict(status_row)) if status_row else None previous_welcome_email = await database.fetch_one( - select([emails]).where( + select(emails).where( and_( emails.c.subscription_id == new_status.subscription_id, emails.c.type == EMAIL_TYPE_SUB_WELCOME, diff --git a/rctab/routers/accounting/summary_emails.py b/rctab/routers/accounting/summary_emails.py index 2df2a5f..eb336de 100644 --- a/rctab/routers/accounting/summary_emails.py +++ b/rctab/routers/accounting/summary_emails.py @@ -94,7 +94,7 @@ async def get_timestamp_last_summary_email() -> Optional[datetime]: The timestamp of the last summary email sent. """ query = ( - select([emails]) + select(emails) .where(emails.c.type == EMAIL_TYPE_SUMMARY) .order_by(desc(emails.c.id)) ) diff --git a/rctab/routers/accounting/usage.py b/rctab/routers/accounting/usage.py index 6a4a092..8367154 100644 --- a/rctab/routers/accounting/usage.py +++ b/rctab/routers/accounting/usage.py @@ -202,7 +202,7 @@ async def post_usage( @router.get("/all-usage", response_model=List[Usage]) async def get_usage(_: UserRBAC = Depends(token_admin_verified)) -> List[Usage]: """Get all usage data.""" - usage_query = select([accounting_models.usage]) + usage_query = select(accounting_models.usage) rows = [dict(x) for x in await database.fetch_all(usage_query)] result = [Usage(**x) for x in rows] @@ -243,7 +243,7 @@ async def post_cm_usage( @router.get("/all-cm-usage", response_model=List[CMUsage]) async def get_cm_usage(_: UserRBAC = Depends(token_admin_verified)) -> List[CMUsage]: """Get all cost-management data.""" - cm_query = select([accounting_models.costmanagement]) + cm_query = select(accounting_models.costmanagement) rows = [dict(x) for x in await database.fetch_all(cm_query)] result = [CMUsage(**x) for x in rows] return result diff --git a/tests/test_routes/conftest.py b/tests/test_routes/conftest.py index 48d9f2a..f6ca37a 100644 --- a/tests/test_routes/conftest.py +++ b/tests/test_routes/conftest.py @@ -15,27 +15,24 @@ def pytest_configure(config: Any) -> None: # pylint: disable=unused-argument This hook is called for every plugin and initial conftest file after command line options have been parsed.""" - conn = engine.connect() - - conn.execute( - insert(user_rbac).values( - (str(constants.ADMIN_UUID), constants.ADMIN_NAME, True, True) + with engine.begin() as conn: + conn.execute( + insert(user_rbac).values( + (str(constants.ADMIN_UUID), constants.ADMIN_NAME, True, True) + ) ) - ) - - conn.close() def pytest_unconfigure(config: Any) -> None: # pylint: disable=unused-argument """Called before test process is exited.""" - conn = engine.connect() + with engine.begin() as conn: - clean_up(conn) + clean_up(conn) - conn.execute(delete(user_rbac).where(user_rbac.c.oid == str(constants.ADMIN_UUID))) - - conn.close() + conn.execute( + delete(user_rbac).where(user_rbac.c.oid == str(constants.ADMIN_UUID)) + ) def clean_up(conn: Connection) -> None: diff --git a/tests/test_routes/test_abolishment.py b/tests/test_routes/test_abolishment.py index 1292f4f..d859544 100644 --- a/tests/test_routes/test_abolishment.py +++ b/tests/test_routes/test_abolishment.py @@ -79,9 +79,7 @@ async def test_abolishment( assert adjustments[0]["approval"] == 10.0 sub_query = get_subscriptions_summary(execute=False).alias() - summary_qr = select([sub_query]).where( - sub_query.c.subscription_id == expired_sub_id - ) + summary_qr = select(sub_query).where(sub_query.c.subscription_id == expired_sub_id) summary = await test_db.fetch_all(summary_qr) assert summary @@ -95,9 +93,7 @@ async def test_abolishment( await set_abolished_flag(inactive_subs) sub_query = get_subscriptions_summary(execute=False).alias() - summary_qr = select([sub_query]).where( - sub_query.c.subscription_id == expired_sub_id - ) + summary_qr = select(sub_query).where(sub_query.c.subscription_id == expired_sub_id) summary = await test_db.fetch_all(summary_qr) assert summary diff --git a/tests/test_routes/test_cost_recovery.py b/tests/test_routes/test_cost_recovery.py index e3399a1..e4423d3 100644 --- a/tests/test_routes/test_cost_recovery.py +++ b/tests/test_routes/test_cost_recovery.py @@ -148,7 +148,7 @@ async def test_cost_recovery_simple( results = [ dict(row) - for row in await test_db.fetch_all(select([accounting_models.cost_recovery])) + for row in await test_db.fetch_all(select(accounting_models.cost_recovery)) ] assert len(results) == 1 @@ -201,7 +201,7 @@ async def test_cost_recovery_two_finances( results = [ dict(row) - for row in await test_db.fetch_all(select([accounting_models.cost_recovery])) + for row in await test_db.fetch_all(select(accounting_models.cost_recovery)) ] assert len(results) == 2 @@ -269,7 +269,7 @@ async def test_cost_recovery_second_month( results = [ dict(row) for row in await test_db.fetch_all( - select([accounting_models.cost_recovery]).order_by( + select(accounting_models.cost_recovery).order_by( accounting_models.cost_recovery.c.id ) ) @@ -338,7 +338,7 @@ async def test_cost_recovery_two_subscriptions( results = [ dict(row) for row in await test_db.fetch_all( - select([accounting_models.cost_recovery]).order_by( + select(accounting_models.cost_recovery).order_by( accounting_models.cost_recovery.c.id ) ) @@ -393,7 +393,7 @@ async def test_cost_recovery_priority_one_month( results = [ dict(row) for row in await test_db.fetch_all( - select([accounting_models.cost_recovery]).order_by( + select(accounting_models.cost_recovery).order_by( accounting_models.cost_recovery.c.finance_code ) ) @@ -450,7 +450,7 @@ async def test_cost_recovery_priority_two_months( results = [ dict(row) for row in await test_db.fetch_all( - select([accounting_models.cost_recovery]).order_by( + select(accounting_models.cost_recovery).order_by( accounting_models.cost_recovery.c.finance_code ) ) @@ -557,7 +557,7 @@ async def test_cost_recovery_commit_param( results = [ dict(row) for row in await no_rollback_test_db.fetch_all( - select([accounting_models.cost_recovery]) + select(accounting_models.cost_recovery) ) ] @@ -567,7 +567,7 @@ async def test_cost_recovery_commit_param( results = [ dict(row) for row in await no_rollback_test_db.fetch_all( - select([accounting_models.cost_recovery_log]) + select(accounting_models.cost_recovery_log) ) ] @@ -633,7 +633,7 @@ async def test_cost_recovery_rollsback( results = [ dict(row) for row in await no_rollback_test_db.fetch_all( - select([accounting_models.cost_recovery]) + select(accounting_models.cost_recovery) ) ] diff --git a/tests/test_routes/test_desired_states.py b/tests/test_routes/test_desired_states.py index 8653f1e..f7351da 100644 --- a/tests/test_routes/test_desired_states.py +++ b/tests/test_routes/test_desired_states.py @@ -57,7 +57,7 @@ async def test_desired_states_budget_adjustment_applied( await refresh_desired_states(constants.ADMIN_UUID, [expired_sub_id]) - desired_state_rows = await test_db.fetch_all(select([status_table])) + desired_state_rows = await test_db.fetch_all(select(status_table)) row_dicts = [dict(row) for row in desired_state_rows] # The subscription expired today @@ -98,7 +98,7 @@ async def test_desired_states_budget_adjustment_approved_ignored( await refresh_desired_states(constants.ADMIN_UUID, [expired_sub_id]) - desired_state_rows = await test_db.fetch_all(select([status_table])) + desired_state_rows = await test_db.fetch_all(select(status_table)) row_dicts = [dict(row) for row in desired_state_rows] # The subscription expired today @@ -138,7 +138,7 @@ async def test_desired_states_budget_adjustment_ignored( await refresh_desired_states(constants.ADMIN_UUID, [expired_sub_id]) - desired_state_rows = await test_db.fetch_all(select([status_table])) + desired_state_rows = await test_db.fetch_all(select(status_table)) row_dicts = [dict(row) for row in desired_state_rows] # The subscription expired today @@ -372,7 +372,7 @@ async def test_refresh_reason_changes(test_db: Database, mocker: MockerFixture) ) assert mock_send_emails.call_count == 1 - desired_state_rows = await test_db.fetch_all(select([status_table])) + desired_state_rows = await test_db.fetch_all(select(status_table)) row_dicts = [dict(row) for row in desired_state_rows] # The subscription expired today @@ -399,7 +399,7 @@ async def test_refresh_reason_changes(test_db: Database, mocker: MockerFixture) row_dicts = [ dict(row) for row in await test_db.fetch_all( - select([status_table]).order_by(status_table.c.time_created) + select(status_table).order_by(status_table.c.time_created) ) ] @@ -430,7 +430,7 @@ async def test_refresh_reason_stays_the_same( await refresh_desired_states(constants.ADMIN_UUID, [expired_sub_id]) - desired_state_rows = await test_db.fetch_all(select([status_table])) + desired_state_rows = await test_db.fetch_all(select(status_table)) row_dicts = [dict(row) for row in desired_state_rows] assert len(row_dicts) == 1 assert row_dicts[0]["reason"] == BillingStatus.EXPIRED @@ -438,7 +438,7 @@ async def test_refresh_reason_stays_the_same( await refresh_desired_states(constants.ADMIN_UUID, [expired_sub_id]) desired_state_rows = await test_db.fetch_all( - select([status_table]).order_by(status_table.c.time_created) + select(status_table).order_by(status_table.c.time_created) ) row_dicts = [dict(row) for row in desired_state_rows] assert len(row_dicts) == 1 @@ -465,7 +465,7 @@ async def test_small_tolerance(test_db: Database, mocker: MockerFixture) -> None await refresh_desired_states(constants.ADMIN_UUID, [close_to_budget_sub_id]) desired_state_rows = await test_db.fetch_all( - select([status_table]).where(status_table.c.reason == None) + select(status_table).where(status_table.c.reason == None) ) row_dicts = [dict(row) for row in desired_state_rows] assert len(row_dicts) == 1 diff --git a/tests/test_routes/test_finances.py b/tests/test_routes/test_finances.py index 76fb8a1..16c3350 100644 --- a/tests/test_routes/test_finances.py +++ b/tests/test_routes/test_finances.py @@ -416,7 +416,7 @@ async def test_finance_history_delete( ) assert actual == [] - rows = await test_db.fetch_all(select([finance_history])) + rows = await test_db.fetch_all(select(finance_history)) dicts = [dict(x) for x in rows] assert len(dicts) == 1 @@ -458,7 +458,7 @@ async def test_finance_history_update( ) assert len(actual) == 1 - rows = await test_db.fetch_all(select([finance_history])) + rows = await test_db.fetch_all(select(finance_history)) dicts = [dict(x) for x in rows] assert len(dicts) == 1 @@ -746,7 +746,7 @@ async def test_check_update_finance_admin( mock_rbac.oid = updater_oid await update_finance(f_b.id, f_b, mock_rbac) - rows = await test_db.fetch_all(select([finance])) + rows = await test_db.fetch_all(select(finance)) updated_finances = [dict(row) for row in rows] assert len(updated_finances) == 1 assert updated_finances[0]["admin"] == updater_oid diff --git a/tests/test_routes/test_routes.py b/tests/test_routes/test_routes.py index ecfde18..5bed132 100644 --- a/tests/test_routes/test_routes.py +++ b/tests/test_routes/test_routes.py @@ -208,7 +208,7 @@ async def test_refresh_desired_states_disable( ], ) - rows = await test_db.fetch_all(select([status]).order_by(status.c.subscription_id)) + rows = await test_db.fetch_all(select(status).order_by(status.c.subscription_id)) disabled_subscriptions = [ (row["subscription_id"], row["reason"]) for row in rows @@ -281,7 +281,7 @@ async def test_refresh_desired_states_enable( [always_on_sub_id, no_allocation_sub_id, currently_disabled_sub_id], ) - rows = await test_db.fetch_all(select([status]).order_by(status.c.subscription_id)) + rows = await test_db.fetch_all(select(status).order_by(status.c.subscription_id)) enabled_subscriptions = [ row["subscription_id"] for row in rows if row["active"] is True @@ -349,12 +349,12 @@ async def test_refresh_desired_states_doesnt_duplicate( ) latest_status_id = ( - select([status.c.subscription_id, func.max(status.c.id).label("max_id")]) + select(status.c.subscription_id, func.max(status.c.id).label("max_id")) .group_by(status.c.subscription_id) .alias() ) - latest_status = select([status.c.subscription_id, status.c.active]).select_from( + latest_status = select(status.c.subscription_id, status.c.active).select_from( status.join( latest_status_id, and_( diff --git a/tests/test_routes/test_send_emails.py b/tests/test_routes/test_send_emails.py index d7142d2..3b71314 100644 --- a/tests/test_routes/test_send_emails.py +++ b/tests/test_routes/test_send_emails.py @@ -440,7 +440,7 @@ async def test_send_generic_emails( mock_get_recipients.return_value, ) - email_query = select([accounting_models.emails]).where( + email_query = select(accounting_models.emails).where( accounting_models.emails.c.type == EMAIL_TYPE_SUB_APPROVAL ) email_results = await test_db.fetch_all(email_query) @@ -750,7 +750,7 @@ async def fetch_one_or_fail(query: Select) -> Record: approved=(100.0, date.today() + timedelta(days=7)), spent=(70.0, 0), ) - the_sub = await fetch_one_or_fail(select([accounting_models.subscription])) + the_sub = await fetch_one_or_fail(select(accounting_models.subscription)) sub_time = the_sub["time_created"] insert_statement = insert(accounting_models.emails) @@ -1190,7 +1190,7 @@ async def test_catches_params_missing( ) # check the last row added is as expected - last_row_query = select([accounting_models.failed_emails]).order_by( + last_row_query = select(accounting_models.failed_emails).order_by( accounting_models.failed_emails.c.id.desc() ) last_row_result = await test_db.fetch_one(last_row_query) diff --git a/tests/test_routes/test_status.py b/tests/test_routes/test_status.py index 546c347..95a0670 100644 --- a/tests/test_routes/test_status.py +++ b/tests/test_routes/test_status.py @@ -611,7 +611,7 @@ async def test_post_status_filters_roles( # No new emails expected mock_send_email.assert_has_calls([welcome_call]) - results = await test_db.fetch_all(select([subscription_details])) + results = await test_db.fetch_all(select(subscription_details)) actual = [SubscriptionStatus(**dict(result)) for result in results] expected = [old_status, newer_status] @@ -706,7 +706,7 @@ async def test_post_status_filters_roles( ) results = await test_db.fetch_all( - select([subscription_details]).order_by(subscription_details.c.id) + select(subscription_details).order_by(subscription_details.c.id) ) actual = [SubscriptionStatus(**dict(result)) for result in results] diff --git a/tests/test_routes/test_transactions.py b/tests/test_routes/test_transactions.py index dd510a7..b00560a 100644 --- a/tests/test_routes/test_transactions.py +++ b/tests/test_routes/test_transactions.py @@ -26,6 +26,6 @@ async def test_databases_rollback( # Should remove the subscription await transaction.rollback() - results = await no_rollback_test_db.fetch_all(select([subscription])) + results = await no_rollback_test_db.fetch_all(select(subscription)) assert len(results) == 0