[Torch Dialect] support decomposition of aten.linspace (llvm#3006)

qingyunqu authored Mar 14, 2024
1 parent 43c6996 commit 870e63b
Showing 6 changed files with 193 additions and 0 deletions.
16 changes: 16 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -8334,6 +8334,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %5 = call @__torch__.torch.jit._shape_functions.arange_end(%0, %1, %2, %3, %4) : (!torch.union<float, int>, !torch.any, !torch.any, !torch.any, !torch.any) -> !torch.list<int>\n"
" return %5 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.linspace\"(%arg0: !torch.float, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
" %0 = torch.prim.ListConstruct %arg2 : (!torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.add.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -12568,6 +12572,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.linspace\"(%arg0: !torch.number, %arg1: !torch.number, %arg2: !torch.int, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %none = torch.constant.none\n"
" %0 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" } else {\n"
" %2 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %2 : !torch.int\n"
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.normal_functional\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
73 changes: 73 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -6331,6 +6331,78 @@ class DecomposeAtenRandOp : public OpRewritePattern<AtenRandOp> {
};
} // namespace

namespace {
class DecomposeAtenLinspaceOp : public OpRewritePattern<AtenLinspaceOp> {
public:
using OpRewritePattern<AtenLinspaceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenLinspaceOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
MLIRContext *context = getContext();

auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
Value none = rewriter.create<ConstantNoneOp>(loc);
Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
Value zero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));

Value addStart;
int64_t steps;
if (matchPattern(op.getSteps(), m_TorchConstantInt(&steps)) && steps == 1) {
// Specifically handle the steps == 1 case.
Value arange = rewriter.create<AtenArangeStartOp>(
loc, baseType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(),
op.getDevice(), op.getPinMemory());
addStart = rewriter.create<AtenAddScalarOp>(loc, baseType, arange,
op.getStart(), one);
} else {
// Handle steps != 1 or dynamic steps; assert at runtime that steps != 1
// to avoid dividing by zero in (end - start) / (steps - 1) below.
Value neOrNot = rewriter.create<AtenNeIntOp>(loc, op.getSteps(), one);
rewriter.create<RuntimeAssertOp>(
loc, neOrNot,
rewriter.getStringAttr("linspace's dynamic steps must not be 1"));
// create arange: [0, ..., steps - 1]
Value arange = rewriter.create<AtenArangeStartOp>(
loc, baseType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(),
op.getDevice(), op.getPinMemory());
// calculate (end - start) / (steps - 1)
Value sub;
if (op.getEnd().getType().isa<Torch::FloatType>() ||
op.getStart().getType().isa<Torch::FloatType>()) {
sub = rewriter.create<AtenSubOp>(loc, Torch::FloatType::get(context),
op.getEnd(), op.getStart());
} else {
sub = rewriter.create<AtenSubIntOp>(loc, op.getEnd(), op.getStart());
}
Value div = rewriter.create<AtenDivOp>(
loc, sub, rewriter.create<AtenSubIntOp>(loc, op.getSteps(), one));
// calculate [0, ..., steps - 1] * ((end - start) / (steps - 1)) + start
Value mulScalar =
rewriter.create<AtenMulScalarOp>(loc, baseType, arange, div);
addStart = rewriter.create<AtenAddScalarOp>(loc, baseType, mulScalar,
op.getStart(), one);
}
// Cast to the requested dtype, defaulting to float32 when dtype is None.
Value result;
if (!op.getDtype().getType().isa<Torch::NoneType>()) {
result = rewriter.create<AtenToDtypeOp>(
loc, op.getType(), addStart, op.getDtype(), /*non_blocking=*/falseVal,
/*copy=*/falseVal, /*memory_format=*/none);
} else {
Value f32Type = rewriter.create<ConstantIntOp>(
loc, (int)torch_upstream::ScalarType::Float);
result = rewriter.create<AtenToDtypeOp>(
loc, op.getType(), addStart, f32Type, /*non_blocking=*/falseVal,
/*copy=*/falseVal, /*memory_format=*/none);
}
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
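
For reference, here is a minimal plain-Python sketch (not part of the commit) of the arithmetic this decomposition emits; linspace_decomposed is a hypothetical name used only for illustration:

    def linspace_decomposed(start: float, end: float, steps: int) -> list:
        # steps == 1 special case: arange(0, 1) + start == [start].
        if steps == 1:
            return [start]
        # General case: arange(0, steps) * ((end - start) / (steps - 1)) + start.
        # The RuntimeAssertOp above guards the dynamic-steps path against
        # steps == 1, which would divide by zero here.
        step = (end - start) / (steps - 1)
        return [i * step + start for i in range(steps)]

    assert linspace_decomposed(-10.0, 10.0, 2) == [-10.0, 10.0]
    assert linspace_decomposed(-10.0, 10.0, 0) == []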

namespace {
class DecomposeAtenVarMeanOp : public OpRewritePattern<AtenVarMeanOp> {
public:
@@ -7216,6 +7288,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose2dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLinspaceOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenArgMinMaxOp<AtenArgmaxOp, AtenMaxDimOp>>(patterns);
addPatternIfTargetOpIsIllegal<
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -424,6 +424,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenConvTranspose2dInputOp>();
target.addIllegalOp<AtenArangeOp>();
target.addIllegalOp<AtenArangeStartOp>();
target.addIllegalOp<AtenLinspaceOp>();
target.addIllegalOp<AtenArgmaxOp>();
target.addIllegalOp<AtenArgminOp>();
target.addIllegalOp<AtenSquareOp>();
8 changes: 8 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -841,6 +841,11 @@
"ZerosModuleFloat3D_basic",
"ZerosModuleInt2D_basic",
"ZerosModuleInt3D_basic",
"LinspaceDtypeModule_basic",
"LinspaceEmptyModule_basic",
"LinspaceModule_basic",
"LinspaceOneSizeModule_basic",
"LinspaceTwoSizeModule_basic",
}

STABLEHLO_CRASHING_SET = {
@@ -1260,6 +1265,9 @@
"_LogSoftmaxModuleStable_basic",
"_LogSoftmaxModule_basic",
"_SoftmaxModule_basic",
"LinspaceModule_basic",
"LinspaceOneSizeModule_basic",
"LinspaceTwoSizeModule_basic",
}

MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
13 changes: 13 additions & 0 deletions projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -1124,6 +1124,9 @@ def aten〇arange〇start〡shape(start: float, end: float, dtype: Optional[int]
def aten〇arange〡shape(end: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
return upstream_shape_functions.arange_end(end, dtype, layout, device, pin_memory)

def aten〇linspace〡shape(start: float, end: float, steps: int, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
return [steps]

@check_shape_function([
Invocation(TensorOfShape(2, 3), TensorOfShape(2, 3)), # Basic case.
Invocation(TensorOfShape(2, 3), TensorOfShape(3)), # Rank broadcasting.
@@ -4248,6 +4251,16 @@ def aten〇randn〡dtype(size: List[int], dtype: Optional[int] = None, layout: O
assert not is_integer_dtype(dtype)
return dtype

@check_dtype_function([Invocation(start=1, end=10, steps=9),
Invocation(start=1, end=10, steps=9, dtype=torch.int32),
Invocation(start=1, end=10, steps=9, dtype=torch.double),
Invocation(start=1, end=10, steps=9, dtype=torch.complex64),
Invocation(start=1, end=10, steps=9, dtype=torch.complex128)])
def aten〇linspace〡dtype(start: Union[int, float, complex], end: Union[int, float, complex], steps: int, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
if dtype is None:
return torch.float32
return dtype
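
As a quick sanity check (a sketch, not part of the commit; assumes eager PyTorch with the global default dtype left at float32), the rule above matches torch.linspace:

    import torch

    # With dtype=None the dtype function returns torch.float32; eager
    # torch.linspace likewise yields float32 under the default dtype.
    assert torch.linspace(1, 10, 9).dtype == torch.float32
    # An explicit dtype is passed through unchanged.
    assert torch.linspace(1, 10, 9, dtype=torch.int32).dtype == torch.int32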

@check_dtype_function(_check_tensors_with_the_same_dtype(
num_of_tensors=1,
error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}))
82 changes: 82 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py
@@ -62,6 +62,7 @@ def forward(self):
def ArangeZeroElementOutputModule_basic(module, tu: TestUtils):
module.forward()

# ==============================================================================

class ArangeStartIntModule(torch.nn.Module):
def __init__(self):
Expand Down Expand Up @@ -130,6 +131,7 @@ def forward(self):
def ArangeNegativeStartFloatModule_basic(module, tu: TestUtils):
module.forward()

# ==============================================================================

class ArangeStartStepIntModule(torch.nn.Module):
def __init__(self):
Expand Down Expand Up @@ -198,6 +200,7 @@ def forward(self):
def ArangeStartNegativeStepFloatModule_basic(module, tu: TestUtils):
module.forward()

# ==============================================================================

class ArangeDtypeFloatModule(torch.nn.Module):
def __init__(self):
Expand Down Expand Up @@ -232,6 +235,7 @@ def forward(self):
def ArangeDtypeIntModule_basic(module, tu: TestUtils):
module.forward()

# ==============================================================================

class ArangeFalsePinMemoryModule(torch.nn.Module):
def __init__(self):
Expand Down Expand Up @@ -298,3 +302,81 @@ def forward(self, x):
@register_test_case(module_factory=lambda: ArangeStartOutDtypeModule())
def ArangeStartOutDtypeModule_basic(module, tu: TestUtils):
module.forward(torch.zeros(12).to(torch.int64))

# ==============================================================================

class LinspaceModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
])
def forward(self):
return torch.linspace(-10.1, 10.1, 10)

@register_test_case(module_factory=lambda: LinspaceModule())
def LinspaceModule_basic(module, tu: TestUtils):
module.forward()

# ==============================================================================

class LinspaceDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
])
def forward(self):
return torch.linspace(-10.1, 10.1, 10, dtype=torch.int64)
@register_test_case(module_factory=lambda: LinspaceDtypeModule())
def LinspaceDtypeModule_basic(module, tu: TestUtils):
module.forward()

# ==============================================================================

class LinspaceEmptyModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
])
def forward(self):
return torch.linspace(-10.1, 10.1, 0)

@register_test_case(module_factory=lambda: LinspaceEmptyModule())
def LinspaceEmptyModule_basic(module, tu: TestUtils):
module.forward()

# ==============================================================================

class LinspaceOneSizeModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
])
def forward(self):
return torch.linspace(-10.1, 10.1, 1)

@register_test_case(module_factory=lambda: LinspaceOneSizeModule())
def LinspaceOneSizeModule_basic(module, tu: TestUtils):
module.forward()

# ==============================================================================

class LinspaceTwoSizeModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
])
def forward(self):
return torch.linspace(-10.1, 10.1, 2)

@register_test_case(module_factory=lambda: LinspaceTwoSizeModule())
def LinspaceTwoSizeModule_basic(module, tu: TestUtils):
module.forward()
