diff --git a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts index 35b1640afa266..a19afd4bac732 100644 --- a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts +++ b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts @@ -55,19 +55,19 @@ let tensorGuid = 1; const createNewTensorId = (): TensorId => tensorGuid++; /** - * Map from MLOperandDataType to size in bytes. + * Map from MLOperandDataType to size in bits. Bits are used instead of bytes to avoid possible precision loss with the sub-byte types int4 and uint4 (0.5 bytes each). */ const webnnDataTypeToSize = new Map([ - ['float32', 4], - ['float16', 2], - ['int32', 4], - ['uint32', 4], - ['int64', 8], - ['uint64', 8], - ['int8', 1], - ['uint8', 1], - ['int4', 0.5], - ['uint4', 0.5], + ['float32', 32], + ['float16', 16], + ['int32', 32], + ['uint32', 32], + ['int64', 64], + ['uint64', 64], + ['int8', 8], + ['uint8', 8], + ['int4', 4], + ['uint4', 4], ]); /** @@ -78,7 +78,7 @@ const calculateByteLength = (dataType: MLOperandDataType, shape: readonly number if (!size) { throw new Error('Unsupported data type.'); } - return Math.ceil(shape.reduce((a, b) => a * b) * size); + return Math.ceil((shape.reduce((a, b) => a * b) * size) / 8); }; /**