Add minimal support for Union types.
A recent PyTorch commit (pytorch/pytorch#73287) made ConstantPad2d call a
helper function with a `Union[int, float]` type annotation. This commit
adds minimal support for representing and dealing with that type.

Changes:
- Add support for `!torch.union<T1, T2, T3>`/`Torch::UnionType`,
  along with the importer and CAPI code.
- Add support in isValidSubtype for union types.
- Add a canonicalizer for `torch.derefine` to help simplify code that
  derefines to a UnionType (this also fixes llvm#664).

There is still more work to do to support UnionType really well, such as
canonicalizing UnionTypes so that they can be compared with pointer
equality.
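
For example, a derefine into a union (an illustrative sketch, using the
type syntax added by this commit):

  %1 = torch.derefine %0 : !torch.float to !torch.union<float, int>

can now be bypassed at any use that allows type refinement, with that use
consuming %0 directly (see the new test in
test/Dialect/Torch/canonicalize.mlir).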
silvasean committed Mar 30, 2022
1 parent 4f61b1f commit 140babd
Showing 11 changed files with 189 additions and 19 deletions.
12 changes: 12 additions & 0 deletions include/torch-mlir-c/TorchTypes.h
@@ -52,6 +52,18 @@ MLIR_CAPI_EXPORTED MlirType
torchMlirTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes,
MlirType const *containedTypes);

//===----------------------------------------------------------------------===//
// torch.union<T1, T2, T3> type.
//===----------------------------------------------------------------------===//

/// Checks whether the given type is a !torch.union type.
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchUnion(MlirType t);

/// Gets the !torch.union type with contained types `containedTypes`.
MLIR_CAPI_EXPORTED MlirType
torchMlirTorchUnionTypeGet(MlirContext context, intptr_t numContainedTypes,
MlirType const *containedTypes);
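
// Illustrative usage sketch (assumes the torchMlirTorchIntTypeGet /
// torchMlirTorchFloatTypeGet helpers declared in this header); builds
// `!torch.union<int, float>`:
//
//   MlirType contained[2] = {torchMlirTorchIntTypeGet(context),
//                            torchMlirTorchFloatTypeGet(context)};
//   MlirType u = torchMlirTorchUnionTypeGet(
//       context, /*numContainedTypes=*/2, contained);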

//===----------------------------------------------------------------------===//
// torch.list<T> type.
//===----------------------------------------------------------------------===//
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/IR/TorchOps.td
@@ -686,6 +686,7 @@ def Torch_DerefineOp : Torch_Op<"derefine", [
}];

let hasFolder = 1;
let hasCanonicalizer = 1;
}

def Torch_OperatorOp : Torch_Op<"operator", [
27 changes: 22 additions & 5 deletions include/torch-mlir/Dialect/Torch/IR/TorchTypes.td
@@ -245,6 +245,24 @@ def Torch_TupleType : Torch_Type<"Tuple", "tuple"> {
);
}

def Torch_UnionType : Torch_Type<"Union", "union"> {
let summary = "!torch.union<T1, T2, T3>";
let description = [{
Union type with 0-N alternative types.

NOTE: We use the terminology "contained types" for consistency with
PyTorch. Strictly speaking, the types aren't "contained" though.

TODO: Canonicalize unions based on subtype relations, to allow
using pointer equality to compare two unions for being the same.
    For now, `!torch.union<T1, T2>` is different from `!torch.union<T2, T1>`,
    and likewise `!torch.union<T1, SubtypeOfT1>` is different from
    `!torch.union<T1>`.
}];
let parameters = (ins
ArrayRefParameter<"::mlir::Type", "contained types">:$containedTypes
);
}

def Torch_DeviceType : Torch_Type<"Device", "Device"> {
let summary = "Torch device";
}
@@ -417,11 +435,9 @@ def AnyTorchOptionalTensorListType :
"Any optional tensor list type (Tensor?[])">;

// Note: TorchScript does not consider !torch.bool to be a Scalar.
-def AnyTorchScalarType : AnyTypeOf<[
-  Torch_IntType,
-  Torch_FloatType,
-  Torch_NumberType,
-], "Any Python numeric type compatible with being the scalar type of a tensor (`Scalar`)">;
+def AnyTorchScalarType :
+    Type<CPred<"isValidSubtype($_self, ::mlir::torch::Torch::NumberType::get($_self.getContext()))">,
+         "Any Python numeric type compatible with being the scalar type of a tensor">;
def AnyTorchOptionalScalarType:
OptionalOf<AnyTorchScalarType, "Optional torch scalar type">;

@@ -454,6 +470,7 @@ def AnyTorchType : AnyTypeOf<[
Torch_OptionalType,
Torch_StringType,
Torch_TupleType,
Torch_UnionType,
], "Any type that is legal to pass to a Torch kernel">;

def AnyTorchListType : ListOf<[AnyType], "Any Torch list Type">;
18 changes: 18 additions & 0 deletions lib/CAPI/TorchTypes.cpp
@@ -60,6 +60,24 @@ MlirType torchMlirTorchTupleTypeGet(MlirContext context,
[](MlirType t) { return unwrap(t); }))));
}

//===----------------------------------------------------------------------===//
// torch.union<T1, T2, T3> type.
//===----------------------------------------------------------------------===//

bool torchMlirTypeIsATorchUnion(MlirType t) {
return unwrap(t).isa<Torch::UnionType>();
}

MlirType torchMlirTorchUnionTypeGet(MlirContext context,
intptr_t numContainedTypes,
MlirType const *containedTypes) {
return wrap(Torch::UnionType::get(
unwrap(context),
llvm::to_vector<6>(
llvm::map_range(llvm::makeArrayRef(containedTypes, numContainedTypes),
[](MlirType t) { return unwrap(t); }))));
}

//===----------------------------------------------------------------------===//
// torch.list<T> type.
//===----------------------------------------------------------------------===//
14 changes: 14 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -452,6 +452,20 @@ OpFoldResult DerefineOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}

void DerefineOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](DerefineOp op, PatternRewriter &rewriter) {
bool madeChange = false;
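// A use that allows type refinement can accept the operand's original,
// more refined type directly, so point that use at the operand and
// bypass the derefine.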
for (OpOperand &use : llvm::make_early_inc_range(op->getUses())) {
if (use.getOwner()->hasTrait<OpTrait::AllowsTypeRefinement>()) {
use.set(op.getOperand());
madeChange = true;
}
}
return success(madeChange);
});
}

static OpFoldResult atenIsOrIsNotFoldHelper(Operation *op, bool equalIsTrue) {
Value lhs = op->getOperand(0);
Value rhs = op->getOperand(1);
70 changes: 59 additions & 11 deletions lib/Dialect/Torch/IR/TorchTypes.cpp
@@ -25,6 +25,16 @@ bool Torch::isValidSubtype(Type subtype, Type type) {
if (subtype == type)
return true;

// For a UnionType to be a valid subtype of `type`, all of its contained
// types must be valid subtypes of `type`.
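// For example, !torch.union<int, float> is a valid subtype of
// !torch.number, since both !torch.int and !torch.float are.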
if (auto unionType = subtype.dyn_cast<UnionType>()) {
for (auto containedType : unionType.getContainedTypes()) {
if (!isValidSubtype(containedType, type))
return false;
}
return true;
}

if (auto any = type.dyn_cast<AnyType>())
return true;

@@ -35,6 +45,14 @@
return isValidSubtype(subtype, optional.getContainedType()) ||
subtype.isa<Torch::NoneType>();

if (auto unionType = type.dyn_cast<UnionType>()) {
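// Conversely, `subtype` is a valid subtype of a UnionType if it is a
// valid subtype of any of the contained types; e.g., !torch.int is a
// valid subtype of !torch.union<float, int>.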
for (auto containedType : unionType.getContainedTypes()) {
if (isValidSubtype(subtype, containedType))
return true;
}
return false;
}

if (auto tuple = type.dyn_cast<Torch::TupleType>()) {
if (!subtype.isa<Torch::TupleType>())
return false;
@@ -63,36 +81,66 @@
}

//===----------------------------------------------------------------------===//
// TupleType
// Helpers for TupleType and UnionType
//===----------------------------------------------------------------------===//

-Type Torch::TupleType::parse(AsmParser &parser) {
-  MLIRContext *context = parser.getContext();
+// Parse the `<T1, T2, T3>` of a type such as `!torch.tuple<T1, T2, T3>`.
+static Optional<SmallVector<Type>>
+parseMultipleContainedTypes(AsmParser &parser) {
   if (parser.parseLess())
-    return Type();
-  if (!parser.parseOptionalGreater())
-    return Torch::TupleType::get(context, {});
+    return None;

   SmallVector<Type> containedTypes;
+  if (!parser.parseOptionalGreater())
+    return containedTypes;
   do {
     Type containedType = parseTorchDialectType(parser);
     if (!containedType)
-      return Type();
+      return None;
     containedTypes.push_back(containedType);
   } while (!parser.parseOptionalComma());
   if (parser.parseGreater())
-    return Type();
-  return Torch::TupleType::get(context, containedTypes);
+    return None;
+  return containedTypes;
 }

void Torch::TupleType::print(::mlir::AsmPrinter &printer) const {
static void printMultipleContainedTypes(AsmPrinter &printer,
ArrayRef<Type> containedTypes) {
printer << "<";
llvm::interleaveComma(getContainedTypes(), printer, [&](Type type) {
llvm::interleaveComma(containedTypes, printer, [&](Type type) {
printTorchDialectType(type, printer);
});
printer << ">";
}

//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//

Type Torch::TupleType::parse(AsmParser &parser) {
if (auto containedTypes = parseMultipleContainedTypes(parser))
return TupleType::get(parser.getContext(), *containedTypes);
return Type();
}

void Torch::TupleType::print(AsmPrinter &printer) const {
printMultipleContainedTypes(printer, getContainedTypes());
}

//===----------------------------------------------------------------------===//
// UnionType
//===----------------------------------------------------------------------===//

Type Torch::UnionType::parse(AsmParser &parser) {
if (auto containedTypes = parseMultipleContainedTypes(parser))
return UnionType::get(parser.getContext(), *containedTypes);
return Type();
}

void Torch::UnionType::print(AsmPrinter &printer) const {
printMultipleContainedTypes(printer, getContainedTypes());
}

//===----------------------------------------------------------------------===//
// BaseTensorType
//===----------------------------------------------------------------------===//
@@ -181,6 +181,15 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
return torchMlirTorchTupleTypeGet(context, containedTypes.size(),
containedTypes.data());
}
case TypeKind::UnionType: {
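// E.g., a TorchScript `Union[int, float]` annotation arrives here as a
// c10::UnionType and is imported as `!torch.union<float, int>` (see
// test/python/importer/jit_ir/node_import/union.py).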
std::vector<MlirType> containedTypes;
for (const c10::TypePtr &type :
torchType->cast<c10::UnionType>()->containedTypes()) {
containedTypes.push_back(getMlirTypeFromTorchType(loc, type));
}
return torchMlirTorchUnionTypeGet(context, containedTypes.size(),
containedTypes.data());
}
case TypeKind::ListType: {
return torchMlirTorchListTypeGet(getMlirTypeFromTorchType(
loc, torchType->cast<c10::ListType>()->getElementType()));
5 changes: 2 additions & 3 deletions requirements.txt
@@ -2,9 +2,8 @@
--pre

numpy
-# TODO: Fix for latest PyTorch.
-torch==1.12.0.dev20220328+cpu
-torchvision==0.13.0.dev20220328+cpu
+torch
+torchvision

# Build requirements.
pybind11
16 changes: 16 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
@@ -938,6 +938,22 @@ func @torch.derefine$of_unchecked_cast(%arg0: !torch.optional<int>) -> !torch.op
return %1 : !torch.optional<int>
}

// CHECK-LABEL: func @torch.derefine$use_allows_type_refinement(
// CHECK-SAME: %{{.*}}: !torch.int) -> (!torch.vtensor, !torch.optional<int>) {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[DEREFINED:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<int>
// For the use that allows type refinement, we replace it with the refined value.
// CHECK: %[[ARANGE:.*]] = torch.aten.arange.start %{{.*}}, %{{.*}}, %[[NONE]], %{{.*}}, %{{.*}}, %{{.*}} : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor
// For the use that does not allow type refinement, don't replace.
// CHECK: return %[[ARANGE]], %[[DEREFINED]] : !torch.vtensor, !torch.optional<int>
func @torch.derefine$use_allows_type_refinement(%arg0: !torch.int) -> (!torch.vtensor, !torch.optional<int>) {
%none = torch.constant.none
%optional = torch.derefine %none : !torch.none to !torch.optional<int>
%ret = torch.aten.arange.start %arg0, %arg0, %optional, %none, %none, %none: !torch.int, !torch.int, !torch.optional<int>, !torch.none, !torch.none, !torch.none -> !torch.vtensor
return %ret, %optional : !torch.vtensor, !torch.optional<int>
}


// CHECK-LABEL: func @torch.tensor_static_info_cast$downcast_first(
// CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: return %[[T]] : !torch.tensor
12 changes: 12 additions & 0 deletions test/Dialect/Torch/ops.mlir
@@ -35,6 +35,13 @@ func private @tuple.one_element() -> !torch.tuple<tensor>
// CHECK: @tuple.two_elements() -> !torch.tuple<tensor, tensor>
func private @tuple.two_elements() -> !torch.tuple<tensor, tensor>

// CHECK: @union.empty() -> !torch.union<>
func private @union.empty() -> !torch.union<>
// CHECK: @union.one_element() -> !torch.union<tensor>
func private @union.one_element() -> !torch.union<tensor>
// CHECK: @union.two_elements() -> !torch.union<tensor, tensor>
func private @union.two_elements() -> !torch.union<tensor, tensor>

// CHECK: @dict() -> !torch.dict<str, tensor>
func private @dict() -> !torch.dict<str, tensor>

@@ -134,3 +141,8 @@ func @shape_calculations(%arg0: !torch.vtensor) -> !torch.vtensor {
} : !torch.vtensor
return %0 : !torch.vtensor
}

func @number_type_subtypes(%arg0: !torch.tensor, %arg1: !torch.list<int>, %arg2: !torch.union<float, int>) {
%0 = torch.aten.constant_pad_nd %arg0, %arg1, %arg2 : !torch.tensor, !torch.list<int>, !torch.union<float, int> -> !torch.tensor
return
}
24 changes: 24 additions & 0 deletions test/python/importer/jit_ir/node_import/union.py
@@ -0,0 +1,24 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See LICENSE.pytorch for license information.

from typing import Union

import torch
from torch_mlir.dialects.torch.importer.jit_ir import ModuleBuilder

# RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s

mb = ModuleBuilder()

# CHECK-LABEL: func @__torch__.f(
# CHECK-SAME: %{{.*}}: !torch.union<float, int>) -> !torch.none {

@mb.import_function
@torch.jit.script
def f(x: Union[int, float]):
return

assert isinstance(f, torch.jit.ScriptFunction)
mb.module.operation.print()
print()
