Skip to content

Commit

Permalink
feat: #1276 add Asyncio SQLAlchemy support
Browse files Browse the repository at this point in the history
  • Loading branch information
galuszkak committed Jan 12, 2025
1 parent 86169db commit 16e211d
Show file tree
Hide file tree
Showing 5 changed files with 680 additions and 1 deletion.
3 changes: 3 additions & 0 deletions requirements/testing.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ boto3<=2
# For AWS tests
moto>=4.0.13,<6
mypy<=1.14.1
# For AsyncSQLAlchemy tests
greenlet<=4
aiosqlite<=1
281 changes: 280 additions & 1 deletion slack_sdk/oauth/installation_store/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
)
from sqlalchemy.engine import Engine
from sqlalchemy.sql.sqltypes import Boolean

from sqlalchemy.ext.asyncio import AsyncEngine
from slack_sdk.oauth.installation_store.installation_store import InstallationStore
from slack_sdk.oauth.installation_store.models.bot import Bot
from slack_sdk.oauth.installation_store.models.installation import Installation
from slack_sdk.oauth.installation_store.async_installation_store import (
AsyncInstallationStore,
)


class SQLAlchemyInstallationStore(InstallationStore):
Expand Down Expand Up @@ -362,3 +365,279 @@ def delete_installation(
)
)
conn.execute(deletion)


class AsyncSQLAlchemyInstallationStore(AsyncInstallationStore):
default_bots_table_name: str = "slack_bots"
default_installations_table_name: str = "slack_installations"

client_id: str
engine: AsyncEngine
metadata: MetaData
installations: Table

def __init__(
self,
client_id: str,
engine: AsyncEngine,
bots_table_name: str = default_bots_table_name,
installations_table_name: str = default_installations_table_name,
logger: Logger = logging.getLogger(__name__),
):
self.metadata = sqlalchemy.MetaData()
self.bots = self.build_bots_table(metadata=self.metadata, table_name=bots_table_name)
self.installations = self.build_installations_table(metadata=self.metadata, table_name=installations_table_name)
self.client_id = client_id
self._logger = logger
self.engine = engine

@classmethod
def build_installations_table(cls, metadata: MetaData, table_name: str) -> Table:
return SQLAlchemyInstallationStore.build_installations_table(metadata, table_name)

@classmethod
def build_bots_table(cls, metadata: MetaData, table_name: str) -> Table:
return SQLAlchemyInstallationStore.build_bots_table(metadata, table_name)

async def create_tables(self):
async with self.engine.begin() as conn:
await conn.run_sync(self.metadata.create_all)

@property
def logger(self) -> Logger:
return self._logger

async def async_save(self, installation: Installation):
async with self.engine.begin() as conn:
i = installation.to_dict()
i["client_id"] = self.client_id

i_column = self.installations.c
installations_rows = await conn.execute(
sqlalchemy.select(i_column.id)
.where(
and_(
i_column.client_id == self.client_id,
i_column.enterprise_id == installation.enterprise_id,
i_column.team_id == installation.team_id,
i_column.installed_at == i.get("installed_at"),
)
)
.limit(1)
)
installations_row_id: Optional[str] = None
for row in installations_rows.mappings():
installations_row_id = row["id"]
if installations_row_id is None:
await conn.execute(self.installations.insert(), i)
else:
update_statement = self.installations.update().where(i_column.id == installations_row_id).values(**i)
await conn.execute(update_statement, i)

# bots
await self.async_save_bot(installation.to_bot())

async def async_save_bot(self, bot: Bot):
async with self.engine.begin() as conn:
# bots
b = bot.to_dict()
b["client_id"] = self.client_id

b_column = self.bots.c
bots_rows = await conn.execute(
sqlalchemy.select(b_column.id)
.where(
and_(
b_column.client_id == self.client_id,
b_column.enterprise_id == bot.enterprise_id,
b_column.team_id == bot.team_id,
b_column.installed_at == b.get("installed_at"),
)
)
.limit(1)
)
bots_row_id: Optional[str] = None
for row in bots_rows.mappings():
bots_row_id = row["id"]
if bots_row_id is None:
await conn.execute(self.bots.insert(), b)
else:
update_statement = self.bots.update().where(b_column.id == bots_row_id).values(**b)
await conn.execute(update_statement, b)

