diff --git a/source/onnx.js b/source/onnx.js index 4f06b6f36c..aa69b136fc 100644 --- a/source/onnx.js +++ b/source/onnx.js @@ -164,7 +164,57 @@ onnx.ModelFactory = class { async open(context, target) { const open = async (model, format) => { const metadata = await onnx.Metadata.open(context); - return new onnx.Model(metadata, model, format); + const graphs = new Set(); + const queue = [ model.graph ]; + const locations = new Set(); + const tensor = (value) => { + if ((onnx.proto && value instanceof onnx.proto.SparseTensorProto) || + (onnx.schema && value instanceof onnx.schema.SparseTensor)) { + tensor(value.indices); + tensor(value.indices); + } else if (value.data_location === onnx.DataLocation.EXTERNAL && Array.isArray(value.external_data)) { + for (const entry of value.external_data) { + if (entry.key === 'location') { + locations.add(entry.value); + } + } + } + }; + while (queue.length > 0) { + const graph = queue.shift(); + for (const initializer of graph.initializer) { + tensor(initializer); + } + for (const sparse_initializer of graph.sparse_initializer) { + tensor(sparse_initializer); + } + if (Array.isArray(graph.node)) { + for (const node of graph.node) { + if (Array.isArray(node.attribute)) { + for (const attribute of node.attribute) { + if (attribute.g) { + queue.push(attribute.g); + } else if (Array.isArray(attribute.graphs) && attribute.graphs.length > 0) { + queue.push(...attribute.graphs); + } else if (attribute.t) { + tensor(attribute.t); + } else if (Array.isArray(attribute.tensors) && attribute.tensors.length > 0) { + attribute.tensors.every((value) => tensor(value)); + } else if (attribute.sparse_tensor) { + tensor(attribute.sparse_tensor); + } else if (Array.isArray(attribute.sparse_tensors) && attribute.sparse_tensors.length > 0) { + attribute.sparse_tensors.every((value) => tensor(value)); + } + } + } + } + } + graphs.add(graph); + } + const keys = Array.from(locations); + const streams = await Promise.all(keys.map((location) => context.request(location, null))); + const weights = new Map(keys.map((key, index) => [ key, streams[index] ])); + return new onnx.Model(metadata, format, model, Array.from(graphs), weights); }; switch (target) { case 'onnx.pbtxt.ModelProto': @@ -276,7 +326,7 @@ onnx.ModelFactory = class { onnx.Model = class { - constructor(metadata, model, format) { + constructor(metadata, format, model, graphs, locations) { this._graphs = []; this._format = format; this._producer = model.producer_name && model.producer_name.length > 0 ? model.producer_name + (model.producer_version && model.producer_version.length > 0 ? ' ' + model.producer_version : '') : null; @@ -286,7 +336,6 @@ onnx.Model = class { this._description = model.doc_string; this._metadata = []; this._imports = null; - const imports = new Map(); if (model.opset_import && model.opset_import.length > 0) { for (const opset_import of model.opset_import) { @@ -302,7 +351,6 @@ onnx.Model = class { imports.set('ai.onnx', 1); imports.set('ai.onnx.ml', 1); } - let imageFormat = ''; const metadata_props = model.metadata_props; if (metadata_props) { @@ -347,27 +395,14 @@ onnx.Model = class { } imageFormat = [ imageMetadata['Image.BitmapPixelFormat'], imageMetadata['Image.ColorSpaceGamma'], imageMetadata['Image.NominalPixelRange'] ].filter((item) => item); } + metadata = new onnx.GraphMetadata(metadata, imports); + const context = new onnx.ModelContext(metadata, locations, imageFormat); + for (const func of model.functions || []) { + context.metadata.add(new onnx.Function(context, func)); + } this._graphs = []; - if (model && model.graph) { - const graphMetadata = new onnx.GraphMetadata(metadata, imports); - const context = new onnx.ModelContext(graphMetadata, imageFormat); - for (const func of model.functions || []) { - context.metadata.add(new onnx.Function(context, func)); - } - const graphs = [ model.graph ]; - while (graphs.length > 0) { - const graph = graphs.shift(); - this._graphs.push(context.graph(graph)); - for (const node of graph.node || []) { - for (const attribute of node.attribute || []) { - if (attribute.g) { - graphs.push(attribute.g); - } else if (attribute.graphs && attribute.graphs.length > 0) { - graphs.push(...attribute.graphs); - } - } - } - } + for (const graph of graphs) { + this._graphs.push(context.graph(graph)); } } @@ -788,7 +823,7 @@ onnx.Tensor = class { (onnx.schema && tensor instanceof onnx.schema.SparseTensor)) { this._name = tensor.values.name || ''; this._type = context.createTensorType(tensor.values.data_type, tensor.dims.map((dim) => dim), null); - this._location = Array.from(new Set([ context.createLocation(tensor.values.data_location), context.createLocation(tensor.indices.data_location) ])).join(':'); + this._location = context.createLocation(tensor.values.data_location); this._layout = 'sparse'; this._values = new onnx.Tensor(context, tensor.values); this._indices = new onnx.Tensor(context, tensor.indices); @@ -796,96 +831,119 @@ onnx.Tensor = class { this._name = tensor.name || ''; this._type = context.createTensorType(tensor.data_type, tensor.dims.map((dim) => dim), null); this._location = context.createLocation(tensor.data_location); - if (tensor.data_location === onnx.DataLocation.DEFAULT) { - switch (tensor.data_type) { - case onnx.DataType.UNDEFINED: { - break; - } - case onnx.DataType.FLOAT: - this._data = new Float32Array(tensor.float_data); - this._layout = '|'; - break; - case onnx.DataType.DOUBLE: - this._data = new Float64Array(tensor.double_data); - this._layout = '|'; - break; - case onnx.DataType.BOOL: - if (tensor.int32_data && tensor.int32_data.length > 0) { - const array = tensor.int32_data; - this._data = new Array(array.length); - for (let i = 0; i < this._data.length; i++) { - this._data[i] = array[i] === 0 ? false : true; + switch (tensor.data_location) { + case onnx.DataLocation.DEFAULT: { + switch (tensor.data_type) { + case onnx.DataType.UNDEFINED: { + break; + } + case onnx.DataType.FLOAT: + this._data = new Float32Array(tensor.float_data); + this._layout = '|'; + break; + case onnx.DataType.DOUBLE: + this._data = new Float64Array(tensor.double_data); + this._layout = '|'; + break; + case onnx.DataType.BOOL: + if (tensor.int32_data && tensor.int32_data.length > 0) { + const array = tensor.int32_data; + this._data = new Array(array.length); + for (let i = 0; i < this._data.length; i++) { + this._data[i] = array[i] === 0 ? false : true; + } + this._layout = '|'; } + break; + case onnx.DataType.INT8: + this._data = new Int8Array(tensor.int32_data); this._layout = '|'; - } - break; - case onnx.DataType.INT8: - this._data = new Int8Array(tensor.int32_data); - this._layout = '|'; - break; - case onnx.DataType.UINT8: - this._data = new Uint8Array(tensor.int32_data); - this._layout = '|'; - break; - case onnx.DataType.INT16: - this._data = new Int32Array(tensor.int32_data); - this._layout = '|'; - break; - case onnx.DataType.UINT16: - this._data = new Int32Array(tensor.int32_data); - this._layout = '|'; - break; - case onnx.DataType.INT32: - this._data = new Int32Array(tensor.int32_data); - this._layout = '|'; - break; - case onnx.DataType.UINT32: - case onnx.DataType.UINT64: - this._data = tensor.uint64_data; - this._layout = '|'; - break; - case onnx.DataType.INT64: - this._data = tensor.int64_data; - this._layout = '|'; - break; - case onnx.DataType.STRING: - this._data = tensor.string_data; - this._layout = '|'; - break; - case onnx.DataType.COMPLEX64: - case onnx.DataType.COMPLEX128: - break; - case onnx.DataType.FLOAT16: - case onnx.DataType.BFLOAT16: - if (tensor.int32_data && tensor.int32_data.length > 0) { - const array = tensor.int32_data; - const buffer = new Uint8Array(array.length << 1); - const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength); - for (let i = 0; i < array.length; i++) { - view.setUint16(i << 1, array[i], true); + break; + case onnx.DataType.UINT8: + this._data = new Uint8Array(tensor.int32_data); + this._layout = '|'; + break; + case onnx.DataType.INT16: + this._data = new Int32Array(tensor.int32_data); + this._layout = '|'; + break; + case onnx.DataType.UINT16: + this._data = new Int32Array(tensor.int32_data); + this._layout = '|'; + break; + case onnx.DataType.INT32: + this._data = new Int32Array(tensor.int32_data); + this._layout = '|'; + break; + case onnx.DataType.UINT32: + case onnx.DataType.UINT64: + this._data = tensor.uint64_data; + this._layout = '|'; + break; + case onnx.DataType.INT64: + this._data = tensor.int64_data; + this._layout = '|'; + break; + case onnx.DataType.STRING: + this._data = tensor.string_data; + this._layout = '|'; + break; + case onnx.DataType.COMPLEX64: + case onnx.DataType.COMPLEX128: + break; + case onnx.DataType.FLOAT16: + case onnx.DataType.BFLOAT16: + if (tensor.int32_data && tensor.int32_data.length > 0) { + const array = tensor.int32_data; + const buffer = new Uint8Array(array.length << 1); + const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength); + for (let i = 0; i < array.length; i++) { + view.setUint16(i << 1, array[i], true); + } + this._data = buffer; + this._layout = '<'; + } + break; + case onnx.DataType.FLOAT8E4M3FN: + case onnx.DataType.FLOAT8E4M3FNUZ: + case onnx.DataType.FLOAT8E5M2: + case onnx.DataType.FLOAT8E5M2FNUZ: + if (tensor.int32_data && tensor.int32_data.length > 0) { + this._data = new Uint8Array(Array.from(tensor.int32_data)); + this._layout = '<'; } - this._data = buffer; - this._layout = '<'; + break; + default: + throw new onnx.Error("Unsupported tensor data type '" + tensor.data_type + "'."); + } + if (this._data && (Array.isArray(this._data) || ArrayBuffer.isView(this._data)) && this._data.length === 0) { + this._data = undefined; + } + if (!this._data && tensor.raw_data && tensor.raw_data.length > 0) { + this._data = tensor.raw_data; + this._layout = '<'; + } + break; + } + case onnx.DataLocation.EXTERNAL: { + if (Array.isArray(tensor.external_data)) { + const external_data = {}; + for (const entry of tensor.external_data) { + external_data[entry.key] = entry.value; } - break; - case onnx.DataType.FLOAT8E4M3FN: - case onnx.DataType.FLOAT8E4M3FNUZ: - case onnx.DataType.FLOAT8E5M2: - case onnx.DataType.FLOAT8E5M2FNUZ: - if (tensor.int32_data && tensor.int32_data.length > 0) { - this._data = new Uint8Array(Array.from(tensor.int32_data)); - this._layout = '<'; + if (external_data.location && external_data.offset && external_data.length) { + const offset = parseInt(external_data.offset, 10); + const length = parseInt(external_data.length, 10); + if (Number.isInteger(offset) && Number.isInteger(length)) { + this._data = context.location(external_data.location, offset, length); + this._layout = '<'; + } } - break; - default: - throw new onnx.Error("Unsupported tensor data type '" + tensor.data_type + "'."); - } - if (this._data && (Array.isArray(this._data) || ArrayBuffer.isView(this._data)) && this._data.length === 0) { - this._data = undefined; + } + break; } - if (!this._data && tensor.raw_data && tensor.raw_data.length > 0) { - this._data = tensor.raw_data; - this._layout = '<'; + default: { + throw new Error(); } } } @@ -912,9 +970,21 @@ onnx.Tensor = class { } get values() { - return this._layout === 'sparse' ? this._values : this._data; + switch (this._layout) { + case 'sparse': { + return this._values; + } + default: { + if (!this._data || this._data instanceof Uint8Array) { + return this._data; + } + if (Array.isArray(this._data) || ArrayBuffer.isView(this._data)) { + return this._data; + } + return this._data.peek(); + } + } } - }; onnx.TensorType = class { @@ -1289,8 +1359,9 @@ onnx.AttributeType = { onnx.ModelContext = class { - constructor(metadata, imageFormat) { + constructor(metadata, locations, imageFormat) { this._metadata = metadata; + this._locations = locations; this._imageFormat = imageFormat; this._graphs = new Map(); } @@ -1303,6 +1374,20 @@ onnx.ModelContext = class { return this._imageFormat; } + location(name, offset, length) { + if (this._locations.has(name)) { + const stream = this._locations.get(name); + if (offset < stream.length && (offset + length) < stream.length) { + const position = stream.position; + stream.seek(offset); + const value = stream.stream(length); + stream.seek(position); + return value; + } + } + return this._locations; + } + graph(value) { if (!this._graphs.has(value)) { this._graphs.set(value, new onnx.Graph(this, value)); @@ -1374,6 +1459,10 @@ onnx.GraphContext = class { return this._tensors.get(name); } + location(name, offset, length) { + return this._context.location(name, offset, length); + } + group(name) { if (!this._groups.has(name)) { const path = name.split('/');