Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][spirv] Add a generic convert-to-spirv pass #95942

Merged
merged 1 commit into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//===- ConvertToSPIRVPass.h - Conversion to SPIR-V pass ---*- C++ -*-=========//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H
#define MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H

#include <memory>

namespace mlir {
class Pass;

#define GEN_PASS_DECL_CONVERTTOSPIRVPASS
#include "mlir/Conversion/Passes.h.inc"

} // namespace mlir

#endif // MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H
1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
#include "mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h"
Expand Down
12 changes: 12 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
];
}

//===----------------------------------------------------------------------===//
// ToSPIRV
//===----------------------------------------------------------------------===//

def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
let summary = "Convert to SPIR-V";
let description = [{
This is a generic pass to convert to SPIR-V.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This claims to be a "generic" pass but it seems to me instead to be a "monolithic" pass.

Why didn't we align this on the convert-to-llvm pass design instead?

Copy link
Member

@kuhar kuhar Aug 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@angelz913 talked about this in https://youtu.be/-qoMMrlYvGs?t=436 (starts around the 7m 15s mark). The TL;DR is we didn't know what interfaces to have, and decided to start with a monolithic multi-stage v0 implementation, with the plan to add make it interface-based when gain confidence in this design.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TL;DR is we didn't know what interfaces to have

I don't quite understand? Why isn't it just copy the convert-to-llvm one and rename it?

I have strong concerns with in-tree monolithic passes like this: what is the timeline to remove this and migrate to a pluggable one?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that in-tree monolithic passes arent great, but the change to use interfaces and go to more LLVM based approach is more involved (I dont think it is as simple as "copy the convert-to-llvm and rename it". Angel and Jakub can fill in more details here). I think timeline will depend on how much community involvement we get here. Jakub and Angel are trying to get upstream support flushed out more. So this is a strict improvement anyway. Its probably better to iterate on this in a bit.

FWIW, for conversion to LLVM in IREE we just throw all the conversion patterns into one pass and run them together. I personally find the split of conversion from each dialect to LLVM kind of artificial. Everything needs to be translated to LLVM. Running multiple passes that walk the IR multiple times seems like a waste. But that is a downstream decision and not having monoliths in upstream MLIR is useful.

Copy link
Collaborator

@joker-eph joker-eph Aug 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think timeline will depend on how much community involvement we get here.

In this case I would want to see this pass moved to the test folder: I'm not comfortable with monolithic passes alongside the rest of the transformations right now.

FWIW, for conversion to LLVM in IREE we just throw all the conversion patterns into one pass and run them together. I personally find the split of conversion from each dialect to LLVM kind of artificial

I suspect you missed the intent of upstream design: these individual passes are all made that way upstream for testing, the intention has always been that downstream projects create a monolithic pass for their own purpose and don't use the upstream passes as-is (which is why the populatePatterns method are exposed): that's exactly "work as intended".

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but the dependency on the spirv conversion code will have to move over there in some form.

Why would it have to be moved to the runner? The current pattern is that the transformations are made with mlir-opt ahead of invoking the runner (hence why we ended up with a "pure" mlir-cpu-runner).

So we can drop the runner, but it will just make the situation worse).

I think we're talking about different things maybe?
You seem to argue about the need for running some of these tests, while I'm just describing the system architecture of the project, which does not impact in any way which tests we're running.

That is, from a very high-level, the specific runner tool goes away and you do instead something like mlir-opt -pass-pipeline="builtin.module(convert-to-spirv)" %s | mlir-cpu-runner --shared-libs=mlir_vulkan_runtime.so.
(this is a transition we made for the mlir-cuda-runner ~3 years ago, and the Vulkan runner has been a TODO ever since IIUC)

Note that in this model, the test pass is available for the opt tool, the runner does not need to know about the test passes (or any pass).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the clarification, Mehdi.

You seem to argue about the need for running some of these tests, while I'm just describing the system architecture of the project, which does not impact in any way which tests we're running.

Right, that's a good way to put it. From my point of view, what I mostly care about is that we do keep these tests (as in, both the implementation of convert-to-spirv and the e2e .mlir tests) while the cleanup in this area (both gpu-to-spirv and the runner implementation). I think we all aggrege on the end state, so I just want to make sure that we can find a path that allows us to make incremental improvements to get there.

From your perspective, would it be an option to make this a test pass and temporarily have it registered it in the vulkan runner (similar to how they are manually listed in mlir-opt.cpp)? Maybe that's a middle ground, although I'm not sure of how much difference to a pass being a test or not there practically is for other users of mlir.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From your perspective, would it be an option to make this a test pass and temporarily have it registered it in the vulkan runner (similar to how they are manually listed in mlir-opt.cpp)?

The problem is that we optionally compile the test passes I think? (With -DMLIR_INCLUDE_TESTS=ON).

It may be easier to just keep the pass as-is until we migrate tests from mlir-vulkan-runner to mlir-cpu-runner?

The only thing that makes be nervous is that unless there is a timeline with people making progress on it, we'll still be adding mlir-vulkan-runner based tests multiple years from now. What I'm trying to see is some sort of a "gradient" on the progression of all this (and some timeline), rather than a quick immediate solution.

Copy link
Member

@kuhar kuhar Aug 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, let me check internally if we can commit to something more concrete along this axis. Overall, I think we did make a lot of recent progress on the spirv conversion test coverage and general cleanup in this area, but I understand why you'd want to prioritize the runner cleanup next.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@joker-eph I checked and we will have @andfau-amd work on the mlir runner migration, starting from ~next week.

}];
let dependentDialects = ["spirv::SPIRVDialect"];
}

