Skip to content

Commit

Permalink
Switched from bytes to bits to avoid precision issues on int4
Browse files Browse the repository at this point in the history
  • Loading branch information
egalli committed Nov 5, 2024
1 parent 959600f commit 63b59ce
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions js/web/lib/wasm/jsep/webnn/tensor-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,19 @@ let tensorGuid = 1;
const createNewTensorId = (): TensorId => tensorGuid++;

/**
* Map from MLOperandDataType to size in bytes.
* Map from MLOperandDataType to size in bits. Bits are used instead of bytes so that int4 and uint4 have a whole-number size (4 bits) rather than a fractional one (0.5 bytes), avoiding possible precision loss.
*/
const webnnDataTypeToSize = new Map<MLOperandDataType, number>([
['float32', 4],
['float16', 2],
['int32', 4],
['uint32', 4],
['int64', 8],
['uint64', 8],
['int8', 1],
['uint8', 1],
['int4', 0.5],
['uint4', 0.5],
['float32', 32],
['float16', 16],
['int32', 32],
['uint32', 32],
['int64', 64],
['uint64', 64],
['int8', 8],
['uint8', 8],
['int4', 4],
['uint4', 4],
]);

/**
Expand All @@ -78,7 +78,7 @@ const calculateByteLength = (dataType: MLOperandDataType, shape: readonly number
if (!size) {
throw new Error('Unsupported data type.');
}
return Math.ceil(shape.reduce((a, b) => a * b) * size);
return Math.ceil((shape.reduce((a, b) => a * b) * size) / 8);
};

/**
Expand Down

0 comments on commit 63b59ce

Please sign in to comment.