WIP: sketch out wasi-nn extensions
This change alters the `wasi-nn` world to split out two different modes
of operation:
- `inference`: this continues the traditional mechanism for computing
  with wasi-nn, passing named `tensor`s to a `context`. Now that
  `tensor`s are resources, we pass all inputs and return all outputs
  together, eliminating `set-input` and `get-output`.
- `prompt`: this new mode expects a `string` prompt which is passed
  along to a backend LLM. The returned string is not streamed, but it
  could be in the future (a sketch of guest code for both modes follows
  below).
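
The sketch below assumes hypothetical wit-bindgen-style Rust guest
bindings: the module paths, identifier casing, and the "input" tensor
name are illustrative assumptions, not part of this change.

    // Hypothetical wit-bindgen-generated bindings for the `ml` world;
    // module paths and casing are assumptions, not part of this commit.
    use wasi_nn::{graph::Graph, inference, prompt, tensor::Tensor};

    fn classify(model: Graph, input: Tensor) -> Result<(), String> {
        // `inference::init` may fail, e.g., if the graph is not inference-ready.
        let ctx = inference::init(model)?;
        // All inputs are passed together and all outputs returned together,
        // replacing the old `set-input`/`compute`/`get-output` sequence.
        let outputs = ctx.compute(vec![("input".to_string(), input)])?;
        for (name, _tensor) in outputs {
            println!("got output tensor: {name}");
        }
        Ok(())
    }

    fn ask(model: Graph) -> Result<String, String> {
        // `prompt::init` may fail if the graph is not an LLM-style model.
        let ctx = prompt::init(model)?;
        // The full response is returned at once; streaming is future work.
        ctx.compute("Describe this change in one sentence.")
    }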

This change also adds metadata inspection and modification of the
`graph` via `list-properties`, `get-property` and `set-property` (a
sketch follows below). It is unclear whether these methods should hang
off the `context` objects instead (TODO). It is also unclear whether the
model of `load`-ing a `graph` and then initializing it into one of the
two modes via `inference::init` or `prompt::init` is the best approach;
most graphs support only one of the two modes, so it does not make sense
to open the door to `init` failures.
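
As above, the bindings here are assumed wit-bindgen-style Rust; the
"temperature" key is a made-up example, since property names are
backend-specific.

    use wasi_nn::graph::Graph;

    // Enumerate a graph's backend-specific metadata, then adjust one
    // property; the "temperature" key is hypothetical and backend-defined.
    fn tune(model: &Graph) -> Result<(), String> {
        for name in model.list_properties() {
            // `get-property` returns `none` when the property does not exist.
            if let Some(value) = model.get_property(&name) {
                println!("{name} = {value}");
            }
        }
        // `set-property` surfaces backend failures as a plain error string.
        model.set_property("temperature", "0.7")
    }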

[bytecodealliance#74] (replace `load` with `load-by-name`) is replicated in this commit.
[bytecodealliance#75] (return errors as records) and [bytecodealliance#76] (remove the error
constructor) are superseded by this commit, since every error is simply
returned as a `string` and the `error` resource is removed.

[bytecodealliance#74]: WebAssembly/wasi-nn#74
[bytecodealliance#75]: WebAssembly/wasi-nn#75
[bytecodealliance#76]: WebAssembly/wasi-nn#76
abrown committed Aug 8, 2024
1 parent 6907868 commit 7a9dd1f
Showing 1 changed file with 51 additions and 89 deletions.
crates/wasi-nn/wit/wasi-nn.wit (+51 −89)
@@ -12,7 +12,7 @@ world ml {
     import tensor;
     import graph;
     import inference;
-    import errors;
+    import prompt;
 }
 
 /// All inputs and outputs to an ML inference are represented as `tensor`s.
@@ -61,108 +61,70 @@ interface tensor {
 /// A `graph` is a loaded instance of a specific ML model (e.g., MobileNet) for a specific ML
 /// framework (e.g., TensorFlow):
 interface graph {
-    use errors.{error};
     use tensor.{tensor};
-    use inference.{graph-execution-context};
-
-    /// An execution graph for performing inference (i.e., a model).
-    resource graph {
-        init-execution-context: func() -> result<graph-execution-context, error>;
-    }
-
-    /// Describes the encoding of the graph. This allows the API to be implemented by various
-    /// backends that encode (i.e., serialize) their graph IR with different formats.
-    enum graph-encoding {
-        openvino,
-        onnx,
-        tensorflow,
-        pytorch,
-        tensorflowlite,
-        ggml,
-        autodetect,
-    }
-
-    /// Define where the graph should be executed.
-    enum execution-target {
-        cpu,
-        gpu,
-        tpu
-    }
-
-    /// The graph initialization data.
-    ///
-    /// This gets bundled up into an array of buffers because implementing backends may encode their
-    /// graph IR in parts (e.g., OpenVINO stores its IR and weights separately).
-    type graph-builder = list<u8>;
-
-    /// Load a `graph` from an opaque sequence of bytes to use for inference.
-    load: func(builder: list<graph-builder>, encoding: graph-encoding, target: execution-target) -> result<graph, error>;
 
     /// Load a `graph` by name.
     ///
     /// How the host expects the names to be passed and how it stores the graphs for retrieval via
     /// this function is **implementation-specific**. This allows hosts to choose name schemes that
     /// range from simple to complex (e.g., URLs?) and caching mechanisms of various kinds.
-    load-by-name: func(name: string) -> result<graph, error>;
-}
+    load: func(name: string) -> result<graph, string>;
 
-/// An inference "session" is encapsulated by a `graph-execution-context`. This structure binds a
-/// `graph` to input tensors before `compute`-ing an inference:
-interface inference {
-    use errors.{error};
-    use tensor.{tensor, tensor-data};
-
-    /// Bind a `graph` to the input and output tensors for an inference.
-    ///
-    /// TODO: this may no longer be necessary in WIT
-    /// (https://github.com/WebAssembly/wasi-nn/issues/43)
-    resource graph-execution-context {
-        /// Define the inputs to use for inference.
-        set-input: func(name: string, tensor: tensor) -> result<_, error>;
+    /// An execution graph for performing inference (i.e., a model).
+    resource graph {
+        /// Retrieve the properties of the graph.
+        ///
+        /// These are metadata about the graph, unique to the graph and the
+        /// ML backend providing it.
+        list-properties: func() -> list<string>;
 
-        /// Compute the inference on the given inputs.
-        ///
-        /// Note the expected sequence of calls: `set-input`, `compute`, `get-output`. TODO: this
-        /// expectation could be removed as a part of
-        /// https://github.com/WebAssembly/wasi-nn/issues/43.
-        compute: func() -> result<_, error>;
+        /// Retrieve the value of a property.
+        ///
+        /// If the property does not exist, this function returns `none`.
+        get-property: func(name: string) -> option<string>;
 
-        /// Extract the outputs after inference.
-        get-output: func(name: string) -> result<tensor, error>;
+        /// Modify the value of a property.
+        ///
+        /// If the operation fails, this function returns a string from the ML
+        /// backend describing the error.
+        set-property: func(name: string, value: string) -> result<_, string>;
     }
 }
 
-/// TODO: create function-specific errors (https://github.com/WebAssembly/wasi-nn/issues/42)
-interface errors {
-    enum error-code {
-        // Caller module passed an invalid argument.
-        invalid-argument,
-        // Invalid encoding.
-        invalid-encoding,
-        // The operation timed out.
-        timeout,
-        // Runtime Error.
-        runtime-error,
-        // Unsupported operation.
-        unsupported-operation,
-        // Graph is too large.
-        too-large,
-        // Graph not found.
-        not-found,
-        // The operation is insecure or has insufficient privilege to be performed.
-        // e.g., cannot access a hardware feature requested
-        security,
-        // The operation failed for an unspecified reason.
-        unknown
-    }
+/// An inference "session" is encapsulated by a `context`; use this to `compute`
+/// an inference.
+interface inference {
+    use graph.{graph};
+    use tensor.{tensor};
 
-    resource error {
-        constructor(code: error-code, data: string);
+    /// Initialize an inference session with a graph.
+    ///
+    /// Note that not all graphs are inference-ready (see `prompt`); this
+    /// function may fail in this case.
+    init: func(graph: graph) -> result<context, string>;
 
-        /// Return the error code.
-        code: func() -> error-code;
+    /// Identify a tensor by name; this is necessary to associate tensors to
+    /// graph inputs and outputs.
+    type named-tensor = tuple<string, tensor>;
 
-        /// Errors can propagated with backend specific status through a string value.
-        data: func() -> string;
+    /// An inference "session."
+    resource context {
+        /// Compute an inference request with the given inputs.
+        compute: func(inputs: list<named-tensor>) -> result<list<named-tensor>, string>;
     }
 }
+
+/// A prompt "session" is encapsulated by a `context`.
+interface prompt {
+    use graph.{graph};
+
+    /// Initialize a prompt session with a graph.
+    ///
+    /// Note that not all graphs are prompt-ready (see `inference`); this
+    /// function may fail in this case.
+    init: func(graph: graph) -> result<context, string>;
+
+    /// A prompt "session."
+    resource context {
+        /// Compute an inference request with the given inputs.
+        compute: func(prompt: string) -> result<string, string>;
+    }
+}
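
Putting the pieces together, an end-to-end prompt session under the new
world might look like this (same assumed bindings as in the sketches
above; the "llama" name is hypothetical and, per the `load` docs,
implementation-specific):

    use wasi_nn::{graph, prompt};

    // Sketch of an end-to-end prompt flow: load a graph by name,
    // initialize a prompt session, and compute one non-streamed response.
    fn main() -> Result<(), String> {
        // The naming scheme is implementation-specific; "llama" is illustrative.
        let model = graph::load("llama")?;
        let ctx = prompt::init(model)?;
        let reply = ctx.compute("What is wasi-nn?")?;
        println!("{reply}");
        Ok(())
    }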
