Skip to content
This repository has been archived by the owner on Sep 6, 2024. It is now read-only.

Commit

Permalink
add pdf
Browse files Browse the repository at this point in the history
  • Loading branch information
Paillat-dev committed Aug 1, 2024
1 parent 367fe86 commit 2a35f89
Show file tree
Hide file tree
Showing 10 changed files with 130 additions and 92 deletions.
Binary file added presentation/presentation.pdf
Binary file not shown.
4 changes: 4 additions & 0 deletions src/db_adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
list_remove_item_safe,
refresh_list_items,
)
from .media import series_get
from .misc import refresh
from .user import user_create_list, user_get, user_get_list_safe, user_get_safe

__all__ = [
Expand All @@ -21,4 +23,6 @@
"refresh_list_items",
"get_list_item",
"list_remove_item_safe",
"refresh",
"series_get",
]
79 changes: 46 additions & 33 deletions src/db_adapters/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,39 +32,47 @@ async def list_put_item(
:raises ValueError: If the item is already present in the list.
"""
if series_id:
await ensure_media(session, tvdb_id, kind, series_id=series_id)
else:
await ensure_media(session, tvdb_id, kind)
if await session.get(UserListItem, (user_list.id, tvdb_id, kind)) is not None:
raise ValueError(f"Item {tvdb_id} is already in list {user_list.id}.")

item = UserListItem(list_id=user_list.id, tvdb_id=tvdb_id, kind=kind)
session.add(item)
await session.commit()
return item
async with session:
if series_id:
await ensure_media(session, tvdb_id, kind, series_id=series_id)
else:
await ensure_media(session, tvdb_id, kind)
if await session.get(UserListItem, (user_list.id, tvdb_id, kind)) is not None:
raise ValueError(f"Item {tvdb_id} is already in list {user_list.id}.")

item = UserListItem(list_id=user_list.id, tvdb_id=tvdb_id, kind=kind)
session.add(item)
await session.commit()
return item


async def list_get_item(
session: AsyncSession, user_list: UserList, tvdb_id: int, kind: UserListItemKind
) -> UserListItem | None:
"""Get an item from a user list."""
return await session.get(UserListItem, (user_list.id, tvdb_id, kind))
async with session:
return await session.get(UserListItem, (user_list.id, tvdb_id, kind))


async def list_remove_item(session: AsyncSession, user_list: UserList, item: UserListItem) -> None:
async def list_remove_item(session: AsyncSession, user_list: UserList, item: UserListItem) -> UserList:
"""Remove an item from a user list."""
await session.delete(item)
await session.commit()
await session.refresh(user_list, ["items"])
async with session:
item = await session.merge(item)
user_list = await session.merge(user_list)
await session.delete(item)
await session.commit()
await session.refresh(user_list, ["items"])
return user_list


async def list_remove_item_safe(
session: AsyncSession, user_list: UserList, tvdb_id: int, kind: UserListItemKind
) -> None:
) -> UserList:
"""Removes an item from a user list if it exists."""
if item := await list_get_item(session, user_list, tvdb_id, kind):
await list_remove_item(session, user_list, item)
async with session:
if item := await list_get_item(session, user_list, tvdb_id, kind):
return await list_remove_item(session, user_list, item)
return user_list


@overload
Expand All @@ -90,23 +98,27 @@ async def list_put_item_safe(
session: AsyncSession, user_list: UserList, tvdb_id: int, kind: UserListItemKind, series_id: int | None = None
) -> UserListItem:
"""Add an item to a user list, or return the existing item if it is already present."""
if series_id:
await ensure_media(session, tvdb_id, kind, series_id=series_id)
else:
await ensure_media(session, tvdb_id, kind)
item = await list_get_item(session, user_list, tvdb_id, kind)
if item:
async with session:
if series_id:
await ensure_media(session, tvdb_id, kind, series_id=series_id)
else:
await ensure_media(session, tvdb_id, kind)
item = await list_get_item(session, user_list, tvdb_id, kind)
if item:
return item

item = UserListItem(list_id=user_list.id, tvdb_id=tvdb_id, kind=kind)
session.add(item)
await session.commit()
return item

item = UserListItem(list_id=user_list.id, tvdb_id=tvdb_id, kind=kind)
session.add(item)
await session.commit()
return item


async def refresh_list_items(session: AsyncSession, user_list: UserList) -> None:
async def refresh_list_items(session: AsyncSession, user_list: UserList) -> UserList:
"""Refresh the items in a user list."""
await session.refresh(user_list, ["items"])
async with session:
user_list = await session.merge(user_list)
await session.refresh(user_list, ["items"])
return user_list


async def get_list_item(
Expand All @@ -116,4 +128,5 @@ async def get_list_item(
kind: UserListItemKind,
) -> UserListItem | None:
"""Get a user list."""
return await session.get(UserListItem, (user_list.id, tvdb_id, kind))
async with session:
return await session.get(UserListItem, (user_list.id, tvdb_id, kind))
43 changes: 25 additions & 18 deletions src/db_adapters/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,29 @@

async def ensure_media(session: AsyncSession, tvdb_id: int, kind: UserListItemKind, **kwargs: Any) -> None:
"""Ensure that a tvdb media item is present in its respective table."""
match kind:
case UserListItemKind.MOVIE:
cls = Movie
case UserListItemKind.SERIES:
cls = Series
case UserListItemKind.EPISODE:
cls = Episode
media = await session.get(cls, tvdb_id)
if media is None:
media = cls(tvdb_id=tvdb_id, **kwargs)
session.add(media)
await session.commit()

if isinstance(media, Episode):
await session.refresh(media, ["series"])
if not media.series:
series = Series(tvdb_id=kwargs["series_id"])
session.add(series)
async with session:
match kind:
case UserListItemKind.MOVIE:
cls = Movie
case UserListItemKind.SERIES:
cls = Series
case UserListItemKind.EPISODE:
cls = Episode
media = await session.get(cls, tvdb_id)
if media is None:
media = cls(tvdb_id=tvdb_id, **kwargs)
session.add(media)
await session.commit()

if isinstance(media, Episode):
await session.refresh(media, ["series"])
if not media.series:
series = Series(tvdb_id=kwargs["series_id"])
session.add(series)
await session.commit()


async def series_get(session: AsyncSession, tvdb_id: int) -> Series | None:
"""Get a series from the database."""
async with session:
return await session.get(Series, tvdb_id)
9 changes: 9 additions & 0 deletions src/db_adapters/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from sqlalchemy.ext.asyncio import AsyncSession


async def refresh[T](session: AsyncSession, item: T, fields: list[str]) -> T:
"""Refresh a media item with the specified fields."""
async with session:
item = await session.merge(item)
await session.refresh(item, fields)
return item
51 changes: 28 additions & 23 deletions src/db_adapters/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,33 @@

async def user_get(session: AsyncSession, discord_id: int) -> User | None:
"""Get a user by their Discord ID."""
return await session.get(User, discord_id)
async with session:
return await session.get(User, discord_id)


async def user_get_safe(session: AsyncSession, discord_id: int) -> User:
"""Get a user by their Discord ID, creating them if they don't exist."""
user = await user_get(session, discord_id)
if user is None:
user = User(discord_id=discord_id)
session.add(user)
await session.commit()
async with session:
user = await user_get(session, discord_id)
if user is None:
user = User(discord_id=discord_id)
session.add(user)
await session.commit()

return user


async def user_get_list(session: AsyncSession, user: User, name: str) -> UserList | None:
"""Get a user's list by name."""
# use where clause on user.id and name
user_list = await session.execute(
select(UserList)
.where(
UserList.user_id == user.discord_id,
async with session:
user_list = await session.execute(
select(UserList)
.where(
UserList.user_id == user.discord_id,
)
.where(UserList.name == name)
)
.where(UserList.name == name)
)
return user_list.scalars().first()


Expand All @@ -39,14 +42,15 @@ async def user_create_list(session: AsyncSession, user: User, name: str, item_ki
:raises ValueError: If a list with the same name already exists for the user.
"""
if await user_get_list(session, user, name) is not None:
raise ValueError(f"List with name {name} already exists for user {user.discord_id}.")
user_list = UserList(user_id=user.discord_id, name=name, item_kind=item_kind)
session.add(user_list)
await session.commit()
await session.refresh(user, ["lists"])
async with session:
if await user_get_list(session, user, name) is not None:
raise ValueError(f"List with name {name} already exists for user {user.discord_id}.")
user_list = UserList(user_id=user.discord_id, name=name, item_kind=item_kind)
session.add(user_list)
await session.commit()
await session.refresh(user, ["lists"])

return user_list
return user_list


async def user_get_list_safe(
Expand All @@ -57,8 +61,9 @@ async def user_get_list_safe(
:param kind: The kind of list to create if it doesn't exist.
:return: The user list.
"""
user_list = await user_get_list(session, user, name)
if user_list is None:
user_list = await user_create_list(session, user, name, kind)
async with session:
user_list = await user_get_list(session, user, name)
if user_list is None:
user_list = await user_create_list(session, user, name, kind)

return user_list
return user_list
2 changes: 1 addition & 1 deletion src/exts/tvdb_info/ui/episode_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ async def set_watched(self, state: bool) -> None:
)
if item is None:
raise ValueError("Episode is not marked as watched, can't re-mark as unwatched.")
await list_remove_item(self.bot.db_session, self.watched_list, item)
self.watched_list = await list_remove_item(self.bot.db_session, self.watched_list, item)
else:
try:
await list_put_item(
Expand Down
8 changes: 4 additions & 4 deletions src/exts/tvdb_info/ui/movie_series_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ async def set_favorite(self, state: bool) -> None:
item = await get_list_item(self.bot.db_session, self.favorite_list, self.media_data.id, self._db_item_kind)
if item is None:
raise ValueError("Media is not marked as favorite, can't re-mark as favorite.")
await list_remove_item(self.bot.db_session, self.watched_list, item)
self.watched_list = await list_remove_item(self.bot.db_session, self.watched_list, item)
else:
try:
await list_put_item(self.bot.db_session, self.favorite_list, self.media_data.id, self._db_item_kind)
Expand All @@ -92,7 +92,7 @@ async def set_watched(self, state: bool) -> None:
item = await get_list_item(self.bot.db_session, self.watched_list, self.media_data.id, self._db_item_kind)
if item is None:
raise ValueError("Media is not marked as watched, can't re-mark as unwatched.")
await list_remove_item(self.bot.db_session, self.watched_list, item)
self.watched_list = await list_remove_item(self.bot.db_session, self.watched_list, item)
else:
try:
await list_put_item(self.bot.db_session, self.watched_list, self.media_data.id, self._db_item_kind)
Expand Down Expand Up @@ -226,14 +226,14 @@ async def set_watched(self, state: bool) -> None:
if not episode.id:
raise ValueError("Episode has no ID")

await list_remove_item_safe(
self.watched_list = await list_remove_item_safe(
self.bot.db_session,
self.watched_list,
episode.id,
UserListItemKind.EPISODE,
)

await refresh_list_items(self.bot.db_session, self.watched_list)
self.watched_list = await refresh_list_items(self.bot.db_session, self.watched_list)
else:
for episode in self.media_data.episodes:
if not episode.id:
Expand Down
22 changes: 11 additions & 11 deletions src/exts/tvdb_info/ui/profile_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import discord

from src.bot import Bot
from src.db_adapters import refresh_list_items
from src.db_adapters import refresh, refresh_list_items, series_get
from src.db_tables.media import Episode as EpisodeTable, Movie as MovieTable, Series as SeriesTable
from src.db_tables.user_list import UserList, UserListItemKind
from src.exts.error_handler.view import ErrorHandledView
Expand Down Expand Up @@ -49,8 +49,8 @@ def __init__(

async def _initialize(self) -> None:
"""Initialize the view, obtaining any necessary state."""
await refresh_list_items(self.bot.db_session, self.watched_list)
await refresh_list_items(self.bot.db_session, self.favorite_list)
self.watched_list = await refresh_list_items(self.bot.db_session, self.watched_list)
self.favorite_list = await refresh_list_items(self.bot.db_session, self.favorite_list)

watched_movies: list[MovieTable] = []
watched_shows: list[SeriesTable] = []
Expand All @@ -60,14 +60,14 @@ async def _initialize(self) -> None:
for item in self.watched_list.items:
match item.kind:
case UserListItemKind.MOVIE:
await self.bot.db_session.refresh(item, ["movie"])
item = await refresh(self.bot.db_session, item, ["movie"])
watched_movies.append(item.movie)
case UserListItemKind.SERIES:
await self.bot.db_session.refresh(item, ["series"])
item = await refresh(self.bot.db_session, item, ["series"])
watched_shows.append(item.series)
case UserListItemKind.EPISODE:
await self.bot.db_session.refresh(item, ["episode"])
await self.bot.db_session.refresh(item.episode, ["series"])
item = await refresh(self.bot.db_session, item, ["episode"])
item.episode = await refresh(self.bot.db_session, item.episode, ["series"])
watched_episodes.append(item.episode)

# We don't actually care about episodes in the profile view, however, we need them
Expand All @@ -91,10 +91,10 @@ async def _initialize(self) -> None:

group_episode_ids = {episode.tvdb_id for episode in episodes_it}
group_episode_ids.add(first_db_episode.tvdb_id)
await self.bot.db_session.refresh(first_db_episode, ["series"])
first_db_episode = await refresh(self.bot.db_session, first_db_episode, ["series"])

if first_db_episode.series is None: # pyright: ignore[reportUnnecessaryComparison]
manual = await self.bot.db_session.get(SeriesTable, first_db_episode.series_id)
manual = await series_get(self.bot.db_session, first_db_episode.series_id)
raise ValueError(f"DB series is None id={first_db_episode.series_id}, manual={manual}")

if last_episode.id in group_episode_ids:
Expand All @@ -108,10 +108,10 @@ async def _initialize(self) -> None:
for item in self.favorite_list.items:
match item.kind:
case UserListItemKind.MOVIE:
await self.bot.db_session.refresh(item, ["movie"])
item = await refresh(self.bot.db_session, item, ["movie"])
favorite_movies.append(item.movie)
case UserListItemKind.SERIES:
await self.bot.db_session.refresh(item, ["series"])
item = await refresh(self.bot.db_session, item, ["series"])
favorite_shows.append(item.series)
case UserListItemKind.EPISODE:
raise TypeError("Found an episode in favorite list")
Expand Down
4 changes: 2 additions & 2 deletions src/exts/tvdb_info/ui/search_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ async def search_view(
user = await user_get_safe(bot.db_session, user_id)
watched_list = await user_get_list_safe(bot.db_session, user, "watched")
favorite_list = await user_get_list_safe(bot.db_session, user, "favorite")
await refresh_list_items(bot.db_session, watched_list)
await refresh_list_items(bot.db_session, favorite_list)
watched_list = await refresh_list_items(bot.db_session, watched_list)
favorite_list = await refresh_list_items(bot.db_session, favorite_list)

return _search_view(bot, user_id, invoker_user_id, watched_list, favorite_list, results, 0)

0 comments on commit 2a35f89

Please sign in to comment.