Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Torch, QNN] Support quantized mobilenet v3 from torch 1.8 #7606

Merged
merged 5 commits into from
Mar 8, 2021

Conversation

masahi
Copy link
Member

@masahi masahi commented Mar 8, 2021

Added support for quantized hardsigmoid and hardswish to support the new, QAT-ed, quantized mobilenet v3 large model from PyTorch 1.8 (see their release note https://github.com/pytorch/vision/releases/tag/v0.9.0).

Here are accuracy and latency numbers on a VNNI capable Icelake laptop.

Model name Torch-Top1 Torch-Top5 TVM-Top1 TVM-Top5 Torch latency (milli sec) TVM latency (milli sec)
mobilenet_v3_large (QAT), per channel 73.70 91.60 72.10 91.90 8.181 5.929

cc @siju-samuel @anijain2305 please review

Copy link
Member

@siju-samuel siju-samuel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@siju-samuel siju-samuel merged commit 760e9b2 into apache:main Mar 8, 2021
@siju-samuel
Copy link
Member

Thanks @masahi @anijain2305

masahi added a commit to masahi/tvm that referenced this pull request Mar 8, 2021
* [Torch] support hardsigmoid

* qhswish first impl

* add qhardsigmoid but the result is not correct

* add qmv3 to test

* comment fix
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request May 6, 2021
* [Torch] support hardsigmoid

* qhswish first impl

* add qhardsigmoid but the result is not correct

* add qmv3 to test

* comment fix
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request May 11, 2021
* [Torch] support hardsigmoid

* qhswish first impl

* add qhardsigmoid but the result is not correct

* add qmv3 to test

* comment fix
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants