[ONNX] Add converter for QAttention from Microsoft onnxruntime contrib opset #13654
Conversation
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment. Generated by tvm-bot
Hello @vvchernov, @echuraev, @AndrewZhaoLuo!
LGTM. But I don't have a lot of knowledge in this codebase. @jwfromm, @AndrewZhaoLuo could you please take a look at this PR?
# Currently only (batch, past_seq_length + seq_length) shape is supported.
mask_index = inputs[5]

# Scalar, which means a per-tensor/layer quantization
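For readers following along, here is a minimal, hypothetical sketch of the shape check the first comment describes. The helper name `check_mask_index` and the assertion message are illustrative, not the PR's actual code, though `infer_shape` is a real utility in tvm.relay.frontend.common:

```python
from tvm.relay.frontend.common import infer_shape

def check_mask_index(mask_index):
    # Only the 2-D (batch, past_seq_length + seq_length) mask layout is
    # handled; the contrib spec also allows other layouts (e.g. a 1-D
    # (batch,) tensor), which this converter would currently reject.
    shape = infer_shape(mask_index)
    assert len(shape) == 2, "QAttention: unsupported mask_index shape"
    return shape
```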
nit: You have exactly the same comment for input[3] and input[7].
done
python/tvm/relay/frontend/onnx.py (outdated)
result,
_op.multiply(lhs_scale, rhs_scale),
zero_point_zero,
axis=-1,  # TODO(agladyshev): what is 'axis' parameter for?
Do you still need this todo comment?
done
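For context, in TVM's QNN ops the `axis` argument only matters for per-channel quantization: it names the channel dimension that a 1-D scale/zero-point tensor applies along, and it is ignored when the quantization parameters are scalars, as in the per-tensor case this PR handles. A minimal sketch of that behavior with relay.qnn.op.dequantize (illustrative values, not the PR's code):

```python
import tvm
from tvm import relay

# With scalar (per-tensor) parameters, axis=-1 is a harmless default;
# it would name the channel dimension if the scale were a 1-D tensor.
acc = relay.var("acc", shape=(2, 8), dtype="int32")
deq = relay.qnn.op.dequantize(
    acc,
    input_scale=relay.const(0.1, "float32"),   # scalar: per-tensor scale
    input_zero_point=relay.const(0, "int32"),
    axis=-1,  # ignored for scalar quantization parameters
)
mod = tvm.IRModule.from_expr(deq)
print(mod)
```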
Apologies, I've been quite sick. I'll try to look at this Thursday.
Thanks! Sorry for the late review. LGTM
[ONNX] Add converter for QAttention from Microsoft onnxruntime contrib opset (apache#13654)

* init QAttention converter
* add type and shape checking
* add test for QAttention
* add tests for optional parameters
* change mask_index shape
* add support for 'past' input
* add support for 'unidirectional' attribute
* expand test coverage
* fix lint
* fix pylint
* fix batch dimension for topi/cuda/batch_matmul_tensorcore.py::batch_matmul_tensorcore_cuda
* code review fix
This PR adds support for QAttention, the quantized version of the Attention operator from the Microsoft onnxruntime contrib opset.
An explanation and illustration of how this layer works can be found, for example, in @lena-voita's NLP course.
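As a quick illustration of how the new converter can be exercised (this is not the PR's test code; the shapes, tensor names, and opset versions below are assumptions based on the contrib spec), one can build a single-node ONNX graph with QAttention from the com.microsoft domain and import it with relay.frontend.from_onnx:

```python
import onnx
from onnx import helper, TensorProto
from tvm import relay

batch, seq, hidden, num_heads = 1, 4, 8, 2

# Single QAttention node; per the contrib spec, weight packs Q/K/V as
# (hidden, 3 * hidden), and with no 'past' input the mask is (batch, seq).
node = helper.make_node(
    "QAttention",
    inputs=["input", "weight", "bias", "input_scale", "weight_scale", "mask_index"],
    outputs=["output"],
    domain="com.microsoft",
    num_heads=num_heads,
)
graph = helper.make_graph(
    [node],
    "qattention_example",
    inputs=[
        helper.make_tensor_value_info("input", TensorProto.UINT8, [batch, seq, hidden]),
        helper.make_tensor_value_info("weight", TensorProto.UINT8, [hidden, 3 * hidden]),
        helper.make_tensor_value_info("bias", TensorProto.FLOAT, [3 * hidden]),
        helper.make_tensor_value_info("input_scale", TensorProto.FLOAT, []),
        helper.make_tensor_value_info("weight_scale", TensorProto.FLOAT, []),
        helper.make_tensor_value_info("mask_index", TensorProto.INT32, [batch, seq]),
    ],
    outputs=[
        helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch, seq, hidden]),
    ],
)
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 14), helper.make_opsetid("com.microsoft", 1)],
)
mod, params = relay.frontend.from_onnx(model)
print(mod)
```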