From 8a8509ae8fc019df2114e010153d45fdd52f9622 Mon Sep 17 00:00:00 2001 From: PromKnight Date: Tue, 5 Nov 2024 13:48:32 +0000 Subject: [PATCH 1/5] feat: integrate dependency injection with kink library - Added dependency injection using the kink library to manage API instances and service initialization. - Updated various modules to utilize dependency injection for better modularity and testability. - Refactored API initialization and validation logic to be more centralized and consistent. - Enhanced Trakt, Plex, Overseerr, Mdblist, and Listrr services to use injected dependencies. - Updated CLI and service modules to align with the new dependency injection approach. - Modified pyproject.toml to include kink as a dependency. # Conflicts: # src/program/db/db_functions.py # src/program/utils/cli.py --- poetry.lock | 16 ++++- pyproject.toml | 1 + src/program/apis/__init__.py | 40 +++++++++++++ src/program/apis/listrr_api.py | 4 +- src/program/apis/overseerr_api.py | 4 +- src/program/apis/plex_api.py | 11 +++- src/program/apis/trakt_api.py | 60 +++++++++++++++---- src/program/db/db_functions.py | 1 + src/program/program.py | 8 ++- src/program/services/content/listrr.py | 4 +- src/program/services/content/mdblist.py | 5 +- src/program/services/content/overseerr.py | 5 +- .../services/content/plex_watchlist.py | 5 +- src/program/services/content/trakt.py | 50 +--------------- src/program/services/indexers/trakt.py | 4 +- src/program/services/updaters/plex.py | 5 +- src/routers/secure/default.py | 21 ++++--- src/routers/secure/webhooks.py | 3 +- 18 files changed, 160 insertions(+), 87 deletions(-) diff --git a/poetry.lock b/poetry.lock index 335d741e..23987ecf 100644 --- a/poetry.lock +++ b/poetry.lock @@ -900,6 +900,20 @@ files = [ [package.dependencies] referencing = ">=0.31.0" +[[package]] +name = "kink" +version = "0.8.1" +description = "Dependency injection for python." +optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "kink-0.8.1-py3-none-any.whl", hash = "sha256:c046be42395de6e18776daa93ac78280a70b3aa5c70b9ea5ca716cc71b3ff91a"}, + {file = "kink-0.8.1.tar.gz", hash = "sha256:9310fa5860ad4df3cdd4a2b998517a718cbc83ed4975c51b8ebd60f640a9702c"}, +] + +[package.dependencies] +typing_extensions = ">=4.9.0,<5.0.0" + [[package]] name = "levenshtein" version = "0.25.1" @@ -3334,4 +3348,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "2757aff75c37be8d41e01a73644d010cd2a08b75cd1438758d5bf052b3e205b8" +content-hash = "683396cff8b9e5aecb1bf9f4c6f567e357af3b5f990a95237d99ebf032dc61b0" diff --git a/pyproject.toml b/pyproject.toml index b933b70e..55355a37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ psutil = "^6.0.0" python-dotenv = "^1.0.1" requests-ratelimiter = "^0.7.0" requests-cache = "^1.2.1" +kink = "^0.8.1" [tool.poetry.group.dev.dependencies] pyright = "^1.1.352" diff --git a/src/program/apis/__init__.py b/src/program/apis/__init__.py index 8b137891..168b915b 100644 --- a/src/program/apis/__init__.py +++ b/src/program/apis/__init__.py @@ -1 +1,41 @@ +from .listrr_api import ListrrAPI, ListrrAPIError +from .trakt_api import TraktAPI, TraktAPIError +from .plex_api import PlexAPI, PlexAPIError +from .overseerr_api import OverseerrAPI, OverseerrAPIError +from .mdblist_api import MdblistAPI, MdblistAPIError +from program.settings.manager import settings_manager +from kink import di +def bootstrap_apis(): + __setup_trakt() + __setup_plex() + __setup_mdblist() + __setup_overseerr() + +def __setup_trakt(): + traktApi = TraktAPI() + di[TraktAPI] = traktApi + +def __setup_plex(): + if not settings_manager.settings.updaters.plex.enabled: + return + plexApi = PlexAPI(settings_manager.settings.updaters.plex.token, settings_manager.settings.updaters.plex.url) + di[PlexAPI] = plexApi + +def __setup_overseerr(): + if not settings_manager.settings.content.overseerr.enabled: + return + overseerrApi = OverseerrAPI(settings_manager.settings.content.overseerr.api_key, settings_manager.settings.content.overseerr.url) + di[OverseerrAPI] = overseerrApi + +def __setup_mdblist(): + if not settings_manager.settings.content.mdblist.enabled: + return + mdblistApi = MdblistAPI(settings_manager.settings.content.mdblist.api_key) + di[MdblistAPI] = mdblistApi + +def __setup_listrr(): + if not settings_manager.settings.content.listrr.enabled: + return + listrrApi = ListrrAPI(settings_manager.settings.content.listrr.api_key) + di[ListrrAPI] = listrrApi diff --git a/src/program/apis/listrr_api.py b/src/program/apis/listrr_api.py index d6c8590c..88d8fd1d 100644 --- a/src/program/apis/listrr_api.py +++ b/src/program/apis/listrr_api.py @@ -1,6 +1,6 @@ from loguru import logger from requests.exceptions import HTTPError - +from kink import di from program.apis.trakt_api import TraktAPI from program.media.item import MediaItem from program.utils.request import create_service_session, BaseRequestHandler, Session, ResponseType, ResponseObject, HttpMethod @@ -25,7 +25,7 @@ def __init__(self, api_key: str): session = create_service_session() session.headers.update(self.headers) self.request_handler = ListrrRequestHandler(session, base_url=self.BASE_URL) - self.trakt_api = TraktAPI(rate_limit=False) + self.trakt_api = di[TraktAPI] def validate(self): return self.request_handler.execute(HttpMethod.GET, self.BASE_URL) diff --git a/src/program/apis/overseerr_api.py b/src/program/apis/overseerr_api.py index 99031f91..14fa497c 100644 --- a/src/program/apis/overseerr_api.py +++ b/src/program/apis/overseerr_api.py @@ -3,7 +3,7 @@ from loguru import logger from requests.exceptions import ConnectionError, RetryError from urllib3.exceptions import MaxRetryError - +from kink import di from program.apis.trakt_api import TraktAPI from program.media.item import MediaItem from program.settings.manager import settings_manager @@ -27,7 +27,7 @@ def __init__(self, api_key: str, base_url: str): self.api_key = api_key rate_limit_params = get_rate_limit_params(max_calls=1000, period=300) session = create_service_session(rate_limit_params=rate_limit_params) - self.trakt_api = TraktAPI(rate_limit=False) + self.trakt_api = di[TraktAPI] self.headers = {"X-Api-Key": self.api_key} session.headers.update(self.headers) self.request_handler = OverseerrRequestHandler(session, base_url=base_url) diff --git a/src/program/apis/plex_api.py b/src/program/apis/plex_api.py index 927a2e10..e14b4d10 100644 --- a/src/program/apis/plex_api.py +++ b/src/program/apis/plex_api.py @@ -22,8 +22,8 @@ def execute(self, method: HttpMethod, endpoint: str, overriden_response_type: Re class PlexAPI: """Handles Plex API communication""" - def __init__(self, token: str, base_url: str, rss_urls: Optional[List[str]]): - self.rss_urls = rss_urls + def __init__(self, token: str, base_url: str): + self.rss_urls: Optional[List[str]] = None self.token = token self.BASE_URL = base_url session = create_service_session() @@ -43,6 +43,13 @@ def validate_account(self): def validate_server(self): self.plex_server = PlexServer(self.BASE_URL, token=self.token, session=self.request_handler.session, timeout=60) + def set_rss_urls(self, rss_urls: List[str]): + self.rss_urls = rss_urls + + def clear_rss_urls(self): + self.rss_urls = None + self.rss_enabled = False + def validate_rss(self, url: str): return self.request_handler.execute(HttpMethod.GET, url) diff --git a/src/program/apis/trakt_api.py b/src/program/apis/trakt_api.py index 17d57822..dcc88358 100644 --- a/src/program/apis/trakt_api.py +++ b/src/program/apis/trakt_api.py @@ -2,23 +2,26 @@ from datetime import datetime from types import SimpleNamespace from typing import Union, List, Optional +from urllib.parse import urlencode from requests import RequestException, Session from program import MediaItem from program.media import Movie, Show, Season, Episode +from program.settings.manager import settings_manager from program.utils.request import get_rate_limit_params, create_service_session, logger, BaseRequestHandler, \ - ResponseType, HttpMethod, ResponseObject + ResponseType, HttpMethod, ResponseObject, get_cache_params class TraktAPIError(Exception): """Base exception for TraktApi related errors""" class TraktRequestHandler(BaseRequestHandler): - def __init__(self, session: Session, request_logging: bool = False): - super().__init__(session, response_type=ResponseType.SIMPLE_NAMESPACE, custom_exception=TraktAPIError, request_logging=request_logging) + def __init__(self, session: Session, response_type=ResponseType.SIMPLE_NAMESPACE, request_logging: bool = False): + super().__init__(session, response_type=response_type, custom_exception=TraktAPIError, request_logging=request_logging) def execute(self, method: HttpMethod, endpoint: str, **kwargs) -> ResponseObject: return super()._request(method, endpoint, **kwargs) + class TraktAPI: """Handles Trakt API communication""" BASE_URL = "https://api.trakt.tv" @@ -29,16 +32,17 @@ class TraktAPI: "short_list": re.compile(r"https://trakt.tv/lists/\d+") } - def __init__(self, api_key: Optional[str] = None, rate_limit: bool = True): - self.api_key = api_key - rate_limit_params = get_rate_limit_params(max_calls=1000, period=300) if rate_limit else None - session = create_service_session( - rate_limit_params=rate_limit_params, - use_cache=False - ) + def __init__(self, oauth_client_id: Optional[str] = None, oauth_client_secret: Optional[str] = None, oauth_redirect_uri: Optional[str] = None): + self.settings = settings_manager.settings.content.trakt + self.oauth_client_id = oauth_client_id + self.oauth_client_secret = oauth_client_secret + self.oauth_redirect_uri = oauth_redirect_uri + rate_limit_params = get_rate_limit_params(max_calls=1000, period=300) + trakt_cache = get_cache_params("trakt", 86400) + session = create_service_session(rate_limit_params=rate_limit_params, use_cache=True, cache_params=trakt_cache) self.headers = { "Content-type": "application/json", - "trakt-api-key": self.api_key or self.CLIENT_ID, + "trakt-api-key": self.CLIENT_ID, "trakt-api-version": "2" } session.headers.update(self.headers) @@ -219,7 +223,7 @@ def get_imdbid_from_tvdb(self, tvdb_id: str, type: str = "show") -> Optional[str def resolve_short_url(self, short_url) -> Union[str, None]: """Resolve short URL to full URL""" try: - response = self.request_handler.execute(HttpMethod.GET, url=short_url, additional_headers={"Content-Type": "application/json", "Accept": "text/html"}) + response = self.request_handler.execute(HttpMethod.GET, endpoint=short_url, additional_headers={"Content-Type": "application/json", "Accept": "text/html"}) if response.is_ok: return response.response.url else: @@ -279,6 +283,38 @@ def map_item_from_data(self, data, item_type: str, show_genres: List[str] = None logger.error(f"Unknown item type {item_type} for {data.title} not found in list of acceptable items") return None + def perform_oauth_flow(self) -> str: + """Initiate the OAuth flow and return the authorization URL.""" + params = { + "response_type": "code", + "client_id": self.oauth_client_id, + "redirect_uri": self.oauth_redirect_uri, + } + return f"{self.BASE_URL}/oauth/authorize?{urlencode(params)}" + + def handle_oauth_callback(self, api_key:str, code: str) -> bool: + """Handle the OAuth callback and exchange the code for an access token.""" + token_url = f"{self.BASE_URL}/oauth/token" + payload = { + "code": code, + "client_id": self.oauth_client_id, + "client_secret": self.oauth_client_secret, + "redirect_uri": self.oauth_redirect_uri, + "grant_type": "authorization_code", + } + headers = self.headers.copy() + headers["trakt-api-key"] = f"Bearer {api_key}" + response = self.request_handler.execute(HttpMethod.POST, token_url, data=payload, additional_headers=headers) + if response.is_ok: + token_data = response.data + self.settings.access_token = token_data.get("access_token") + self.settings.refresh_token = token_data.get("refresh_token") + settings_manager.save() # Save the tokens to settings + return True + else: + logger.error(f"Failed to obtain OAuth token: {response.status_code}") + return False + def _get_imdb_id_from_list(self, namespaces: List[SimpleNamespace], id_type: str = None, _id: str = None, type: str = None) -> Optional[str]: """Get the imdb_id from the list of namespaces.""" diff --git a/src/program/db/db_functions.py b/src/program/db/db_functions.py index b49a6f52..e2317436 100644 --- a/src/program/db/db_functions.py +++ b/src/program/db/db_functions.py @@ -10,6 +10,7 @@ from program.utils import root_dir import alembic +from program.utils import root_dir from program.media.stream import Stream, StreamBlacklistRelation, StreamRelation from program.services.libraries.symlink import fix_broken_symlinks from program.settings.manager import settings_manager diff --git a/src/program/program.py b/src/program/program.py index 36a60ca4..f3d8c940 100644 --- a/src/program/program.py +++ b/src/program/program.py @@ -19,6 +19,7 @@ PlexWatchlist, TraktContent, ) +from program.apis import bootstrap_apis from program.services.downloaders import Downloader from program.services.indexers.trakt import TraktIndexer from program.services.libraries import SymlinkLibrary @@ -64,8 +65,11 @@ def __init__(self): self.malloc_time = time.monotonic()-50 self.last_snapshot = None - def initialize_services(self): + def initialize_apis(self): + bootstrap_apis() + def initialize_services(self): + """Initialize all services.""" self.requesting_services = { Overseerr: Overseerr(), PlexWatchlist: PlexWatchlist(), @@ -122,6 +126,7 @@ def start(self): latest_version = get_version() logger.log("PROGRAM", f"Riven v{latest_version} starting!") + settings_manager.register_observer(self.initialize_apis) settings_manager.register_observer(self.initialize_services) os.makedirs(data_dir_path, exist_ok=True) @@ -129,6 +134,7 @@ def start(self): logger.log("PROGRAM", "Settings file not found, creating default settings") settings_manager.save() + self.initialize_apis() self.initialize_services() max_worker_env_vars = [var for var in os.environ if var.endswith("_MAX_WORKERS")] diff --git a/src/program/services/content/listrr.py b/src/program/services/content/listrr.py index 764ea50e..1a7768fe 100644 --- a/src/program/services/content/listrr.py +++ b/src/program/services/content/listrr.py @@ -1,5 +1,6 @@ """Listrr content module""" from typing import Generator +from kink import di from program.utils.request import logger from program.media.item import MediaItem from program.settings.manager import settings_manager @@ -12,7 +13,7 @@ class Listrr: def __init__(self): self.key = "listrr" self.settings = settings_manager.settings.content.listrr - self.api = ListrrAPI(self.settings.api_key) + self.api = None self.initialized = self.validate() if not self.initialized: return @@ -40,6 +41,7 @@ def validate(self) -> bool: logger.error("Both Movie and Show lists are empty or not set.") return False try: + self.api = di[ListrrAPI] response = self.api.validate() if not response.is_ok: logger.error( diff --git a/src/program/services/content/mdblist.py b/src/program/services/content/mdblist.py index 8199a4af..2dd8558a 100644 --- a/src/program/services/content/mdblist.py +++ b/src/program/services/content/mdblist.py @@ -2,7 +2,7 @@ from typing import Generator from loguru import logger - +from kink import di from program.apis.mdblist_api import MdblistAPI from program.media.item import MediaItem from program.settings.manager import settings_manager @@ -14,7 +14,7 @@ class Mdblist: def __init__(self): self.key = "mdblist" self.settings = settings_manager.settings.content.mdblist - self.api = MdblistAPI(self.settings.api_key) + self.api = None self.initialized = self.validate() if not self.initialized: return @@ -30,6 +30,7 @@ def validate(self): if not self.settings.lists: logger.error("Mdblist is enabled, but list is empty.") return False + self.api = di[MdblistAPI] response = self.api.validate() if "Invalid API key!" in response.response.text: logger.error("Mdblist api key is invalid.") diff --git a/src/program/services/content/overseerr.py b/src/program/services/content/overseerr.py index d145d4b6..1eca2faa 100644 --- a/src/program/services/content/overseerr.py +++ b/src/program/services/content/overseerr.py @@ -3,7 +3,7 @@ from loguru import logger from requests.exceptions import ConnectionError, RetryError from urllib3.exceptions import MaxRetryError, NewConnectionError - +from kink import di from program.apis.overseerr_api import OverseerrAPI from program.media.item import MediaItem from program.settings.manager import settings_manager @@ -15,7 +15,7 @@ class Overseerr: def __init__(self): self.key = "overseerr" self.settings = settings_manager.settings.content.overseerr - self.api = OverseerrAPI(self.settings.api_key, self.settings.url) + self.api = None self.initialized = self.validate() self.run_once = False if not self.initialized: @@ -29,6 +29,7 @@ def validate(self) -> bool: logger.error("Overseerr api key is not set.") return False try: + self.api = di[OverseerrAPI] response = self.api.validate() if response.status_code >= 201: logger.error( diff --git a/src/program/services/content/plex_watchlist.py b/src/program/services/content/plex_watchlist.py index 8bbef8c6..c0b725fd 100644 --- a/src/program/services/content/plex_watchlist.py +++ b/src/program/services/content/plex_watchlist.py @@ -3,6 +3,7 @@ from loguru import logger from requests import HTTPError from program.apis.plex_api import PlexAPI +from kink import di from program.media.item import MediaItem from program.settings.manager import settings_manager @@ -13,7 +14,7 @@ class PlexWatchlist: def __init__(self): self.key = "plex_watchlist" self.settings = settings_manager.settings.content.plex_watchlist - self.api = PlexAPI(settings_manager.settings.updaters.plex.token, settings_manager.settings.updaters.plex.url, self.settings.rss) + self.api = None self.initialized = self.validate() if not self.initialized: return @@ -26,11 +27,13 @@ def validate(self): logger.error("Plex token is not set!") return False try: + self.api = di[PlexAPI] self.api.validate_account() except Exception as e: logger.error(f"Unable to authenticate Plex account: {e}") return False if self.settings.rss: + self.api.set_rss_urls(self.settings.rss) for rss_url in self.settings.rss: try: response = self.api.validate_rss(rss_url) diff --git a/src/program/services/content/trakt.py b/src/program/services/content/trakt.py index b8e831f5..10be967f 100644 --- a/src/program/services/content/trakt.py +++ b/src/program/services/content/trakt.py @@ -1,24 +1,12 @@ """Trakt content module""" from datetime import datetime, timedelta -from typing import Type, Optional -from urllib.parse import urlencode - from loguru import logger from requests import RequestException - +from kink import di from program.apis.trakt_api import TraktAPI from program.media.item import MediaItem from program.settings.manager import settings_manager -from program.utils.request import create_service_session, BaseRequestHandler, Session, ResponseType, ResponseObject, HttpMethod - - -class TraktOAuthRequestHandler(BaseRequestHandler): - def __init__(self, session: Session, response_type=ResponseType.SIMPLE_NAMESPACE, custom_exception: Optional[Type[Exception]] = None, request_logging: bool = False): - super().__init__(session, response_type=response_type, custom_exception=custom_exception, request_logging=request_logging) - - def execute(self, method: HttpMethod, endpoint: str, **kwargs) -> ResponseObject: - return super()._request(method, endpoint, **kwargs) class TraktContent: """Content class for Trakt""" @@ -26,9 +14,7 @@ class TraktContent: def __init__(self): self.key = "trakt" self.settings = settings_manager.settings.content.trakt - self.api = TraktAPI(self.settings.api_key) - session = create_service_session() - self.oauth_request_handler = TraktOAuthRequestHandler(session) + self.api = di[TraktAPI] self.initialized = self.validate() if not self.initialized: return @@ -185,34 +171,4 @@ def _extract_imdb_ids_with_none_type(items: list) -> list: imdb_id = getattr(ids, "imdb", None) if imdb_id: imdb_ids.append((imdb_id, None)) - return imdb_ids - - def perform_oauth_flow(self) -> str: - """Initiate the OAuth flow and return the authorization URL.""" - params = { - "response_type": "code", - "client_id": self.settings.oauth_client_id, - "redirect_uri": self.settings.oauth_redirect_uri, - } - return f"{self.api.BASE_URL}/oauth/authorize?{urlencode(params)}" - - def handle_oauth_callback(self, code: str) -> bool: - """Handle the OAuth callback and exchange the code for an access token.""" - token_url = f"{self.api.BASE_URL}/oauth/token" - payload = { - "code": code, - "client_id": self.settings.oauth_client_id, - "client_secret": self.settings.oauth_client_secret, - "redirect_uri": self.settings.oauth_redirect_uri, - "grant_type": "authorization_code", - } - response = self.oauth_request_handler.execute(HttpMethod.POST, token_url, data=payload, additional_headers=self.api.headers) - if response.is_ok: - token_data = response.data - self.settings.access_token = token_data.get("access_token") - self.settings.refresh_token = token_data.get("refresh_token") - settings_manager.save() # Save the tokens to settings - return True - else: - logger.error(f"Failed to obtain OAuth token: {response.status_code}") - return False \ No newline at end of file + return imdb_ids \ No newline at end of file diff --git a/src/program/services/indexers/trakt.py b/src/program/services/indexers/trakt.py index 2c60f477..a2c854cc 100644 --- a/src/program/services/indexers/trakt.py +++ b/src/program/services/indexers/trakt.py @@ -3,7 +3,7 @@ from datetime import datetime, timedelta from typing import Generator, Union from loguru import logger - +from kink import di from program.apis.trakt_api import TraktAPI from program.media.item import Episode, MediaItem, Movie, Season, Show from program.settings.manager import settings_manager @@ -19,7 +19,7 @@ def __init__(self): self.initialized = True self.settings = settings_manager.settings.indexer self.failed_ids = set() - self.api = TraktAPI(rate_limit=False) + self.api = di[TraktAPI] @staticmethod def copy_attributes(source, target): diff --git a/src/program/services/updaters/plex.py b/src/program/services/updaters/plex.py index 5d7aa137..469503d3 100644 --- a/src/program/services/updaters/plex.py +++ b/src/program/services/updaters/plex.py @@ -7,7 +7,7 @@ from plexapi.library import LibrarySection from requests.exceptions import ConnectionError as RequestsConnectionError from urllib3.exceptions import MaxRetryError, NewConnectionError, RequestError - +from kink import di from program.apis.plex_api import PlexAPI from program.media.item import Episode, Movie, Season, Show from program.settings.manager import settings_manager @@ -21,7 +21,7 @@ def __init__(self): os.path.dirname(settings_manager.settings.symlink.library_path) ) self.settings = settings_manager.settings.updaters.plex - self.api = PlexAPI(self.settings.token, self.settings.url, self.settings) + self.api = None self.sections: Dict[LibrarySection, List[str]] = {} self.initialized = self.validate() if not self.initialized: @@ -43,6 +43,7 @@ def validate(self) -> bool: # noqa: C901 return False try: + self.api = di[PlexAPI] self.api.validate_server() self.sections = self.api.map_sections_with_paths() self.initialized = True diff --git a/src/routers/secure/default.py b/src/routers/secure/default.py index 0b76ead8..1b9f102b 100644 --- a/src/routers/secure/default.py +++ b/src/routers/secure/default.py @@ -5,12 +5,12 @@ from loguru import logger from pydantic import BaseModel, Field from sqlalchemy import func, select - +from kink import di +from program.apis import TraktAPI from program.db.db import db from program.managers.event_manager import EventUpdate from program.media.item import Episode, MediaItem, Movie, Season, Show from program.media.state import States -from program.services.content.trakt import TraktContent from program.settings.manager import settings_manager from program.utils import generate_api_key @@ -99,19 +99,22 @@ class TraktOAuthInitiateResponse(BaseModel): @router.get("/trakt/oauth/initiate", operation_id="trakt_oauth_initiate") async def initiate_trakt_oauth(request: Request) -> TraktOAuthInitiateResponse: - trakt = request.app.program.services.get(TraktContent) - if trakt is None: + trakt_api = di[TraktAPI] + if trakt_api is None: raise HTTPException(status_code=404, detail="Trakt service not found") - auth_url = trakt.perform_oauth_flow() + auth_url = trakt_api.perform_oauth_flow() return {"auth_url": auth_url} @router.get("/trakt/oauth/callback", operation_id="trakt_oauth_callback") async def trakt_oauth_callback(code: str, request: Request) -> MessageResponse: - trakt = request.app.program.services.get(TraktContent) - if trakt is None: - raise HTTPException(status_code=404, detail="Trakt service not found") - success = trakt.handle_oauth_callback(code) + trakt_api = di[TraktAPI] + trakt_api_key = settings_manager.settings.content.trakt.api_key + if trakt_api is None: + raise HTTPException(status_code=404, detail="Trakt Api not found") + if trakt_api_key is None: + raise HTTPException(status_code=404, detail="Trakt Api key not found in settings") + success = trakt_api.handle_oauth_callback(trakt_api_key, code) if success: return {"message": "OAuth token obtained successfully"} else: diff --git a/src/routers/secure/webhooks.py b/src/routers/secure/webhooks.py index 345c0d28..3d0969a7 100644 --- a/src/routers/secure/webhooks.py +++ b/src/routers/secure/webhooks.py @@ -8,6 +8,7 @@ from program.apis.trakt_api import TraktAPI from program.media.item import MediaItem from program.services.content.overseerr import Overseerr +from kink import di from ..models.overseerr import OverseerrWebhook @@ -50,7 +51,7 @@ async def overseerr(request: Request) -> Dict[str, Any]: def get_imdbid_from_overseerr(req: OverseerrWebhook) -> str: """Get the imdb_id from the Overseerr webhook""" imdb_id = req.media.imdbId - trakt_api = TraktAPI(rate_limit=False) + trakt_api = di[TraktAPI] if not imdb_id: try: _type = req.media.media_type From 4fff037fa46a12a98f107d51b1fc797f83ef86a7 Mon Sep 17 00:00:00 2001 From: PromKnight Date: Tue, 5 Nov 2024 14:33:15 +0000 Subject: [PATCH 2/5] feat: enhance Trakt API with OAuth support and settings integration - Updated TraktAPI to accept settings via TraktModel, enabling OAuth configuration. - Added OAuth flow methods to handle authorization and token exchange. - Integrated TraktOauthModel into TraktModel for structured OAuth settings. - Modified API bootstrap to pass settings to TraktAPI. - Ensured backward compatibility with existing settings structure. --- src/program/apis/__init__.py | 3 ++- src/program/apis/trakt_api.py | 28 ++++++++++++++++++---------- src/program/settings/models.py | 2 +- 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/src/program/apis/__init__.py b/src/program/apis/__init__.py index 168b915b..f8a3c7aa 100644 --- a/src/program/apis/__init__.py +++ b/src/program/apis/__init__.py @@ -11,9 +11,10 @@ def bootstrap_apis(): __setup_plex() __setup_mdblist() __setup_overseerr() + __setup_listrr() def __setup_trakt(): - traktApi = TraktAPI() + traktApi = TraktAPI(settings_manager.settings.content.trakt) di[TraktAPI] = traktApi def __setup_plex(): diff --git a/src/program/apis/trakt_api.py b/src/program/apis/trakt_api.py index dcc88358..a4260206 100644 --- a/src/program/apis/trakt_api.py +++ b/src/program/apis/trakt_api.py @@ -7,10 +7,10 @@ from program import MediaItem from program.media import Movie, Show, Season, Episode from program.settings.manager import settings_manager +from program.settings.models import TraktModel from program.utils.request import get_rate_limit_params, create_service_session, logger, BaseRequestHandler, \ ResponseType, HttpMethod, ResponseObject, get_cache_params - class TraktAPIError(Exception): """Base exception for TraktApi related errors""" @@ -32,11 +32,11 @@ class TraktAPI: "short_list": re.compile(r"https://trakt.tv/lists/\d+") } - def __init__(self, oauth_client_id: Optional[str] = None, oauth_client_secret: Optional[str] = None, oauth_redirect_uri: Optional[str] = None): + def __init__(self, settings: TraktModel): self.settings = settings_manager.settings.content.trakt - self.oauth_client_id = oauth_client_id - self.oauth_client_secret = oauth_client_secret - self.oauth_redirect_uri = oauth_redirect_uri + self.oauth_client_id = settings.oauth.oauth_client_id + self.oauth_client_secret = settings.oauth.oauth_client_secret + self.oauth_redirect_uri = settings.oauth.oauth_redirect_uri rate_limit_params = get_rate_limit_params(max_calls=1000, period=300) trakt_cache = get_cache_params("trakt", 86400) session = create_service_session(rate_limit_params=rate_limit_params, use_cache=True, cache_params=trakt_cache) @@ -152,7 +152,7 @@ def get_show(self, imdb_id: str) -> dict: """Wrapper for trakt.tv API show method.""" if not imdb_id: return {} - url = f"https://api.trakt.tv/shows/{imdb_id}/seasons?extended=episodes,full" + url = f"{self.BASE_URL}/shows/{imdb_id}/seasons?extended=episodes,full" response = self.request_handler.execute(HttpMethod.GET, url, timeout=30) return response.data if response.is_ok and response.data else {} @@ -160,7 +160,7 @@ def get_show_aliases(self, imdb_id: str, item_type: str) -> List[dict]: """Wrapper for trakt.tv API show method.""" if not imdb_id: return [] - url = f"https://api.trakt.tv/{item_type}/{imdb_id}/aliases" + url = f"{self.BASE_URL}/{item_type}/{imdb_id}/aliases" try: response = self.request_handler.execute(HttpMethod.GET, url, timeout=30) if response.is_ok and response.data: @@ -182,7 +182,7 @@ def get_show_aliases(self, imdb_id: str, item_type: str) -> List[dict]: def create_item_from_imdb_id(self, imdb_id: str, type: str = None) -> Optional[MediaItem]: """Wrapper for trakt.tv API search method.""" - url = f"https://api.trakt.tv/search/imdb/{imdb_id}?extended=full" + url = f"{self.BASE_URL}/search/imdb/{imdb_id}?extended=full" response = self.request_handler.execute(HttpMethod.GET, url, timeout=30) if not response.is_ok or not response.data: logger.error( @@ -198,7 +198,7 @@ def create_item_from_imdb_id(self, imdb_id: str, type: str = None) -> Optional[M def get_imdbid_from_tmdb(self, tmdb_id: str, type: str = "movie") -> Optional[str]: """Wrapper for trakt.tv API search method.""" - url = f"https://api.trakt.tv/search/tmdb/{tmdb_id}" # ?extended=full + url = f"{self.BASE_URL}/search/tmdb/{tmdb_id}" # ?extended=full response = self.request_handler.execute(HttpMethod.GET, url, timeout=30) if not response.is_ok or not response.data: return None @@ -210,7 +210,7 @@ def get_imdbid_from_tmdb(self, tmdb_id: str, type: str = "movie") -> Optional[st def get_imdbid_from_tvdb(self, tvdb_id: str, type: str = "show") -> Optional[str]: """Wrapper for trakt.tv API search method.""" - url = f"https://api.trakt.tv/search/tvdb/{tvdb_id}" + url = f"{self.BASE_URL}/search/tvdb/{tvdb_id}" response = self.request_handler.execute(HttpMethod.GET, url, timeout=30) if not response.is_ok or not response.data: return None @@ -285,6 +285,10 @@ def map_item_from_data(self, data, item_type: str, show_genres: List[str] = None def perform_oauth_flow(self) -> str: """Initiate the OAuth flow and return the authorization URL.""" + if not self.oauth_client_id or not self.oauth_client_secret or not self.oauth_redirect_uri: + logger.error("OAuth settings not found in Trakt settings") + raise TraktAPIError("OAuth settings not found in Trakt settings") + params = { "response_type": "code", "client_id": self.oauth_client_id, @@ -294,6 +298,10 @@ def perform_oauth_flow(self) -> str: def handle_oauth_callback(self, api_key:str, code: str) -> bool: """Handle the OAuth callback and exchange the code for an access token.""" + if not self.oauth_client_id or not self.oauth_client_secret or not self.oauth_redirect_uri: + logger.error("OAuth settings not found in Trakt settings") + return False + token_url = f"{self.BASE_URL}/oauth/token" payload = { "code": code, diff --git a/src/program/settings/models.py b/src/program/settings/models.py index 003986f2..8773da40 100644 --- a/src/program/settings/models.py +++ b/src/program/settings/models.py @@ -177,7 +177,7 @@ class TraktModel(Updatable): most_watched_period: str = "weekly" most_watched_count: int = 10 update_interval: int = 86400 - # oauth: TraktOauthModel = TraktOauthModel() + oauth: TraktOauthModel = TraktOauthModel() class ContentModel(Observable): From e52b47242de5eb20c68949e32a38e678196fb900 Mon Sep 17 00:00:00 2001 From: PromKnight Date: Tue, 5 Nov 2024 14:35:01 +0000 Subject: [PATCH 3/5] fix: assignment of trakt api key in oauth --- src/program/apis/trakt_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/program/apis/trakt_api.py b/src/program/apis/trakt_api.py index a4260206..98a9a1d8 100644 --- a/src/program/apis/trakt_api.py +++ b/src/program/apis/trakt_api.py @@ -311,7 +311,7 @@ def handle_oauth_callback(self, api_key:str, code: str) -> bool: "grant_type": "authorization_code", } headers = self.headers.copy() - headers["trakt-api-key"] = f"Bearer {api_key}" + headers["trakt-api-key"] = api_key response = self.request_handler.execute(HttpMethod.POST, token_url, data=payload, additional_headers=headers) if response.is_ok: token_data = response.data From f96717bed5ee6bf6f2d70103cb3056bc4ee756de Mon Sep 17 00:00:00 2001 From: PromKnight Date: Tue, 5 Nov 2024 14:37:24 +0000 Subject: [PATCH 4/5] fix: duplicate import --- src/program/db/db_functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/program/db/db_functions.py b/src/program/db/db_functions.py index e2317436..c7339bb7 100644 --- a/src/program/db/db_functions.py +++ b/src/program/db/db_functions.py @@ -7,7 +7,6 @@ from sqlalchemy import delete, desc, func, insert, inspect, select, text from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import Session, joinedload, selectinload -from program.utils import root_dir import alembic from program.utils import root_dir From 21edfd39e8bdb571c9a2dcdbe8ebcc13184176b6 Mon Sep 17 00:00:00 2001 From: PromKnight Date: Tue, 5 Nov 2024 14:38:58 +0000 Subject: [PATCH 5/5] fix: correct TraktAPI settings initialization in constructor --- src/program/apis/trakt_api.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/program/apis/trakt_api.py b/src/program/apis/trakt_api.py index 98a9a1d8..3a911ce8 100644 --- a/src/program/apis/trakt_api.py +++ b/src/program/apis/trakt_api.py @@ -33,10 +33,10 @@ class TraktAPI: } def __init__(self, settings: TraktModel): - self.settings = settings_manager.settings.content.trakt - self.oauth_client_id = settings.oauth.oauth_client_id - self.oauth_client_secret = settings.oauth.oauth_client_secret - self.oauth_redirect_uri = settings.oauth.oauth_redirect_uri + self.settings = settings + self.oauth_client_id = self.settings.oauth.oauth_client_id + self.oauth_client_secret = self.settings.oauth.oauth_client_secret + self.oauth_redirect_uri = self.settings.oauth.oauth_redirect_uri rate_limit_params = get_rate_limit_params(max_calls=1000, period=300) trakt_cache = get_cache_params("trakt", 86400) session = create_service_session(rate_limit_params=rate_limit_params, use_cache=True, cache_params=trakt_cache)