Skip to content

Commit

Permalink
build: manually update PyTorch version (#3919)
Browse files Browse the repository at this point in the history
This commit sets the PyTorch and TorchVision version to nightly release
2024-12-16.

This commit adds the support for `aten.rrelu_with_noise_functional` op
by decomposing it. And, also updates the existing decomposition of
`aten.rrelu_with_noise` op by decomposing it to the newly added
`aten.rrelu_with_noise_functional` op. It also updates the e2e tests for
`aten.rrelu_with_noise` op by replacing it with its functional variant
which is added here:
pytorch/pytorch@f85e238
and which captures the noise mutation which was earlier a reason for the
test failures during the training mode.

This commit also removes the newly passing tests from the xfail_sets.

---------

Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
  • Loading branch information
vivekkhandelwal1 authored Dec 20, 2024
1 parent 02fa411 commit a6179c0
Show file tree
Hide file tree
Showing 11 changed files with 121 additions and 87 deletions.
33 changes: 30 additions & 3 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -310,9 +310,7 @@ def Torch_AtenRrelu_Op : Torch_Op<"aten.rrelu_", [
}

def Torch_AtenRreluWithNoiseOp : Torch_Op<"aten.rrelu_with_noise", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
let arguments = (ins
Expand Down Expand Up @@ -17519,6 +17517,35 @@ def Torch_AtenRreluWithNoiseBackwardOp : Torch_Op<"aten.rrelu_with_noise_backwar
}];
}

def Torch_AtenRreluWithNoiseFunctionalOp : Torch_Op<"aten.rrelu_with_noise_functional", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::rrelu_with_noise_functional : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$noise,
AnyTorchScalarType:$lower,
AnyTorchScalarType:$upper,
Torch_BoolType:$training,
AnyTorchOptionalGeneratorType:$generator
);
let results = (outs
AnyTorchOptionalTensorType:$result0,
AnyTorchOptionalTensorType:$noise_out
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRreluWithNoiseFunctionalOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 2);
}
void AtenRreluWithNoiseFunctionalOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 2);
}
}];
}

