This occurs when calling evaluate() after mixed-precision training with ORTTrainer, but I think the problem is general to exporting any PyTorch model with FP16 weights to ONNX.
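For context, a hypothetical standalone sketch of roughly what the failing export path boils down to (simplified and untested; the real call goes through optimum.exporters.onnx.export with generated dummy inputs, so this exact snippet may not reproduce the failure on its own):

```python
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

# Illustrative repro sketch: trace an FP16 GPT-2 decoder whose forward
# takes past_key_values, as the exporter does for the "with past" variant.
model = GPT2LMHeadModel.from_pretrained("gpt2").half().eval().to("cuda")
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

inputs = tokenizer("hello world", return_tensors="pt").to("cuda")

# Run the model once to obtain past_key_values for the traced call.
with torch.no_grad():
    past_key_values = model(**inputs, use_cache=True).past_key_values

# torch.onnx.export treats a trailing dict in `args` as keyword arguments.
torch.onnx.export(
    model,
    (inputs["input_ids"], {"past_key_values": past_key_values, "use_cache": True}),
    "decoder_with_past.onnx",
    opset_version=13,
)
# In the reported setup, tracing fails with:
#   RuntimeError: expected scalar type Float but found Half
```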
Error message
```
Traceback (most recent call last):
  File "test_onnxruntime_train.py", line 279, in test_ort_trainer_decoder
    ort_eval_metrics = trainer.evaluate(inference_with_ort=inference_with_ort)
  File "/workspace/optimum/onnxruntime/trainer.py", line 813, in evaluate
    output = eval_loop(
  File "/workspace/optimum/onnxruntime/trainer.py", line 969, in evaluation_loop_ort
    self._export(onnx_model_path, with_loss=with_loss, device=export_device)
  File "/workspace/optimum/onnxruntime/trainer.py", line 1501, in _export
    _ = export(
  File "/workspace/optimum/exporters/onnx/convert.py", line 607, in export
    return export_pytorch(model, config, opset, output, device=device, input_shapes=input_shapes)
  File "/workspace/optimum/exporters/onnx/convert.py", line 370, in export_pytorch
    onnx_export(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/__init__.py", line 350, in export
    return utils.export(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 163, in export
    _export(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1074, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 727, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 602, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 517, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 1175, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1118, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/gpt2/modeling_gpt2.py", line 1047, in forward
    transformer_outputs = self.transformer(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1118, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/gpt2/modeling_gpt2.py", line 891, in forward
    outputs = block(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1118, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/gpt2/modeling_gpt2.py", line 392, in forward
    attn_outputs = self.attn(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1118, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/gpt2/modeling_gpt2.py", line 333, in forward
    attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/gpt2/modeling_gpt2.py", line 185, in _attn
    attn_weights = torch.matmul(query, key.transpose(-1, -2))
RuntimeError: expected scalar type Float but found Half
```
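For reference, this is the generic error PyTorch raises whenever the two matmul operands disagree on dtype, which suggests that one operand somewhere in the traced forward pass is still float32. A minimal illustration (mine, not from the traceback; shapes are arbitrary):

```python
import torch

# Any matmul whose operands disagree on dtype triggers the same error.
q = torch.randn(1, 12, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 12, 8, 64, dtype=torch.float32, device="cuda")
attn_weights = torch.matmul(q, k.transpose(-1, -2))
# RuntimeError: expected scalar type Float but found Half
# (exact wording varies with PyTorch version and device)
```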
Expected behavior
A decoder with past key/values and FP16 weights can be exported to an ONNX model successfully.
Contribution
I can take a closer look, but I don't have the bandwidth at the moment.
This is odd, as I've verified that both query and key are float16 and on CUDA. Moreover, exporting the decoder without past doesn't hit this issue. I'll investigate further when I have the bandwidth...
[Update] I won't fix this unless there is significant demand from the community: proper inference with ORT should go through the ORTModel subclasses rather than ORTTrainer, whose inference path exists only for quick testing.
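For anyone landing here, a minimal sketch of that recommended ORTModel path (illustrative only, not from this issue; recent Optimum releases accept export=True, while releases from around the time of this issue used from_transformers=True instead):

```python
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer

# Export the model to ONNX and run inference through ONNX Runtime.
model = ORTModelForCausalLM.from_pretrained("gpt2", export=True)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

inputs = tokenizer("Hello, my dog is", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```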
JingyaHuang changed the title from "Decode with cache ONNX export failed after mixed precision training" to "Decoder with cache ONNX export failed after mixed precision training" on Jan 24, 2023.