Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: integrate dependency injection with kink library #859

Merged
merged 5 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ psutil = "^6.0.0"
python-dotenv = "^1.0.1"
requests-ratelimiter = "^0.7.0"
requests-cache = "^1.2.1"
kink = "^0.8.1"

[tool.poetry.group.dev.dependencies]
pyright = "^1.1.352"
Expand Down
41 changes: 41 additions & 0 deletions src/program/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,42 @@
from .listrr_api import ListrrAPI, ListrrAPIError
from .trakt_api import TraktAPI, TraktAPIError
from .plex_api import PlexAPI, PlexAPIError
from .overseerr_api import OverseerrAPI, OverseerrAPIError
from .mdblist_api import MdblistAPI, MdblistAPIError
from program.settings.manager import settings_manager
from kink import di

def bootstrap_apis():
__setup_trakt()
__setup_plex()
__setup_mdblist()
__setup_overseerr()
__setup_listrr()

def __setup_trakt():
traktApi = TraktAPI(settings_manager.settings.content.trakt)
di[TraktAPI] = traktApi

def __setup_plex():
if not settings_manager.settings.updaters.plex.enabled:
return
plexApi = PlexAPI(settings_manager.settings.updaters.plex.token, settings_manager.settings.updaters.plex.url)
di[PlexAPI] = plexApi

def __setup_overseerr():
if not settings_manager.settings.content.overseerr.enabled:
return
overseerrApi = OverseerrAPI(settings_manager.settings.content.overseerr.api_key, settings_manager.settings.content.overseerr.url)
di[OverseerrAPI] = overseerrApi

def __setup_mdblist():
if not settings_manager.settings.content.mdblist.enabled:
return
mdblistApi = MdblistAPI(settings_manager.settings.content.mdblist.api_key)
di[MdblistAPI] = mdblistApi

def __setup_listrr():
if not settings_manager.settings.content.listrr.enabled:
return
listrrApi = ListrrAPI(settings_manager.settings.content.listrr.api_key)
di[ListrrAPI] = listrrApi
iPromKnight marked this conversation as resolved.
Show resolved Hide resolved
4 changes: 2 additions & 2 deletions src/program/apis/listrr_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from loguru import logger
from requests.exceptions import HTTPError

from kink import di
from program.apis.trakt_api import TraktAPI
from program.media.item import MediaItem
from program.utils.request import create_service_session, BaseRequestHandler, Session, ResponseType, ResponseObject, HttpMethod
Expand All @@ -25,7 +25,7 @@ def __init__(self, api_key: str):
session = create_service_session()
session.headers.update(self.headers)
self.request_handler = ListrrRequestHandler(session, base_url=self.BASE_URL)
self.trakt_api = TraktAPI(rate_limit=False)
self.trakt_api = di[TraktAPI]

def validate(self):
return self.request_handler.execute(HttpMethod.GET, self.BASE_URL)
Expand Down
4 changes: 2 additions & 2 deletions src/program/apis/overseerr_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from loguru import logger
from requests.exceptions import ConnectionError, RetryError
from urllib3.exceptions import MaxRetryError

