From 96fcde4d77d8f77ecb0fe58061a671f258870924 Mon Sep 17 00:00:00 2001 From: JianzheXiao Date: Sat, 9 Dec 2023 20:30:37 -0800 Subject: [PATCH] [Torch Dialect] Support Einsum Op (#2230) As title, support torch.aten.einsum op Right now only support Static Shape, because of the known issue, the fixed solution is here: https://github.com/llvm/torch-mlir/pull/2154 Co-authored-by: Jiawei Wu [wujiawei.aml@bytedance.com](mailto:wujiawei.aml@bytedance.com) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 25 ++ .../Transforms/AbstractInterpLibrary.cpp | 27 ++ .../Torch/Transforms/DecomposeComplexOps.cpp | 425 ++++++++++++++++++ .../Transforms/LowerToBackendContract.cpp | 1 + projects/pt1/e2e_testing/xfail_sets.py | 6 + .../build_tools/abstract_interp_lib_gen.py | 13 + .../build_tools/torch_ods_gen.py | 1 + .../test_suite/reshape_like.py | 56 +++ 8 files changed, 554 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 233a71621d10..7ccd0449c66b 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -8447,6 +8447,31 @@ def Torch_AtenOneHotOp : Torch_Op<"aten.one_hot", [ }]; } +def Torch_AtenEinsumOp : Torch_Op<"aten.einsum", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::einsum : (str, Tensor[], int[]?) -> (Tensor)`"; + let arguments = (ins + Torch_StringType:$equation, + AnyTorchListOfTensorType:$tensors, + AnyTorchOptionalListOfTorchIntType:$path + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenEinsumOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenEinsumOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenBucketizeTensorOp : Torch_Op<"aten.bucketize.Tensor", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 102fb2bbfd88..7df929e73872 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -11321,6 +11321,33 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" " return %5 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.einsum\"(%arg0: !torch.str, %arg1: !torch.list>, %arg2: !torch.optional>) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list>\n" +" %1 = torch.prim.ListConstruct : () -> !torch.list\n" +" %2 = torch.aten.len.t %arg1 : !torch.list> -> !torch.int\n" +" %3 = torch.aten.ne.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.len.t %arg1 : !torch.list> -> !torch.int\n" +" torch.prim.Loop %4, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %6 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list>, !torch.int -> !torch.tuple\n" +" %7:2 = torch.prim.TupleUnpack %6 : !torch.tuple -> !torch.int, !torch.int\n" +" %8 = torch.aten.append.t %0, %7#0 : !torch.list>, !torch.int -> !torch.list>\n" +" %9 = torch.aten.append.t %1, %7#1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %5 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten._shape_as_tensor\"(%arg0: !torch.tuple) -> !torch.int {\n" " %int4 = torch.constant.int 4\n" " return %int4 : !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index b7b2c26709ca..679af29d5db6 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -187,6 +187,358 @@ static SmallVector computeDimsOrderForMoveDim(int64_t srcDimInt, return dimsOrder; } +static bool parseEquation(const std::string &equation, + SmallVector> &inputTokens, + SmallVector &resultTokens) { + SmallVector inputToken; + size_t index = 0; + enum EquationVariable { kIsInput, kIsResult }; + EquationVariable currentVariable = kIsInput; + while (index < equation.size()) { + if (std::isalpha(equation[index])) { + if (currentVariable == kIsInput) { + inputToken.push_back(equation[index]); + } else { + resultTokens.push_back(equation[index]); + } + } else if (equation[index] == ',') { + inputTokens.push_back(inputToken); + inputToken.clear(); + } else if ((index < (equation.size() - 1)) && + (equation.substr(index, 2).find("->") != std::string::npos)) { + inputTokens.push_back(inputToken); + inputToken.clear(); + currentVariable = kIsResult; + index++; + } else { + return false; + } + index++; + } + return true; +} + +// [*batchingDims, *lhsOtherDims, *lhsReduceDims, *lhsContractingDims] => +// [batchingDimsProd, lhsOtherDimsProd, lhsContractingDimsProd] +static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc, + Value input, int64_t batchDimsLength, + int64_t contractingDimsLength, + int64_t otherDimsLength, + int64_t reduceDimsLength, bool isLhs) { + auto inputType = input.getType().cast(); + auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength + + reduceDimsLength; + SmallVector inputShapeTensor; + for (auto i = 0; i < inputRank; ++i) { + inputShapeTensor.emplace_back(rewriter.create( + loc, input, + rewriter.create(loc, + rewriter.getI64IntegerAttr(i)))); + } + + SmallVector outShapeTensor; + Value constOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + auto dimOffset = 0; + + auto appendDims = [&](int64_t dimLength) { + Value prod = constOne; + for (auto i = 0; i < dimLength; ++i) { + prod = rewriter.create(loc, prod, + inputShapeTensor[i + dimOffset]); + } + outShapeTensor.emplace_back(prod); + dimOffset += dimLength; + }; + + appendDims(batchDimsLength); + if (!isLhs) + appendDims(contractingDimsLength); + appendDims(otherDimsLength + reduceDimsLength); + if (isLhs) + appendDims(contractingDimsLength); + + auto outShapeValue = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(input.getContext())), + outShapeTensor); + + auto outType = inputType.getWithSizesAndDtype(std::nullopt, + inputType.getOptionalDtype()); + return rewriter.create(loc, outType, input, + outShapeValue); +} + +// classify every dim token into different categories. Note that although we +// parse out reduce dims, we delay their execution until +// `performLastPermuteAndReduce`. +static void parseDimTokens( + SmallVector &lhsTokens, SmallVector &rhsTokens, + SmallVector &finalResultTokens, SmallVector &contractingDims, + SmallVector &lhsReduceDims, SmallVector &rhsReduceDims, + SmallVector &batchingDims, SmallVector &lhsOtherDims, + SmallVector &rhsOtherDims) { + llvm::SmallDenseSet lhsTokenSet(lhsTokens.begin(), lhsTokens.end()); + llvm::SmallDenseSet rhsTokenSet(rhsTokens.begin(), rhsTokens.end()); + llvm::SmallDenseSet finalResultTokenSet(finalResultTokens.begin(), + finalResultTokens.end()); + + for (size_t i = 0; i < lhsTokens.size(); ++i) { + bool rhsContains = rhsTokenSet.contains(lhsTokens[i]); + bool finalResultConatins = finalResultTokenSet.contains(lhsTokens[i]); + // batching dim + if (rhsContains && finalResultConatins) { + batchingDims.push_back(lhsTokens[i]); + // reduce dim of lhs + } else if (!rhsContains && !finalResultConatins) { + lhsReduceDims.push_back(lhsTokens[i]); + // other dim of lhs + } else if (finalResultConatins) { + lhsOtherDims.push_back(lhsTokens[i]); + // contracting dim of lhs + } else if (rhsContains) { + contractingDims.push_back(lhsTokens[i]); + } + } + + for (size_t i = 0; i < rhsTokens.size(); ++i) { + bool lhsContains = lhsTokenSet.contains(rhsTokens[i]); + bool finalResultConatins = finalResultTokenSet.contains(rhsTokens[i]); + // batching dim + if (lhsContains && finalResultConatins) { + // reduce dim of rhs + } else if (!lhsContains && !finalResultConatins) { + rhsReduceDims.push_back(rhsTokens[i]); + // other dim of rhs + } else if (finalResultConatins) { + rhsOtherDims.push_back(rhsTokens[i]); + // contracting dim of rhs + } else if (lhsContains) { + } + } +} + +static void generateIdealReusltDimTokens(SmallVector &batchingDims, + SmallVector &lhsOtherDims, + SmallVector &rhsOtherDims, + SmallVector &lhsReduceDims, + SmallVector &rhsReduceDims, + SmallVector &resultTokens) { + // generate ideal result dims, i.e., + // [*batchingDims, *lhsOtherDims, *lhsReduceDims, *rhsOtherDims, + // *rhsReduceDims] + resultTokens.insert(resultTokens.end(), batchingDims.begin(), + batchingDims.end()); + resultTokens.insert(resultTokens.end(), lhsOtherDims.begin(), + lhsOtherDims.end()); + resultTokens.insert(resultTokens.end(), lhsReduceDims.begin(), + lhsReduceDims.end()); + resultTokens.insert(resultTokens.end(), rhsOtherDims.begin(), + rhsOtherDims.end()); + resultTokens.insert(resultTokens.end(), rhsReduceDims.begin(), + rhsReduceDims.end()); +} + +static Value permuteTensorForMatmul(PatternRewriter &rewriter, Location loc, + Value input, SmallVector &dimTokens, + SmallVector &batchingDims, + SmallVector &contractingDims, + SmallVector &otherDims, + SmallVector &reduceDims, bool isLhs) { + auto inputType = input.getType().cast(); + llvm::SmallDenseMap dimTokenMap; + for (size_t idx = 0; idx < dimTokens.size(); ++idx) { + dimTokenMap[dimTokens[idx]] = idx; + } + + SmallVector permuteVec; + auto appendDims = [&](SmallVector dimTokens) { + for (auto d : dimTokens) { + permuteVec.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(dimTokenMap[d]))); + } + }; + + appendDims(batchingDims); + if (!isLhs) + appendDims(contractingDims); + appendDims(otherDims); + appendDims(reduceDims); + if (isLhs) + appendDims(contractingDims); + + Value dstDims = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(rewriter.getContext())), + permuteVec); + auto outType = inputType.getWithSizesAndDtype(std::nullopt, + inputType.getOptionalDtype()); + return rewriter.create(loc, outType, input, dstDims); +} + +static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc, + Value lhs, SmallVector &lhsTokens, + Value rhs, SmallVector &rhsTokens, + Value &result, + SmallVector &resultTokens, + SmallVector &finalResultTokens) { + auto lhsType = lhs.getType().cast(); + auto rhsType = rhs.getType().cast(); + + Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype() + : rhsType.getOptionalDtype(); + + llvm::SmallDenseMap lhsDimShapeMap; + for (size_t idx = 0; idx < lhsTokens.size(); ++idx) { + char d = lhsTokens[idx]; + lhsDimShapeMap[d] = rewriter.create( + loc, lhs, + rewriter.create(loc, + rewriter.getI64IntegerAttr(idx))); + } + llvm::SmallDenseMap rhsDimShapeMap; + for (size_t idx = 0; idx < rhsTokens.size(); ++idx) { + char d = rhsTokens[idx]; + rhsDimShapeMap[d] = rewriter.create( + loc, rhs, + rewriter.create(loc, + rewriter.getI64IntegerAttr(idx))); + } + + // parse batch, contracting, other, reduce dims of lhs and rhs + SmallVector contractingDims; + SmallVector lhsReduceDims; + SmallVector rhsReduceDims; + SmallVector lhsOtherDims; + SmallVector rhsOtherDims; + SmallVector batchingDims; + parseDimTokens(lhsTokens, rhsTokens, finalResultTokens, contractingDims, + lhsReduceDims, rhsReduceDims, batchingDims, lhsOtherDims, + rhsOtherDims); + + llvm::SmallDenseMap outDimShapeMap; + auto generateOutDimShapeMap = [&](SmallVector &dims) { + for (auto d : dims) { + bool lhsContains = lhsDimShapeMap.count(d) > 0; + bool rhsContains = rhsDimShapeMap.count(d) > 0; + if (lhsContains && rhsContains) { + outDimShapeMap[d] = rewriter.create( + loc, lhsDimShapeMap[d], rhsDimShapeMap[d]); + } else if (lhsContains) { + outDimShapeMap[d] = lhsDimShapeMap[d]; + } else if (rhsContains) { + outDimShapeMap[d] = rhsDimShapeMap[d]; + } + } + }; + + generateOutDimShapeMap(contractingDims); + generateOutDimShapeMap(batchingDims); + generateOutDimShapeMap(lhsReduceDims); + generateOutDimShapeMap(rhsReduceDims); + generateOutDimShapeMap(lhsOtherDims); + generateOutDimShapeMap(rhsOtherDims); + + if (contractingDims.size() == 0 && lhsOtherDims.size() == 0 && + rhsOtherDims.size() == 0) { + return rewriter.notifyMatchFailure( + loc, "Hadamard product is currently not supported"); + } + + // shape: [*batchingDims, *lhsOtherDims, *lhsReduceDims, *lhsContractingDims] + lhs = permuteTensorForMatmul(rewriter, loc, lhs, lhsTokens, batchingDims, + contractingDims, lhsOtherDims, lhsReduceDims, + true); + // shape: [*batchingDims, *rhsContractingDims, *rhsOtherDims, *rhsReduceDims] + rhs = permuteTensorForMatmul(rewriter, loc, rhs, rhsTokens, batchingDims, + contractingDims, rhsOtherDims, rhsReduceDims, + false); + // shape: [batchingDimsProd, lhsOtherDimsProd, lhsContractingDimsProd] + lhs = collapseDimForMatmul(rewriter, loc, lhs, batchingDims.size(), + contractingDims.size(), lhsOtherDims.size(), + lhsReduceDims.size(), true); + // shape: [batchingDimsProd, rhsContractingDimsProd, rhsOtherDimsProd] + rhs = collapseDimForMatmul(rewriter, loc, rhs, batchingDims.size(), + contractingDims.size(), rhsOtherDims.size(), + rhsReduceDims.size(), false); + + // perform matmul + auto outType = lhsType.getWithSizesAndDtype(std::nullopt, outputDType); + result = rewriter.create(loc, outType, lhs, rhs); + + // generate ideal result dims. + generateIdealReusltDimTokens(batchingDims, lhsOtherDims, rhsOtherDims, + lhsReduceDims, rhsReduceDims, resultTokens); + + // reshape matmul result to ideal shape: + // [batchingDimsProd, lhsOtherDimsProd, rhsOtherDimsProd] => + // [*batchingDims, *lhsOtherDims, *lhsReduceDims, *rhsOtherDims, + // *rhsReduceDims] + SmallVector outShapeTensors; + for (char d : resultTokens) { + outShapeTensors.emplace_back(outDimShapeMap[d]); + } + + auto outResultShape = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(lhs.getContext())), + outShapeTensors); + result = rewriter.create( + loc, lhsType.getWithSizesAndDtype(std::nullopt, outputDType), result, + outResultShape); + return success(); +} + + +static Value performLastReduceAndPermute(PatternRewriter &rewriter, + Location loc, Type outType, + Value input, + SmallVector &inputTokens, + SmallVector &outTokens) { + auto inputType = input.getType().cast(); + + llvm::SmallDenseSet outTokenSet(outTokens.begin(), outTokens.end()); + SmallVector sumDims; + llvm::SmallDenseMap inputDimToIdx; + int64_t idx = 0; + for (size_t i = 0; i < inputTokens.size(); ++i) { + char d = inputTokens[i]; + if (!outTokenSet.contains(d)) { + sumDims.emplace_back(i); + } else { + inputDimToIdx[d] = idx++; + } + } + + if (sumDims.size() > 0) { + SmallVector sumDimsTensor; + for (auto d : sumDims) { + sumDimsTensor.emplace_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(d))); + } + auto sumDimsListValue = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(rewriter.getContext())), + sumDimsTensor); + auto falseValue = rewriter.create( + loc, rewriter.getBoolAttr(false)); + auto noneValue = rewriter.create(loc); + input = rewriter.create( + loc, + inputType.getWithSizesAndDtype(std::nullopt, + inputType.getOptionalDtype()), + input, sumDimsListValue, falseValue, noneValue); + } + + SmallVector permuteDimsTensor; + for (auto d : outTokens) { + permuteDimsTensor.emplace_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(inputDimToIdx[d]))); + } + auto permuteDimsListValue = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(input.getContext())), + permuteDimsTensor); + auto out = rewriter.create(loc, outType, input, + permuteDimsListValue); + return out; +} + namespace { /// We decompose aten.amax into a set of aten.max.dim op(s) depending on the /// number of dimensions across which the max needs to be computed. @@ -628,6 +980,78 @@ class DecomposeAtenReshapeOp : public OpRewritePattern { }; } // namespace +namespace { +// Decompose AtenEinsumOp to AtenMatmulOp, and supports possible reduce +// operation and permute operation. Currently, this pass doesn't support +// Hadamard product. The basic idea is that: +// Step 1: split the string equation to input/result tokens and find +// batchingDims, contractingDims, otherDims and reduceDims. +// Step 2: permute and reshape input tensors suitable +// for matmul operations. +// Step 3: use AtenMatmulOp to get the result. +// Step 4: iteratively execute step 2 & 3 until we get the final result. +// Step 5: perform remaining permute and reduce operations. +// notice: support static shape only + +class DecomposeAtenEinsumOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenEinsumOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + std::string equation; + if (!matchPattern(op.getEquation(), m_TorchConstantStr(equation))) { + return rewriter.notifyMatchFailure(op, "Unsupported value of equation"); + } + SmallVector resultTokens; + SmallVector> inputTokens; + if (!parseEquation(equation, inputTokens, resultTokens)) { + return rewriter.notifyMatchFailure( + op, "Unexpected character in equations encountered"); + } + + SmallVector inputTensors; + if (!getListConstructElements(op.getTensors(), inputTensors)) { + return rewriter.notifyMatchFailure( + op, "input should comes from a PrimListConstructOp"); + } + + auto allTensorHasSizes = [](Value tensor) { + auto type = tensor.getType().dyn_cast(); + if (!type || !type.hasSizes()) + return false; + return true; + }; + + if (!llvm::all_of(inputTensors, allTensorHasSizes)) { + return rewriter.notifyMatchFailure(op, + "all input tensors should have sizes"); + } + + SmallVector lhsTokens = inputTokens[0]; + Value lhs = inputTensors[0]; + Value result; + + for (size_t i = 1; i < inputTensors.size(); ++i) { + auto rhs = inputTensors[i]; + auto rhsTokens = inputTokens[i]; + SmallVector outTokens; + if (failed(performMatmul(rewriter, loc, lhs, lhsTokens, rhs, rhsTokens, + result, outTokens, resultTokens))) { + return failure(); + } + lhs = result; + lhsTokens = outTokens; + } + + result = performLastReduceAndPermute(rewriter, loc, op.getType(), lhs, + lhsTokens, resultTokens); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + // Calculates the softmax function on the given `input` tensor. Softmax(x) = // exp(x)/sum(exp(x)). // To avoid overflow we use the following decomposition rule: @@ -5798,6 +6222,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 71b3a9d91c2a..9af47ae93faf 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -385,6 +385,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 356235528f45..16061629318e 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -554,6 +554,9 @@ "EmptyLikeModule_int", "ExpandAsIntModule_basic", "ExpandModule_basic", + "EinsumStaticModule_basic", + "EinsumStaticFourDimensionModule_basic", + "EinsumStaticContractRhsModule_basic", "Fill_TensorFloat64WithFloat32_basic", "Fill_TensorFloat64WithFloat64_basic", "Fill_TensorFloat64WithInt64_basic", @@ -1020,6 +1023,9 @@ "RsubFloatModule_basic", "RsubFloatModule_noalpha_basic", "RsubInt0d_NumToTensor_Module_basic", + "EinsumStaticModule_basic", + "EinsumStaticFourDimensionModule_basic", + "EinsumStaticContractRhsModule_basic", "ElementwiseBitwiseAndModule_basic", "ElementwiseBitwiseAndStaticShapeModule_basic", "ElementwiseBitwiseNotInt32Module_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 478c5a16131f..df1d6de5e66a 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -3684,6 +3684,19 @@ def aten〇cat〡dtype(tensors_rank_dtype: List[Tuple[int, int]], dim: int = 0) dtypes.append(tensor_dtype) return promote_dtypes(ranks, dtypes) +@check_dtype_function( + [Invocation("i,j->ij", [TensorOfShape(1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.int32)]),]) +def aten〇einsum〡dtype(equation: str, tensors_rank_dtype: List[Tuple[int, int]], path: Optional[List[int]] = None) -> int: + ranks: List[Optional[int]] = [] + dtypes: List[int] = [] + assert len(tensors_rank_dtype) != 0 + for tensor_rank_dtype in tensors_rank_dtype: + tensor_rank, tensor_dtype = tensor_rank_dtype + ranks.append(tensor_rank) + dtypes.append(tensor_dtype) + return promote_dtypes(ranks, dtypes) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇_shape_as_tensor〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.int64 diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 84718b382cd5..f50eb461b76c 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -566,6 +566,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)") emit("aten::argmin : (Tensor, int?, bool) -> (Tensor)") emit("aten::one_hot : (Tensor, int) -> (Tensor)") + emit("aten::einsum : (str, Tensor[], int[]?) -> (Tensor)") emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)") emit("aten::clone : (Tensor, int?) -> (Tensor)") emit("aten::lift_fresh_copy : (Tensor) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py index 723ec085363e..a73435c3c1ad 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py @@ -1044,3 +1044,59 @@ def forward(self, inputs): @register_test_case(module_factory=lambda: UnflattenIntNegativeOneSizeStaticModule()) def UnflattenIntNegativeOneSizeStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(5, 12, 3)) + +# ============================================================================== + +class EinsumStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 2, 4], torch.float32, True), + ([5, 4, 6], torch.float32, True), + ([3, 7, 6], torch.float32, True), + ]) + def forward(self, tensor1, tensor2, tensor3): + return torch.ops.aten.einsum('bqe,ked,btd->bqtk', [tensor1, tensor2, tensor3]) + +@register_test_case(module_factory=lambda: EinsumStaticModule()) +def EinsumStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 2, 4), tu.rand(5, 4, 6), tu.rand(3, 7, 6)) + + +class EinsumStaticFourDimensionModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4, 5, 6], torch.float32, True), + ([3, 7, 5, 6], torch.float32, True), + ]) + def forward(self, tensor1, tensor2): + return torch.ops.aten.einsum('blhd,bshd->blhs', [tensor1, tensor2]) + +@register_test_case(module_factory=lambda: EinsumStaticFourDimensionModule()) +def EinsumStaticFourDimensionModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5, 6), tu.rand(3, 7, 5, 6)) + + +class EinsumStaticContractRhsModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4, 5], torch.float32, True), + ([4, 5], torch.float32, True), + ]) + def forward(self, tensor1, tensor2): + return torch.ops.aten.einsum('abc,bc->a', [tensor1, tensor2]) + +@register_test_case(module_factory=lambda: EinsumStaticContractRhsModule()) +def EinsumStaticContractRhsModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5), tu.rand(4, 5)) \ No newline at end of file