Adds Some uint8 Quantization Fixes (llvm#3122)
1. Changes the linalg lowering for dequantization ops to always sign-cast to
float, so that the i32 result of subtracting the zero point is not misread as a
wrapped-around unsigned value for unsigned quantizations (a short numeric
sketch follows this list).
2. Adds a basic quantized model test that only quantizes and dequantizes; with
these changes it now passes in the linalg and onnx configs.
3. Changes the aten.mm lowering to allow mismatched quantized types. 
4. If a quantized matmul arg is uint8, we shift the values and the zero point
down by 128 to faithfully represent the quantization as a signed i8
quantization (see the second sketch below). This works fine in the AtenMmOp
lowering, but I'd be happy to move it to a rewrite in FuseQuantizedOps.cpp
instead if that seems more appropriate.
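
For change 1, here is a minimal numpy sketch of the failure mode (the scale, zero point, and values are made up for illustration): when the i32 result of the zero-point subtraction is cast unsigned-to-float, a quantized value below the zero point wraps around instead of going negative.

```python
import numpy as np

# Hypothetical uint8 quantization: dequant(q) = scale * (q - zero_point),
# with a quantized value that falls below the zero point.
scale, zero_point = 0.5, 10
q = np.array([3], dtype=np.uint8)

# The lowering subtracts the zero point in i32. If that i32 is then cast
# unsigned-to-float (the old UIToFP path), the negative result wraps around.
diff = q.astype(np.uint32) - np.uint32(zero_point)              # 4294967289, i.e. a wrapped -7
unsigned_cast = diff.astype(np.float32) * scale                 # ~2.15e9, nonsense
signed_cast = diff.view(np.int32).astype(np.float32) * scale    # -3.5, as intended

print(unsigned_cast[0], signed_cast[0])
```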

With changes 3 and 4, the QuantizedMLP_basic and
QuantizedSingleLayer_basic e2e tests now pass with the onnx config.
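
The shift in change 4 works because the dequantized value scale * (q - zp) is unchanged when both q and zp are shifted by the same constant. A rough numpy sketch with made-up values (128 corresponds to the minSI constant the lowering builds for 8-bit types):

```python
import numpy as np

scale = 0.1
q_u8 = np.array([0, 37, 130, 255], dtype=np.uint8)
zp_u8 = 130

# Shift the values and the zero point down by 128 to reinterpret the uint8
# quantization as an equivalent signed i8 quantization.
q_i8 = (q_u8.astype(np.int32) - 128).astype(np.int8)
zp_i8 = zp_u8 - 128

dequant_u8 = scale * (q_u8.astype(np.int32) - zp_u8)
dequant_i8 = scale * (q_i8.astype(np.int32) - zp_i8)
assert np.allclose(dequant_u8, dequant_i8)  # same real values either way
```

Since every shifted value stays inside the signed i8 range, the existing signed quantized matmul path can be reused.
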
zjgarvey authored Apr 10, 2024
1 parent 3b84a71 commit aa5e150
Showing 4 changed files with 91 additions and 41 deletions.
43 changes: 36 additions & 7 deletions lib/Conversion/TorchToLinalg/Linear.cpp
@@ -103,15 +103,16 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
    }

    if (lhsTorchType.getDtype() != rhsTorchType.getDtype()) {
      return rewriter.notifyMatchFailure(
          op, "unsupported: aten.mm with different input element types");
      if (!lhsZeroPoint) {
        return rewriter.notifyMatchFailure(
            op, "unsupported: aten.mm with different input element types");
      }
      // Allows quantized types to mismatch since they will be cast to the same
      // type.
    }

    bool isUnsigned = torch_to_linalg::isUnsignedTorchType(lhsTorchType);
    if (lhsZeroPoint && isUnsigned) {
      return rewriter.notifyMatchFailure(
          op, "unsupported: unsigned quantized matmul not supported");
    }
    bool isUnsignedR = torch_to_linalg::isUnsignedTorchType(rhsTorchType);

    Value lhsDim0 = rewriter.create<tensor::DimOp>(loc, lhs, 0);
    Value rhsDim1 = rewriter.create<tensor::DimOp>(loc, rhs, 1);
@@ -139,7 +140,7 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
        rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);

    Value matmul;
    if (lhsZeroPoint && !isUnsigned) {
    if (lhsZeroPoint) {
      lhsZeroPoint = typeConverter->materializeTargetConversion(
          rewriter, loc,
          getTypeConverter()->convertType(lhsZeroPoint.getType()),
@@ -152,6 +153,34 @@
          loc, rewriter.getI32Type(), lhsZeroPoint);
      rhsZeroPoint = rewriter.create<arith::TruncIOp>(
          loc, rewriter.getI32Type(), rhsZeroPoint);

      // for uint8 types, we shift down by 128 so that we can faithfully
      // represent the quantization with signed i8 types.
      auto signShift = [&](Value &arg, Value &zp, bool isUnsignedType,
                           int64_t numBits) {
        if (!isUnsignedType)
          return;
        int64_t minSI = -std::pow(2, numBits - 1);
        Value minSIValue =
            rewriter.create<arith::ConstantIntOp>(loc, minSI, 32);
        zp = rewriter.create<arith::AddIOp>(loc, zp, minSIValue);
        minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits);
        arg = torch_to_linalg::createElementwiseLinalgGeneric(
            rewriter, loc, ValueRange{arg},
            arg.getType().cast<TensorType>().getElementType(),
            [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
              Value result = rewriter.create<arith::AddIOp>(loc, payloadArgs[0],
                                                            minSIValue);
              b.create<linalg::YieldOp>(loc, result);
            });
      };

      int64_t numBits =
          lhsType.getElementType().cast<mlir::IntegerType>().getWidth();
      signShift(lhs, lhsZeroPoint, isUnsigned, numBits);
      numBits = rhsType.getElementType().cast<mlir::IntegerType>().getWidth();
      signShift(rhs, rhsZeroPoint, isUnsignedR, numBits);

      matmul =
          rewriter
              .create<linalg::QuantizedMatmulOp>(
9 changes: 3 additions & 6 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -1481,12 +1481,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
    }

    value = b.create<arith::SubIOp>(loc, value, zp);

    if (torch_to_linalg::isUnsignedTorchType(qtensorTy)) {
      value = b.create<arith::UIToFPOp>(loc, outFpTy, value);
    } else {
      value = b.create<arith::SIToFPOp>(loc, outFpTy, value);
    }
    // treat the i32 as a signed int regardless of original signed-ness
    // this will prevent overflow from subtraction for unsigned quantizations.
    value = b.create<arith::SIToFPOp>(loc, outFpTy, value);

    scale = converter->materializeTargetConversion(
        b, loc, converter->convertType(scale.getType()), scale);
5 changes: 1 addition & 4 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -150,6 +150,7 @@
    'AtenIntBoolOpModule_basic',
    'QuantizedMLP_basic',
    'QuantizedSingleLayer_basic',
    'QuantizedNoLayer_basic',
    'ScalarImplicitFloatModule_basic',
    'ScalarImplicitIntModule_basic',
    # END tests failing due to: torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default
@@ -2112,10 +2113,6 @@
    "IndexTensorMultiInputContiguousOneDimDynamic_basic",
    "IndexTensorMultiInputNonContiguousOneDimDynamic_basic",

    # Failure - torch.aten.mm lower (mixed signedness of qtypes)
    "QuantizedMLP_basic",
    "QuantizedSingleLayer_basic",

    # Failure - torch.aten.squeeze lower
    "BucketizeTensorOutInt32RightModule_basic", # unsupported by backend contract: tensor with unknown rank

@@ -12,6 +12,41 @@

# ==============================================================================

def get_quant_model_input():
    return 2 * torch.rand((1, 16)) - 1

class QuantizedNoLayer(nn.Module):
    def __init__(self):
        super().__init__()
        torch.random.manual_seed(0)
        self.quantize = torch.quantization.QuantStub()
        self.dequantize = torch.quantization.DeQuantStub()

    @export
    @annotate_args([
        None,
        ([1, 16], torch.float32, True),
    ])
    def forward(self, x):
        x = self.quantize(x)
        x = self.dequantize(x)
        return x

def get_quantized_no_layer():
    model = QuantizedNoLayer()
    model.eval()
    model.qconfig = torch.quantization.default_qconfig
    torch.quantization.prepare(model, inplace=True)
    torch.manual_seed(0)
    for _ in range(32):
        model(get_quant_model_input())
    torch.quantization.convert(model, inplace=True)
    return model

@register_test_case(module_factory=get_quantized_no_layer)
def QuantizedNoLayer_basic(module, tu: TestUtils):
    module.forward(get_quant_model_input())

class QuantizedSingleLayer(nn.Module):
    def __init__(self):
        super().__init__()
@@ -22,7 +57,6 @@ def __init__(self):
        self.quantize = torch.quantization.QuantStub()
        self.dequantize = torch.quantization.DeQuantStub()

    @export
    @export
    @annotate_args([
        None,
@@ -34,6 +68,20 @@ def forward(self, x):
        x = self.dequantize(x)
        return x

def get_quantized_single_layer():
    model = QuantizedSingleLayer()
    model.eval()
    model.qconfig = torch.quantization.default_qconfig
    torch.quantization.prepare(model, inplace=True)
    torch.manual_seed(0)
    for _ in range(32):
        model(get_quant_model_input())
    torch.quantization.convert(model, inplace=True)
    return model

@register_test_case(module_factory=get_quantized_single_layer)
def QuantizedSingleLayer_basic(module, tu: TestUtils):
    module.forward(get_quant_model_input())

class QuantizedMLP(nn.Module):
    def __init__(self):
@@ -47,7 +95,6 @@ def __init__(self):
        self.quantize = torch.quantization.QuantStub()
        self.dequantize = torch.quantization.DeQuantStub()

    @export
    @export
    @annotate_args([
        None,
@@ -59,37 +106,17 @@ def forward(self, x):
        x = self.dequantize(x)
        return x


def get_mlp_input():
    return 2 * torch.rand((1, 16)) - 1


def get_quantized_mlp():
    model = QuantizedMLP()
    model.eval()
    model.qconfig = torch.quantization.default_qconfig
    torch.quantization.prepare(model, inplace=True)
    torch.manual_seed(0)
    for _ in range(32):
        model(get_mlp_input())
    torch.quantization.convert(model, inplace=True)
    return model

def get_quantized_single_layer():
    model = QuantizedSingleLayer()
    model.eval()
    model.qconfig = torch.quantization.default_qconfig
    torch.quantization.prepare(model, inplace=True)
    torch.manual_seed(0)
    for _ in range(32):
        model(get_mlp_input())
        model(get_quant_model_input())
    torch.quantization.convert(model, inplace=True)
    return model

@register_test_case(module_factory=get_quantized_single_layer)
def QuantizedSingleLayer_basic(module, tu: TestUtils):
    module.forward(get_mlp_input())

@register_test_case(module_factory=get_quantized_mlp)
def QuantizedMLP_basic(module, tu: TestUtils):
    module.forward(get_mlp_input())
    module.forward(get_quant_model_input())