
[ONNX] Add converter for QAttention from Microsoft onnxruntime contrib opset #13654

Merged: 12 commits into apache:main on Jan 3, 2023

Conversation

KJlaccHoeUM9l (Contributor)

This PR adds support for QAttention, the quantized version of Attention from the Microsoft onnxruntime contrib opset.
An explanation and illustration of how this layer works can be found, for example, in @lena-voita's NLP course.
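
Roughly, QAttention dequantizes its quantized input and weight, applies the packed QKV projection, and then runs standard multi-head scaled dot-product attention. Below is a minimal NumPy sketch of these semantics, assuming per-tensor quantization; mask handling, the 'past' state, and the 'unidirectional' attribute are omitted, and all names are illustrative rather than the converter's actual code.

```python
import numpy as np

def qattention_reference(x_q, w_q, bias, x_scale, w_scale,
                         x_zero_point, w_zero_point, num_heads):
    # Dequantize the per-tensor quantized operands back to float.
    x = (x_q.astype(np.float32) - x_zero_point) * x_scale
    w = (w_q.astype(np.float32) - w_zero_point) * w_scale

    batch, seq_len, hidden = x.shape
    head_size = hidden // num_heads

    # One packed projection yields Q, K, and V.
    qkv = x @ w + bias                        # [batch, seq, 3 * hidden]
    q, k, v = np.split(qkv, 3, axis=-1)

    def split_heads(t):
        return t.reshape(batch, seq_len, num_heads, head_size).transpose(0, 2, 1, 3)

    q, k, v = map(split_heads, (q, k, v))

    # Scaled dot-product attention with a numerically stable softmax.
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_size)
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)

    out = probs @ v                           # [batch, heads, seq, head_size]
    return out.transpose(0, 2, 1, 3).reshape(batch, seq_len, hidden)
```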

tvm-bot (Collaborator) commented on Dec 23, 2022

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

KJlaccHoeUM9l (Contributor, Author)

Hello @vvchernov, @echuraev, @AndrewZhaoLuo!
Could you review this PR?

echuraev (Contributor) left a comment


LGTM, but I don't have much knowledge of this codebase. @jwfromm, @AndrewZhaoLuo, could you please take a look at this PR?

# Currently only (batch, past_seq_length + seq_length) shape is supported.
mask_index = inputs[5]

# Scalar, which means a per-tensor/layer quantization
Reviewer (Contributor):

nit: You have exactly the same comment for inputs[3] and inputs[7]

KJlaccHoeUM9l (Author):

done
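
For context, the indices above refer to the QAttention input list from the com.microsoft contrib-op spec; the mapping below is an annotation added for readability, not code from the PR.

```python
# QAttention inputs (com.microsoft contrib opset) -- annotation only:
# inputs[0] input    inputs[3] input_scale   inputs[6] input_zero_point
# inputs[1] weight   inputs[4] weight_scale  inputs[7] weight_zero_point
# inputs[2] bias     inputs[5] mask_index    inputs[8] past
```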

result,
_op.multiply(lhs_scale, rhs_scale),
zero_point_zero,
axis=-1, # TODO(agladyshev): what is 'axis' parameter for?
Reviewer (Contributor):

Do you still need this todo comment?

KJlaccHoeUM9l (Author):

done
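
For readers wondering about the removed TODO: in TVM's QNN ops such as relay.qnn.op.dequantize, `axis` selects the channel axis that a 1-D (per-channel) scale applies along; with a scalar per-tensor scale it has no effect. A standalone illustration, not code from this PR:

```python
from tvm import relay

# With a scalar (per-tensor) scale, `axis` is irrelevant; it only
# matters when the scale is a 1-D per-channel tensor.
data = relay.var("data", shape=(1, 4, 8), dtype="int8")
scale = relay.const(0.05, "float32")   # per-tensor scale
zero_point = relay.const(0, "int32")
deq = relay.qnn.op.dequantize(data, scale, zero_point, axis=-1)
```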

AndrewZhaoLuo (Contributor)

Apologies, I've been quite sick. I'll try to look at this Thursday.

AndrewZhaoLuo (Contributor) left a comment


Thanks! Sorry for the late review. LGTM

AndrewZhaoLuo merged commit e24d4fb into apache:main on Jan 3, 2023.
KJlaccHoeUM9l deleted the agladyshev/dev/qattention branch on January 10, 2023 at 12:02.
fzi-peccia pushed a commit to fzi-peccia/tvm that referenced this pull request on Mar 27, 2023: [ONNX] Add converter for QAttention from Microsoft onnxruntime contrib opset (apache#13654)

* init QAttention converter
* add type and shape checking
* add test for QAttention
* add tests for optional parameters
* change mask_index shape
* add support for 'past' input
* add support for 'unidirectional' attribute
* expand test coverage
* fix lint
* fix pylint
* fix batch dimension for topi/cuda/batch_matmul_tensorcore.py::batch_matmul_tensorcore_cuda
* code review fix
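
Once merged, the converter is exercised through TVM's normal ONNX import path. A hedged usage sketch; the model file name is a placeholder:

```python
import onnx
from tvm import relay

# Placeholder path; any ONNX graph containing com.microsoft.QAttention
# nodes would do once the converter is available.
model = onnx.load("model_with_qattention.onnx")
mod, params = relay.frontend.from_onnx(model)
```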