Skip to content

Commit

Permalink
Lite: SharedWorker mode (#6099)
Browse files Browse the repository at this point in the history
* Create the SharedWorker mode

* Fix the app_id management mechanism to be context-based so it works properly with run_file() and run_code()

* Adjust updateProgress() for the SharedWorker mode

* Make `get_continuous_fn()` Wasm-compatible in order for the created function not to occupy the thread with a loop with `time.sleep()` which is no-op in the Wasm env

* add changeset

* Add the `shared-worker` option to the custom-element API

* Apply formatter to js/app/src/lite/index.ts

* format

* Revert pnpm-lock.yml file version which is not needed to be changed in this PR

* Fix the unittest for `get_continuous_fn()`

* Fix the unit test because `anext()` is not available in Python 3.8, which is used in the CI env

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
3 people authored Nov 13, 2023
1 parent 9d6d72f commit d842097
Show file tree
Hide file tree
Showing 14 changed files with 468 additions and 246 deletions.
7 changes: 7 additions & 0 deletions .changeset/strong-bars-cross.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"@gradio/app": minor
"@gradio/wasm": minor
"gradio": minor
---

feat:Lite: SharedWorker mode
12 changes: 7 additions & 5 deletions gradio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import pkgutil
import re
import threading
import time
import traceback
import typing
import warnings
Expand Down Expand Up @@ -597,16 +596,19 @@ def is_update(val):


def get_continuous_fn(fn: Callable, every: float) -> Callable:
def continuous_fn(*args):
# For Wasm-compatibility, we need to use asyncio.sleep() instead of time.sleep(),
# so we need to make the function async.
async def continuous_coro(*args):
while True:
output = fn(*args)
if isinstance(output, GeneratorType):
yield from output
for item in output:
yield item
else:
yield output
time.sleep(every)
await asyncio.sleep(every)

return continuous_fn
return continuous_coro


def function_wrapper(
Expand Down
41 changes: 34 additions & 7 deletions gradio/wasm_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from __future__ import annotations

import sys
from contextlib import contextmanager
from contextvars import ContextVar

# See https://pyodide.org/en/stable/usage/faq.html#how-to-detect-that-code-is-run-with-pyodide
IS_WASM = sys.platform == "emscripten"
Expand All @@ -8,19 +12,42 @@ class WasmUnsupportedError(Exception):
pass


app = None
# Mapping from app ID to the Gradio's FastAPI app instance (`app`).
# To support the SharedWorker mode where multiple apps are running in the same worker,
# we need to keep track of the app instances for each app ID.
app_map = {}


# `with app_id_context(app_id):` is used to set the app ID
# which `register_app()` uses to register the app instance.
# Context variables are natively supported in asyncio and
# can manage data in each task (https://docs.python.org/3/library/contextvars.html#asyncio-support),
# so we can use them for this purpose.
_app_id_context_var: ContextVar[str | None] = ContextVar("app_id", default=None)


@contextmanager
def app_id_context(app_id: str):
token = _app_id_context_var.set(app_id)
yield
_app_id_context_var.reset(token)


# `register_app` and `get_registered_app` are used
# for the Wasm worker to get a reference to
# the Gradio's FastAPI app instance (`app`).
def register_app(_app):
global app
if app:
global app_map

app_id = _app_id_context_var.get()

if app_id in app_map:
app = app_map[app_id]
app.blocks.close()
app = _app

app_map[app_id] = _app


def get_registered_app():
global app
return app
def get_registered_app(app_id: str):
global app_map
return app_map[app_id]
5 changes: 4 additions & 1 deletion js/app/src/lite/custom-element/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ interface GradioComponentOptions {
autoScroll: Options["autoScroll"];
controlPageTitle: Options["controlPageTitle"];
appMode: Options["appMode"];
sharedWorkerMode?: Options["sharedWorkerMode"];
}

interface GradioLiteAppOptions {
Expand Down Expand Up @@ -75,6 +76,7 @@ export function bootstrap_custom_element(): void {
const autoScroll = this.hasAttribute("auto-scroll");
const controlPageTitle = this.hasAttribute("control-page-title");
const appMode = this.hasAttribute("app-mode");
const sharedWorkerMode = this.hasAttribute("shared-worker");

return {
info,
Expand All @@ -88,7 +90,8 @@ export function bootstrap_custom_element(): void {
: null,
autoScroll,
controlPageTitle,
appMode
appMode,
sharedWorkerMode
};
}

Expand Down
1 change: 1 addition & 0 deletions js/app/src/lite/dev/App.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def hi(name):
files,
entrypoint,
requirements,
sharedWorkerMode: true,
info: true,
container: true,
isEmbed: false,
Expand Down
4 changes: 3 additions & 1 deletion js/app/src/lite/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ export interface Options {
requirements?: WorkerProxyOptions["requirements"];
code?: string;
entrypoint?: string;
sharedWorkerMode?: boolean;
info: boolean;
container: boolean;
isEmbed: boolean;
Expand All @@ -75,7 +76,8 @@ export function create(options: Options): GradioAppController {
gradioWheelUrl: new URL(gradioWheel, import.meta.url).href,
gradioClientWheelUrl: new URL(gradioClientWheel, import.meta.url).href,
files: options.files ?? {},
requirements: options.requirements ?? []
requirements: options.requirements ?? [],
sharedWorkerMode: options.sharedWorkerMode ?? false
});

worker_proxy.addEventListener("initialization-error", (event) => {
Expand Down
14 changes: 10 additions & 4 deletions js/wasm/src/cross-origin-worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@
// This implementation was based on https://github.com/whitphx/stlite/blob/v0.34.0/packages/kernel/src/kernel.ts,
// and this technique was introduced originally for Webpack at https://github.com/webpack/webpack/discussions/14648#discussioncomment-1589272
export class CrossOriginWorkerMaker {
public readonly worker: Worker;
public readonly worker: Worker | SharedWorker;

constructor(url: URL, options?: WorkerOptions & { shared?: boolean }) {
const { shared = false, ...workerOptions } = options ?? {};

constructor(url: URL) {
try {
// This is the normal way to load a worker script, which is the best straightforward if possible.
this.worker = new Worker(url);
this.worker = shared
? new SharedWorker(url, workerOptions)
: new Worker(url, workerOptions);
} catch (e) {
console.debug(
`Failed to load a worker script from ${url.toString()}. Trying to load a cross-origin worker...`
Expand All @@ -23,7 +27,9 @@ export class CrossOriginWorkerMaker {
type: "text/javascript"
});
const workerBlobUrl = URL.createObjectURL(workerBlob);
this.worker = new Worker(workerBlobUrl);
this.worker = shared
? new SharedWorker(workerBlobUrl, workerOptions)
: new Worker(workerBlobUrl, workerOptions);
URL.revokeObjectURL(workerBlobUrl);
}
}
Expand Down
12 changes: 9 additions & 3 deletions js/wasm/src/message-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,16 @@ export interface InMessageBase {
data: unknown;
}

export interface InMessageInit extends InMessageBase {
type: "init";
export interface InMessageInitEnv extends InMessageBase {
type: "init-env";
data: {
gradioWheelUrl: string;
gradioClientWheelUrl: string;
};
}
export interface InMessageInitApp extends InMessageBase {
type: "init-app";
data: {
files: Record<string, EmscriptenFile | EmscriptenFileUrl>;
requirements: string[];
};
Expand Down Expand Up @@ -92,7 +97,8 @@ export interface InMessageEcho extends InMessageBase {
}

export type InMessage =
| InMessageInit
| InMessageInitEnv
| InMessageInitApp
| InMessageRunPythonCode
| InMessageRunPythonFile
| InMessageWebSocket
Expand Down
8 changes: 8 additions & 0 deletions js/wasm/src/webworker/file.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import path from "path-browserify";
import type { PyodideInterface } from "pyodide";

export function addAppIdIfRelative(appId: string, filePath: string): string {
if (path.isAbsolute(filePath)) {
return filePath;
}

return path.join(appId, filePath);
}

function ensureParent(pyodide: PyodideInterface, filePath: string): void {
const normalized = path.normalize(filePath);

Expand Down
Loading

0 comments on commit d842097

Please sign in to comment.