Implement lowering of torch.aten.norm.Scalar (llvm#2899)
ptrifunovic98 authored Feb 26, 2024
1 parent 89e02c1 commit c5a1da1
Showing 8 changed files with 177 additions and 8 deletions.
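For context, `aten::norm.Scalar` is the `torch.norm` overload that takes a single scalar `p` and reduces over every element of the input, producing a rank-0 tensor. A minimal eager-mode sketch of the semantics this commit lowers (illustrative only, not part of the change):

```python
import torch

a = torch.rand(3, 4, 5)
out = torch.ops.aten.norm(a, 3.0)    # dispatches to aten::norm.Scalar
assert out.shape == torch.Size([])   # rank-0 result: the op reduces over all elements
assert torch.allclose(out, torch.linalg.vector_norm(a, ord=3.0))
```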
25 changes: 25 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6325,6 +6325,31 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
}];
}

def Torch_AtenNormScalarOp : Torch_Op<"aten.norm.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$p
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenNormScalarOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenNormScalarOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasVerifier = 1;
}

def Torch_AtenNormScalarOptDimOp : Torch_Op<"aten.norm.ScalarOpt_dim", [
AllowsTypeRefinement,
HasValueSemantics,
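With the op declared in `GeneratedTorchOps.td`, the end-to-end lowering can be inspected from Python. The snippet below is a sketch that assumes the legacy pt1 `torch_mlir.compile` API (and its `"linalg-on-tensors"` output type) is available in the build:

```python
import torch
import torch_mlir  # legacy pt1 API; availability in a given build is an assumption

class Norm(torch.nn.Module):
    def forward(self, a):
        return torch.ops.aten.norm(a, 3.0)

# Lowers through torch.aten.norm.Scalar down to linalg-on-tensors, exercising the
# conversion pattern added in Reduction.cpp below.
module = torch_mlir.compile(Norm(), torch.rand(3, 4, 5), output_type="linalg-on-tensors")
print(module)
```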
53 changes: 45 additions & 8 deletions lib/Conversion/TorchToLinalg/Reduction.cpp
@@ -275,7 +275,8 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
elementType.getIntOrFloatBitWidth())));
}

if (isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op))
if (isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op) ||
isa<AtenNormScalarOp>(op))
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));

if (isa<AtenAllDimOp>(op)) {
@@ -341,6 +342,26 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
if (intType.isSigned())
return b.create<arith::MinSIOp>(loc, self, result);
}
} else if (isa<AtenNormScalarOp>(op)) {
// This creates payload for only the first of the two linalg.generic ops.
// TODO: Short-circuit operations if `p` is zero or one.
Value elem = payloadArgs[0];
Value result = payloadArgs[1];

// TODO: Fix this part to support complex elements.
if (elem.getType().isa<mlir::ComplexType>()) {
op->emitError("lowering of complex input type for torch.aten.norm.Scalar "
"is currently unimplemented");
return nullptr;
}

Value self = convertScalarToDtype(b, loc, elem, resultElementType);

auto abs = b.create<math::AbsFOp>(loc, self);
AtenNormScalarOp::Adaptor adaptor(operands);
Value p = convertScalarToDtype(b, loc, adaptor.getP(), resultElementType);
auto pow = b.create<math::PowFOp>(loc, abs, p);
return b.create<arith::AddFOp>(loc, pow, result);
} else if (isa<AtenLinalgVectorNormOp>(op)) {
// This creates payload for only the first of the two linalg.generic ops.
// TODO: Short-circuit operations if `ord` is zero or one.
@@ -433,7 +454,7 @@ class ConvertReductionOp : public ConversionPattern {
ConversionPatternRewriter &rewriter) const {
auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};

if (isa<AtenMaxOp, AtenMinOp, AtenSumOp>(op)) {
if (isa<AtenMaxOp, AtenMinOp, AtenSumOp, AtenNormScalarOp>(op)) {
opInfo.tensorOperand = operands[0];
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();

@@ -484,10 +505,12 @@ class ConvertReductionOp : public ConversionPattern {
return err ? Value{} : powOp;
}

FailureOr<Value> createSecondReductionForVectorNormOp(
Location loc, Type elemType, AtenLinalgVectorNormOp op, Value ordOp,
Value firstReduction, const torch_to_linalg::ReductionOpInfo &opInfo,
ConversionPatternRewriter &rewriter) const {
template <typename TOp>
FailureOr<Value>
createSecondReductionForNormOp(Location loc, Type elemType, TOp op,
Value ordOp, Value firstReduction,
const torch_to_linalg::ReductionOpInfo &opInfo,
ConversionPatternRewriter &rewriter) const {
// Cast `ord` to float so that we can readily pass it to math.powf.
Value ordValue = convertScalarToDtype(rewriter, loc, ordOp, elemType);

@@ -544,13 +567,15 @@ class ConvertReductionOp : public ConversionPattern {
LogicalResult
validateReductionElementType(Operation *op, Type elemType,
ConversionPatternRewriter &rewriter) const {
if ((isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op)) &&
if ((isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op) ||
isa<AtenNormScalarOp>(op)) &&
!elemType.isa<mlir::FloatType>())
return rewriter.notifyMatchFailure(
op, "only float types are valid for vector norm ops");
if (isa<AtenAllDimOp>(op) && elemType.isa<mlir::IntegerType>() &&
elemType.getIntOrFloatBitWidth() == 8)
return rewriter.notifyMatchFailure(op, "uint8 is not supported");

// No checks for all other reduction operations
return success();
}
@@ -587,11 +612,22 @@ class ConvertReductionOp : public ConversionPattern {
return rewriter.notifyMatchFailure(
op, "failed to create linalg.generic operation for reduction");

// If this is an aten.norm.Scalar op, then we need to generate another
// linalg.generic op that references the first linalg.generic op.
if (isa<AtenNormScalarOp>(op)) {
AtenNormScalarOp::Adaptor adaptor(operands);
FailureOr<Value> secondReduceOp = createSecondReductionForNormOp(
loc, elemType, op, adaptor.getP(), reduceOp, *opInfo, rewriter);
if (failed(secondReduceOp))
return secondReduceOp;
reduceOp = *secondReduceOp;
}

// If this is an aten.linalg_vector_norm op, then we need to generate another
// linalg.generic op that references the first linalg.generic op.
if (auto normOp = dyn_cast<AtenLinalgVectorNormOp>(op)) {
AtenLinalgVectorNormOp::Adaptor adaptor(operands);
FailureOr<Value> secondReduceOp = createSecondReductionForVectorNormOp(
FailureOr<Value> secondReduceOp = createSecondReductionForNormOp(
loc, elemType, normOp, adaptor.getOrd(), reduceOp, *opInfo, rewriter);
if (failed(secondReduceOp))
return secondReduceOp;
@@ -627,6 +663,7 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
target.addIllegalOp<AtenMaxOp>();
target.addIllegalOp<AtenMinOp>();
target.addIllegalOp<AtenAllDimOp>();
target.addIllegalOp<AtenNormScalarOp>();
target.addIllegalOp<AtenLinalgVectorNormOp>();
target.addIllegalOp<AtenFrobeniusNormDimOp>();
patterns.add<ConvertReductionOp>(typeConverter, context);
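The conversion reuses the `aten.linalg_vector_norm` machinery: the first `linalg.generic` accumulates `sum(|x|^p)` through the payload added above (`AbsF` → `PowF` → `AddF`), and the now-templated `createSecondReductionForNormOp` emits a second `linalg.generic` that raises the accumulator to `1/p`. A rough eager-mode mirror of the two stages (the helper name below is illustrative):

```python
import torch

def norm_scalar_two_stage(x: torch.Tensor, p: float) -> torch.Tensor:
    # Stage 1: first linalg.generic accumulates AddF(PowF(AbsF(elem), p), acc)
    acc = x.abs().pow(p).sum()
    # Stage 2: second linalg.generic (createSecondReductionForNormOp) applies acc ** (1 / p)
    return acc.pow(1.0 / p)

x = torch.rand(2, 3)
assert torch.allclose(norm_scalar_two_stage(x, 3.0), torch.ops.aten.norm(x, 3.0))
```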
36 changes: 36 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -3767,6 +3767,42 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// AtenNormScalarOp
//===----------------------------------------------------------------------===//

LogicalResult AtenNormScalarOp::verify() {

// Verification of input type for torch.aten.norm.Scalar.
// Per PyTorch docs, only float and complex types are valid for norm
// operation.

auto inTensor = getSelf().getType().cast<BaseTensorType>();

// If no dtype is specified, it will default to a float one.
if (!inTensor.hasDtype()) {
return success();
}

auto inTensorDtype = inTensor.getDtype();

// Check if dtype is one of those supported by norm operation.
// ComplexType will match any torch complex types, but each float must be
// checked individually.
if (!inTensorDtype.isa<mlir::ComplexType, mlir::Float16Type,
mlir::Float32Type, mlir::Float64Type>()) {
return emitOpError(
"expected a float or complex type for input tensor, but got ")
<< inTensorDtype;
}

return success();
}

//===----------------------------------------------------------------------===//
// AtenPermuteOp
//===----------------------------------------------------------------------===//

LogicalResult AtenPermuteOp::verify() {

// Verification of the permute op for input & output dimensions with
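The verifier added above only fires when the input tensor type carries a dtype: f16/f32/f64 and complex element types pass, while integer and bool element types are rejected, matching PyTorch's own behavior. A small eager-mode illustration (the assumption here is that PyTorch raises a `RuntimeError` for integer inputs):

```python
import torch

torch.ops.aten.norm(torch.rand(3), 2.0)  # float input: accepted
try:
    torch.ops.aten.norm(torch.ones(3, dtype=torch.int32), 2.0)
except RuntimeError:
    # Integer inputs are rejected by PyTorch; the new verifier enforces the same
    # restriction on the IR whenever the tensor type has a known dtype.
    pass
```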
32 changes: 32 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -9339,6 +9339,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.norm.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
" %false = torch.constant.bool false\n"
" %none = torch.constant.none\n"
" %0 = torch.derefine %none : !torch.none to !torch.optional<list<int>>\n"
" %1 = torch.derefine %none : !torch.none to !torch.any\n"
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %false, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.norm.ScalarOpt_dim\"(%arg0: !torch.list<int>, %arg1: !torch.optional<float>, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.derefine %arg2 : !torch.list<int> to !torch.optional<list<int>>\n"
@@ -12038,6 +12046,30 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.norm.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %true = torch.constant.bool true\n"
" %int5 = torch.constant.int 5\n"
" %int8 = torch.constant.int 8\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.aten.eq.int %0#1, %int8 : !torch.int, !torch.int -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
" torch.prim.If.yield %int5 : !torch.int\n"
" } else {\n"
" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
" torch.prim.If.yield %5 : !torch.int\n"
" }\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.tensor.float\"(%arg0: !torch.float, %arg1: !torch.optional<int>, %arg2: !torch.optional<Device>, %arg3: !torch.bool) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %none = torch.constant.none\n"
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1667,6 +1667,7 @@
"NllLossModule_ignore_index_out_of_bounds_basic",
"NllLossModule_mean_basic",
"NllLossModule_sum_basic",
"NormScalarModule_basic",
"NormScalarOptDimKeepDimModule_basic",
"NormScalarOptDimModule_basic",
"NormalFunctionalModule_basic",
@@ -1722,6 +1722,9 @@ def aten〇linalg_vector_norm〡shape(self: List[int], ord: float = 2, dim: Opti
def aten〇frobenius_norm〇dim〡shape(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0)

def aten〇norm〇Scalar〡shape(self: List[int], p: float = 2) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, None, False, None)

def aten〇norm〇ScalarOpt_dim〡shape(self: List[int], p: Optional[float], dim: List[int], keepdim: bool = False) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0)

@@ -3924,6 +3927,21 @@ def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Uni
return dtype
return aten〇std〡dtype(self_rank_dtype)

@check_dtype_function(
_check_tensors_with_the_same_dtype(
num_of_tensors=1,
error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}))
def aten〇norm〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], p: Union[int, float, complex] = 2) -> int:
self_rank, self_dtype = self_rank_dtype
assert not is_integer_dtype(self_dtype)
# The following check is added because aten〇std〡dtype
# does not handle complex32 transformation to float,
# so it is done manually (torch.half == torch.float16).
# Should possibly be added to aten〇std〡dtype.
if self_dtype == torch.complex32:
return torch.half
return aten〇std〡dtype(self_rank_dtype)

@check_dtype_function([Invocation(0.0),
Invocation(0.0, dtype=torch.int32),
Invocation(0.0, dtype=torch.float16),
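The `aten〇norm〇Scalar〡dtype` function above defers to `aten〇std〡dtype`, so complex inputs produce the matching real dtype, with `complex32 → float16` special-cased because `aten〇std〡dtype` does not handle it. An illustrative eager-mode check of that rule:

```python
import torch

# complex64 -> float32 (and, per the special case above, complex32 -> float16)
out = torch.ops.aten.norm(torch.rand(4, dtype=torch.complex64), 2.0)
assert out.dtype == torch.float32
```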
@@ -449,6 +449,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit(
"aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
)
emit("aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)", has_verifier=True)
emit(
"aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)"
)
19 changes: 19 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -1100,6 +1100,25 @@ def ReduceL3NormKeepDimModule_basic(module, tu: TestUtils):

# ==============================================================================

class NormScalarModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.p = 3.0

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.norm(a, self.p)

@register_test_case(module_factory=lambda: NormScalarModule())
def NormScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))

# ==============================================================================

class NormScalarOptDimModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