from kink import di
from program.apis.trakt_api import TraktAPI
from program.media.item import MediaItem
from program.settings.manager import settings_manager
Expand All @@ -27,7 +27,7 @@ def __init__(self, api_key: str, base_url: str):
self.api_key = api_key
rate_limit_params = get_rate_limit_params(max_calls=1000, period=300)
session = create_service_session(rate_limit_params=rate_limit_params)
self.trakt_api = TraktAPI(rate_limit=False)
self.trakt_api = di[TraktAPI]
self.headers = {"X-Api-Key": self.api_key}
session.headers.update(self.headers)
self.request_handler = OverseerrRequestHandler(session, base_url=base_url)
Expand Down
11 changes: 9 additions & 2 deletions src/program/apis/plex_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def execute(self, method: HttpMethod, endpoint: str, overriden_response_type: Re
class PlexAPI:
"""Handles Plex API communication"""

def __init__(self, token: str, base_url: str, rss_urls: Optional[List[str]]):
self.rss_urls = rss_urls
def __init__(self, token: str, base_url: str):
self.rss_urls: Optional[List[str]] = None
self.token = token
self.BASE_URL = base_url
session = create_service_session()
Expand All @@ -43,6 +43,13 @@ def validate_account(self):
def validate_server(self):
self.plex_server = PlexServer(self.BASE_URL, token=self.token, session=self.request_handler.session, timeout=60)

def set_rss_urls(self, rss_urls: List[str]):
self.rss_urls = rss_urls

def clear_rss_urls(self):
self.rss_urls = None
self.rss_enabled = False

Comment on lines +46 to +52
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add validation and improve documentation for RSS URL management methods.

The new methods provide good encapsulation of RSS URL management, but could benefit from additional validation and documentation.

Consider these improvements:

-def set_rss_urls(self, rss_urls: List[str]):
-    self.rss_urls = rss_urls
+def set_rss_urls(self, rss_urls: List[str]) -> None:
+    """Set RSS URLs for Plex feed fetching.
+    
+    Args:
+        rss_urls (List[str]): List of valid RSS feed URLs
+    
+    Raises:
+        ValueError: If any URL in the list is malformed
+    """
+    if not all(url.startswith(('http://', 'https://')) for url in rss_urls):
+        raise ValueError("All RSS URLs must be valid HTTP(S) URLs")
+    self.rss_urls = rss_urls
+    self.rss_enabled = bool(rss_urls)

-def clear_rss_urls(self):
+def clear_rss_urls(self) -> None:
+    """Clear all RSS URLs and disable RSS functionality."""
     self.rss_urls = None
     self.rss_enabled = False

This implementation:

  1. Adds basic URL validation
  2. Updates rss_enabled flag based on URL presence
  3. Adds type hints for return values
  4. Includes docstrings with parameter descriptions
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def set_rss_urls(self, rss_urls: List[str]):
self.rss_urls = rss_urls
def clear_rss_urls(self):
self.rss_urls = None
self.rss_enabled = False
def set_rss_urls(self, rss_urls: List[str]) -> None:
"""Set RSS URLs for Plex feed fetching.
Args:
rss_urls (List[str]): List of valid RSS feed URLs
Raises:
ValueError: If any URL in the list is malformed
"""
if not all(url.startswith(('http://', 'https://')) for url in rss_urls):
raise ValueError("All RSS URLs must be valid HTTP(S) URLs")
self.rss_urls = rss_urls
self.rss_enabled = bool(rss_urls)
def clear_rss_urls(self) -> None:
"""Clear all RSS URLs and disable RSS functionality."""
self.rss_urls = None
self.rss_enabled = False

def validate_rss(self, url: str):
return self.request_handler.execute(HttpMethod.GET, url)

Expand Down
80 changes: 62 additions & 18 deletions src/program/apis/trakt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,26 @@
from datetime import datetime
from types import SimpleNamespace
from typing import Union, List, Optional
from urllib.parse import urlencode
from requests import RequestException, Session
from program import MediaItem
from program.media import Movie, Show, Season, Episode
from program.settings.manager import settings_manager
from program.settings.models import TraktModel
from program.utils.request import get_rate_limit_params, create_service_session, logger, BaseRequestHandler, \
ResponseType, HttpMethod, ResponseObject

ResponseType, HttpMethod, ResponseObject, get_cache_params

class TraktAPIError(Exception):
"""Base exception for TraktApi related errors"""

class TraktRequestHandler(BaseRequestHandler):
def __init__(self, session: Session, request_logging: bool = False):
super().__init__(session, response_type=ResponseType.SIMPLE_NAMESPACE, custom_exception=TraktAPIError, request_logging=request_logging)
def __init__(self, session: Session, response_type=ResponseType.SIMPLE_NAMESPACE, request_logging: bool = False):
super().__init__(session, response_type=response_type, custom_exception=TraktAPIError, request_logging=request_logging)

def execute(self, method: HttpMethod, endpoint: str, **kwargs) -> ResponseObject:
return super()._request(method, endpoint, **kwargs)


class TraktAPI:
"""Handles Trakt API communication"""
BASE_URL = "https://api.trakt.tv"
Expand All @@ -29,16 +32,17 @@ class TraktAPI:
"short_list": re.compile(r"https://trakt.tv/lists/\d+")
}

