diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 0f6f92bd7c2c..0532b4b19d94 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -1627,6 +1627,25 @@ class ConvertElementwiseOp : public ConversionPattern {
         operands, [](Value v) { return isa<RankedTensorType>(v.getType()); }));
     auto resultType = cast<RankedTensorType>(
         getTypeConverter()->convertType(op->getResult(0).getType()));
+    bool isScalarOp = resultType.getShape().size() == 0;
+    if (isScalarOp) {
+      // for elementwise ops that are actually rank0 scalar computations,
+      // perform the payload outside a linalg generic op.
+      SmallVector<Value> payloadArgs;
+      for (auto t : tensorOperands) {
+        payloadArgs.push_back(rewriter.create<tensor::ExtractOp>(loc, t));
+      }
+      Value scalarResult = createLinalgPayloadCalculationForElementwiseOp(
+          rewriter, loc, getTypeConverter(), payloadArgs, op, operands);
+      if (!scalarResult)
+        return rewriter.notifyMatchFailure(
+            op, "Failed to create payload for scalar elementwise op");
+      Value rank0Result =
+          createInitTensor(rewriter, loc, ValueRange{},
+                           resultType.getElementType(), scalarResult);
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, rank0Result);
+      return success();
+    }
     bool hadErrorCreatingPayload = false;
     Value generic = torch_to_linalg::createElementwiseLinalgGeneric(
         rewriter, loc, tensorOperands, resultType.getElementType(),
diff --git a/test/Conversion/TorchToLinalg/elementwise.mlir b/test/Conversion/TorchToLinalg/elementwise.mlir
index aa2be74f5d7e..ecf4caa58389 100644
--- a/test/Conversion/TorchToLinalg/elementwise.mlir
+++ b/test/Conversion/TorchToLinalg/elementwise.mlir
@@ -4,13 +4,11 @@
 // CHECK-LABEL:   func.func @elementwise$unary(
 // CHECK-SAME:        %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
 // CHECK-DAG:       %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
-// CHECK:           %[[INIT_TENSOR:.*]] = tensor.empty() : tensor<f32>
-// CHECK:           %[[GENERIC:.*]] = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%[[BUILTIN_TENSOR]] : tensor<f32>) outs(%[[INIT_TENSOR]] : tensor<f32>) {
-// CHECK:           ^bb0(%[[BBARG0:.*]]: f32, %{{.*}}: f32):
-// CHECK:             %[[TANH:.*]] = math.tanh %[[BBARG0]] : f32
-// CHECK:             linalg.yield %[[TANH]] : f32
-// CHECK:           } -> tensor<f32>
-// CHECK:           %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor<f32> to tensor<f32>
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract %[[BUILTIN_TENSOR]][] : tensor<f32>
+// CHECK:           %[[TANH:.*]] = math.tanh %[[EXTRACT]] : f32
+// CHECK:           %[[EMPTY:.*]] = tensor.empty() : tensor<f32>
+// CHECK:           %[[FILL:.*]] = linalg.fill ins(%[[TANH]] : f32) outs(%[[EMPTY]] : tensor<f32>) -> tensor<f32>
+// CHECK:           %[[CASTED:.*]] = tensor.cast %[[FILL:.*]] : tensor<f32> to tensor<f32>
 // CHECK:           %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<f32> -> !torch.vtensor<[],f32>
 // CHECK:           return %[[RESULT]] : !torch.vtensor<[],f32>
 // CHECK:           }