[onnx] Import onnx import to pass remaining tests (llvm#2951)
Finish support for importing the vast majority of `onnx` operations. This
includes (a brief illustration follows the list):
- region support
- region value inheritance
- `torch.string` support
- `torch.list` support
- `torch.optional` support
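As a rough sketch of what the region support enables (the operator name, SSA value names, and types below are illustrative only, not taken from this change's tests), an imported ONNX control-flow operator can now carry its branch bodies as regions on `torch.operator`, each closed by the new `torch.operator_terminator`, with values from the enclosing scope remaining visible inside the regions:

```mlir
// Hypothetical printed form after this change; "onnx.If" and the names
// used here are examples, not copied from the repository.
func.func @select(%cond: !torch.vtensor<[],i1>,
                  %a: !torch.vtensor<[1],f32>,
                  %b: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> {
  %0 = torch.operator "onnx.If"(%cond) : (!torch.vtensor<[],i1>) -> !torch.vtensor<[1],f32> {
    // %a is defined in the enclosing function body and used here directly
    // (the "region value inheritance" mentioned above).
    torch.operator_terminator %a : !torch.vtensor<[1],f32>
  }, {
    torch.operator_terminator %b : !torch.vtensor<[1],f32>
  }
  return %0 : !torch.vtensor<[1],f32>
}
```

With zero regions, the trailing `$regions` directive in the new assembly format still emits a separating space, which appears to be why the regenerated library strings further down differ only by a trailing space.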
rsuderman authored Feb 28, 2024
1 parent 6f3d62a commit e48fe45
Showing 6 changed files with 185 additions and 180 deletions.
13 changes: 12 additions & 1 deletion include/torch-mlir/Dialect/Torch/IR/TorchOps.td
@@ -843,12 +843,23 @@ def Torch_OperatorOp : Torch_Op<"operator", [

let arguments = (ins StrAttr:$name, Variadic<AnyTorchType>:$operands);
let results = (outs Variadic<AnyTorchType>:$results);
+ let regions = (region VariadicRegion<AnyRegion>:$regions);

let assemblyFormat = [{
- $name `(` $operands `)` attr-dict `:` functional-type($operands, $results)
+ $name `(` $operands `)` attr-dict `:` functional-type($operands, $results) $regions
}];
}

+ def Torch_OperatorTerminatorOp : Torch_Op<"operator_terminator", [Terminator,
+ HasParent<"::mlir::torch::Torch::OperatorOp">]> {
+ let summary = "Implicit terminator for torch.operator";
+
+ let arguments = (ins Variadic<AnyTorchType>:$operands);
+ let results = (outs);
+
+ let assemblyFormat = "$operands attr-dict `:` type($operands)";
+ }

def Torch_LinearParamsCreateOp : Torch_Op<"linear_params.create", [
AllowsTypeRefinement,
AllowedInModuleInitializer,
3 changes: 2 additions & 1 deletion lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -181,7 +181,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(

c = rewriter
.create<Torch::OperatorOp>(binder.getLoc(), cTy, newOperands,
- newAttributes)
+ newAttributes,
+ binder.op->getRegions().size())
.getResult(0);

Value outScale = rewriter.create<Torch::AtenMulFloatOp>(
6 changes: 6 additions & 0 deletions lib/Dialect/Torch/IR/TorchTypes.cpp
@@ -190,6 +190,12 @@ static bool isValidTorchDtype(Type dtype) {
// Builtin floating point types.
if (dtype.isa<Float16Type, BFloat16Type, Float32Type, Float64Type>())
return true;
+ if (dtype.isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+ Float8E4M3FNUZType, Float8E4M3B11FNUZType>())
+ return true;
+
+ if (dtype.isa<Torch::StringType>())
+ return true;
// Builtin integer types.
if (IntegerType type = dtype.dyn_cast<IntegerType>()) {
if (type.isSignless() && type.getWidth() == 1)
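For context, `isValidTorchDtype` gates which element types a Torch tensor type may carry; the added cases accept the builtin float8 variants and `!torch.str` as tensor dtypes, both of which can appear in imported ONNX models. A hypothetical sketch of types the check should now admit (the function names are invented, and the spelling of the string-element case is an assumption, not copied from a test in this change):

```mlir
// Hypothetical declarations whose tensor dtypes the updated check should accept.
func.func private @takes_f8E5M2(!torch.vtensor<[4,4],f8E5M2>)        // Float8E5M2Type dtype
func.func private @takes_f8E4M3FN(!torch.vtensor<[4,4],f8E4M3FN>)    // Float8E4M3FNType dtype
func.func private @takes_str_tensor(!torch.vtensor<[2],!torch.str>)  // Torch::StringType dtype
```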
22 changes: 11 additions & 11 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -122,7 +122,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.operator \"aten.ge\"(%arg0, %int0) : (!torch.union<float, int>, !torch.int) -> !torch.bool\n"
" %0 = torch.operator \"aten.ge\"(%arg0, %int0) : (!torch.union<float, int>, !torch.int) -> !torch.bool \n"
" torch.prim.If %0 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
@@ -138,14 +138,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.operator \"aten.ge\"(%arg1, %int0) : (!torch.union<float, int>, !torch.int) -> !torch.bool\n"
" %0 = torch.operator \"aten.ge\"(%arg1, %int0) : (!torch.union<float, int>, !torch.int) -> !torch.bool \n"
" torch.prim.If %0 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %1 = torch.operator \"aten.ge\"(%arg1, %arg0) : (!torch.union<float, int>, !torch.union<float, int>) -> !torch.bool\n"
" %1 = torch.operator \"aten.ge\"(%arg1, %arg0) : (!torch.union<float, int>, !torch.union<float, int>) -> !torch.bool \n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
@@ -162,16 +162,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.operator \"aten.ne\"(%arg2, %int0) : (!torch.union<float, int>, !torch.int) -> !torch.bool\n"
" %0 = torch.operator \"aten.ne\"(%arg2, %int0) : (!torch.union<float, int>, !torch.int) -> !torch.bool \n"
" torch.prim.If %0 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %1 = torch.operator \"aten.lt\"(%arg2, %int0) : (!torch.union<float, int>, !torch.int) -> !torch.bool\n"
" %1 = torch.operator \"aten.lt\"(%arg2, %int0) : (!torch.union<float, int>, !torch.int) -> !torch.bool \n"
" torch.prim.If %1 -> () {\n"
" %6 = torch.operator \"aten.ge\"(%arg0, %arg1) : (!torch.union<float, int>, !torch.union<float, int>) -> !torch.bool\n"
" %6 = torch.operator \"aten.ge\"(%arg0, %arg1) : (!torch.union<float, int>, !torch.union<float, int>) -> !torch.bool \n"
" torch.prim.If %6 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
@@ -180,7 +180,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" torch.prim.If.yield\n"
" } else {\n"
" %6 = torch.operator \"aten.ge\"(%arg1, %arg0) : (!torch.union<float, int>, !torch.union<float, int>) -> !torch.bool\n"
" %6 = torch.operator \"aten.ge\"(%arg1, %arg0) : (!torch.union<float, int>, !torch.union<float, int>) -> !torch.bool \n"
" torch.prim.If %6 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
@@ -5507,12 +5507,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %16 = torch.aten.__getitem__.t %11, %int0 : !torch.list<float>, !torch.int -> !torch.float\n"
" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float\n"
" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n"
" %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n"
" %19 = torch.aten.append.t %1, %18 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" %20 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %21 = torch.aten.__getitem__.t %11, %int1 : !torch.list<float>, !torch.int -> !torch.float\n"
" %22 = torch.operator \"aten.mul.int_float\"(%20, %21) : (!torch.int, !torch.float) -> !torch.float\n"
" %22 = torch.operator \"aten.mul.int_float\"(%20, %21) : (!torch.int, !torch.float) -> !torch.float \n"
" %23 = torch.aten.Int.float %22 : !torch.float -> !torch.int\n"
" %24 = torch.aten.append.t %1, %23 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.If.yield\n"
@@ -7246,7 +7246,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %3 = torch.prim.If %2 -> (!torch.list<int>) {\n"
" %5 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>\n"
" %6 = torch.aten.sub.int %1, %0 : !torch.int, !torch.int -> !torch.int\n"
" %7 = torch.operator \"aten.mul.left_t\"(%5, %6) : (!torch.list<int>, !torch.int) -> !torch.list<int>\n"
" %7 = torch.operator \"aten.mul.left_t\"(%5, %6) : (!torch.list<int>, !torch.int) -> !torch.list<int> \n"
" %8 = torch.aten.add.t %7, %arg1 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
" torch.prim.If.yield %8 : !torch.list<int>\n"
" } else {\n"
@@ -9304,7 +9304,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @__torch__.hacky_get_unknown_dimension_size() -> !torch.int {\n"
" %0 = torch.prim.CreateObject !torch.nn.Module<\"__torch__.DummyClassType\">\n"
" %1 = torch.prim.CallMethod %0[\"__init__\"] () : !torch.nn.Module<\"__torch__.DummyClassType\">, () -> !torch.none\n"
" %2 = torch.operator \"prim.id\"(%0) : (!torch.nn.Module<\"__torch__.DummyClassType\">) -> !torch.int\n"
" %2 = torch.operator \"prim.id\"(%0) : (!torch.nn.Module<\"__torch__.DummyClassType\">) -> !torch.int \n"
" return %2 : !torch.int\n"
" }\n"
" func.func @__torch__.DummyClassType.__init__(%arg0: !torch.nn.Module<\"__torch__.DummyClassType\">) -> !torch.none {\n"
