From 140babd952c0cb86bc888a4d616ff41cd3496f4f Mon Sep 17 00:00:00 2001
From: Sean Silva
Date: Tue, 29 Mar 2022 22:57:31 +0000
Subject: [PATCH] Add minimal support for Union types.

A recent PyTorch commit made ConstantPad2d call a helper function with a
`Union[int, float]` type annotation. This commit adds minimal support for
representing and dealing with that type.
https://github.com/pytorch/pytorch/pull/73287
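As a minimal illustration, this is the kind of TorchScript function that
can now be imported (it mirrors the new union.py test added at the end of
this patch):

    from typing import Union

    import torch

    @torch.jit.script
    def f(x: Union[int, float]):
        return

The `Union[int, float]` annotation imports as `!torch.union<int, float>`.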
Changes:
- Add support for `!torch.union`/`Torch::UnionType`, along with the
  importer and CAPI code.
- Add support in isValidSubtype for union types.
- Add a canonicalizer for `torch.derefine` to help simplify some code that
  derefines to a UnionType (this also fixes #664); see the IR sketch below.
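An abridged before/after sketch of the `torch.derefine` canonicalization,
based on the pattern exercised by the new canonicalize.mlir test:

    // Before: the arange op consumes the derefined value.
    %optional = torch.derefine %none : !torch.none to !torch.optional<int>
    %ret = torch.aten.arange.start %arg0, %arg0, %optional, %none, %none, %none : ...

    // After: arange allows type refinement, so its operand is rewritten to
    // use %none directly; uses that don't allow refinement keep %optional.
    %ret = torch.aten.arange.start %arg0, %arg0, %none, %none, %none, %none : ...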
There is still more work to do to really support UnionType well, such as
canonicalizing UnionTypes so that they can be compared with pointer
equality.
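In the meantime, downstream users can already build the new type through
the C API. An illustrative sketch (`ctx` is assumed to be an existing
MlirContext; torchMlirTorchIntTypeGet and torchMlirTorchFloatTypeGet are
the pre-existing helpers in the same header):

    // Build !torch.union<int, float> from its contained types.
    MlirType contained[2] = {torchMlirTorchIntTypeGet(ctx),
                             torchMlirTorchFloatTypeGet(ctx)};
    MlirType unionType = torchMlirTorchUnionTypeGet(ctx, 2, contained);
    // Expected: torchMlirTypeIsATorchUnion(unionType) returns true.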
---
 include/torch-mlir-c/TorchTypes.h             | 12 ++++
 .../torch-mlir/Dialect/Torch/IR/TorchOps.td   |  1 +
 .../torch-mlir/Dialect/Torch/IR/TorchTypes.td | 27 +++++--
 lib/CAPI/TorchTypes.cpp                       | 18 +++++
 lib/Dialect/Torch/IR/TorchOps.cpp             | 14 ++++
 lib/Dialect/Torch/IR/TorchTypes.cpp           | 70 ++++++++++++++++---
 .../jit_ir/csrc/torch_to_mlir_utils.cpp       |  9 +++
 requirements.txt                              |  5 +-
 test/Dialect/Torch/canonicalize.mlir          | 16 +++++
 test/Dialect/Torch/ops.mlir                   | 12 ++++
 .../importer/jit_ir/node_import/union.py      | 24 +++++++
 11 files changed, 189 insertions(+), 19 deletions(-)
 create mode 100644 test/python/importer/jit_ir/node_import/union.py

diff --git a/include/torch-mlir-c/TorchTypes.h b/include/torch-mlir-c/TorchTypes.h
index 4268314afc56..586468ba7d0c 100644
--- a/include/torch-mlir-c/TorchTypes.h
+++ b/include/torch-mlir-c/TorchTypes.h
@@ -52,6 +52,18 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchTupleTypeGet(MlirContext context,
                                                        intptr_t numContainedTypes,
                                                        MlirType const *containedTypes);
 
+//===----------------------------------------------------------------------===//
+// torch.union type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.union type.
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchUnion(MlirType t);
+
+/// Gets the !torch.union type with contained types `containedTypes`.
+MLIR_CAPI_EXPORTED MlirType
+torchMlirTorchUnionTypeGet(MlirContext context, intptr_t numContainedTypes,
+                           MlirType const *containedTypes);
+
 //===----------------------------------------------------------------------===//
 // torch.list type.
 //===----------------------------------------------------------------------===//
diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
index 7560514fb039..63061074bb43 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
@@ -686,6 +686,7 @@ def Torch_DerefineOp : Torch_Op<"derefine", [
   }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Torch_OperatorOp : Torch_Op<"operator", [
diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
index 19e0ab01a357..2eb6c5a43489 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
@@ -245,6 +245,24 @@ def Torch_TupleType : Torch_Type<"Tuple", "tuple"> {
   );
 }
 
+def Torch_UnionType : Torch_Type<"Union", "union"> {
+  let summary = "!torch.union<T1, T2, T3>";
+  let description = [{
+    Union type with 0-N alternative types.
+
+    NOTE: We use the terminology "contained types" for consistency with
+    PyTorch. Strictly speaking, the types aren't "contained" though.
+
+    TODO: Canonicalize unions based on subtype relations, to allow
+    using pointer equality to compare two unions for being the same.
+    For now, `!torch.union<T1, T2>` is different from `!torch.union<T2, T1>`,
+    and same for `!torch.union<T1>` vs `!torch.union<T1, T1>`.
+  }];
+  let parameters = (ins
+    ArrayRefParameter<"::mlir::Type", "contained types">:$containedTypes
+  );
+}
+
 def Torch_DeviceType : Torch_Type<"Device", "Device"> {
   let summary = "Torch device";
 }
@@ -417,11 +435,9 @@ def AnyTorchOptionalTensorListType :
                  "Any optional tensor list type (Tensor?[])">;
 
 // Note: TorchScript does not consider !torch.bool to be a Scalar.
-def AnyTorchScalarType : AnyTypeOf<[
-  Torch_IntType,
-  Torch_FloatType,
-  Torch_NumberType,
-], "Any Python numeric type compatible with being the scalar type of a tensor (`Scalar`)">;
+def AnyTorchScalarType :
+    Type<CPred<"isValidSubtype($_self, ::mlir::torch::Torch::NumberType::get($_self.getContext()))">,
+         "Any Python numeric type compatible with being the scalar type of a tensor">;
 
 def AnyTorchOptionalScalarType: OptionalOf<AnyTorchScalarType, "Optional torch scalar type">;
@@ -454,6 +470,7 @@ def AnyTorchType : AnyTypeOf<[
   Torch_OptionalType,
   Torch_StringType,
   Torch_TupleType,
+  Torch_UnionType,
 ], "Any type that is legal to pass to a Torch kernel">;
 
 def AnyTorchListType : ListOf<[AnyType], "Any Torch list Type">;
diff --git a/lib/CAPI/TorchTypes.cpp b/lib/CAPI/TorchTypes.cpp
index df9d9771f6ce..ff01b9518639 100644
--- a/lib/CAPI/TorchTypes.cpp
+++ b/lib/CAPI/TorchTypes.cpp
@@ -60,6 +60,24 @@ MlirType torchMlirTorchTupleTypeGet(MlirContext context,
                      [](MlirType t) { return unwrap(t); }))));
 }
 
+//===----------------------------------------------------------------------===//
+// torch.union type.
+//===----------------------------------------------------------------------===//
+
+bool torchMlirTypeIsATorchUnion(MlirType t) {
+  return unwrap(t).isa<Torch::UnionType>();
+}
+
+MlirType torchMlirTorchUnionTypeGet(MlirContext context,
+                                    intptr_t numContainedTypes,
+                                    MlirType const *containedTypes) {
+  return wrap(Torch::UnionType::get(
+      unwrap(context),
+      llvm::to_vector<6>(
+          llvm::map_range(llvm::makeArrayRef(containedTypes, numContainedTypes),
+                          [](MlirType t) { return unwrap(t); }))));
+}
+
 //===----------------------------------------------------------------------===//
 // torch.list type.
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 0834be386612..ac03e1d123f7 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -452,6 +452,20 @@ OpFoldResult DerefineOp::fold(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
+void DerefineOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add(+[](DerefineOp op, PatternRewriter &rewriter) {
+    bool madeChange = false;
+    for (OpOperand &use : llvm::make_early_inc_range(op->getUses())) {
+      if (use.getOwner()->hasTrait<OpTrait::AllowsTypeRefinement>()) {
+        use.set(op.getOperand());
+        madeChange = true;
+      }
+    }
+    return success(madeChange);
+  });
+}
+
 static OpFoldResult atenIsOrIsNotFoldHelper(Operation *op, bool equalIsTrue) {
   Value lhs = op->getOperand(0);
   Value rhs = op->getOperand(1);
diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp
index a5fdc824bfad..bf4046cbcd76 100644
--- a/lib/Dialect/Torch/IR/TorchTypes.cpp
+++ b/lib/Dialect/Torch/IR/TorchTypes.cpp
@@ -25,6 +25,16 @@ bool Torch::isValidSubtype(Type subtype, Type type) {
   if (subtype == type)
     return true;
 
+  // For a UnionType to be a subtype, all of its contained types must be
+  // subtypes.
+  if (auto unionType = subtype.dyn_cast<UnionType>()) {
+    for (auto containedType : unionType.getContainedTypes()) {
+      if (!isValidSubtype(containedType, type))
+        return false;
+    }
+    return true;
+  }
+
   if (auto any = type.dyn_cast<AnyType>())
     return true;
 
@@ -35,6 +45,14 @@ bool Torch::isValidSubtype(Type subtype, Type type) {
   return isValidSubtype(subtype, optional.getContainedType()) ||
          subtype.isa<Torch::NoneType>();
 
+  if (auto unionType = type.dyn_cast<UnionType>()) {
+    for (auto containedType : unionType.getContainedTypes()) {
+      if (isValidSubtype(subtype, containedType))
+        return true;
+    }
+    return false;
+  }
+
   if (auto tuple = type.dyn_cast<Torch::TupleType>()) {
     if (!subtype.isa<Torch::TupleType>())
       return false;
@@ -63,36 +81,66 @@ bool Torch::isValidSubtype(Type subtype, Type type) {
 }
 
 //===----------------------------------------------------------------------===//
-// TupleType
+// Helpers for TupleType and UnionType
 //===----------------------------------------------------------------------===//
 
-Type Torch::TupleType::parse(AsmParser &parser) {
-  MLIRContext *context = parser.getContext();
+// Parse the `<T1, T2, T3>` of a type such as `!torch.tuple<T1, T2, T3>`.
+static Optional<SmallVector<Type>>
+parseMultipleContainedTypes(AsmParser &parser) {
   if (parser.parseLess())
-    return Type();
-  if (!parser.parseOptionalGreater())
-    return Torch::TupleType::get(context, {});
+    return None;
 
   SmallVector<Type> containedTypes;
+  if (!parser.parseOptionalGreater())
+    return containedTypes;
   do {
     Type containedType = parseTorchDialectType(parser);
    if (!containedType)
-      return Type();
+      return None;
     containedTypes.push_back(containedType);
   } while (!parser.parseOptionalComma());
   if (parser.parseGreater())
-    return Type();
-  return Torch::TupleType::get(context, containedTypes);
+    return None;
+  return containedTypes;
 }
 
-void Torch::TupleType::print(::mlir::AsmPrinter &printer) const {
+static void printMultipleContainedTypes(AsmPrinter &printer,
+                                        ArrayRef<Type> containedTypes) {
   printer << "<";
-  llvm::interleaveComma(getContainedTypes(), printer, [&](Type type) {
+  llvm::interleaveComma(containedTypes, printer, [&](Type type) {
     printTorchDialectType(type, printer);
   });
   printer << ">";
 }
 
+//===----------------------------------------------------------------------===//
+// TupleType
+//===----------------------------------------------------------------------===//
+
+Type Torch::TupleType::parse(AsmParser &parser) {
+  if (auto containedTypes = parseMultipleContainedTypes(parser))
+    return TupleType::get(parser.getContext(), *containedTypes);
+  return Type();
+}
+
+void Torch::TupleType::print(AsmPrinter &printer) const {
+  printMultipleContainedTypes(printer, getContainedTypes());
+}
+
+//===----------------------------------------------------------------------===//
+// UnionType
+//===----------------------------------------------------------------------===//
+
+Type Torch::UnionType::parse(AsmParser &parser) {
+  if (auto containedTypes = parseMultipleContainedTypes(parser))
+    return UnionType::get(parser.getContext(), *containedTypes);
+  return Type();
+}
+
+void Torch::UnionType::print(AsmPrinter &printer) const {
+  printMultipleContainedTypes(printer, getContainedTypes());
+}
+
 //===----------------------------------------------------------------------===//
 // BaseTensorType
 //===----------------------------------------------------------------------===//
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp
index 7b0fade2fa07..feace543b801 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp
@@ -181,6 +181,15 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
     return torchMlirTorchTupleTypeGet(context, containedTypes.size(),
                                       containedTypes.data());
   }
+  case TypeKind::UnionType: {
+    std::vector<MlirType> containedTypes;
+    for (const c10::TypePtr &type :
+         torchType->cast<c10::UnionType>()->containedTypes()) {
+      containedTypes.push_back(getMlirTypeFromTorchType(loc, type));
+    }
+    return torchMlirTorchUnionTypeGet(context, containedTypes.size(),
+                                      containedTypes.data());
+  }
   case TypeKind::ListType: {
     return torchMlirTorchListTypeGet(getMlirTypeFromTorchType(
         loc, torchType->cast<c10::ListType>()->getElementType()));
diff --git a/requirements.txt b/requirements.txt
index ead7c8ebf5f4..45e653a35ac2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,9 +2,8 @@
 --pre
 numpy
 
-# TODO: Fix for latest PyTorch.
-torch==1.12.0.dev20220328+cpu
-torchvision==0.13.0.dev20220328+cpu
+torch
+torchvision
 
 # Build requirements.
 pybind11
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index a03326ac0a64..cd2527b2322e 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -938,6 +938,22 @@ func @torch.derefine$of_unchecked_cast(%arg0: !torch.optional<int>) -> !torch.optional<int> {
   return %1 : !torch.optional<int>
 }
 
+// CHECK-LABEL:   func @torch.derefine$use_allows_type_refinement(
+// CHECK-SAME:      %{{.*}}: !torch.int) -> (!torch.vtensor, !torch.optional<int>) {
+// CHECK:           %[[NONE:.*]] = torch.constant.none
+// CHECK:           %[[DEREFINED:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<int>
+// For the use that allows type refinement, we replace it with the refined value.
+// CHECK:           %[[ARANGE:.*]] = torch.aten.arange.start %{{.*}}, %{{.*}}, %[[NONE]], %{{.*}}, %{{.*}}, %{{.*}} : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor
+// For the use that does not allow type refinement, don't replace.
+// CHECK:           return %[[ARANGE]], %[[DEREFINED]] : !torch.vtensor, !torch.optional<int>
+func @torch.derefine$use_allows_type_refinement(%arg0: !torch.int) -> (!torch.vtensor, !torch.optional<int>) {
+  %none = torch.constant.none
+  %optional = torch.derefine %none : !torch.none to !torch.optional<int>
+  %ret = torch.aten.arange.start %arg0, %arg0, %optional, %none, %none, %none: !torch.int, !torch.int, !torch.optional<int>, !torch.none, !torch.none, !torch.none -> !torch.vtensor
+  return %ret, %optional : !torch.vtensor, !torch.optional<int>
+}
+
+
 // CHECK-LABEL:   func @torch.tensor_static_info_cast$downcast_first(
 // CHECK-SAME:        %[[T:.*]]: !torch.tensor) -> !torch.tensor {
 // CHECK:           return %[[T]] : !torch.tensor
diff --git a/test/Dialect/Torch/ops.mlir b/test/Dialect/Torch/ops.mlir
index 43339cc0b2b3..83ec03c7599c 100644
--- a/test/Dialect/Torch/ops.mlir
+++ b/test/Dialect/Torch/ops.mlir
@@ -35,6 +35,13 @@ func private @tuple.one_element() -> !torch.tuple<tensor>
 // CHECK: @tuple.two_elements() -> !torch.tuple<tensor, tensor>
 func private @tuple.two_elements() -> !torch.tuple<tensor, tensor>
 
+// CHECK: @union.empty() -> !torch.union<>
+func private @union.empty() -> !torch.union<>
+// CHECK: @union.one_element() -> !torch.union<tensor>
+func private @union.one_element() -> !torch.union<tensor>
+// CHECK: @union.two_elements() -> !torch.union<tensor, tensor>
+func private @union.two_elements() -> !torch.union<tensor, tensor>
+
 // CHECK: @dict() -> !torch.dict<str, tensor>
 func private @dict() -> !torch.dict<str, tensor>
 
@@ -134,3 +141,8 @@ func @shape_calculations(%arg0: !torch.vtensor) -> !torch.vtensor {
   } : !torch.vtensor
   return %0 : !torch.vtensor
 }
+
+func @number_type_subtypes(%arg0: !torch.tensor, %arg1: !torch.list<int>, %arg2: !torch.union<int, float>) {
+  %0 = torch.aten.constant_pad_nd %arg0, %arg1, %arg2 : !torch.tensor, !torch.list<int>, !torch.union<int, float> -> !torch.tensor
+  return
+}
diff --git a/test/python/importer/jit_ir/node_import/union.py b/test/python/importer/jit_ir/node_import/union.py
new file mode 100644
index 000000000000..f87ba76654d6
--- /dev/null
+++ b/test/python/importer/jit_ir/node_import/union.py
@@ -0,0 +1,24 @@
+# -*- Python -*-
+# This file is licensed under a pytorch-style license
+# See LICENSE.pytorch for license information.
+
+from typing import Union
+
+import torch
+from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder
+
+# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s
+
+mb = ModuleBuilder()
+
+# CHECK-LABEL: func @__torch__.f(
+# CHECK-SAME:    %{{.*}}: !torch.union<int, float>) -> !torch.none {
+
+@mb.import_function
+@torch.jit.script
+def f(x: Union[int, float]):
+    return
+
+assert isinstance(f, torch.jit.ScriptFunction)
+mb.module.operation.print()
+print()