//===----------------------------------------------------------------------===//
// AffineToStandard
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ add_subdirectory(ControlFlowToLLVM)
add_subdirectory(ControlFlowToSCF)
add_subdirectory(ControlFlowToSPIRV)
add_subdirectory(ConvertToLLVM)
add_subdirectory(ConvertToSPIRV)
add_subdirectory(FuncToEmitC)
add_subdirectory(FuncToLLVM)
add_subdirectory(FuncToSPIRV)
Expand Down
22 changes: 22 additions & 0 deletions mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
set(LLVM_OPTIONAL_SOURCES
ConvertToSPIRVPass.cpp
)

add_mlir_conversion_library(MLIRConvertToSPIRVPass
ConvertToSPIRVPass.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ConvertToSPIRV

DEPENDS
MLIRConversionPassIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRRewrite
MLIRSPIRVConversion
MLIRSPIRVDialect
MLIRSupport
MLIRTransformUtils
)
71 changes: 71 additions & 0 deletions mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
//===- ConvertToSPIRVPass.cpp - MLIR SPIR-V Conversion --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <memory>

#define DEBUG_TYPE "convert-to-spirv"

namespace mlir {
#define GEN_PASS_DEF_CONVERTTOSPIRVPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

/// A pass to perform the SPIR-V conversion.
struct ConvertToSPIRVPass final
: impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {

void runOnOperation() override {
MLIRContext *context = &getContext();
Operation *op = getOperation();

spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
SPIRVTypeConverter typeConverter(targetAttr);

RewritePatternSet patterns(context);
ScfToSPIRVContext scfToSPIRVContext;

// Populate patterns.
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
arith::populateArithToSPIRVPatterns(typeConverter, patterns);
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
populateFuncToSPIRVPatterns(typeConverter, patterns);
index::populateIndexToSPIRVPatterns(typeConverter, patterns);
populateVectorToSPIRVPatterns(typeConverter, patterns);
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);

std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);

if (failed(applyPartialConversion(op, *target, std::move(patterns))))
return signalPassFailure();
}
};

} // namespace
218 changes: 218 additions & 0 deletions mlir/test/Conversion/ConvertToSPIRV/arith.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
// RUN: mlir-opt -convert-to-spirv -split-input-file %s | FileCheck %s

//===----------------------------------------------------------------------===//
// arithmetic ops
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @int32_scalar
func.func @int32_scalar(%lhs: i32, %rhs: i32) {
// CHECK: spirv.IAdd %{{.*}}, %{{.*}}: i32
%0 = arith.addi %lhs, %rhs: i32
// CHECK: spirv.ISub %{{.*}}, %{{.*}}: i32
%1 = arith.subi %lhs, %rhs: i32
// CHECK: spirv.IMul %{{.*}}, %{{.*}}: i32
%2 = arith.muli %lhs, %rhs: i32
// CHECK: spirv.SDiv %{{.*}}, %{{.*}}: i32
%3 = arith.divsi %lhs, %rhs: i32
// CHECK: spirv.UDiv %{{.*}}, %{{.*}}: i32
%4 = arith.divui %lhs, %rhs: i32
// CHECK: spirv.UMod %{{.*}}, %{{.*}}: i32
%5 = arith.remui %lhs, %rhs: i32
return
}

