Skip to content

Commit

Permalink
feat: improvements to reset/retry/remove endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
dreulavelle committed Sep 20, 2024
1 parent 365f022 commit 98f9e49
Show file tree
Hide file tree
Showing 15 changed files with 185 additions and 142 deletions.
3 changes: 1 addition & 2 deletions makefile
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,9 @@ clean:
@find . -type d -name '__pycache__' -exec rm -rf {} +
@find . -type d -name '.pytest_cache' -exec rm -rf {} +
@find . -type d -name '.ruff_cache' -exec rm -rf {} +
@rm -rf data/alembic/
@rm -rf data/*.db

hard_reset: clean
@rm -rf data/alembic/
@poetry run python src/main.py --hard_reset_db

install:
Expand Down
24 changes: 14 additions & 10 deletions src/controllers/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)


@router.get("/")
@router.get("/", operation_id="root")
async def root():
return {
"success": True,
Expand All @@ -26,15 +26,15 @@ async def root():
}


@router.get("/health")
@router.get("/health", operation_id="health")
async def health(request: Request):
    """Report whether the program attached to this app has finished initializing."""
    initialized = request.app.program.initialized
    return {"success": True, "message": initialized}


@router.get("/rd")
@router.get("/rd", operation_id="rd")
async def get_rd_user():
api_key = settings_manager.settings.downloaders.real_debrid.api_key
headers = {"Authorization": f"Bearer {api_key}"}
Expand All @@ -57,7 +57,7 @@ async def get_rd_user():
}


@router.get("/torbox")
@router.get("/torbox", operation_id="torbox")
async def get_torbox_user():
api_key = settings_manager.settings.downloaders.torbox.api_key
headers = {"Authorization": f"Bearer {api_key}"}
Expand All @@ -67,7 +67,7 @@ async def get_torbox_user():
return response.json()


@router.get("/services")
@router.get("/services", operation_id="services")
async def get_services(request: Request):
data = {}
if hasattr(request.app.program, "services"):
Expand All @@ -80,7 +80,7 @@ async def get_services(request: Request):
return {"success": True, "data": data}


@router.get("/trakt/oauth/initiate")
@router.get("/trakt/oauth/initiate", operation_id="trakt_oauth_initiate")
async def initiate_trakt_oauth(request: Request):
trakt = request.app.program.services.get(TraktContent)
if trakt is None:
Expand All @@ -89,7 +89,7 @@ async def initiate_trakt_oauth(request: Request):
return {"auth_url": auth_url}


@router.get("/trakt/oauth/callback")
@router.get("/trakt/oauth/callback", operation_id="trakt_oauth_callback")
async def trakt_oauth_callback(code: str, request: Request):
trakt = request.app.program.services.get(TraktContent)
if trakt is None:
Expand All @@ -101,7 +101,7 @@ async def trakt_oauth_callback(code: str, request: Request):
raise HTTPException(status_code=400, detail="Failed to obtain OAuth token")


@router.get("/stats")
@router.get("/stats", operation_id="stats")
async def get_stats(_: Request):
payload = {}
with db.Session() as session:
Expand Down Expand Up @@ -137,7 +137,7 @@ async def get_stats(_: Request):

return {"success": True, "data": payload}

@router.get("/logs")
@router.get("/logs", operation_id="logs")
async def get_logs():
log_file_path = None
for handler in logger._core.handlers.values():
Expand All @@ -154,4 +154,8 @@ async def get_logs():
return {"success": True, "logs": log_contents}
except Exception as e:
logger.error(f"Failed to read log file: {e}")
return {"success": False, "message": "Failed to read log file"}
return {"success": False, "message": "Failed to read log file"}

@router.get("/events", operation_id="events")
async def get_events(request: Request):
    """Expose the event manager's pending event updates for polling clients."""
    updates = request.app.program.em.get_event_updates()
    return {"success": True, "data": updates}
53 changes: 32 additions & 21 deletions src/controllers/items.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from datetime import datetime
from typing import Optional

Expand All @@ -6,11 +7,12 @@

from program.content import Overseerr
from program.db.db import db
from program.db.db_functions import get_media_items_by_ids, delete_media_item, reset_media_item
from program.media.item import MediaItem
from program.db.db_functions import clear_streams, get_media_items_by_ids, delete_media_item, reset_media_item, get_parent_items_by_ids
from program.media.item import MediaItem, Season
from program.media.state import States
from sqlalchemy import func, select

from sqlalchemy import delete, func, select
from sqlalchemy.orm import joinedload
from sqlalchemy.exc import NoResultFound
from program.symlink import Symlinker
from utils.logger import logger

Expand Down Expand Up @@ -164,8 +166,9 @@ async def add_items(
)
async def get_item(request: Request, id: int):
with db.Session() as session:
item = session.execute(select(MediaItem).where(MediaItem._id == id)).unique().scalar_one()
if not item:
try:
item = session.execute(select(MediaItem).where(MediaItem._id == id)).unique().scalar_one()
except NoResultFound:
raise HTTPException(status_code=404, detail="Item not found")
return {"success": True, "item": item.to_extended_dict()}

Expand Down Expand Up @@ -198,8 +201,13 @@ async def reset_items(
if not media_items or len(media_items) != len(ids):
raise ValueError("Invalid item ID(s) provided. Some items may not exist.")
for media_item in media_items:
request.app.program.em.cancel_job(media_item)
reset_media_item(media_item)
try:
request.app.program.em.cancel_job(media_item)
clear_streams(media_item)
reset_media_item(media_item)
except Exception as e:
logger.error(f"Failed to reset item with id {media_item._id}: {str(e)}")
continue
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return {"success": True, "message": f"Reset items with id {ids}"}
Expand All @@ -209,19 +217,20 @@ async def reset_items(
summary="Retry Media Items",
description="Retry media items with bases on item IDs",
)
async def retry_items(
request: Request, ids: str
):
async def retry_items(request: Request, ids: str):
ids = handle_ids(ids)
with db.Session() as session:
items = []
for id in ids:
items.append(session.execute(select(MediaItem).where(MediaItem._id == id)).unique().scalar_one())
for item in items:
request.app.program.em.cancel_job(item)
request.app.program.em.add_item(item)
try:
media_items = get_media_items_by_ids(ids)
if not media_items or len(media_items) != len(ids):
raise ValueError("Invalid item ID(s) provided. Some items may not exist.")
for media_item in media_items:
request.app.program.em.cancel_job(media_item)
await asyncio.sleep(0.1) # Ensure cancellation is processed
request.app.program.em.add_item(media_item)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))

return {"success": True, "message": f"Retried items with id {ids}"}
return {"success": True, "message": f"Retried items with ids {ids}"}

@router.delete(
"/remove",
Expand All @@ -231,12 +240,14 @@ async def retry_items(
async def remove_item(request: Request, ids: str):
ids = handle_ids(ids)
try:
media_items = get_media_items_by_ids(ids)
if not media_items or len(media_items) != len(ids):
media_items = get_parent_items_by_ids(ids)
if not media_items:
raise ValueError("Invalid item ID(s) provided. Some items may not exist.")
for media_item in media_items:
logger.debug(f"Removing item {media_item.title} with ID {media_item._id}")
request.app.program.em.cancel_job(media_item)
await asyncio.sleep(0.1) # Ensure cancellation is processed
clear_streams(media_item)
symlink_service = request.app.program.services.get(Symlinker)
if symlink_service:
symlink_service.delete_item_symlinks(media_item)
Expand Down
22 changes: 22 additions & 0 deletions src/program/db/db_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ def get_media_items_by_ids(media_item_ids: list[int]):

return items

def get_parent_items_by_ids(media_item_ids: list[int]):
    """Retrieve MediaItems of type 'movie' or 'show' by a list of MediaItem _ids.

    Args:
        media_item_ids: _id values of the parent items to fetch.

    Returns:
        A list of matching MediaItems. Ids that do not exist or that refer to
        non-parent rows (seasons/episodes) are omitted from the result, so
        callers can detect invalid input by comparing lengths instead of
        getting an unhandled NoResultFound (which previously surfaced as a 500
        at the /remove endpoint).
    """
    from program.media.item import MediaItem
    if not media_item_ids:
        # Avoid issuing a pointless query for an empty id list.
        return []
    with db.Session() as session:
        # Single IN query instead of one round-trip per id (the old loop was
        # N+1 and used scalar_one(), which raises on any missing id).
        query = select(MediaItem).where(
            MediaItem._id.in_(media_item_ids),
            MediaItem.type.in_(["movie", "show"]),
        )
        return list(session.execute(query).unique().scalars().all())

def delete_media_item(item: "MediaItem"):
"""Delete a MediaItem and all its associated relationships."""
with db.Session() as session:
Expand Down Expand Up @@ -123,6 +133,18 @@ def reset_streams(item: "MediaItem", active_stream_hash: str = None):
item.active_stream = {}
session.commit()

def clear_streams(item: "MediaItem"):
    """Delete every stream association (normal and blacklisted) for *item*."""
    with db.Session() as session:
        merged = session.merge(item)
        # Each relation table references the item through a different column.
        statements = (
            delete(StreamRelation).where(StreamRelation.parent_id == merged._id),
            delete(StreamBlacklistRelation).where(
                StreamBlacklistRelation.media_item_id == merged._id
            ),
        )
        for statement in statements:
            session.execute(statement)
        session.commit()

def blacklist_stream(item: "MediaItem", stream: Stream, session: Session = None) -> bool:
"""Blacklist a stream for a media item."""
close_session = False
Expand Down
29 changes: 16 additions & 13 deletions src/program/indexers/trakt.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,19 +189,22 @@ def get_show_aliases(imdb_id: str, item_type: str) -> List[dict]:
if not imdb_id:
return []
url = f"https://api.trakt.tv/{item_type}/{imdb_id}/aliases"
response = get(url, timeout=30, additional_headers={"trakt-api-version": "2", "trakt-api-key": CLIENT_ID})
if response.is_ok and response.data:
aliases = {}
for ns in response.data:
country = ns.country
title = ns.title
if title.startswith("Anime-"):
title = title[len("Anime-"):]
if country not in aliases:
aliases[country] = []
if title not in aliases[country]:
aliases[country].append(title)
return aliases
try:
response = get(url, timeout=30, additional_headers={"trakt-api-version": "2", "trakt-api-key": CLIENT_ID})
if response.is_ok and response.data:
aliases = {}
for ns in response.data:
country = ns.country
title = ns.title
if title.startswith("Anime-"):
title = title[len("Anime-"):]
if country not in aliases:
aliases[country] = []
if title not in aliases[country]:
aliases[country].append(title)
return aliases
except Exception:
logger.error(f"Failed to get show aliases for {imdb_id}")
return {}


Expand Down
4 changes: 3 additions & 1 deletion src/program/media/stream.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from typing import TYPE_CHECKING
from RTN import Torrent
from sqlalchemy import Index
from sqlalchemy import Index, and_

from program.db.db import db
import sqlalchemy
from sqlalchemy.orm import Mapped, mapped_column, relationship
from utils.logger import logger

if TYPE_CHECKING:
from program.media.item import MediaItem
from program.media.state import States


class StreamRelation(db.Model):
Expand Down
5 changes: 4 additions & 1 deletion src/program/post_processing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from program.media.state import States
from program.post_processing.subliminal import Subliminal
from program.settings.manager import settings_manager
from program.db.db_functions import clear_streams
from utils.notifications import notify_on_complete
from loguru import logger

Expand All @@ -21,6 +22,8 @@ def __init__(self):
def run(self, item: MediaItem):
if Subliminal.should_submit(item):
self.services[Subliminal].run(item)
if item.last_state == States.Completed:
clear_streams(item)
yield item

def notify(item: MediaItem):
Expand All @@ -38,4 +41,4 @@ def _notify(_item: Show | Movie):
duration = round((datetime.now() - _item.requested_at).total_seconds())
logger.success(f"{_item.log_string} has been completed in {duration} seconds.")
if settings_manager.settings.notifications.enabled:
notify_on_complete(_item)
notify_on_complete(_item)
24 changes: 12 additions & 12 deletions src/program/scrapers/orionoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self):
self.initialized = True
else:
return
self.second_limiter = RateLimiter(max_calls=1, period=5)
self.rate_limiter = RateLimiter(max_calls=1, period=5)
logger.success("Orionoid initialized!")

def validate(self) -> bool:
Expand All @@ -41,9 +41,6 @@ def validate(self) -> bool:
if not isinstance(self.timeout, int) or self.timeout <= 0:
logger.error("Orionoid timeout is not set or invalid.")
return False
if not isinstance(self.settings.ratelimit, bool):
logger.error("Orionoid ratelimit must be a valid boolean.")
return False
try:
url = f"{self.base_url}?keyapp={KEY_APP}&keyuser={self.settings.api_key}&mode=user&action=retrieve"
response = get(url, retry_if_failed=True, timeout=self.timeout)
Expand Down Expand Up @@ -106,7 +103,7 @@ def run(self, item: MediaItem) -> Dict[str, str]:
try:
return self.scrape(item)
except RateLimitExceeded:
self.second_limiter.limit_hit()
self.rate_limiter.limit_hit()
except ConnectTimeout:
logger.warning(f"Orionoid connection timeout for item: {item.log_string}")
except ReadTimeout:
Expand Down Expand Up @@ -136,17 +133,22 @@ def construct_url(self, media_type, imdb_id, season=None, episode=None) -> str:
"type": media_type,
"idimdb": imdb_id[2:],
"streamtype": "torrent",
"protocoltorrent": "magnet",
"video3d": "false",
"videoquality": "sd_hd8k",
"limitcount": self.settings.limitcount or 5
"protocoltorrent": "magnet"
}

if season:
params["numberseason"] = season
if episode:
params["numberepisode"] = episode

if self.settings.cached_results_only:
params["access"] = "realdebridtorrent"
params["debridlookup"] = "realdebrid"

for key, value in self.settings.parameters.items():
if key not in params:
params[key] = value

return f"{self.base_url}?{'&'.join([f'{key}={value}' for key, value in params.items()])}"

def api_scrape(self, item: MediaItem) -> tuple[Dict, int]:
Expand All @@ -164,9 +166,7 @@ def api_scrape(self, item: MediaItem) -> tuple[Dict, int]:
imdb_id = item.parent.parent.imdb_id
url = self.construct_url("show", imdb_id, season=item.parent.number, episode=item.number)

with self.second_limiter:
response = get(url, timeout=self.timeout)

response = get(url, timeout=self.timeout, specific_rate_limiter=self.rate_limiter)
if not response.is_ok or not hasattr(response.data, "data"):
return {}, 0

Expand Down
Loading

0 comments on commit 98f9e49

Please sign in to comment.