def __init__(self, api_key: Optional[str] = None, rate_limit: bool = True):
self.api_key = api_key
rate_limit_params = get_rate_limit_params(max_calls=1000, period=300) if rate_limit else None
session = create_service_session(
rate_limit_params=rate_limit_params,
use_cache=False
)
def __init__(self, settings: TraktModel):
self.settings = settings
self.oauth_client_id = self.settings.oauth.oauth_client_id
self.oauth_client_secret = self.settings.oauth.oauth_client_secret
self.oauth_redirect_uri = self.settings.oauth.oauth_redirect_uri
rate_limit_params = get_rate_limit_params(max_calls=1000, period=300)
trakt_cache = get_cache_params("trakt", 86400)
session = create_service_session(rate_limit_params=rate_limit_params, use_cache=True, cache_params=trakt_cache)
self.headers = {
"Content-type": "application/json",
"trakt-api-key": self.api_key or self.CLIENT_ID,
"trakt-api-key": self.CLIENT_ID,
iPromKnight marked this conversation as resolved.
Show resolved Hide resolved
"trakt-api-version": "2"
}
session.headers.update(self.headers)
Expand Down Expand Up @@ -148,15 +152,15 @@ def get_show(self, imdb_id: str) -> dict:
"""Wrapper for trakt.tv API show method."""
if not imdb_id:
return {}
url = f"https://api.trakt.tv/shows/{imdb_id}/seasons?extended=episodes,full"
url = f"{self.BASE_URL}/shows/{imdb_id}/seasons?extended=episodes,full"
response = self.request_handler.execute(HttpMethod.GET, url, timeout=30)
return response.data if response.is_ok and response.data else {}

def get_show_aliases(self, imdb_id: str, item_type: str) -> List[dict]:
"""Wrapper for trakt.tv API show method."""
if not imdb_id:
return []
url = f"https://api.trakt.tv/{item_type}/{imdb_id}/aliases"
url = f"{self.BASE_URL}/{item_type}/{imdb_id}/aliases"
try:
response = self.request_handler.execute(HttpMethod.GET, url, timeout=30)
if response.is_ok and response.data:
Expand All @@ -178,7 +182,7 @@ def get_show_aliases(self, imdb_id: str, item_type: str) -> List[dict]:

