Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/webgpu] Optimize ConvTranspose #22774

Merged
merged 4 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
313 changes: 83 additions & 230 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,229 +29,27 @@ import {
ShaderHelper,
tensorTypeToWsglStorageType,
UniformsArrayType,
getMaxComponents,
} from '../common';
import { ConvTransposeAttributes } from '../conv-transpose';

/**
 * Builds the WGSL shader source for the ConvTranspose2D program.
 *
 * Two kernel bodies are generated and `isVec4` selects which one is emitted:
 *  - the vec4 body (`codeSnippet4`) derives its coordinates from `global_id` / `workgroup_id`,
 *    accumulates `workPerThread` (2) neighbouring output columns per invocation and produces four
 *    output channels at a time;
 *  - the scalar body (`codeSnippet`) computes one output element per `global_idx` and honours the
 *    `dilations`, `input_channels_per_group` and `output_channels_per_group` uniforms.
 *
 * @param shaderHelper Helper used to register uniforms, declare variables and emit the main prologue.
 * @param inputs Input tensors: `[0]` is Dy, `[1]` is W, `[2]` is the bias (only read when `hasBias`).
 * @param outputShape Shape of the result tensor; only its rank is used when declaring the output.
 * @param hasBias Whether a bias input is declared and added to the accumulated value.
 * @param is1DimensionDispatch In the vec4 body, picks `global_id` (true) or `workgroup_id` (false).
 * @param isVec4 Emits the vec4 body and declares W, Dy, bias and result with 4 components.
 * @param dataType WGSL scalar type name used for casts and accumulators in the generated code.
 * @param uniforms Uniform declarations forwarded to `shaderHelper.registerUniforms`.
 * @param isChannelsLast Selects which output/Dy axes are treated as row, column and channel.
 * @returns The complete WGSL source text.
 */
