feat: support lowering of channelwise quantization to linalg #3
base: bartel/roo-62-fix-channelwise-quantization-in-torch-mlir-for-qlinearconv
Conversation
Warning: This pull request is not mergeable via GitHub because a downstack PR is open. Once all requirements are satisfied, merge this PR as a stack on Graphite. This stack of pull requests is managed by Graphite.
QuantizationValues getQuantizationPerTensorValues(
    ConversionPatternRewriter &rewriter, Location loc,
    Aten_MakePerTensorQuantizedTensorOp makePerTensorQuantizedTensorOp,
    const TypeConverter *const typeConverter) {
nit: make this a reference, as we assume the pointer is valid throughout the function.
      zeroPoint);

  // create a linalg op since we need to do some arithmetic on the zero point
  // but is it a tensor.
Suggested change:
- // but is it a tensor.
+ // as it is a tensor.
  // create a linalg op since we need to do some arithmetic on the zero point
  // but is it a tensor.
  RankedTensorType zeroPointType = cast<RankedTensorType>(zeroPoint.getType());
nitty-gritty nit: use auto.
Looks ok to me. Maybe in the future we could break this PR up even further. To me there were three concepts here: quantized per-channel conv, handling that case with a transpose, and a group conv implementation.
Overall an improvement, but yeah, we will have to revisit this one day.
  zeroPoint = torch_to_linalg::createElementwiseLinalgGeneric(
      rewriter, loc, zeroPoint, rewriter.getI32Type(),
      [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
        Value result = rewriter.create<arith::ExtUIOp>(
Suggested change:
- Value result = rewriter.create<arith::ExtUIOp>(
+ Value result = rewriter.create<arith::ExtIOp>(
Since we assume it is an integer?
    ConversionPatternRewriter &rewriter, Location loc,
    Aten_MakePerChannelQuantizedTensorOp makePerChannelQuantizedTensorOp,
    const TypeConverter *const typeConverter) {
  QuantizationValues values;
nit: move the definition down to where the members are being set at the end of the function.
      convolutionAttributes.outputPadding[i]));

  // Set stride to 1
  convolutionAttributes.stride.clear();
Why is the stride cleared and simply set to 1?
It is handled by the InsertSliceOp in line 1050. Not sure if this is the most performant way, but I basically just copied code and moved it into a function.
Yeah, I only added the channel-wise case and the rest was refactoring. I think it would have been easy with Graphite; I will do it next time. Sorry!