def create_item_from_imdb_id(self, imdb_id: str, type: str = None) -> Optional[MediaItem]:
"""Wrapper for trakt.tv API search method."""
url = f"https://api.trakt.tv/search/imdb/{imdb_id}?extended=full"
url = f"{self.BASE_URL}/search/imdb/{imdb_id}?extended=full"
response = self.request_handler.execute(HttpMethod.GET, url, timeout=30)
if not response.is_ok or not response.data:
logger.error(
Expand All @@ -194,7 +198,7 @@ def create_item_from_imdb_id(self, imdb_id: str, type: str = None) -> Optional[M

def get_imdbid_from_tmdb(self, tmdb_id: str, type: str = "movie") -> Optional[str]:
"""Wrapper for trakt.tv API search method."""
url = f"https://api.trakt.tv/search/tmdb/{tmdb_id}" # ?extended=full
url = f"{self.BASE_URL}/search/tmdb/{tmdb_id}" # ?extended=full
response = self.request_handler.execute(HttpMethod.GET, url, timeout=30)
if not response.is_ok or not response.data:
return None
Expand All @@ -206,7 +210,7 @@ def get_imdbid_from_tmdb(self, tmdb_id: str, type: str = "movie") -> Optional[st

def get_imdbid_from_tvdb(self, tvdb_id: str, type: str = "show") -> Optional[str]:
"""Wrapper for trakt.tv API search method."""
url = f"https://api.trakt.tv/search/tvdb/{tvdb_id}"
url = f"{self.BASE_URL}/search/tvdb/{tvdb_id}"
response = self.request_handler.execute(HttpMethod.GET, url, timeout=30)
if not response.is_ok or not response.data:
return None
Expand All @@ -219,7 +223,7 @@ def get_imdbid_from_tvdb(self, tvdb_id: str, type: str = "show") -> Optional[str
def resolve_short_url(self, short_url) -> Union[str, None]:
"""Resolve short URL to full URL"""
try:
response = self.request_handler.execute(HttpMethod.GET, url=short_url, additional_headers={"Content-Type": "application/json", "Accept": "text/html"})
response = self.request_handler.execute(HttpMethod.GET, endpoint=short_url, additional_headers={"Content-Type": "application/json", "Accept": "text/html"})
if response.is_ok:
return response.response.url
else:
Expand Down Expand Up @@ -279,6 +283,46 @@ def map_item_from_data(self, data, item_type: str, show_genres: List[str] = None
logger.error(f"Unknown item type {item_type} for {data.title} not found in list of acceptable items")
return None

def perform_oauth_flow(self) -> str:
"""Initiate the OAuth flow and return the authorization URL."""
if not self.oauth_client_id or not self.oauth_client_secret or not self.oauth_redirect_uri:
logger.error("OAuth settings not found in Trakt settings")
raise TraktAPIError("OAuth settings not found in Trakt settings")

params = {
"response_type": "code",
"client_id": self.oauth_client_id,
"redirect_uri": self.oauth_redirect_uri,
}
return f"{self.BASE_URL}/oauth/authorize?{urlencode(params)}"

def handle_oauth_callback(self, api_key:str, code: str) -> bool:
"""Handle the OAuth callback and exchange the code for an access token."""
if not self.oauth_client_id or not self.oauth_client_secret or not self.oauth_redirect_uri:
logger.error("OAuth settings not found in Trakt settings")
return False

token_url = f"{self.BASE_URL}/oauth/token"
payload = {
"code": code,
"client_id": self.oauth_client_id,
"client_secret": self.oauth_client_secret,
"redirect_uri": self.oauth_redirect_uri,
"grant_type": "authorization_code",
}
headers = self.headers.copy()
headers["trakt-api-key"] = api_key
response = self.request_handler.execute(HttpMethod.POST, token_url, data=payload, additional_headers=headers)
if response.is_ok:
token_data = response.data
self.settings.access_token = token_data.get("access_token")
self.settings.refresh_token = token_data.get("refresh_token")
settings_manager.save() # Save the tokens to settings
iPromKnight marked this conversation as resolved.
Show resolved Hide resolved
return True
else:
logger.error(f"Failed to obtain OAuth token: {response.status_code}")
return False

def _get_imdb_id_from_list(self, namespaces: List[SimpleNamespace], id_type: str = None, _id: str = None,
type: str = None) -> Optional[str]:
"""Get the imdb_id from the list of namespaces."""
Expand Down
2 changes: 1 addition & 1 deletion src/program/db/db_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from sqlalchemy import delete, desc, func, insert, inspect, select, text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session, joinedload, selectinload
from program.utils import root_dir

import alembic
from program.utils import root_dir
iPromKnight marked this conversation as resolved.
Show resolved Hide resolved
from program.media.stream import Stream, StreamBlacklistRelation, StreamRelation
from program.services.libraries.symlink import fix_broken_symlinks
from program.settings.manager import settings_manager
Expand Down
8 changes: 7 additions & 1 deletion src/program/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
PlexWatchlist,
TraktContent,
)
from program.apis import bootstrap_apis
from program.services.downloaders import Downloader
from program.services.indexers.trakt import TraktIndexer
from program.services.libraries import SymlinkLibrary
Expand Down Expand Up @@ -64,8 +65,11 @@ def __init__(self):
self.malloc_time = time.monotonic()-50
self.last_snapshot = None

def initialize_services(self):
def initialize_apis(self):
bootstrap_apis()
Comment on lines +68 to +69
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add error handling and logging to API initialization.

The method should include error handling and logging to track API initialization status and handle potential failures gracefully.

Consider this implementation:

 def initialize_apis(self):
+    logger.log("PROGRAM", "Initializing APIs...")
+    try:
         bootstrap_apis()
+        logger.success("APIs initialized successfully")
+    except Exception as e:
+        logger.error(f"Failed to initialize APIs: {e}")
+        raise
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def initialize_apis(self):
bootstrap_apis()
def initialize_apis(self):
logger.log("PROGRAM", "Initializing APIs...")
try:
bootstrap_apis()
logger.success("APIs initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize APIs: {e}")
raise


def initialize_services(self):
"""Initialize all services."""
self.requesting_services = {
Overseerr: Overseerr(),
PlexWatchlist: PlexWatchlist(),
Expand Down Expand Up @@ -122,13 +126,15 @@ def start(self):
latest_version = get_version()
logger.log("PROGRAM", f"Riven v{latest_version} starting!")

settings_manager.register_observer(self.initialize_apis)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add validation for API initialization success.

While the initialization order is correct, the code should validate that APIs are properly initialized before proceeding with service initialization.

Consider adding a validation method and using it:

+    def validate_apis(self) -> bool:
+        """Validate that all required APIs are initialized."""
+        try:
+            # Add specific API validation logic here
+            return True
+        except Exception as e:
+            logger.error(f"API validation failed: {e}")
+            return False

     def start(self):
         # ... existing code ...
         self.initialize_apis()
+        if not self.validate_apis():
+            logger.error("APIs failed to initialize properly")
+            return
         self.initialize_services()

Also applies to: 137-137

settings_manager.register_observer(self.initialize_services)
os.makedirs(data_dir_path, exist_ok=True)

if not settings_manager.settings_file.exists():
logger.log("PROGRAM", "Settings file not found, creating default settings")
settings_manager.save()

self.initialize_apis()
self.initialize_services()

max_worker_env_vars = [var for var in os.environ if var.endswith("_MAX_WORKERS")]
Expand Down
4 changes: 3 additions & 1 deletion src/program/services/content/listrr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Listrr content module"""
from typing import Generator
from kink import di
from program.utils.request import logger
from program.media.item import MediaItem
from program.settings.manager import settings_manager
Expand All @@ -12,7 +13,7 @@ class Listrr:
def __init__(self):
self.key = "listrr"
self.settings = settings_manager.settings.content.listrr
self.api = ListrrAPI(self.settings.api_key)
self.api = None
self.initialized = self.validate()
if not self.initialized:
return
Expand Down Expand Up @@ -40,6 +41,7 @@ def validate(self) -> bool:
logger.error("Both Movie and Show lists are empty or not set.")
return False
try:
self.api = di[ListrrAPI]
response = self.api.validate()
if not response.is_ok:
logger.error(
Expand Down
5 changes: 3 additions & 2 deletions src/program/services/content/mdblist.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Generator
from loguru import logger

from kink import di
from program.apis.mdblist_api import MdblistAPI
from program.media.item import MediaItem
from program.settings.manager import settings_manager
Expand All @@ -14,7 +14,7 @@ class Mdblist:
def __init__(self):
self.key = "mdblist"
self.settings = settings_manager.settings.content.mdblist
self.api = MdblistAPI(self.settings.api_key)
self.api = None
self.initialized = self.validate()
if not self.initialized:
return
Expand All @@ -30,6 +30,7 @@ def validate(self):
if not self.settings.lists:
logger.error("Mdblist is enabled, but list is empty.")
return False
self.api = di[MdblistAPI]
iPromKnight marked this conversation as resolved.
Show resolved Hide resolved
response = self.api.validate()
if "Invalid API key!" in response.response.text:
logger.error("Mdblist api key is invalid.")
Expand Down
Loading
Loading