const createConvTranspose2DOpProgramShaderSource = (
  shaderHelper: ShaderHelper,
  inputs: readonly TensorView[],
  outputShape: readonly number[],
  hasBias: boolean,
  is1DimensionDispatch: boolean,
  isVec4 = false,
  dataType: string,
  uniforms: UniformsArrayType,
  isChannelsLast = false,
): string => {
  // Axis positions inside a rank-4 shape for the selected layout.
  const rowDim = isChannelsLast ? 1 : 2;
  const colDim = isChannelsLast ? 2 : 3;
  const channelDim = isChannelsLast ? 3 : 1;
  // The vec4 body writes two adjacent output columns per invocation.
  const workPerThread = isVec4 ? 2 : 1;

  // Helper functions emitted ahead of main(). Neither kernel body below calls them;
  // they are kept so the emitted shader text keeps the same declarations.
  let declareFunctions = `
  fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? `vec4<${dataType}>` : dataType}) {
    result[flatIndex] = ${isVec4 ? `vec4<${dataType}>` : dataType}(value);
  }`;
  if (hasBias) {
    declareFunctions += `
    fn getBiasByOutputCoords(coords : vec4<u32>) -> ${isVec4 ? `vec4<${dataType}>` : dataType} {
      return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
    }`;
  }
  const components = isVec4 ? 4 : 1;
  const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components);
  const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims.length, components);
  const inputVariables = [dy, w];
  if (hasBias) {
    // The bias is declared as a rank-1 variable (the previous expression
    // `[outputShape[channelDim]].length` always evaluated to 1).
    inputVariables.push(inputVariable('bias', inputs[2].dataType, 1, components));
  }
  const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);

  // vec4 kernel body.
  // NOTE(review): Dy is indexed as (batch, row, col, channel) here, i.e. this body appears to
  // assume a channels-last layout - confirm before enabling isVec4 from the caller.
  // The bias is read at `d1 / 4`: `d1` is the first of the four output channels handled by this
  // invocation and the bias variable is packed four channels per element.
  const codeSnippet4 = `{
        let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / uniforms.result_shape[1];
        let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % uniforms.result_shape[1];
        let c = ${is1DimensionDispatch ? 'global_id.y' : 'workgroup_id.y'} * ${workPerThread};
        let d1: u32 = ${is1DimensionDispatch ? 'global_id.x' : 'workgroup_id.x'} * 4;

        let dyCorner = vec2<i32>(i32(r), i32(c)) - vec2<i32>(uniforms.pads);

        // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
        // ? = to be determined. : = across all values in that axis.
        var dotProd: array<vec4<${dataType}>, ${workPerThread}>;
        for (var i = 0; i < ${workPerThread}; i++) {
          dotProd[i] = vec4<${dataType}>(0.0);
        }
        for (var wR: u32 = 0; wR < uniforms.filter_dims[0]; wR = wR + 1) {
          var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(uniforms.strides.x);
          let wRPerm = uniforms.filter_dims[0] - 1 - wR;
          if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[1]) ||
              fract(dyR) > 0.0 || wRPerm < 0) {
            continue;
          }
          let idyR: u32 = u32(dyR);

          for (var wC: u32 = 0; wC < uniforms.filter_dims[1]; wC = wC + 1) {
            let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y);
            let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(uniforms.strides.y);
            let wCPerm = uniforms.filter_dims[1] - 1 - wC;
            if (wCPerm < 0) {
              continue;
            }
            var bDyCVal = true;
            var bDyCVal2 = true;
            if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[2]) ||
                fract(dyC) > 0.0) {
              bDyCVal = false;
            }
            if (dyC2 < 0.0 || dyC2 >= ${dataType}(uniforms.Dy_shape[2]) ||
                fract(dyC2) > 0.0) {
              bDyCVal2 = false;
            }

            let idyC: u32 = u32(dyC);
            let idyC2: u32 = u32(dyC2);
            if (bDyCVal && bDyCVal2) {
              let d2Length = uniforms.Dy_shape[3];
              for (var d2 :u32 = 0; d2 < d2Length; d2 = d2 + 4) {
                let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')};
                let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')};
                let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')};
                let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')};

                var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')};
                let tmpval = vec4<${dataType}>(dot(xValue, wValue0),
                                      dot(xValue, wValue1),
                                      dot(xValue, wValue2),
                                      dot(xValue, wValue3));
                dotProd[0] = dotProd[0] + tmpval;

                xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')};

                dotProd[1] = dotProd[1] + vec4<${dataType}>(dot(xValue, wValue0),
                                                    dot(xValue, wValue1),
                                                    dot(xValue, wValue2),
                                                    dot(xValue, wValue3));
              }
            } else if (bDyCVal) {
              let d2Length = uniforms.Dy_shape[${channelDim}];
              for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) {
                let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')};
                let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')};
                let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')};
                let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')};

                var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')};
                let tmpval = vec4<${dataType}>(dot(xValue, wValue0),
                                      dot(xValue, wValue1),
                                      dot(xValue, wValue2),
                                      dot(xValue, wValue3));
                dotProd[0] = dotProd[0] + tmpval;
              }
            } else if (bDyCVal2) {
              let d2Length = uniforms.Dy_shape[3];
              for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) {
                let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')};
                let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')};
                let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')};
                let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')};

                var xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')};
                let tmpval = vec4<${dataType}>(dot(xValue, wValue0),
                                      dot(xValue, wValue1),
                                      dot(xValue, wValue2),
                                      dot(xValue, wValue3));
                dotProd[1] = dotProd[1] + tmpval;
              }
            }
          }
        }

        for (var i: u32 = 0; i < ${workPerThread}; i = i + 1) {
          let value = dotProd[i] + ${hasBias ? 'bias[d1 / 4]' : `vec4<${dataType}>(0.0)`};
          ${output.set('batch', 'r', 'c + i', 'd1', 'value')};
        }
      }`;

  // Scalar kernel body: one output element per global_idx. Filter taps are walked over the
  // effective (dilated) filter extent; taps that do not fall on a dilation step, or that map to
  // a fractional / out-of-range Dy coordinate, are skipped.
  const codeSnippet = `
          let outputIndices = ${output.offsetToIndices('global_idx')};
          let batch = ${output.indicesGet('outputIndices', 0)};
          let d1 = ${output.indicesGet('outputIndices', channelDim)};
          let r = ${output.indicesGet('outputIndices', rowDim)};
          let c = ${output.indicesGet('outputIndices', colDim)};
          let dyCorner = vec2<i32>(i32(r), i32(c)) - uniforms.pads;
          let dyRCorner = dyCorner.x;
          let dyCCorner = dyCorner.y;
          let groupId = d1 / uniforms.output_channels_per_group;
          let wOutChannel = d1 - groupId * uniforms.output_channels_per_group;
          // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
          // ? = to be determined. : = across all values in that axis.
          var dotProd = ${dataType}(0.0);
          for (var wR: u32 = 0; wR < uniforms.effective_filter_dims.x; wR = wR + 1) {
            if (wR % uniforms.dilations.x != 0) {
              continue;
            }
            let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(uniforms.strides[0]);
            let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x;
            if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[${rowDim}]) || fract(dyR) > 0.0 ||
                wRPerm < 0) {
              continue;
            }
            let idyR: u32 = u32(dyR);

            for (var wC: u32 = 0; wC < uniforms.effective_filter_dims.y; wC = wC + 1) {
              if (wC % uniforms.dilations.y != 0) {
                continue;
              }
              let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y);
              let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y;
              if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[${colDim}]) ||
                  fract(dyC) > 0.0 || wCPerm < 0) {
                continue;
              }
              let idyC: u32 = u32(dyC);
              var inputChannel = groupId * uniforms.input_channels_per_group;
              for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + 1) {
                let xValue = ${
                  isChannelsLast
                    ? dy.get('batch', 'idyR', 'idyC', 'inputChannel')
                    : dy.get('batch', 'inputChannel', 'idyR', 'idyC')
                };
                let wValue = ${w.get('inputChannel', 'wOutChannel', 'u32(wRPerm)', 'u32(wCPerm)')};
                dotProd = dotProd + xValue * wValue;
                inputChannel = inputChannel + 1;
              }
            }
          }
          let value = dotProd + ${hasBias ? 'bias[d1]' : `${dataType}(0.0)`};
          ${output.setByOffset('global_idx', 'value')};
        `;

  // Assemble: uniform/variable declarations, helper functions, then main() with the bounds guard
  // and the selected kernel body. The trailing '}' closes the main function opened by mainStart().
  return `
  ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
  ${declareFunctions}

  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')};
  ${isVec4 ? codeSnippet4 : codeSnippet}}`;
};

