use '-e webgpu' to generate a model for webgpu #1278

Merged · 3 commits · Feb 26, 2025
19 changes: 12 additions & 7 deletions src/python/py/models/builder.py
@@ -65,7 +65,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
"tunable_op_tuning_enable": "1",
},
"dml": {},
"web": {},
"webgpu": {},
}

# Map input names to their types and shapes
@@ -246,6 +246,8 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
("cuda", TensorProto.FLOAT16),
("rocm", TensorProto.FLOAT16),
("dml", TensorProto.FLOAT16),
("webgpu", TensorProto.FLOAT16),
("webgpu", TensorProto.FLOAT),
]
if (self.ep, self.io_dtype) in valid_gqa_configurations:
# Change model settings for GroupQueryAttention
@@ -254,10 +256,10 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):

# DML doesn't support packed Q/K/V for GQA yet
# Packed MatMul with LoRA/QLoRA is not currently supported
self.attention_attrs["use_packed_matmul"] = self.ep != "dml" and not self.matmul_attrs["use_lora"]
self.attention_attrs["use_packed_matmul"] = self.ep not in ["dml", "webgpu"] and not self.matmul_attrs["use_lora"]

# GQA + Rot.Emb. does not require `position ids` as input
if self.ep != "dml":
if self.ep not in ["dml", "webgpu"]:
self.attention_attrs["use_rotemb_in_attn"] = True
self.input_names.remove("position_ids")

@@ -3195,7 +3197,8 @@ def create_model(model_name, input_path, output_dir, precision, execution_provider, …
config.update(peft_config.__dict__)

# Set input/output precision of ONNX model
- io_dtype = TensorProto.FLOAT if precision in {"int8", "fp32"} or (precision == "int4" and execution_provider == "cpu") else TensorProto.FLOAT16
+ use_webgpu_fp32 = extra_options.get("use_webgpu_fp32", "0") == "1"
+ io_dtype = TensorProto.FLOAT if precision in {"int8", "fp32"} or (precision == "int4" and execution_provider == "cpu") or use_webgpu_fp32 else TensorProto.FLOAT16

if "config_only" not in extra_options:
# List architecture options in alphabetical order
@@ -3299,8 +3302,8 @@ def get_args():
"-e",
"--execution_provider",
required=True,
choices=["cpu", "cuda", "rocm", "dml", "web"],
help="Execution provider to target with precision of model (e.g. FP16 CUDA, INT4 CPU, INT4 WEB)",
choices=["cpu", "cuda", "rocm", "dml", "webgpu"],
help="Execution provider to target with precision of model (e.g. FP16 CUDA, INT4 CPU, INT4 WEBGPU)",
)

parser.add_argument(
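With the new choice in place, a WebGPU model can be generated by passing -e webgpu to the builder. A minimal example invocation follows; the model name, output directory, and cache directory are placeholders, and the -m/-o/-p/-c flags are assumed to be the builder's existing arguments, since only -e appears in this diff:

python3 builder.py -m <hf_model_name> -o <output_dir> -p int4 -e webgpu -c <cache_dir>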
@@ -3358,6 +3361,8 @@ def get_args():
If true, the QMoE op will use 8-bit quantization. If false, the QMoE op will use 4-bit quantization.
use_qdq = Use the QDQ decomposition for ops.
Use this option when you want to use quantize-dequantize ops. For example, you will have a quantized MatMul op instead of the MatMulNBits op.
+ use_webgpu_fp32 = Use FP32 for WebGPU EP.
+ Use this option to enable GPUs that do not support FP16 on WebGPU (e.g. GTX 10xx).
adapter_path = Path to folder on disk containing the adapter files (adapter_config.json and adapter model weights).
Use this option for LoRA models.
include_prompt_templates = Include prompt templates in the GenAI config file. Default is false.
@@ -3366,7 +3371,7 @@ def get_args():
)

args = parser.parse_args()
print("Valid precision + execution provider combinations are: FP32 CPU, FP32 CUDA, FP16 CUDA, FP16 DML, INT4 CPU, INT4 CUDA, INT4 DML")
print("Valid precision + execution provider combinations are: FP32 CPU, FP32 CUDA, FP16 CUDA, FP16 DML, INT4 CPU, INT4 CUDA, INT4 DML, INT4 WEBGPU")
return args

if __name__ == '__main__':
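For GPUs without FP16 support (e.g. the GTX 10xx series), the new use_webgpu_fp32 option forces FP32 input/output precision. A sketch of how it would be passed, assuming the builder's existing --extra_options key=value syntax (paths and model name are placeholders):

python3 builder.py -m <hf_model_name> -o <output_dir> -p int4 -e webgpu --extra_options use_webgpu_fp32=1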