Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JS/WebGPU] Creating devices with subgroup features enabled if possible #21833

Merged
merged 6 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 59 additions & 6 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import { ProgramManager } from './webgpu/program-manager';
import {
AdapterInfo,
ComputeContext,
DeviceInfo,
GpuArchitecture,
GpuData,
GpuVendor,
Expand Down Expand Up @@ -134,13 +135,59 @@ class AdapterInfoImpl implements AdapterInfo {
}
}

class DeviceInfoImpl implements DeviceInfo {
readonly deviceLimits: GPUSupportedLimits;
readonly deviceFeatures: GPUSupportedFeatures;

constructor(device: GPUDevice) {
this.deviceLimits = device.limits;
this.deviceFeatures = device.features;
}

get maxComputeWorkgroupSizes(): [number, number, number] {
return [
this.deviceLimits.maxComputeWorkgroupSizeX,
this.deviceLimits.maxComputeWorkgroupSizeY,
this.deviceLimits.maxComputeWorkgroupSizeZ,
];
}

get maxComputeWorkgroupStoragesize(): number {
return this.deviceLimits.maxComputeWorkgroupStorageSize;
}

get isSubgroupsSupported(): boolean {
return this.deviceFeatures.has('subgroups' as GPUFeatureName);
}

get isSubgroupsF16Supported(): boolean {
return this.deviceFeatures.has('subgroups-f16' as GPUFeatureName);
}

get subgroupSizeRange(): [number, number] | undefined {
// Currently subgroups feature is still experimental and size attributes are not in the WebGPU IDL, so we have to
// workaround the IDL type checks.
// TODO: clean this after subgroups feature is sattled in IDL.
const deviceSubgroupsLimits = this.deviceLimits as { minSubgroupSize?: number; maxSubgroupSize?: number };
if (
!this.isSubgroupsSupported ||
!deviceSubgroupsLimits.minSubgroupSize ||
!deviceSubgroupsLimits.maxSubgroupSize
) {
return undefined;
}
return [deviceSubgroupsLimits.minSubgroupSize, deviceSubgroupsLimits.maxSubgroupSize];
}
}

