diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 50d83f5af26e0..a0010df4643a4 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -13,6 +13,7 @@ import { ProgramManager } from './webgpu/program-manager'; import { AdapterInfo, ComputeContext, + DeviceInfo, GpuArchitecture, GpuData, GpuVendor, @@ -134,6 +135,26 @@ class AdapterInfoImpl implements AdapterInfo { } } +class DeviceInfoImpl implements DeviceInfo { + readonly subgroupsSupported: boolean; + readonly subgroupsF16Supported: boolean; + readonly subgroupSizeRange?: readonly [number, number]; + + constructor(device: GPUDevice) { + this.subgroupsSupported = device.features.has('subgroups' as GPUFeatureName); + this.subgroupsF16Supported = device.features.has('subgroups' as GPUFeatureName); + // Currently subgroups feature is still experimental and size attributes are not in the WebGPU IDL, so we have to + // workaround the IDL type checks. + // TODO: clean this after subgroups feature is settled in IDL. + const deviceSubgroupsLimits = device.limits as { minSubgroupSize?: number; maxSubgroupSize?: number }; + if (!this.subgroupsSupported || !deviceSubgroupsLimits.minSubgroupSize || !deviceSubgroupsLimits.maxSubgroupSize) { + this.subgroupSizeRange = undefined; + } else { + this.subgroupSizeRange = [deviceSubgroupsLimits.minSubgroupSize, deviceSubgroupsLimits.maxSubgroupSize]; + } + } +} + /** * this class is designed to store status and being used as a singleton for JSEP. It will be passed to jsepInit() as * the first parameter so that it is stored for future use. @@ -141,6 +162,7 @@ class AdapterInfoImpl implements AdapterInfo { export class WebGpuBackend { adapterInfo: AdapterInfoImpl; device: GPUDevice; + deviceInfo: DeviceInfoImpl; /** * an instance of GpuDataManager to manage a GpuDataId -> GpuBuffer mapping */ @@ -243,16 +265,22 @@ export class WebGpuBackend { requiredFeatures, }; - if (adapter.features.has('chromium-experimental-timestamp-query-inside-passes')) { - requiredFeatures.push('chromium-experimental-timestamp-query-inside-passes' as GPUFeatureName); - } else if (adapter.features.has('timestamp-query')) { - requiredFeatures.push('timestamp-query'); + // Try requiring WebGPU features + const requireFeatureIfAvailable = (feature: GPUFeatureName) => + adapter.features.has(feature) && requiredFeatures.push(feature) && true; + // Try chromium-experimental-timestamp-query-inside-passes and fallback to timestamp-query + if (!requireFeatureIfAvailable('chromium-experimental-timestamp-query-inside-passes' as GPUFeatureName)) { + requireFeatureIfAvailable('timestamp-query'); } - if (adapter.features.has('shader-f16')) { - requiredFeatures.push('shader-f16'); + requireFeatureIfAvailable('shader-f16'); + // Try subgroups + if (requireFeatureIfAvailable('subgroups' as GPUFeatureName)) { + // If subgroups feature is available, also try subgroups-f16 + requireFeatureIfAvailable('subgroups-f16' as GPUFeatureName); } this.device = await adapter.requestDevice(deviceDescriptor); + this.deviceInfo = new DeviceInfoImpl(this.device); this.adapterInfo = new AdapterInfoImpl(adapter.info || (await adapter.requestAdapterInfo())); this.gpuDataManager = createGpuDataManager(this); this.programManager = new ProgramManager(this); diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index fddc061cd775a..48bd3ef2bc36f 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -11,7 +11,13 @@ import { WebGpuBackend } from './backend-webgpu'; import { LOG_DEBUG } from './log'; import { TensorView } from './tensor-view'; import { ShapeUtil } from './util'; -import { AdapterInfo, ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo } from './webgpu/types'; +import { + AdapterInfo, + ComputeContext, + ComputeContextInputsOutputsMapping, + DeviceInfo, + ProgramInfo, +} from './webgpu/types'; import { WebNNBackend } from './backend-webnn'; /* eslint-disable no-bitwise */ @@ -70,6 +76,7 @@ class TensorViewImpl implements TensorView { class ComputeContextImpl implements ComputeContext { readonly adapterInfo: AdapterInfo; + readonly deviceInfo: DeviceInfo; readonly opKernelContext: number; readonly inputs: readonly TensorView[]; readonly outputCount: number; @@ -87,6 +94,7 @@ class ComputeContextImpl implements ComputeContext { contextDataOffset: number, ) { this.adapterInfo = backend.adapterInfo; + this.deviceInfo = backend.deviceInfo; // extract context data const ptrSize = module.PTR_SIZE; @@ -112,18 +120,6 @@ class ComputeContextImpl implements ComputeContext { this.inputs = inputs; } - getMaxComputeWorkgroupSizes(): [number, number, number] { - return [ - this.backend.device.limits.maxComputeWorkgroupSizeX, - this.backend.device.limits.maxComputeWorkgroupSizeY, - this.backend.device.limits.maxComputeWorkgroupSizeZ, - ]; - } - - getMaxComputeWorkgroupStoragesize(): number { - return this.backend.device.limits.maxComputeWorkgroupStorageSize; - } - compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[] { // prepare inputs. inputs should always be valid data. const mappedInputs = diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index c5b8f579c3aae..2c5180c5db3ee 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -93,13 +93,23 @@ export class ProgramManager { build(programInfo: ProgramInfo, normalizedDispatchGroupSize: [number, number, number]): Artifact { TRACE_FUNC_BEGIN(programInfo.name); const device = this.backend.device; - const extensions: string[] = []; - if (device.features.has('shader-f16')) { - extensions.push('enable f16;'); - } + const enableDirectives: string[] = []; + + // Enable WGSL extensions based on available WebGPU features + const extensionsInfo: Array<{ feature: GPUFeatureName; extension: string }> = [ + { feature: 'shader-f16', extension: 'f16' }, + { feature: 'subgroups' as GPUFeatureName, extension: 'subgroups' }, + { feature: 'subgroups-f16' as GPUFeatureName, extension: 'subgroups_f16' }, + ]; + extensionsInfo.forEach((info) => { + if (device.features.has(info.feature)) { + enableDirectives.push(`enable ${info.extension};`); + } + }); + const shaderHelper = createShaderHelper(normalizedDispatchGroupSize, this.backend.device.limits); const userCode = programInfo.getShaderSource(shaderHelper); - const code = `${extensions.join('\n')}\n${shaderHelper.additionalImplementations}\n${userCode}`; + const code = `${enableDirectives.join('\n')}\n${shaderHelper.additionalImplementations}\n${userCode}`; const shaderModule = device.createShaderModule({ code, label: programInfo.name }); LOG_DEBUG('verbose', () => `[WebGPU] ${programInfo.name} shader code: ${code}`); diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index 3b3c55733c973..9321ac170d036 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -21,6 +21,11 @@ export interface AdapterInfo { isArchitecture: (architecture: GpuArchitecture) => boolean; isVendor: (vendor: GpuVendor) => boolean; } +export interface DeviceInfo { + readonly subgroupsSupported: boolean; + readonly subgroupsF16Supported: boolean; + readonly subgroupSizeRange?: readonly [number, number]; +} export interface GpuData { type: GpuDataType; @@ -160,6 +165,11 @@ export interface ComputeContext { */ readonly adapterInfo: AdapterInfo; + /** + * gpu device info + */ + readonly deviceInfo: DeviceInfo; + /** * stores the pointer to OpKernelContext */ @@ -187,8 +197,6 @@ export interface ComputeContext { compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[]; output(index: number, dims: readonly number[]): number; - getMaxComputeWorkgroupSizes(): [number, number, number]; - getMaxComputeWorkgroupStoragesize(): number; } export type TimestampQuery = 'none' | 'inside-passes' | 'at-passes';