diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ae3d79b71d..b04a133df0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,7 @@ repos: # - --remove-duplicate-keys # - --remove-unused-variables - repo: https://github.com/asottile/pyupgrade - rev: v3.4.0 + rev: v3.7.0 hooks: - id: pyupgrade args: [--py38-plus] diff --git a/CHANGELOG.md b/CHANGELOG.md index 8fc36e7686..aa6cbc5342 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -60,8 +60,17 @@ These changes are available on the `master` branch, but have not yet been releas ([#2042](https://github.com/Pycord-Development/pycord/pull/2042)) - Added `icon` and `unicode_emoji` to `Guild.create_role`. ([#2086](https://github.com/Pycord-Development/pycord/pull/2086)) +- Added `cooldown` and `max_concurrency` to `SlashCommandGroup`. + ([#2091](https://github.com/Pycord-Development/pycord/pull/2091)) - Added new embedded activities, Gartic Phone and Jamspace. ([#2102](https://github.com/Pycord-Development/pycord/pull/2102)) +- Added `bridge.Context` as a shortcut to `Union` of subclasses. + ([#2106](https://github.com/Pycord-Development/pycord/pull/2106)) +- Added Annotated forms support for typehinting slash command options. + ([#2124](https://github.com/Pycord-Development/pycord/pull/2124)) +- Added `suppress` and `allowed_mentions` parameters to `Webhook` and + `InteractionResponse` edit methods. + ([#2138](https://github.com/Pycord-Development/pycord/pull/2138)) ### Changed @@ -129,6 +138,8 @@ These changes are available on the `master` branch, but have not yet been releas ([#2079](https://github.com/Pycord-Development/pycord/pull/2079)) - Fixed `HTTPException` when trying to create a forum thread with files. ([#2075](https://github.com/Pycord-Development/pycord/pull/2075)) +- Fixed `before_invoke` not being run for `SlashCommandGroup`. + ([#2091](https://github.com/Pycord-Development/pycord/pull/2091)) - Fixed `AttributeError` when accessing a `Select`'s values when it hasn't been interacted with. ([#2104](https://github.com/Pycord-Development/pycord/pull/2104)) - Fixed `Thread.applied_tags` not being updated. diff --git a/discord/channel.py b/discord/channel.py index 9a2e931eba..076b2c704f 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -751,7 +751,7 @@ def _repr_attrs(self) -> tuple[str, ...]: def _update(self, guild: Guild, data: TextChannelPayload) -> None: super()._update(guild, data) - async def _get_channel(self) -> "TextChannel": + async def _get_channel(self) -> TextChannel: return self def is_news(self) -> bool: @@ -1064,11 +1064,11 @@ async def edit( available_tags: list[ForumTag] = ..., require_tag: bool = ..., overwrites: Mapping[Role | Member | Snowflake, PermissionOverwrite] = ..., - ) -> "ForumChannel" | None: + ) -> ForumChannel | None: ... @overload - async def edit(self) -> "ForumChannel" | None: + async def edit(self) -> ForumChannel | None: ... async def edit(self, *, reason=None, **options): diff --git a/discord/colour.py b/discord/colour.py index c202351859..a8dbf7236a 100644 --- a/discord/colour.py +++ b/discord/colour.py @@ -335,8 +335,8 @@ def nitro_pink(cls: type[CT]) -> CT: @classmethod def embed_background(cls: type[CT], theme: str = "dark") -> CT: - """A factory method that returns a :class:`Color` corresponding to the - embed colors on discord clients, with a value of: + """A factory method that returns a :class:`Colour` corresponding to the + embed colours on discord clients, with a value of: - ``0x2B2D31`` (dark) - ``0xEEEFF1`` (light) @@ -347,7 +347,7 @@ def embed_background(cls: type[CT], theme: str = "dark") -> CT: Parameters ---------- theme: :class:`str` - The theme color to apply, must be one of "dark", "light", or "amoled". + The theme colour to apply, must be one of "dark", "light", or "amoled". """ themes_cls = { "dark": 0x2B2D31, diff --git a/discord/commands/core.py b/discord/commands/core.py index 90a614a786..09222ab72b 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -30,6 +30,7 @@ import functools import inspect import re +import sys import types from collections import OrderedDict from enum import Enum @@ -64,6 +65,11 @@ from .context import ApplicationContext, AutocompleteContext from .options import Option, OptionChoice +if sys.version_info >= (3, 11): + from typing import Annotated, get_args, get_origin +else: + from typing_extensions import Annotated, get_args, get_origin + __all__ = ( "_BaseCommand", "ApplicationCommand", @@ -84,6 +90,7 @@ from .. import Permissions from ..cog import Cog + from ..ext.commands.cooldowns import CooldownMapping, MaxConcurrency T = TypeVar("T") CogT = TypeVar("CogT", bound="Cog") @@ -294,18 +301,17 @@ async def prepare(self, ctx: ApplicationContext) -> None: f"The check functions for the command {self.name} failed" ) - if hasattr(self, "_max_concurrency"): - if self._max_concurrency is not None: - # For this application, context can be duck-typed as a Message - await self._max_concurrency.acquire(ctx) # type: ignore # ctx instead of non-existent message + if self._max_concurrency is not None: + # For this application, context can be duck-typed as a Message + await self._max_concurrency.acquire(ctx) # type: ignore # ctx instead of non-existent message - try: - self._prepare_cooldowns(ctx) - await self.call_before_hooks(ctx) - except: - if self._max_concurrency is not None: - await self._max_concurrency.release(ctx) # type: ignore # ctx instead of non-existent message - raise + try: + self._prepare_cooldowns(ctx) + await self.call_before_hooks(ctx) + except: + if self._max_concurrency is not None: + await self._max_concurrency.release(ctx) # type: ignore # ctx instead of non-existent message + raise def is_on_cooldown(self, ctx: ApplicationContext) -> bool: """Checks whether the command is currently on cooldown. @@ -732,6 +738,19 @@ def _parse_options(self, params, *, check_params: bool = True) -> list[Option]: if option == inspect.Parameter.empty: option = str + if self._is_typing_annotated(option): + type_hint = get_args(option)[0] + metadata = option.__metadata__ + # If multiple Options in metadata, the first will be used. + option_gen = (elem for elem in metadata if isinstance(elem, Option)) + option = next(option_gen, Option()) + # Handle Optional + if self._is_typing_optional(type_hint): + option.input_type = get_args(type_hint)[0] + option.default = None + else: + option.input_type = type_hint + if self._is_typing_union(option): if self._is_typing_optional(option): option = Option(option.__args__[0], default=None) @@ -820,6 +839,9 @@ def _is_typing_union(self, annotation): def _is_typing_optional(self, annotation): return self._is_typing_union(annotation) and type(None) in annotation.__args__ # type: ignore + def _is_typing_annotated(self, annotation): + return get_origin(annotation) is Annotated + @property def cog(self): return getattr(self, "_cog", MISSING) @@ -1119,6 +1141,8 @@ def __init__( description: str | None = None, guild_ids: list[int] | None = None, parent: SlashCommandGroup | None = None, + cooldown: CooldownMapping | None = None, + max_concurrency: MaxConcurrency | None = None, **kwargs, ) -> None: self.name = str(name) @@ -1153,6 +1177,33 @@ def __init__( "description_localizations", MISSING ) + # similar to ApplicationCommand + from ..ext.commands.cooldowns import BucketType, CooldownMapping, MaxConcurrency + + # no need to getattr, since slash cmds groups cant be created using a decorator + + if cooldown is None: + buckets = CooldownMapping(cooldown, BucketType.default) + elif isinstance(cooldown, CooldownMapping): + buckets = cooldown + else: + raise TypeError( + "Cooldown must be a an instance of CooldownMapping or None." + ) + + self._buckets: CooldownMapping = buckets + + # no need to getattr, since slash cmds groups cant be created using a decorator + + if max_concurrency is not None and not isinstance( + max_concurrency, MaxConcurrency + ): + raise TypeError( + "max_concurrency must be an instance of MaxConcurrency or None" + ) + + self._max_concurrency: MaxConcurrency | None = max_concurrency + @property def module(self) -> str | None: return self.__module__ diff --git a/discord/embeds.py b/discord/embeds.py index b7037dd349..fb34acc33f 100644 --- a/discord/embeds.py +++ b/discord/embeds.py @@ -173,16 +173,18 @@ class EmbedMedia: # Thumbnail, Image, Video url: str proxy_url: str | None - height: int - width: int + height: int | None + width: int | None @classmethod def from_dict(cls, data: dict[str, str | int]) -> EmbedMedia: self = cls.__new__(cls) self.url = str(data.get("url")) - self.proxy_url = str(data.get("proxy_url")) - self.height = int(data["height"]) - self.width = int(data["width"]) + self.proxy_url = ( + str(proxy_url) if (proxy_url := data.get("proxy_url")) else None + ) + self.height = int(height) if (height := data.get("height")) else None + self.width = int(width) if (width := data.get("width")) else None return self def __repr__(self) -> str: diff --git a/discord/ext/bridge/context.py b/discord/ext/bridge/context.py index 3ba5989886..033fdd6caf 100644 --- a/discord/ext/bridge/context.py +++ b/discord/ext/bridge/context.py @@ -25,7 +25,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, overload +from typing import TYPE_CHECKING, Any, Union, overload from discord.commands import ApplicationContext from discord.interactions import Interaction, InteractionMessage @@ -38,7 +38,7 @@ from .core import BridgeExtCommand, BridgeSlashCommand -__all__ = ("BridgeContext", "BridgeExtContext", "BridgeApplicationContext") +__all__ = ("BridgeContext", "BridgeExtContext", "BridgeApplicationContext", "Context") class BridgeContext(ABC): @@ -195,3 +195,10 @@ async def delete( """ if self._original_response_message: await self._original_response_message.delete(delay=delay, reason=reason) + + +Context = Union[BridgeExtContext, BridgeApplicationContext] +""" +A Union class for either :class:`BridgeExtContext` or :class:`BridgeApplicationContext`. +Can be used as a type hint for Context for bridge commands. +""" diff --git a/discord/ext/commands/flags.py b/discord/ext/commands/flags.py index 0848a831a1..54e7e0c37c 100644 --- a/discord/ext/commands/flags.py +++ b/discord/ext/commands/flags.py @@ -33,6 +33,11 @@ from discord.utils import MISSING, MissingField, maybe_coroutine, resolve_annotation +if sys.version_info >= (3, 11): + _MISSING = MissingField +else: + _MISSING = MISSING + from .converter import run_converters from .errors import ( BadFlagArgument, @@ -81,13 +86,13 @@ class Flag: Whether multiple given values overrides the previous value. """ - name: str = MISSING + name: str = _MISSING aliases: list[str] = field(default_factory=list) - attribute: str = MISSING - annotation: Any = MISSING - default: Any = MISSING - max_args: int = MISSING - override: bool = MISSING + attribute: str = _MISSING + annotation: Any = _MISSING + default: Any = _MISSING + max_args: int = _MISSING + override: bool = _MISSING cast_to_dict: bool = False @property diff --git a/discord/interactions.py b/discord/interactions.py index 4a74edf201..83ca14f128 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -33,6 +33,7 @@ from .enums import InteractionResponseType, InteractionType, try_enum from .errors import ClientException, InteractionResponded, InvalidArgument from .file import File +from .flags import MessageFlags from .member import Member from .message import Attachment, Message from .object import Object @@ -386,6 +387,7 @@ async def edit_original_response( view: View | None = MISSING, allowed_mentions: AllowedMentions | None = None, delete_after: float | None = None, + suppress: bool = False, ) -> InteractionMessage: """|coro| @@ -424,6 +426,8 @@ async def edit_original_response( If provided, the number of seconds to wait in the background before deleting the message we just edited. If the deletion fails, then it is silently ignored. + suppress: :class:`bool` + Whether to suppress embeds for the message. Returns ------- @@ -453,6 +457,7 @@ async def edit_original_response( view=view, allowed_mentions=allowed_mentions, previous_allowed_mentions=previous_mentions, + suppress=suppress, ) adapter = async_context.get() http = self._state.http @@ -936,6 +941,8 @@ async def edit_message( attachments: list[Attachment] = MISSING, view: View | None = MISSING, delete_after: float | None = None, + suppress: bool | None = MISSING, + allowed_mentions: AllowedMentions | None = None, ) -> None: """|coro| @@ -966,6 +973,15 @@ async def edit_message( If provided, the number of seconds to wait in the background before deleting the message we just edited. If the deletion fails, then it is silently ignored. + suppress: Optional[:class:`bool`] + Whether to suppress embeds for the message. + allowed_mentions: Optional[:class:`~discord.AllowedMentions`] + Controls the mentions being processed in this message. If this is + passed, then the object is merged with :attr:`~discord.Client.allowed_mentions`. + The merging behaviour only overrides attributes that have been explicitly passed + to the object, otherwise it uses the attributes set in :attr:`~discord.Client.allowed_mentions`. + If no object is passed at all then the defaults given by :attr:`~discord.Client.allowed_mentions` + are used instead. Raises ------ @@ -1029,6 +1045,23 @@ async def edit_message( # we keep previous attachments when adding new files payload["attachments"] = [a.to_dict() for a in msg.attachments] + if suppress is not MISSING: + flags = MessageFlags._from_value(self._parent.message.flags.value) + flags.suppress_embeds = suppress + payload["flags"] = flags.value + + if allowed_mentions is None: + payload["allowed_mentions"] = ( + state.allowed_mentions and state.allowed_mentions.to_dict() + ) + + elif state.allowed_mentions is not None: + payload["allowed_mentions"] = state.allowed_mentions.merge( + allowed_mentions + ).to_dict() + else: + payload["allowed_mentions"] = allowed_mentions.to_dict() + adapter = async_context.get() http = parent._state.http try: @@ -1215,6 +1248,7 @@ async def edit( view: View | None = MISSING, allowed_mentions: AllowedMentions | None = None, delete_after: float | None = None, + suppress: bool | None = MISSING, ) -> InteractionMessage: """|coro| @@ -1247,6 +1281,8 @@ async def edit( If provided, the number of seconds to wait in the background before deleting the message we just edited. If the deletion fails, then it is silently ignored. + suppress: Optional[:class:`bool`] + Whether to suppress embeds for the message. Returns ------- @@ -1266,6 +1302,8 @@ async def edit( """ if attachments is MISSING: attachments = self.attachments or MISSING + if suppress is MISSING: + suppress = self.flags.suppress_embeds return await self._state._interaction.edit_original_response( content=content, embeds=embeds, @@ -1276,6 +1314,7 @@ async def edit( view=view, allowed_mentions=allowed_mentions, delete_after=delete_after, + suppress=suppress, ) async def delete(self, *, delay: float | None = None) -> None: diff --git a/discord/webhook/async_.py b/discord/webhook/async_.py index f7d98f80a9..3a98c4baca 100644 --- a/discord/webhook/async_.py +++ b/discord/webhook/async_.py @@ -47,6 +47,7 @@ InvalidArgument, NotFound, ) +from ..flags import MessageFlags from ..http import Route from ..message import Attachment, Message from ..mixins import Hashable @@ -622,6 +623,7 @@ def handle_message_parameters( view: View | None = MISSING, allowed_mentions: AllowedMentions | None = MISSING, previous_allowed_mentions: AllowedMentions | None = None, + suppress: bool = False, ) -> ExecuteWebhookParameters: if files is not MISSING and file is not MISSING: raise TypeError("Cannot mix file and files keyword arguments.") @@ -648,8 +650,9 @@ def handle_message_parameters( payload["avatar_url"] = str(avatar_url) if username: payload["username"] = username - if ephemeral: - payload["flags"] = 64 + + flags = MessageFlags(suppress_embeds=suppress, ephemeral=ephemeral) + payload["flags"] = flags.value if allowed_mentions: if previous_allowed_mentions is not None: @@ -827,6 +830,7 @@ async def edit( attachments: list[Attachment] = MISSING, view: View | None = MISSING, allowed_mentions: AllowedMentions | None = None, + suppress: bool | None = MISSING, ) -> WebhookMessage: """|coro| @@ -868,6 +872,8 @@ async def edit( the view is removed. .. versionadded:: 2.0 + suppress: Optional[:class:`bool`] + Whether to suppress embeds for the message. Returns ------- @@ -898,6 +904,9 @@ async def edit( if attachments is MISSING: attachments = self.attachments or MISSING + if suppress is MISSING: + suppress = self.flags.suppress_embeds + return await self._state._webhook.edit_message( self.id, content=content, @@ -909,6 +918,7 @@ async def edit( view=view, allowed_mentions=allowed_mentions, thread=thread, + suppress=suppress, ) async def delete(self, *, delay: float | None = None) -> None: @@ -1845,6 +1855,7 @@ async def edit_message( view: View | None = MISSING, allowed_mentions: AllowedMentions | None = None, thread: Snowflake | None = MISSING, + suppress: bool = False, ) -> WebhookMessage: """|coro| @@ -1892,6 +1903,8 @@ async def edit_message( .. versionadded:: 2.0 thread: Optional[:class:`~discord.abc.Snowflake`] The thread that contains the message. + suppress: :class:`bool` + Whether to suppress embeds for the message. Returns ------- @@ -1939,6 +1952,7 @@ async def edit_message( view=view, allowed_mentions=allowed_mentions, previous_allowed_mentions=previous_mentions, + suppress=suppress, ) thread_id: int | None = None diff --git a/discord/webhook/sync.py b/discord/webhook/sync.py index 042130ac25..9e812b5709 100644 --- a/discord/webhook/sync.py +++ b/discord/webhook/sync.py @@ -472,6 +472,7 @@ def edit( file: File = MISSING, files: list[File] = MISSING, allowed_mentions: AllowedMentions | None = None, + suppress: bool | None = MISSING, ) -> SyncWebhookMessage: """Edits the message. @@ -492,6 +493,8 @@ def edit( allowed_mentions: :class:`AllowedMentions` Controls the mentions being processed in this message. See :meth:`.abc.Messageable.send` for more information. + suppress: Optional[:class:`bool`] + Whether to suppress embeds for the message. Returns ------- @@ -517,6 +520,9 @@ def edit( elif isinstance(self.channel, Thread): thread = Object(self.channel.id) + if suppress is MISSING: + suppress = self.flags.suppress_embeds + return self._state._webhook.edit_message( self.id, content=content, @@ -526,6 +532,7 @@ def edit( files=files, allowed_mentions=allowed_mentions, thread=thread, + suppress=suppress, ) def delete(self, *, delay: float | None = None) -> None: @@ -952,6 +959,7 @@ def send( thread: Snowflake = MISSING, thread_name: str | None = None, wait: Literal[False] = ..., + suppress: bool = MISSING, ) -> None: ... @@ -970,6 +978,7 @@ def send( thread: Snowflake = MISSING, thread_name: str | None = None, wait: bool = False, + suppress: bool = False, ) -> SyncWebhookMessage | None: """Sends a message using the webhook. @@ -1022,6 +1031,8 @@ def send( The name of the thread to create. Only works for forum channels. .. versionadded:: 2.0 + suppress: :class:`bool` + Whether to suppress embeds for the message. Returns ------- @@ -1070,6 +1081,7 @@ def send( embeds=embeds, allowed_mentions=allowed_mentions, previous_allowed_mentions=previous_mentions, + suppress=suppress, ) adapter: WebhookAdapter = _get_webhook_adapter() thread_id: int | None = None @@ -1151,6 +1163,7 @@ def edit_message( files: list[File] = MISSING, allowed_mentions: AllowedMentions | None = None, thread: Snowflake | None = MISSING, + suppress: bool = False, ) -> SyncWebhookMessage: """Edits a message owned by this webhook. @@ -1211,6 +1224,7 @@ def edit_message( embeds=embeds, allowed_mentions=allowed_mentions, previous_allowed_mentions=previous_mentions, + suppress=suppress, ) adapter: WebhookAdapter = _get_webhook_adapter() diff --git a/docs/ext/bridge/api.rst b/docs/ext/bridge/api.rst index 327e6600fb..d8f4d78821 100644 --- a/docs/ext/bridge/api.rst +++ b/docs/ext/bridge/api.rst @@ -151,3 +151,9 @@ BridgeContext Subclasses .. autoclass:: discord.ext.bridge.BridgeExtContext :members: + +.. attributetable:: discord.ext.bridge.Context + +.. data:: discord.ext.bridge.Context + + Alias of :data:`typing.Union` [ :class:`.BridgeExtContext`, :class:`.BridgeApplicationContext` ] for typing convenience. diff --git a/requirements/dev.txt b/requirements/dev.txt index 5be4239173..a62ccb7f74 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -1,11 +1,11 @@ -r _.txt pylint~=2.17.4 -pytest~=7.3.1 +pytest~=7.4.0 pytest-asyncio~=0.21.0 # pytest-order~=1.0.1 -mypy~=1.3.0 +mypy~=1.4.0 coverage~=7.2 -pre-commit==3.3.2 -codespell==2.2.4 +pre-commit==3.3.3 +codespell==2.2.5 bandit==1.7.5 flake8==6.0.0 diff --git a/tests/test_typing_annotated.py b/tests/test_typing_annotated.py new file mode 100644 index 0000000000..582bd4f8a0 --- /dev/null +++ b/tests/test_typing_annotated.py @@ -0,0 +1,86 @@ +from typing import Optional + +import pytest +from typing_extensions import Annotated + +import discord +from discord import ApplicationContext +from discord.commands.core import SlashCommand, slash_command + + +def test_typing_annotated(): + async def echo(ctx, txt: Annotated[str, discord.Option()]): + await ctx.respond(txt) + + cmd = SlashCommand(echo) + bot = discord.Bot() + bot.add_application_command(cmd) + + +def test_typing_annotated_decorator(): + bot = discord.Bot() + + @bot.slash_command() + async def echo(ctx, txt: Annotated[str, discord.Option(description="Some text")]): + await ctx.respond(txt) + + +def test_typing_annotated_cog(): + class echoCog(discord.Cog): + def __init__(self, bot_) -> None: + self.bot = bot_ + super().__init__() + + @slash_command() + async def echo( + self, ctx, txt: Annotated[str, discord.Option(description="Some text")] + ): + await ctx.respond(txt) + + bot = discord.Bot() + bot.add_cog(echoCog(bot)) + + +def test_typing_annotated_cog_slashgroup(): + class echoCog(discord.Cog): + grp = discord.commands.SlashCommandGroup("echo") + + def __init__(self, bot_) -> None: + self.bot = bot_ + super().__init__() + + @grp.command() + async def echo( + self, ctx, txt: Annotated[str, discord.Option(description="Some text")] + ): + await ctx.respond(txt) + + bot = discord.Bot() + bot.add_cog(echoCog(bot)) + + +def test_typing_annotated_optional(): + async def echo(ctx, txt: Annotated[Optional[str], discord.Option()]): + await ctx.respond(txt) + + cmd = SlashCommand(echo) + bot = discord.Bot() + bot.add_application_command(cmd) + + +def test_no_annotation(): + async def echo(ctx, txt: str): + await ctx.respond(txt) + + cmd = SlashCommand(echo) + bot = discord.Bot() + bot.add_application_command(cmd) + + +def test_annotated_no_option(): + async def echo(ctx, txt: Annotated[str, "..."]): + await ctx.respond(txt) + + cmd = SlashCommand(echo) + bot = discord.Bot() + bot.add_application_command(cmd)