async def async_find_bot(
self,
*,
enterprise_id: Optional[str],
team_id: Optional[str],
is_enterprise_install: Optional[bool] = False,
) -> Optional[Bot]:
if is_enterprise_install or team_id is None:
team_id = None

c = self.bots.c
query = (
self.bots.select()
.where(
and_(
c.client_id == self.client_id,
c.enterprise_id == enterprise_id,
c.team_id == team_id,
c.bot_token.is_not(None), # the latest one that has a bot token
)
)
.order_by(desc(c.installed_at))
.limit(1)
)

async with self.engine.connect() as conn:
result: object = await conn.execute(query)
for row in result.mappings(): # type: ignore[attr-defined]
return Bot(
app_id=row["app_id"],
enterprise_id=row["enterprise_id"],
enterprise_name=row["enterprise_name"],
team_id=row["team_id"],
team_name=row["team_name"],
bot_token=row["bot_token"],
bot_id=row["bot_id"],
bot_user_id=row["bot_user_id"],
bot_scopes=row["bot_scopes"],
bot_refresh_token=row["bot_refresh_token"],
bot_token_expires_at=row["bot_token_expires_at"],
is_enterprise_install=row["is_enterprise_install"],
installed_at=row["installed_at"],
)
return None

async def async_find_installation(
self,
*,
enterprise_id: Optional[str],
team_id: Optional[str],
user_id: Optional[str] = None,
is_enterprise_install: Optional[bool] = False,
) -> Optional[Installation]:
if is_enterprise_install or team_id is None:
team_id = None

c = self.installations.c
where_clause = and_(
c.client_id == self.client_id,
c.enterprise_id == enterprise_id,
c.team_id == team_id,
)
if user_id is not None:
where_clause = and_(
c.client_id == self.client_id,
c.enterprise_id == enterprise_id,
c.team_id == team_id,
c.user_id == user_id,
)

query = self.installations.select().where(where_clause).order_by(desc(c.installed_at)).limit(1)

installation: Optional[Installation] = None
async with self.engine.connect() as conn:
result: object = await conn.execute(query)
for row in result.mappings(): # type: ignore[attr-defined]
installation = Installation(
app_id=row["app_id"],
enterprise_id=row["enterprise_id"],
enterprise_name=row["enterprise_name"],
enterprise_url=row["enterprise_url"],
team_id=row["team_id"],
team_name=row["team_name"],
bot_token=row["bot_token"],
bot_id=row["bot_id"],
bot_user_id=row["bot_user_id"],
bot_scopes=row["bot_scopes"],
bot_refresh_token=row["bot_refresh_token"],
bot_token_expires_at=row["bot_token_expires_at"],
user_id=row["user_id"],
user_token=row["user_token"],
user_scopes=row["user_scopes"],
user_refresh_token=row["user_refresh_token"],
user_token_expires_at=row["user_token_expires_at"],
# Only the incoming webhook issued in the latest installation is set in this logic
incoming_webhook_url=row["incoming_webhook_url"],
incoming_webhook_channel=row["incoming_webhook_channel"],
incoming_webhook_channel_id=row["incoming_webhook_channel_id"],
incoming_webhook_configuration_url=row["incoming_webhook_configuration_url"],
is_enterprise_install=row["is_enterprise_install"],
token_type=row["token_type"],
installed_at=row["installed_at"],
)

has_user_installation = user_id is not None and installation is not None
no_bot_token_installation = installation is not None and installation.bot_token is None
should_find_bot_installation = has_user_installation or no_bot_token_installation
if should_find_bot_installation:
# Retrieve the latest bot token, just in case
# See also: https://github.com/slackapi/bolt-python/issues/664
latest_bot_installation = await self.async_find_bot(
enterprise_id=enterprise_id,
team_id=team_id,
is_enterprise_install=is_enterprise_install,
)
if (
latest_bot_installation is not None
and installation is not None
and installation.bot_token != latest_bot_installation.bot_token
):
installation.bot_id = latest_bot_installation.bot_id
installation.bot_user_id = latest_bot_installation.bot_user_id
installation.bot_token = latest_bot_installation.bot_token
installation.bot_scopes = latest_bot_installation.bot_scopes
installation.bot_refresh_token = latest_bot_installation.bot_refresh_token
installation.bot_token_expires_at = latest_bot_installation.bot_token_expires_at