export const createConvTranspose2DProgramInfo = (
inputs: readonly TensorView[],
attributes: ConvTransposeAttributes,
squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
): ProgramInfo => {
const hasBias = inputs.length > 2;
// const isChannelsLast = attributes.format === 'NHWC';
const outputShape = attributes.outputShape;
const outputSize = ShapeUtil.size(outputShape);

// const inChannels = inputs[0].dims[isChannelsLast ? 3 : 1];
// TODO Enable isVec4 for performance
// Disabled due to weight matrix layout issue
// const isVec4 = attributes.group === 1 && isChannelsLast && inChannels % 4 === 0 && outChannels % 4 === 0;
const isChannelsLast = attributes.format === 'NHWC';
const group = attributes.group;
const wShape = inputs[1].dims;
const inputChannelsPerGroup = wShape[2] / group;
const outputChannelsPerGroup = wShape[3];
const components = isChannelsLast ? getMaxComponents(outputChannelsPerGroup) : 1;
const outputSize = ShapeUtil.size(outputShape) / components;
const dispatch = [Math.ceil(outputSize / 64), 1, 1];
LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`);

const isChannelsLast = attributes.format === 'NHWC';
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
const strides = [attributes.strides[0], attributes.strides[1]];
const filterDims = [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]];
Expand All @@ -268,15 +66,9 @@ export const createConvTranspose2DProgramInfo = (
];
const pads = [
effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2),
effectiveFilterDims[1] - 1 - Math.floor(attributes.pads[1] + attributes.pads[3]) / 2,
effectiveFilterDims[1] - 1 - Math.floor((attributes.pads[1] + attributes.pads[3]) / 2),
];

const isVec4 = false;
const group = attributes.group;
const wShape = inputs[1].dims;
const inputChannelsPerGroup = wShape[0] / group;
const outputChannelsPerGroup = wShape[1];

const programUniforms: ProgramUniform[] = [
{ type: DataType.uint32, data: outputSize },
{ type: DataType.uint32, data: strides },
Expand All @@ -294,7 +86,6 @@ export const createConvTranspose2DProgramInfo = (
}
programUniforms.push(...createTensorShapeVariables(outputShape));

const is1DimensionDispatch = dispatch[1] === 1 && dispatch[2] === 1;
const getShaderSource = (shaderHelper: ShaderHelper) => {
const uniforms: UniformsArrayType = [
{ name: 'output_size', type: 'u32' },
Expand All @@ -307,21 +98,83 @@ export const createConvTranspose2DProgramInfo = (
{ name: 'output_channels_per_group', type: 'u32' },
];
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
return `${createConvTranspose2DOpProgramShaderSource(
shaderHelper,
inputs,
outputShape,
hasBias,
is1DimensionDispatch,
isVec4,
dataType,
uniforms,
isChannelsLast,
)}`;
const rowDim = isChannelsLast ? 1 : 2;
const colDim = isChannelsLast ? 2 : 3;
const channelDim = isChannelsLast ? 3 : 1;

const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components);
const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims.length);
const inputVariables = [dy, w];
if (hasBias) {
inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]].length, components));
}
const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);

const codeSnippet = `
let outputIndices = ${output.offsetToIndices(`global_idx * ${components}`)};
let batch = ${output.indicesGet('outputIndices', 0)};
let d1 = ${output.indicesGet('outputIndices', channelDim)};
let r = ${output.indicesGet('outputIndices', rowDim)};
let c = ${output.indicesGet('outputIndices', colDim)};
let dyCorner = vec2<i32>(i32(r), i32(c)) - uniforms.pads;
let dyRCorner = dyCorner.x;
let dyCCorner = dyCorner.y;
let groupId = d1 / uniforms.output_channels_per_group;
let wOutChannel = d1 - groupId * uniforms.output_channels_per_group;
// Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
// ? = to be determined. : = across all values in that axis.
var dotProd = ${output.type.value}(0.0);
for (var wR: u32 = 0; wR < uniforms.effective_filter_dims.x; wR = wR + 1) {
if (wR % uniforms.dilations.x != 0) {
continue;
}
let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(uniforms.strides[0]);
let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x;
if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[${rowDim}]) || fract(dyR) > 0.0 ||
wRPerm < 0) {
continue;
}
let idyR: u32 = u32(dyR);