/**
* this class is designed to store status and being used as a singleton for JSEP. It will be passed to jsepInit() as
* the first parameter so that it is stored for future use.
*/
export class WebGpuBackend {
adapterInfo: AdapterInfoImpl;
device: GPUDevice;
deviceInfo: DeviceInfoImpl;
/**
* an instance of GpuDataManager to manage a GpuDataId -> GpuBuffer mapping
*/
Expand Down Expand Up @@ -243,16 +290,22 @@ export class WebGpuBackend {
requiredFeatures,
};

if (adapter.features.has('chromium-experimental-timestamp-query-inside-passes')) {
requiredFeatures.push('chromium-experimental-timestamp-query-inside-passes' as GPUFeatureName);
} else if (adapter.features.has('timestamp-query')) {
requiredFeatures.push('timestamp-query');
// Try requiring WebGPU features
const requireFeatureIfAvailable = (feature: GPUFeatureName) =>
adapter.features.has(feature) && requiredFeatures.push(feature) && true;
// Try chromium-experimental-timestamp-query-inside-passes and fallback to timestamp-query
if (!requireFeatureIfAvailable('chromium-experimental-timestamp-query-inside-passes' as GPUFeatureName)) {
requireFeatureIfAvailable('timestamp-query');
}
if (adapter.features.has('shader-f16')) {
requiredFeatures.push('shader-f16');
requireFeatureIfAvailable('shader-f16');
// Try subgroups
if (requireFeatureIfAvailable('subgroups' as GPUFeatureName)) {
// If subgroups feature is available, also try subgroups-f16
requireFeatureIfAvailable('subgroups-f16' as GPUFeatureName);
}

this.device = await adapter.requestDevice(deviceDescriptor);
this.deviceInfo = new DeviceInfoImpl(this.device);
this.adapterInfo = new AdapterInfoImpl(adapter.info || (await adapter.requestAdapterInfo()));
this.gpuDataManager = createGpuDataManager(this);
this.programManager = new ProgramManager(this);
Expand Down
22 changes: 9 additions & 13 deletions js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@ import { WebGpuBackend } from './backend-webgpu';
import { LOG_DEBUG } from './log';
import { TensorView } from './tensor-view';
import { ShapeUtil } from './util';
import { AdapterInfo, ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo } from './webgpu/types';
import {
AdapterInfo,
ComputeContext,
ComputeContextInputsOutputsMapping,
DeviceInfo,
ProgramInfo,
} from './webgpu/types';
import { WebNNBackend } from './backend-webnn';

/* eslint-disable no-bitwise */
Expand Down Expand Up @@ -70,6 +76,7 @@ class TensorViewImpl implements TensorView {

class ComputeContextImpl implements ComputeContext {
readonly adapterInfo: AdapterInfo;
readonly deviceInfo: DeviceInfo;
readonly opKernelContext: number;
readonly inputs: readonly TensorView[];
readonly outputCount: number;
Expand All @@ -87,6 +94,7 @@ class ComputeContextImpl implements ComputeContext {
contextDataOffset: number,
) {
this.adapterInfo = backend.adapterInfo;
this.deviceInfo = backend.deviceInfo;

// extract context data
const ptrSize = module.PTR_SIZE;
Expand All @@ -112,18 +120,6 @@ class ComputeContextImpl implements ComputeContext {
this.inputs = inputs;
}

getMaxComputeWorkgroupSizes(): [number, number, number] {
return [
this.backend.device.limits.maxComputeWorkgroupSizeX,
this.backend.device.limits.maxComputeWorkgroupSizeY,
this.backend.device.limits.maxComputeWorkgroupSizeZ,
];
}

getMaxComputeWorkgroupStoragesize(): number {
return this.backend.device.limits.maxComputeWorkgroupStorageSize;
}

compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[] {
// prepare inputs. inputs should always be valid data.
const mappedInputs =
Expand Down
20 changes: 15 additions & 5 deletions js/web/lib/wasm/jsep/webgpu/program-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,23 @@ export class ProgramManager {
build(programInfo: ProgramInfo, normalizedDispatchGroupSize: [number, number, number]): Artifact {
TRACE_FUNC_BEGIN(programInfo.name);
const device = this.backend.device;
const extensions: string[] = [];
if (device.features.has('shader-f16')) {
extensions.push('enable f16;');
}
const enableDirectives: string[] = [];

// Enable WGSL extensions based on available WebGPU features
const extensionsInfo: Array<{ feature: GPUFeatureName; extension: string }> = [
{ feature: 'shader-f16', extension: 'f16' },
{ feature: 'subgroups' as GPUFeatureName, extension: 'subgroups' },
{ feature: 'subgroups-f16' as GPUFeatureName, extension: 'subgroups_f16' },
];
extensionsInfo.forEach((info) => {
if (device.features.has(info.feature)) {
enableDirectives.push(`enable ${info.extension};`);
}
});

const shaderHelper = createShaderHelper(normalizedDispatchGroupSize, this.backend.device.limits);
const userCode = programInfo.getShaderSource(shaderHelper);
const code = `${extensions.join('\n')}\n${shaderHelper.additionalImplementations}\n${userCode}`;
const code = `${enableDirectives.join('\n')}\n${shaderHelper.additionalImplementations}\n${userCode}`;
const shaderModule = device.createShaderModule({ code, label: programInfo.name });
LOG_DEBUG('verbose', () => `[WebGPU] ${programInfo.name} shader code: ${code}`);

Expand Down
14 changes: 12 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ export interface AdapterInfo {
isArchitecture: (architecture: GpuArchitecture) => boolean;
isVendor: (vendor: GpuVendor) => boolean;
}
export interface DeviceInfo {
get maxComputeWorkgroupSizes(): [number, number, number];
get maxComputeWorkgroupStoragesize(): number;
get isSubgroupsSupported(): boolean;
get isSubgroupsF16Supported(): boolean;
get subgroupSizeRange(): [number, number] | undefined;
}

export interface GpuData {
type: GpuDataType;
Expand Down Expand Up @@ -160,6 +167,11 @@ export interface ComputeContext {
*/
readonly adapterInfo: AdapterInfo;

/**
* gpu device info
*/
readonly deviceInfo: DeviceInfo;

/**
* stores the pointer to OpKernelContext
*/
Expand Down Expand Up @@ -187,8 +199,6 @@ export interface ComputeContext {

compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[];
output(index: number, dims: readonly number[]): number;
getMaxComputeWorkgroupSizes(): [number, number, number];
getMaxComputeWorkgroupStoragesize(): number;
}

export type TimestampQuery = 'none' | 'inside-passes' | 'at-passes';