return installation

async def async_delete_bot(
self,
*,
enterprise_id: Optional[str],
team_id: Optional[str],
) -> None:
table = self.bots
c = table.c
async with self.engine.begin() as conn:
deletion = table.delete().where(
and_(
c.client_id == self.client_id,
c.enterprise_id == enterprise_id,
c.team_id == team_id,
)
)
await conn.execute(deletion)

async def async_delete_installation(
self,
*,
enterprise_id: Optional[str],
team_id: Optional[str],
user_id: Optional[str] = None,
) -> None:
table = self.installations
c = table.c
async with self.engine.begin() as conn:
if user_id is not None:
deletion = table.delete().where(
and_(
c.client_id == self.client_id,
c.enterprise_id == enterprise_id,
c.team_id == team_id,
c.user_id == user_id,
)
)
await conn.execute(deletion)
else:
deletion = table.delete().where(
and_(
c.client_id == self.client_id,
c.enterprise_id == enterprise_id,
c.team_id == team_id,
)
)
await conn.execute(deletion)
71 changes: 71 additions & 0 deletions slack_sdk/oauth/state_store/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from uuid import uuid4

from ..state_store import OAuthStateStore
from ..async_state_store import AsyncOAuthStateStore
import sqlalchemy
from sqlalchemy import Table, Column, Integer, String, DateTime, and_, MetaData
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import AsyncEngine


class SQLAlchemyOAuthStateStore(OAuthStateStore):
Expand Down Expand Up @@ -76,3 +78,72 @@ def consume(self, state: str) -> bool:
message = f"Failed to find any persistent data for state: {state} - {e}"
self.logger.warning(message)
return False


class AsyncSQLAlchemyOAuthStateStore(AsyncOAuthStateStore):
default_table_name: str = "slack_oauth_states"

expiration_seconds: int
engine: AsyncEngine
metadata: MetaData
oauth_states: Table

@classmethod
def build_oauth_states_table(cls, metadata: MetaData, table_name: str) -> Table:
return sqlalchemy.Table(
table_name,
metadata,
metadata,
Column("id", Integer, primary_key=True, autoincrement=True),
Column("state", String(200), nullable=False),
Column("expire_at", DateTime, nullable=False),
)

def __init__(
self,
expiration_seconds: int,
engine: Engine,
logger: Logger = logging.getLogger(__name__),
table_name: str = default_table_name,
):
self.expiration_seconds = expiration_seconds
self._logger = logger
self.engine = engine
self.metadata = MetaData()
self.oauth_states = self.build_oauth_states_table(self.metadata, table_name)

async def create_tables(self):
async with self.engine.begin() as conn:
await conn.run_sync(self.metadata.create_all)

@property
def logger(self) -> Logger:
if self._logger is None:
self._logger = logging.getLogger(__name__)
return self._logger

async def async_issue(self, *args, **kwargs) -> str:
state: str = str(uuid4())
now = datetime.utcfromtimestamp(time.time() + self.expiration_seconds)
async with self.engine.begin() as conn:
await conn.execute(
self.oauth_states.insert(),
{"state": state, "expire_at": now},
)
return state

async def async_consume(self, state: str) -> bool:
try:
async with self.engine.begin() as conn:
c = self.oauth_states.c
query = self.oauth_states.select().where(and_(c.state == state, c.expire_at > datetime.utcnow()))
result = await conn.execute(query)
for row in result.mappings():
self.logger.debug(f"consume's query result: {row}")
await conn.execute(self.oauth_states.delete().where(c.id == row["id"]))
return True
return False
except Exception as e:
message = f"Failed to find any persistent data for state: {state} - {e}"
self.logger.warning(message)
return False
Loading

0 comments on commit 16e211d

Please sign in to comment.