def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
56 changes: 17 additions & 39 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7304,6 +7304,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise_functional\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.tuple<list<int>, list<int>> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" %1 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
" %2 = torch.prim.TupleConstruct %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %2 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down Expand Up @@ -12599,17 +12605,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rrelu\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number, %arg3: !torch.bool, %arg4: !torch.any) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %true = torch.constant.bool true\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" torch.prim.If.yield %3 : !torch.bool\n"
" }\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
Expand All @@ -12618,46 +12622,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.int {\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise_functional\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.tuple<int, int> {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %true = torch.constant.bool true\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %3 = torch.prim.If %2 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" torch.prim.If.yield %7 : !torch.bool\n"
" }\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %5 = torch.prim.If %4 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" torch.prim.If.yield %7 : !torch.bool\n"
" }\n"
" torch.prim.If %5 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %6 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %6 -> () {\n"
" %2 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" %3 = torch.prim.TupleConstruct %0#1, %1#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %3 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
Expand Down
39 changes: 32 additions & 7 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3791,11 +3791,7 @@ class DecomposeAtenRreluOp : public OpRewritePattern<AtenRreluOp> {
// Create a uniform random op with low and high set to `lower` and
// `upper`, respectively.
Value none = rewriter.create<ConstantNoneOp>(loc);
Value emptyTensor = rewriter.create<AtenFullLikeOp>(
loc, resType, self, constantZeroFloat, /*dtype=*/none,
/*layout=*/none,
/*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none);
alpha = rewriter.create<AtenUniformOp>(loc, resType, emptyTensor,
alpha = rewriter.create<AtenUniformOp>(loc, resType, self,
/*from=*/lower, /*to=*/upper,
/*generator=*/none);
} else {
Expand Down Expand Up @@ -3840,6 +3836,33 @@ class DecomposeAtenRreluWithNoiseOp
Value lower = op.getLower();
Value upper = op.getUpper();
auto resType = cast<BaseTensorType>(op.getType());
Value cstNone = rewriter.create<ConstantNoneOp>(loc);
Value cstFalse =
rewriter.create<ConstantBoolOp>(loc, rewriter.getBoolAttr(false));
Value result =
rewriter
.create<AtenRreluWithNoiseFunctionalOp>(
loc, resType, self, noise, lower, upper, cstFalse, cstNone)
->getResult(0);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace

namespace {
class DecomposeAtenRreluWithNoiseFunctionalOp
: public OpRewritePattern<AtenRreluWithNoiseFunctionalOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenRreluWithNoiseFunctionalOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
Value noise = op.getNoise();
Value lower = op.getLower();
Value upper = op.getUpper();
auto resType = cast<BaseTensorType>(op.getResultTypes()[0]);
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
Expand Down Expand Up @@ -3885,7 +3908,7 @@ class DecomposeAtenRreluWithNoiseOp
rewriter.getI1Type());
Value oneTensor =
createRank0Tensor(rewriter, loc, resType, constantOneFloat);
Value not_positive = rewriter.create<AtenLtScalarOp>(
Value not_positive = rewriter.create<AtenLeScalarOp>(
loc, boolResType, self, constantZeroFloat);
noise = rewriter.create<AtenWhereSelfOp>(loc, resType, not_positive,
alpha, oneTensor);
Expand All @@ -3897,7 +3920,7 @@ class DecomposeAtenRreluWithNoiseOp
rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledSelf);
Value rreluOutput = rewriter.create<AtenAddTensorOp>(
loc, resType, positiveOutput, negativeOutput, constantOneFloat);
rewriter.replaceOp(op, rreluOutput);
rewriter.replaceOp(op, {rreluOutput, noise});
return success();
}
};
Expand Down Expand Up @@ -11568,6 +11591,8 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenPreluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRreluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseFunctionalOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseBackwardOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenPreluOp>();
target.addIllegalOp<AtenRreluOp>();
target.addIllegalOp<AtenRreluWithNoiseOp>();
target.addIllegalOp<AtenRreluWithNoiseFunctionalOp>();
target.addIllegalOp<AtenRreluWithNoiseBackwardOp>();
target.addIllegalOp<AtenCeluOp>();
target.addIllegalOp<AtenToDtypeLayoutOp>();
Expand Down
21 changes: 0 additions & 21 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,6 @@
"AtenIntBoolOpConstTrueModule_basic",
"AtenIntBoolOpModule_basic",
"AtenIntMM_basic",
"AtenItemFpOpModule_basic",
"AtenNonzero1DDynamicModule_basic", # no lowering for torch.aten.sym_constrain_range_for_size
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
Expand All @@ -425,7 +424,6 @@
"CumsumModule_basic",
"CumprodModule_basic",
"DeformConv2D_basic",
"DivFloatModule_basic",
"DivIntModule_basic",
"ElementwiseDequantizePerChannelModule_basic",
"ElementwiseDequantizePerTensorModule_basic",
Expand All @@ -439,7 +437,6 @@
"IntFloatModule_basic",
"IntImplicitModule_basic",
"LenStrModule_basic",
"MulFloatModule_basic",
"NativeGroupNormBackwardModule_basic",
"NeFloatIntModule_basic",
"NllLossModuleBackward1DMeanWeight_basic",
Expand All @@ -464,15 +461,11 @@
"QuantizedSingleLayer_basic",
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
"ScalarImplicitFloatModule_basic",
"SplitDimDynamicModule_basic",
"SplitDimStaticModule_basic",
"SqrtIntModule_basic",
"SubFloatModule_basic",
"TensorToBoolZeroRank_basic",
"TensorToBool_basic",
"TensorToFloatZeroRank_basic",
"TensorToFloat_basic",
"ThresholdBackward2dMixedModule_basic",
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
"UpSampleNearest2dDynamicFactor_basic",
Expand Down Expand Up @@ -507,9 +500,6 @@
"MeshgridIndexingIJ_basic",
"MeshgridIndexingXY_basic",
"Meshgrid_basic",
# RuntimeError: cannot mutate tensors with frozen storage
"ElementwiseRreluWithNoiseTrainModule_basic",
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
"ElementwiseSignbitModule_basic",
"ElementwiseCopysignModule_basic",
"BernoulliFloatModule_basic",
Expand All @@ -527,9 +517,6 @@
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"Aten_TrilinearModuleSumAllDims_basic",
"Aten_TrilinearModuleSumdims_basic",
# torch export: RuntimeError: cannot mutate tensors with frozen storage
"ElementwiseRreluWithNoiseTrainModule_basic",
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
}