for (var wC: u32 = 0; wC < uniforms.effective_filter_dims.y; wC = wC + 1) {
if (wC % uniforms.dilations.y != 0) {
continue;
}
let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y);
let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y;
if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[${colDim}]) ||
fract(dyC) > 0.0 || wCPerm < 0) {
continue;
}
let idyC: u32 = u32(dyC);
var inputChannel = groupId * uniforms.input_channels_per_group;
for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + 1) {
let xValue = ${
isChannelsLast
? dy.get('batch', 'idyR', 'idyC', 'inputChannel')
: dy.get('batch', 'inputChannel', 'idyR', 'idyC')
};
let w_offset = ${w.indicesToOffset(`${w.type.indices}(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)`)};
let wValue = ${w.getByOffset(`w_offset / ${components}`)};
dotProd = dotProd + xValue * wValue;
inputChannel = inputChannel + 1;
}
}
}
let value = dotProd${hasBias ? ` + bias[d1 / ${components}]` : ''};
${output.setByOffset('global_idx', 'value')};
`;

return `
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')};
${codeSnippet}}`;
};

return {
name: 'ConvTranspose2D',
shaderCache: { hint: `${attributes.cacheKey};`, inputDependencies },
shaderCache: { hint: `${attributes.cacheKey};${components}`, inputDependencies },
getRunData: () => ({
dispatchGroup: { x: dispatch[0], y: dispatch[1], z: dispatch[2] },
outputs: [
Expand Down
Loading
Loading