Skip to content

Commit

Permalink
Allow shared array buffer view for MLGraphBuilder.constant()
Browse files Browse the repository at this point in the history
Fix #788
  • Loading branch information
huningxin committed Nov 16, 2024
1 parent c237cf1 commit 830e747
Showing 1 changed file with 11 additions and 23 deletions.
34 changes: 11 additions & 23 deletions index.bs
Original file line number Diff line number Diff line change
Expand Up @@ -1341,7 +1341,8 @@ interface MLGraphBuilder {
MLOperand input(USVString name, MLOperandDescriptor descriptor);

// Create an operand for a graph constant.
MLOperand constant(MLOperandDescriptor descriptor, ArrayBufferView bufferView);
MLOperand constant(MLOperandDescriptor descriptor,
[AllowShared] ArrayBufferView bufferView);

// Create a scalar operand from the specified number of the specified type.
MLOperand constant(MLOperandDataType type, MLNumber value);
Expand Down Expand Up @@ -2012,8 +2013,7 @@ partial dictionary MLOpSupportLimits {
input, builder.constant(input.dataType, options.minValue));
} else {
return builder.min(
builder.max(
input, builder.constant(input.dataType, options.minValue)),
builder.max(input, builder.constant(input.dataType, options.minValue)),
builder.constant(input.dataType, options.maxValue));
}
}
Expand Down Expand Up @@ -3421,8 +3421,8 @@ partial dictionary MLOpSupportLimits {
{shape: [4, 3]},
new Float32Array([0, 1, 2, 10, 11, 12, 20, 21, 22, 30, 31, 32]));

const indices1 = builder.constant(
{dataType: 'uint32', shape: [2]}, new Uint32Array([3, 1]));
const indices1 =
builder.constant({dataType: 'uint32', shape: [2]}, new Uint32Array([3, 1]));

const indices2 = builder.constant(
{dataType: 'uint32', shape: [3]}, new Uint32Array([2, 1, 1]));
Expand Down Expand Up @@ -3937,10 +3937,7 @@ partial dictionary MLOpSupportLimits {
let hiddenState = options.initialHiddenState;

if (!hiddenState) {
const desc = {
dataType: 'float32',
shape: [numDirections, 1, hiddenSize]
};
const desc = {dataType: 'float32', shape: [numDirections, 1, hiddenSize]};
const totalSize = numDirections * hiddenSize;
hiddenState = builder.constant(desc, new Float32Array(totalSize).fill(0));
}
Expand Down Expand Up @@ -4619,8 +4616,7 @@ partial dictionary MLOpSupportLimits {
const reduceOptions = {axes: [2, 3], keepDimensions: true};
const mean = builder.reduceMean(input, reduceOptions);
const variance = builder.reduceMean(
builder.pow(
builder.sub(input, mean), builder.constant(input.dataType, 2)),
builder.pow(builder.sub(input, mean), builder.constant(input.dataType, 2)),
reduceOptions);

// The scale and bias values are applied per input feature
Expand Down Expand Up @@ -4765,8 +4761,7 @@ partial dictionary MLOpSupportLimits {
const reduceOptions = {axes: [1, 2, 3], keepDimensions: true};
const mean = builder.reduceMean(input, reduceOptions);
const variance = builder.reduceMean(
builder.pow(
builder.sub(input, mean), builder.constant(input.dataType, 2)),
builder.pow(builder.sub(input, mean), builder.constant(input.dataType, 2)),
reduceOptions);

// The scale and bias tensors are of the shape of the input
Expand Down Expand Up @@ -5222,19 +5217,13 @@ partial dictionary MLOpSupportLimits {
let cellState = options.initialCellState;

if (!hiddenState) {
const desc = {
dataType: 'float32',
shape: [numDirections, 1, hiddenSize]
};
const desc = {dataType: 'float32', shape: [numDirections, 1, hiddenSize]};
const totalSize = numDirections * hiddenSize;
hiddenState = builder.constant(desc, new Float32Array(totalSize).fill(0));
}

if (!cellState) {
const desc = {
dataType: 'float32',
shape: [numDirections, 1, hiddenSize]
};
const desc = {dataType: 'float32', shape: [numDirections, 1, hiddenSize]};
const totalSize = numDirections * hiddenSize;
cellState = builder.constant(desc, new Float32Array(totalSize).fill(0));
}
Expand Down Expand Up @@ -5878,8 +5867,7 @@ partial dictionary MLOpSupportLimits {
<pre highlight="js">
// input: [[1,2,3], [4,5,6]]
const input = builder.constant(
{dataType: 'float32', shape: [2, 3]},
new Float32Array([1, 2, 3, 4, 5, 6]));
{dataType: 'float32', shape: [2, 3]}, new Float32Array([1, 2, 3, 4, 5, 6]));

const beginningPadding = [1, 2];
const endingPadding = [1, 2];
Expand Down

0 comments on commit 830e747

Please sign in to comment.