FX_IMPORTER_STABLEHLO_XFAIL_SET = {
Expand Down Expand Up @@ -934,9 +921,6 @@
"UpSampleNearest2dStaticFactor_basic",
"UpSampleNearest2dStaticSize_basic",
"UpSampleNearest2d_basic",
# RuntimeError: cannot mutate tensors with frozen storage
"ElementwiseRreluWithNoiseTrainModule_basic",
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
"BernoulliFloatModule_basic",
"UniformModule_basic",
"UniformStaticShapeModule_basic",
Expand All @@ -961,9 +945,6 @@
"Aten_TrilinearModuleSumdims_basic",
"Aten_TrilinearModuleSumAllDims_basic",
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
# torch export: RuntimeError: cannot mutate tensors with frozen storage
"ElementwiseRreluWithNoiseTrainModule_basic",
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
"CrossEntropyLossModule_basic",
"CrossEntropyLossNoReductionModule_basic",
}
Expand Down Expand Up @@ -3459,8 +3440,6 @@
"ElementwiseSignbitModule_basic",
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"ElementwiseRreluWithNoiseTrainModule_basic",
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
"MaxPool3dEmptyStrideStaticModule_basic",
"MaxPool3dLargeDatadModule_basic",
"MaxPool3dModuleRandomSimple_basic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,9 @@ def aten〇rrelu〡shape(self: List[int], lower: float = 0.125, upper: float = 0
def aten〇rrelu_with_noise〡shape(self: List[int], noise: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇rrelu_with_noise_functional〡shape(self: List[int], noise: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> Tuple[List[int], List[int]]:
return upstream_shape_functions.unary(self), upstream_shape_functions.unary(noise)

def aten〇selu〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)

Expand Down Expand Up @@ -3472,21 +3475,25 @@ def aten〇celu〡dtype(self_rank_dtype: Tuple[int, int], alpha: Union[int, floa
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool, *all_integer_dtypes()}))
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇rrelu〡dtype(self_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> int:
self_rank, self_dtype = self_rank_dtype
assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype)
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2, error_types={torch.bool, *all_integer_dtypes()}))
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2))
def aten〇rrelu_with_noise〡dtype(self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> int:
self_rank, self_dtype = self_rank_dtype
noise_rank, noise_dtype = noise_rank_dtype
assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype)
assert is_float_dtype(noise_dtype) or is_complex_dtype(noise_dtype)
assert self_rank == noise_rank
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2))
def aten〇rrelu_with_noise_functional〡dtype(self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> Tuple[int, int]:
self_rank, self_dtype = self_rank_dtype
noise_rank, noise_dtype = noise_rank_dtype
assert self_rank == noise_rank
return self_dtype, noise_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}))
def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1212,6 +1212,9 @@ def emit_with_mutating_variants(key, **kwargs):
emit(
"aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)"
)
emit(
"aten::rrelu_with_noise_functional : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor, Tensor)"
)

# quantized ops
emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)")
Expand Down
Loading

0 comments on commit a6179c0

Please sign in to comment.