
Commit

cleanup for webgpu option
guschmue committed Feb 26, 2025
1 parent 6a2eeeb commit 9e845e7
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions src/python/py/models/builder.py
@@ -65,7 +65,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
"tunable_op_tuning_enable": "1",
},
"dml": {},
- "web": {},
+ "webgpu": {},
}

# Map input names to their types and shapes
@@ -246,6 +246,8 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
("cuda", TensorProto.FLOAT16),
("rocm", TensorProto.FLOAT16),
("dml", TensorProto.FLOAT16),
+ ("webgpu", TensorProto.FLOAT16),
+ ("webgpu", TensorProto.FLOAT),
]
if (self.ep, self.io_dtype) in valid_gqa_configurations:
# Change model settings for GroupQueryAttention
@@ -254,10 +256,10 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):

# DML doesn't support packed Q/K/V for GQA yet
# Packed MatMul with LoRA/QLoRA is not currently supported
- self.attention_attrs["use_packed_matmul"] = self.ep != "dml" and not self.matmul_attrs["use_lora"]
+ self.attention_attrs["use_packed_matmul"] = self.ep not in ["dml", "webgpu"] and not self.matmul_attrs["use_lora"]

# GQA + Rot.Emb. does not require `position ids` as input
- if self.ep != "dml":
+ if self.ep not in ["dml", "webgpu"]:
self.attention_attrs["use_rotemb_in_attn"] = True
self.input_names.remove("position_ids")
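
For reference, a minimal sketch (not part of the commit) of how the two attention settings above resolve per execution provider; the helper name and default values are assumptions, while the conditions mirror the added lines in this hunk:

    # Hypothetical helper mirroring the logic in the hunk above; names and defaults are assumptions.
    def resolve_attention_flags(ep, use_lora):
        # Packed Q/K/V MatMul is skipped for dml and webgpu, and whenever LoRA is in use.
        use_packed_matmul = ep not in ["dml", "webgpu"] and not use_lora
        # Rotary embedding is fused into GQA (and position_ids dropped) only for the other EPs.
        use_rotemb_in_attn = ep not in ["dml", "webgpu"]
        return use_packed_matmul, use_rotemb_in_attn

    # resolve_attention_flags("cuda", False)   -> (True, True)   # position_ids removed from inputs
    # resolve_attention_flags("webgpu", False) -> (False, False) # position_ids kept as an input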

@@ -3195,7 +3197,8 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid
config.update(peft_config.__dict__)

# Set input/output precision of ONNX model
- io_dtype = TensorProto.FLOAT if precision in {"int8", "fp32"} or (precision == "int4" and execution_provider == "cpu") else TensorProto.FLOAT16
+ use_webgpu_fp32 = extra_options.get("use_webgpu_fp32", "0") == "1"
+ io_dtype = TensorProto.FLOAT if precision in {"int8", "fp32"} or (precision == "int4" and execution_provider == "cpu") or use_webgpu_fp32 else TensorProto.FLOAT16

if "config_only" not in extra_options:
# List architecture options in alphabetical order
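
As a quick illustration (not part of the commit), the precision selection above can be read as the following standalone sketch; the helper name is hypothetical, and extra_options is assumed to be the parsed key/value dict:

    from onnx import TensorProto

    # Hypothetical standalone version of the io_dtype selection shown in the hunk above.
    def resolve_io_dtype(precision, execution_provider, extra_options):
        use_webgpu_fp32 = extra_options.get("use_webgpu_fp32", "0") == "1"
        if precision in {"int8", "fp32"} or (precision == "int4" and execution_provider == "cpu") or use_webgpu_fp32:
            return TensorProto.FLOAT
        return TensorProto.FLOAT16

    # resolve_io_dtype("int4", "webgpu", {})                        -> TensorProto.FLOAT16
    # resolve_io_dtype("int4", "webgpu", {"use_webgpu_fp32": "1"})  -> TensorProto.FLOAT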
@@ -3299,8 +3302,8 @@ def get_args():
"-e",
"--execution_provider",
required=True,
- choices=["cpu", "cuda", "rocm", "dml", "web"],
- help="Execution provider to target with precision of model (e.g. FP16 CUDA, INT4 CPU, INT4 WEB)",
+ choices=["cpu", "cuda", "rocm", "dml", "webgpu"],
+ help="Execution provider to target with precision of model (e.g. FP16 CUDA, INT4 CPU, INT4 WEBGPU)",
)

parser.add_argument(
@@ -3358,6 +3361,8 @@ def get_args():
If true, the QMoE op will use 8-bit quantization. If false, the QMoE op will use 4-bit quantization.
use_qdq = Use the QDQ decomposition for ops.
Use this option when you want to use quantize-dequantize ops. For example, you will have a quantized MatMul op instead of the MatMulNBits op.
+ use_webgpu_fp32 = Use fp32 for the webgpu EP.
+ Use this option to enable GPUs that do not support fp16 on webgpu, e.g. GTX 10xx.
adapter_path = Path to folder on disk containing the adapter files (adapter_config.json and adapter model weights).
Use this option for LoRA models.
include_prompt_templates = Include prompt templates in the GenAI config file. Default is false.
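
A hedged usage sketch (not part of the commit): only -e/--execution_provider and the use_webgpu_fp32 option are confirmed by this diff; the remaining flags and the key=value format for --extra_options are assumptions about builder.py's usual interface:

    import subprocess

    # Build an INT4 model for webgpu, forcing fp32 I/O for GPUs without fp16 support.
    subprocess.run([
        "python", "src/python/py/models/builder.py",
        "-m", "<model_name>",        # assumed flag
        "-o", "<output_dir>",        # assumed flag
        "-p", "int4",                # assumed flag
        "-e", "webgpu",
        "--extra_options", "use_webgpu_fp32=1",  # key=value format assumed
    ], check=True)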
@@ -3366,7 +3371,7 @@
)

args = parser.parse_args()
- print("Valid precision + execution provider combinations are: FP32 CPU, FP32 CUDA, FP16 CUDA, FP16 DML, INT4 CPU, INT4 CUDA, INT4 DML")
+ print("Valid precision + execution provider combinations are: FP32 CPU, FP32 CUDA, FP16 CUDA, FP16 DML, INT4 CPU, INT4 CUDA, INT4 DML, INT4 WEBGPU")
return args

if __name__ == '__main__':