// CHECK-LABEL: @int32_scalar_srem
// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
func.func @int32_scalar_srem(%lhs: i32, %rhs: i32) {
// CHECK: %[[LABS:.+]] = spirv.GL.SAbs %[[LHS]] : i32
// CHECK: %[[RABS:.+]] = spirv.GL.SAbs %[[RHS]] : i32
// CHECK: %[[ABS:.+]] = spirv.UMod %[[LABS]], %[[RABS]] : i32
// CHECK: %[[POS:.+]] = spirv.IEqual %[[LHS]], %[[LABS]] : i32
// CHECK: %[[NEG:.+]] = spirv.SNegate %[[ABS]] : i32
// CHECK: %{{.+}} = spirv.Select %[[POS]], %[[ABS]], %[[NEG]] : i1, i32
%0 = arith.remsi %lhs, %rhs: i32
return
}

// -----

//===----------------------------------------------------------------------===//
// arith bit ops
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @bitwise_scalar
func.func @bitwise_scalar(%arg0 : i32, %arg1 : i32) {
// CHECK: spirv.BitwiseAnd
%0 = arith.andi %arg0, %arg1 : i32
// CHECK: spirv.BitwiseOr
%1 = arith.ori %arg0, %arg1 : i32
// CHECK: spirv.BitwiseXor
%2 = arith.xori %arg0, %arg1 : i32
return
}

// CHECK-LABEL: @bitwise_vector
func.func @bitwise_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
// CHECK: spirv.BitwiseAnd
%0 = arith.andi %arg0, %arg1 : vector<4xi32>
// CHECK: spirv.BitwiseOr
%1 = arith.ori %arg0, %arg1 : vector<4xi32>
// CHECK: spirv.BitwiseXor
%2 = arith.xori %arg0, %arg1 : vector<4xi32>
return
}

// CHECK-LABEL: @logical_scalar
func.func @logical_scalar(%arg0 : i1, %arg1 : i1) {
// CHECK: spirv.LogicalAnd
%0 = arith.andi %arg0, %arg1 : i1
// CHECK: spirv.LogicalOr
%1 = arith.ori %arg0, %arg1 : i1
// CHECK: spirv.LogicalNotEqual
%2 = arith.xori %arg0, %arg1 : i1
return
}

// CHECK-LABEL: @logical_vector
func.func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
// CHECK: spirv.LogicalAnd
%0 = arith.andi %arg0, %arg1 : vector<4xi1>
// CHECK: spirv.LogicalOr
%1 = arith.ori %arg0, %arg1 : vector<4xi1>
// CHECK: spirv.LogicalNotEqual
%2 = arith.xori %arg0, %arg1 : vector<4xi1>
return
}

// CHECK-LABEL: @shift_scalar
func.func @shift_scalar(%arg0 : i32, %arg1 : i32) {
// CHECK: spirv.ShiftLeftLogical
%0 = arith.shli %arg0, %arg1 : i32
// CHECK: spirv.ShiftRightArithmetic
%1 = arith.shrsi %arg0, %arg1 : i32
// CHECK: spirv.ShiftRightLogical
%2 = arith.shrui %arg0, %arg1 : i32
return
}

// CHECK-LABEL: @shift_vector
func.func @shift_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
// CHECK: spirv.ShiftLeftLogical
%0 = arith.shli %arg0, %arg1 : vector<4xi32>
// CHECK: spirv.ShiftRightArithmetic
%1 = arith.shrsi %arg0, %arg1 : vector<4xi32>
// CHECK: spirv.ShiftRightLogical
%2 = arith.shrui %arg0, %arg1 : vector<4xi32>
return
}

// -----

//===----------------------------------------------------------------------===//
// arith.cmpf
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @cmpf
func.func @cmpf(%arg0 : f32, %arg1 : f32) {
// CHECK: spirv.FOrdEqual
%1 = arith.cmpf oeq, %arg0, %arg1 : f32
return
}

