From 9a29db9a968213851785d9baf14628595505e675 Mon Sep 17 00:00:00 2001
From: Ruslan Kuprieiev
Date: Wed, 12 Feb 2025 18:47:48 +0200
Subject: [PATCH] feat(fal_client): support multipart (#413)
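
Files larger than MULTIPART_THRESHOLD (100 MiB) are now uploaded through
the CDN multipart API: the upload is initiated once, the payload is split
into MULTIPART_CHUNK_SIZE (10 MiB) parts uploaded with up to
MULTIPART_MAX_CONCURRENCY (10) concurrent requests, and the upload is then
completed with the collected ETags. Both the sync and async clients take
this path transparently in upload() and upload_file().

Caller-facing usage is unchanged. A minimal sketch (file name
hypothetical), assuming a valid FAL_KEY and the module-level upload_file
helper:

    import fal_client

    # Files over the 100 MiB threshold are split into parts automatically.
    url = fal_client.upload_file("video.mp4")
    print(url)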
---
 projects/fal_client/src/fal_client/client.py | 417 +++++++++++++++++++
 1 file changed, 417 insertions(+)

diff --git a/projects/fal_client/src/fal_client/client.py b/projects/fal_client/src/fal_client/client.py
index 2741fff7..66e99fb1 100644
--- a/projects/fal_client/src/fal_client/client.py
+++ b/projects/fal_client/src/fal_client/client.py
@@ -1,9 +1,11 @@
 from __future__ import annotations
 
 import io
+import math
 import os
 import mimetypes
 import asyncio
+from pathlib import Path
 import time
 import base64
 import threading
@@ -119,6 +121,374 @@ async def get_token(self) -> CDNToken:
         return self._token
 
 
+MULTIPART_THRESHOLD = 100 * 1024 * 1024
+MULTIPART_CHUNK_SIZE = 10 * 1024 * 1024
+MULTIPART_MAX_CONCURRENCY = 10
+
+
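+# A multipart upload talks to the CDN in three steps:
+#   1. create():      POST {base_upload_url}/files/upload/multipart
+#                     -> access_url and uploadId
+#   2. upload_part(): PUT {access_url}/multipart/{uploadId}/{partNumber}
+#                     for each chunk, collecting the returned ETag header
+#   3. complete():    POST {access_url}/multipart/{uploadId}/complete with
+#                     the gathered {partNumber, etag} pairs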
f"{self.access_url}/multipart/{self.upload_id}/{part_number}" + headers = await self.auth_headers + + response = await _async_request( + self._client, + "PUT", + url, + headers={ + **headers, + "Content-Type": self.content_type, + "Accept-Encoding": "identity", # Keep this to ensure we get ETag headers + }, + content=data, + timeout=None, + ) + + etag = response.headers["etag"] + self._parts.append( + { + "partNumber": part_number, + "etag": etag, + } + ) + + async def complete(self) -> str: + url = f"{self.access_url}/multipart/{self.upload_id}/complete" + headers = await self.auth_headers + await _async_maybe_retry_request( + self._client, + "POST", + url, + headers=headers, + json={"parts": self._parts}, + ) + return self.access_url + + @classmethod + async def save( + cls, + *, + client: httpx.AsyncClient, + token_manager: AsyncCDNTokenManager, + file_name: str, + data: bytes, + content_type: str | None = None, + chunk_size: int | None = None, + max_concurrency: int | None = None, + ) -> str: + multipart = cls( + file_name=file_name, + client=client, + token_manager=token_manager, + chunk_size=chunk_size, + content_type=content_type, + max_concurrency=max_concurrency, + ) + await multipart.create() + parts = math.ceil(len(data) / multipart.chunk_size) + + async def upload_part(part_number: int) -> None: + start = (part_number - 1) * multipart.chunk_size + chunk = data[start : start + multipart.chunk_size] + await multipart.upload_part(part_number, chunk) + + tasks = [ + asyncio.create_task(upload_part(part_number)) + for part_number in range(1, parts + 1) + ] + + # Process concurrent uploads with semaphore to limit concurrency + sem = asyncio.Semaphore(multipart.max_concurrency) + + async def bounded_upload(task): + async with sem: + await task + + await asyncio.gather(*[bounded_upload(task) for task in tasks]) + return await multipart.complete() + + @classmethod + async def save_file( + cls, + *, + client: httpx.AsyncClient, + token_manager: AsyncCDNTokenManager, + file_path: str | Path, + chunk_size: int | None = None, + content_type: str | None = None, + max_concurrency: int | None = None, + ) -> str: + file_name = os.path.basename(file_path) + size = os.path.getsize(file_path) + multipart = cls( + file_name=file_name, + client=client, + token_manager=token_manager, + chunk_size=chunk_size, + content_type=content_type, + max_concurrency=max_concurrency, + ) + await multipart.create() + parts = math.ceil(size / multipart.chunk_size) + + async def upload_part(part_number: int) -> None: + with open(file_path, "rb") as f: + start = (part_number - 1) * multipart.chunk_size + f.seek(start) + data = f.read(multipart.chunk_size) + await multipart.upload_part(part_number, data) + + tasks = [ + asyncio.create_task(upload_part(part_number)) + for part_number in range(1, parts + 1) + ] + + # Process concurrent uploads with semaphore to limit concurrency + sem = asyncio.Semaphore(multipart.max_concurrency) + + async def bounded_upload(task): + async with sem: + await task + + await asyncio.gather(*[bounded_upload(task) for task in tasks]) + return await multipart.complete() + + class FalClientError(Exception): pass @@ -623,6 +993,20 @@ async def upload( client = await self._get_cdn_client() + if isinstance(data, str): + data = data.encode("utf-8") + + if len(data) > MULTIPART_THRESHOLD: + if file_name is None: + file_name = "upload.bin" + return await AsyncMultipartUpload.save( + client=client, + token_manager=self._token_manager, + file_name=file_name, + data=data, + content_type=content_type, 
+class AsyncMultipartUpload:
+    def __init__(
+        self,
+        *,
+        file_name: str,
+        client: httpx.AsyncClient,
+        token_manager: AsyncCDNTokenManager,
+        chunk_size: int | None = None,
+        content_type: str | None = None,
+        max_concurrency: int | None = None,
+    ) -> None:
+        self.file_name = file_name
+        self._client = client
+        self._token_manager = token_manager
+        self.chunk_size = chunk_size or MULTIPART_CHUNK_SIZE
+        self.content_type = content_type or "application/octet-stream"
+        self.max_concurrency = max_concurrency or MULTIPART_MAX_CONCURRENCY
+        self._access_url: str | None = None
+        self._upload_id: str | None = None
+        self._parts: list[dict] = []
+
+    @property
+    def access_url(self) -> str:
+        if not self._access_url:
+            raise ValueError("Upload not initiated")
+        return self._access_url
+
+    @property
+    def upload_id(self) -> str:
+        if not self._upload_id:
+            raise ValueError("Upload not initiated")
+        return self._upload_id
+
+    @property
+    async def auth_headers(self) -> dict[str, str]:
+        token = await self._token_manager.get_token()
+        return {
+            "Authorization": f"{token.token_type} {token.token}",
+            "User-Agent": "fal/0.1.0",
+        }
+
+    async def create(self) -> None:
+        token = await self._token_manager.get_token()
+        url = f"{token.base_upload_url}/files/upload/multipart"
+        headers = await self.auth_headers
+        response = await _async_maybe_retry_request(
+            self._client,
+            "POST",
+            url,
+            headers={
+                **headers,
+                "Accept": "application/json",
+                "Content-Type": self.content_type,
+                "X-Fal-File-Name": self.file_name,
+            },
+        )
+        result = response.json()
+        self._access_url = result["access_url"]
+        self._upload_id = result["uploadId"]
+
+    async def upload_part(self, part_number: int, data: bytes) -> None:
+        url = f"{self.access_url}/multipart/{self.upload_id}/{part_number}"
+        headers = await self.auth_headers
+
+        response = await _async_request(
+            self._client,
+            "PUT",
+            url,
+            headers={
+                **headers,
+                "Content-Type": self.content_type,
+                "Accept-Encoding": "identity",  # Keep this to ensure we get ETag headers
+            },
+            content=data,
+            timeout=None,
+        )
+
+        etag = response.headers["etag"]
+        self._parts.append(
+            {
+                "partNumber": part_number,
+                "etag": etag,
+            }
+        )
+
+    async def complete(self) -> str:
+        url = f"{self.access_url}/multipart/{self.upload_id}/complete"
+        headers = await self.auth_headers
+        await _async_maybe_retry_request(
+            self._client,
+            "POST",
+            url,
+            headers=headers,
+            # Parts finish out of order under concurrency; send them sorted.
+            json={"parts": sorted(self._parts, key=lambda p: p["partNumber"])},
+        )
+        return self.access_url
+
+    @classmethod
+    async def save(
+        cls,
+        *,
+        client: httpx.AsyncClient,
+        token_manager: AsyncCDNTokenManager,
+        file_name: str,
+        data: bytes,
+        content_type: str | None = None,
+        chunk_size: int | None = None,
+        max_concurrency: int | None = None,
+    ) -> str:
+        multipart = cls(
+            file_name=file_name,
+            client=client,
+            token_manager=token_manager,
+            chunk_size=chunk_size,
+            content_type=content_type,
+            max_concurrency=max_concurrency,
+        )
+        await multipart.create()
+        parts = math.ceil(len(data) / multipart.chunk_size)
+
+        # Bound concurrency with a semaphore. The coroutines are wrapped
+        # *before* they are scheduled; asyncio.create_task would start every
+        # part at once and leave the semaphore nothing to limit.
+        sem = asyncio.Semaphore(multipart.max_concurrency)
+
+        async def bounded_upload(part_number: int) -> None:
+            async with sem:
+                start = (part_number - 1) * multipart.chunk_size
+                chunk = data[start : start + multipart.chunk_size]
+                await multipart.upload_part(part_number, chunk)
+
+        await asyncio.gather(
+            *(bounded_upload(part_number) for part_number in range(1, parts + 1))
+        )
+        return await multipart.complete()
+
+    @classmethod
+    async def save_file(
+        cls,
+        *,
+        client: httpx.AsyncClient,
+        token_manager: AsyncCDNTokenManager,
+        file_path: str | Path,
+        chunk_size: int | None = None,
+        content_type: str | None = None,
+        max_concurrency: int | None = None,
+    ) -> str:
+        file_name = os.path.basename(file_path)
+        size = os.path.getsize(file_path)
+        multipart = cls(
+            file_name=file_name,
+            client=client,
+            token_manager=token_manager,
+            chunk_size=chunk_size,
+            content_type=content_type,
+            max_concurrency=max_concurrency,
+        )
+        await multipart.create()
+        parts = math.ceil(size / multipart.chunk_size)
+
+        # Same pattern as save(): wrap the coroutines in the semaphore before
+        # scheduling so at most max_concurrency parts are in flight. Each
+        # chunk is read (blocking, but lazily) inside the semaphore to cap
+        # memory use.
+        sem = asyncio.Semaphore(multipart.max_concurrency)
+
+        async def bounded_upload(part_number: int) -> None:
+            async with sem:
+                with open(file_path, "rb") as f:
+                    f.seek((part_number - 1) * multipart.chunk_size)
+                    data = f.read(multipart.chunk_size)
+                await multipart.upload_part(part_number, data)
+
+        await asyncio.gather(
+            *(bounded_upload(part_number) for part_number in range(1, parts + 1))
+        )
+        return await multipart.complete()
+
+
 class FalClientError(Exception):
     pass
 
@@ -623,6 +993,20 @@ async def upload(
         client = await self._get_cdn_client()
 
+        if isinstance(data, str):
+            data = data.encode("utf-8")
+
+        if len(data) > MULTIPART_THRESHOLD:
+            if file_name is None:
+                file_name = "upload.bin"
+            return await AsyncMultipartUpload.save(
+                client=client,
+                token_manager=self._token_manager,
+                file_name=file_name,
+                data=data,
+                content_type=content_type,
+            )
+
         headers = {"Content-Type": content_type}
         if file_name is not None:
             headers["X-Fal-File-Name"] = file_name
@@ -643,6 +1027,15 @@ async def upload_file(self, path: os.PathLike) -> str:
         if mime_type is None:
             mime_type = "application/octet-stream"
 
+        if os.path.getsize(path) > MULTIPART_THRESHOLD:
+            client = await self._get_cdn_client()
+            return await AsyncMultipartUpload.save_file(
+                file_path=path,
+                client=client,
+                token_manager=self._token_manager,
+                content_type=mime_type,
+            )
+
         with open(path, "rb") as file:
             return await self.upload(
                 file.read(), mime_type, file_name=os.path.basename(path)
@@ -675,6 +1068,7 @@ def _client(self) -> httpx.Client:
                 "User-Agent": USER_AGENT,
             },
             timeout=self.default_timeout,
+            follow_redirects=True,
         )
 
     @cached_property
@@ -849,6 +1243,20 @@ def upload(
         client = self._get_cdn_client()
 
+        if isinstance(data, str):
+            data = data.encode("utf-8")
+
+        if len(data) > MULTIPART_THRESHOLD:
+            if file_name is None:
+                file_name = "upload.bin"
+            return MultipartUpload.save(
+                client=client,
+                token_manager=self._token_manager,
+                file_name=file_name,
+                data=data,
+                content_type=content_type,
+            )
+
         headers = {"Content-Type": content_type}
         if file_name is not None:
             headers["X-Fal-File-Name"] = file_name
@@ -869,6 +1277,15 @@ def upload_file(self, path: os.PathLike) -> str:
         if mime_type is None:
             mime_type = "application/octet-stream"
 
+        if os.path.getsize(path) > MULTIPART_THRESHOLD:
+            client = self._get_cdn_client()
+            return MultipartUpload.save_file(
+                file_path=path,
+                client=client,
+                token_manager=self._token_manager,
+                content_type=mime_type,
+            )
+
         with open(path, "rb") as file:
             return self.upload(file.read(), mime_type, file_name=os.path.basename(path))