Skip to content
This repository has been archived by the owner on Sep 6, 2024. It is now read-only.

Rate limits #58

Merged
merged 7 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,17 @@ convenient.
TODO: Separate these to variables necessary to run the bot, and those only relevant during development.
-->

| Variable name | Type | Default | Description |
| ---------------------- | ------ | ------------- | ------------------------------------------------------------------------------------------------------------------ |
| `BOT_TOKEN` | string | N/A | Bot token of the discord application (see: [this guide][bot-token-guide] if you don't have one yet) |
| `TVDB_API_KEY` | string | N/A | API key for TVDB (see [this page][tvdb-api-page] if you don't have one yet) |
| `SQLITE_DATABASE_FILE` | path | ./database.db | Path to sqlite database file, can be relative to project root (if the file doesn't yet exists, it will be created) |
| `ECHO_SQL` | bool | 0 | If `1`, print out every SQL command that SQLAlchemy library runs internally (can be useful when debugging) |
| `DEBUG` | bool | 0 | If `1`, debug logs will be enabled, if `0` only info logs and above will be shown |
| `LOG_FILE` | path | N/A | If set, also write the logs into given file, otherwise, only print them |
| `TRACE_LEVEL_FILTER` | custom | N/A | Configuration for trace level logging, see: [trace logs config section](#trace-logs-config) |
| Variable name | Type | Default | Description |
| -------------------------- | ------ | ------------- | ------------------------------------------------------------------------------------------------------------------- |
| `BOT_TOKEN` | string | N/A | Bot token of the discord application (see: [this guide][bot-token-guide] if you don't have one yet) |
| `TVDB_API_KEY` | string | N/A | API key for TVDB (see [this page][tvdb-api-page] if you don't have one yet) |
| `TVDB_RATE_LIMIT_REQUESTS` | int | 5 | Amount of requests that the bot is allowed to make to the TVDB API within `TVDB_RATE_LIMIT_PERIOD` |
| `TVDB_RATE_LIMIT_PERIOD` | float | 5 | Period of time in seconds, within which the bot can make up to `TVDB_RATE_LIMIT_REQUESTS` requests to the TVDB API. |
| `SQLITE_DATABASE_FILE` | path | ./database.db | Path to sqlite database file, can be relative to project root (if the file doesn't yet exists, it will be created) |
| `ECHO_SQL` | bool | 0 | If `1`, print out every SQL command that SQLAlchemy library runs internally (can be useful when debugging) |
| `DEBUG` | bool | 0 | If `1`, debug logs will be enabled, if `0` only info logs and above will be shown |
| `LOG_FILE` | path | N/A | If set, also write the logs into given file, otherwise, only print them |
| `TRACE_LEVEL_FILTER` | custom | N/A | Configuration for trace level logging, see: [trace logs config section](#trace-logs-config) |

[bot-token-guide]: https://guide.pycord.dev/getting-started/creating-your-first-bot#creating-the-bot-application
[tvdb-api-page]: https://www.thetvdb.com/api-information
Expand Down
53 changes: 45 additions & 8 deletions src/exts/error_handler/error_handler.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import textwrap
from typing import cast

from discord import Any, ApplicationContext, Cog, Colour, Embed, errors
from discord import Any, ApplicationContext, Cog, Colour, Embed, EmbedField, EmbedFooter, errors
from discord.ext.commands import errors as commands_errors

from src.bot import Bot
from src.settings import FAIL_EMOJI, GITHUB_REPO
from src.utils.log import get_logger
from src.utils.ratelimit import RateLimitExceededError

log = get_logger(__name__)

Expand All @@ -23,12 +24,20 @@ async def send_error_embed(
*,
title: str | None = None,
description: str | None = None,
fields: list[EmbedField] | None = None,
footer: EmbedFooter | None = None,
) -> None:
"""Send an embed regarding the unhandled exception that occurred."""
if title is None and description is None:
raise ValueError("You need to provide either a title or a description.")

embed = Embed(title=title, description=description, color=Colour.red())
embed = Embed(
title=title,
description=description,
color=Colour.red(),
fields=fields,
footer=footer,
)
await ctx.respond(f"Sorry, {ctx.author.mention}", embed=embed)

async def send_unhandled_embed(self, ctx: ApplicationContext, exc: BaseException) -> None:
Expand Down Expand Up @@ -128,6 +137,38 @@ async def _handle_check_failure(

await self.send_unhandled_embed(ctx, exc)

async def _handle_command_invoke_error(
self,
ctx: ApplicationContext,
exc: errors.ApplicationCommandInvokeError,
) -> None:
original_exception = exc.__cause__

if original_exception is None:
await self.send_unhandled_embed(ctx, exc)
log.exception("Got ApplicationCommandInvokeError without a cause.", exc_info=exc)
return

if isinstance(original_exception, RateLimitExceededError):
msg = original_exception.msg or "Hit a rate-limit, please try again later."
time_remaining = f"Expected reset: <t:{round(original_exception.closest_expiration)}:R>"
footer = None
if original_exception.updates_when_exceeded:
footer = EmbedFooter(
text="Spamming the command will only increase the time you have to wait.",
)
await self.send_error_embed(
ctx,
title="Rate limit exceeded",
description=f"{FAIL_EMOJI} {msg}",
fields=[EmbedField(name="", value=time_remaining)],
footer=footer,
)
return

await self.send_unhandled_embed(ctx, original_exception)
log.exception("Unhandled exception occurred.", exc_info=original_exception)

@Cog.listener()
async def on_application_command_error(self, ctx: ApplicationContext, exc: errors.DiscordException) -> None:
"""Handle exceptions that have occurred while running some command."""
Expand All @@ -136,12 +177,8 @@ async def on_application_command_error(self, ctx: ApplicationContext, exc: error
return

if isinstance(exc, errors.ApplicationCommandInvokeError):
original_exception = exc.__cause__

if original_exception is not None:
await self.send_unhandled_embed(ctx, original_exception)
log.exception("Unhandled exception occurred.", exc_info=original_exception)
return
await self._handle_command_invoke_error(ctx, exc)
return

await self.send_unhandled_embed(ctx, exc)

Expand Down
69 changes: 37 additions & 32 deletions src/exts/tvdb_info/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from collections.abc import Sequence
from typing import Literal

import aiohttp
import discord
from discord import ApplicationContext, Cog, option, slash_command

Expand All @@ -10,6 +9,7 @@
from src.tvdb import FetchMeta, Movie, Series, TvdbClient
from src.tvdb.errors import InvalidIdError
from src.utils.log import get_logger
from src.utils.ratelimit import rate_limited

log = get_logger(__name__)

Expand Down Expand Up @@ -106,6 +106,7 @@ class InfoCog(Cog):

def __init__(self, bot: Bot) -> None:
self.bot = bot
self.tvdb_client = TvdbClient(self.bot.http_session, self.bot.cache)

@slash_command()
@option("query", input_type=str, description="The query to search for.")
Expand All @@ -118,6 +119,7 @@ def __init__(self, bot: Bot) -> None:
required=False,
)
@option("by_id", input_type=bool, description="Search by tvdb ID.", required=False)
@rate_limited(key=lambda self, ctx: f"{ctx.user}", limit=2, period=8, update_when_exceeded=True, prefix_key=True)
Paillat-dev marked this conversation as resolved.
Show resolved Hide resolved
async def search(
self,
ctx: ApplicationContext,
Expand All @@ -128,37 +130,40 @@ async def search(
) -> None:
"""Search for a movie or series."""
await ctx.defer()
async with aiohttp.ClientSession() as session:
client = TvdbClient(session, self.bot.cache)
if by_id:
if query.startswith("movie-"):
entity_type = "movie"
query = query[6:]
elif query.startswith("series-"):
entity_type = "series"
query = query[7:]
try:
match entity_type:
case "movie":
response = [await Movie.fetch(query, client, extended=True, meta=FetchMeta.TRANSLATIONS)]
case "series":
response = [await Series.fetch(query, client, extended=True, meta=FetchMeta.TRANSLATIONS)]
case None:
await ctx.respond(
"You must specify a type (movie or series) when searching by ID.", ephemeral=True
)
return
except InvalidIdError:
await ctx.respond(
'Invalid ID. Id must be an integer, or "movie-" / "series-" followed by an integer.',
ephemeral=True,
)
return
else:
response = await client.search(query, limit=5, entity_type=entity_type)
if not response:
await ctx.respond("No results found.")
return

if by_id:
if query.startswith("movie-"):
entity_type = "movie"
query = query[6:]
elif query.startswith("series-"):
entity_type = "series"
query = query[7:]
try:
match entity_type:
case "movie":
response = [
await Movie.fetch(query, self.tvdb_client, extended=True, meta=FetchMeta.TRANSLATIONS)
]
case "series":
response = [
await Series.fetch(query, self.tvdb_client, extended=True, meta=FetchMeta.TRANSLATIONS)
]
case None:
await ctx.respond(
"You must specify a type (movie or series) when searching by ID.", ephemeral=True
)
return
except InvalidIdError:
await ctx.respond(
'Invalid ID. Id must be an integer, or "movie-" / "series-" followed by an integer.',
ephemeral=True,
)
return
else:
response = await self.tvdb_client.search(query, limit=5, entity_type=entity_type)
if not response:
await ctx.respond("No results found.")
return
view = InfoView(response)
await view.send(ctx)

Expand Down
10 changes: 10 additions & 0 deletions src/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,13 @@
"Metadata provided by TheTVDB. Please consider adding missing information or subscribing at " "thetvdb.com."
)
THETVDB_LOGO = "https://www.thetvdb.com/images/attribution/logo1.png"

# The default rate-limit might be a bit too small for production-ready bots that live
# on multiple guilds. But it's good enough for our demonstration purposes and it's
# still actually quite hard to hit this rate-limit on a single guild, unless multiple
# people actually try to make many requests after each other..
#
# Note that tvdb doesn't actually have rate-limits (or at least they aren't documented),
# but we should still be careful not to spam the API too much and be on the safe side.
TVDB_RATE_LIMIT_REQUESTS = get_config("TVDB_RATE_LIMIT_REQUESTS", cast=int, default=5)
TVDB_RATE_LIMIT_PERIOD = get_config("TVDB_RATE_LIMIT_PERIOD", cast=float, default=5) # seconds
13 changes: 12 additions & 1 deletion src/tvdb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from aiocache import BaseCache
from yarl import URL

from src.settings import TVDB_API_KEY
from src.settings import TVDB_API_KEY, TVDB_RATE_LIMIT_PERIOD, TVDB_RATE_LIMIT_REQUESTS
from src.tvdb.generated_models import (
MovieBaseRecord,
MovieExtendedRecord,
Expand All @@ -21,6 +21,7 @@
)
from src.utils.iterators import get_first
from src.utils.log import get_logger
from src.utils.ratelimit import rate_limit

from .errors import BadCallError, InvalidApiKeyError, InvalidIdError

Expand Down Expand Up @@ -266,6 +267,16 @@ async def request(
"""Make an authorized request to the TVDB API."""
log.trace(f"Making TVDB {method} request to {endpoint}")

# TODO: It would be better to instead use a queue to handle rate-limits
# and block until the next request can be made.
await rate_limit(
self.cache,
"tvdb",
limit=TVDB_RATE_LIMIT_REQUESTS,
period=TVDB_RATE_LIMIT_PERIOD,
err_msg="Bot wide rate-limit for TheTVDB API was exceeded.",
)

if self.auth_token is None:
log.trace("No auth token found, requesting initial login.")
await self._login()
Expand Down
1 change: 1 addition & 0 deletions src/utils/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,5 +196,6 @@ def _setup_external_log_levels(root_log: LoggerClass) -> None:
get_logger("discord.gateway").setLevel(logging.WARNING)
get_logger("aiosqlite").setLevel(logging.INFO)
get_logger("alembic.runtime.migration").setLevel(logging.WARNING)
get_logger("aiocache.base").setLevel(logging.INFO)

get_logger("parso").setLevel(logging.WARNING) # For usage in IPython
Loading
Loading