// CHECK-LABEL: @vec1cmpf
func.func @vec1cmpf(%arg0 : vector<1xf32>, %arg1 : vector<1xf32>) {
// CHECK: spirv.FOrdGreaterThan
%0 = arith.cmpf ogt, %arg0, %arg1 : vector<1xf32>
// CHECK: spirv.FUnordLessThan
%1 = arith.cmpf ult, %arg0, %arg1 : vector<1xf32>
return
}

// -----

//===----------------------------------------------------------------------===//
// arith.cmpi
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @cmpi
func.func @cmpi(%arg0 : i32, %arg1 : i32) {
// CHECK: spirv.IEqual
%0 = arith.cmpi eq, %arg0, %arg1 : i32
return
}

// CHECK-LABEL: @indexcmpi
func.func @indexcmpi(%arg0 : index, %arg1 : index) {
// CHECK: spirv.IEqual
%0 = arith.cmpi eq, %arg0, %arg1 : index
return
}

// CHECK-LABEL: @vec1cmpi
func.func @vec1cmpi(%arg0 : vector<1xi32>, %arg1 : vector<1xi32>) {
// CHECK: spirv.ULessThan
%0 = arith.cmpi ult, %arg0, %arg1 : vector<1xi32>
// CHECK: spirv.SGreaterThan
%1 = arith.cmpi sgt, %arg0, %arg1 : vector<1xi32>
return
}

// CHECK-LABEL: @boolcmpi_equality
func.func @boolcmpi_equality(%arg0 : i1, %arg1 : i1) {
// CHECK: spirv.LogicalEqual
%0 = arith.cmpi eq, %arg0, %arg1 : i1
// CHECK: spirv.LogicalNotEqual
%1 = arith.cmpi ne, %arg0, %arg1 : i1
return
}

// CHECK-LABEL: @boolcmpi_unsigned
func.func @boolcmpi_unsigned(%arg0 : i1, %arg1 : i1) {
// CHECK-COUNT-2: spirv.Select
// CHECK: spirv.UGreaterThanEqual
%0 = arith.cmpi uge, %arg0, %arg1 : i1
// CHECK-COUNT-2: spirv.Select
// CHECK: spirv.ULessThan
%1 = arith.cmpi ult, %arg0, %arg1 : i1
return
}

// CHECK-LABEL: @vec1boolcmpi_equality
func.func @vec1boolcmpi_equality(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
// CHECK: spirv.LogicalEqual
%0 = arith.cmpi eq, %arg0, %arg1 : vector<1xi1>
// CHECK: spirv.LogicalNotEqual
%1 = arith.cmpi ne, %arg0, %arg1 : vector<1xi1>
return
}

// CHECK-LABEL: @vec1boolcmpi_unsigned
func.func @vec1boolcmpi_unsigned(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
// CHECK-COUNT-2: spirv.Select
// CHECK: spirv.UGreaterThanEqual
%0 = arith.cmpi uge, %arg0, %arg1 : vector<1xi1>
// CHECK-COUNT-2: spirv.Select
// CHECK: spirv.ULessThan
%1 = arith.cmpi ult, %arg0, %arg1 : vector<1xi1>
return
}

// CHECK-LABEL: @vecboolcmpi_equality
func.func @vecboolcmpi_equality(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
// CHECK: spirv.LogicalEqual
%0 = arith.cmpi eq, %arg0, %arg1 : vector<4xi1>
// CHECK: spirv.LogicalNotEqual
%1 = arith.cmpi ne, %arg0, %arg1 : vector<4xi1>
return
}

// CHECK-LABEL: @vecboolcmpi_unsigned
func.func @vecboolcmpi_unsigned(%arg0 : vector<3xi1>, %arg1 : vector<3xi1>) {
// CHECK-COUNT-2: spirv.Select
// CHECK: spirv.UGreaterThanEqual
%0 = arith.cmpi uge, %arg0, %arg1 : vector<3xi1>
// CHECK-COUNT-2: spirv.Select
// CHECK: spirv.ULessThan
%1 = arith.cmpi ult, %arg0, %arg1 : vector<3xi1>
return
}
Loading
Loading