Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Flang][MLIR] Addition of declare target to and other related map fixes #219

Merged
merged 1 commit into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions flang/include/flang/Optimizer/Builder/FIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,10 @@ elideLengthsAlreadyInType(mlir::Type type, mlir::ValueRange lenParams);
/// Get the address space which should be used for allocas
uint64_t getAllocaAddressSpace(mlir::DataLayout *dataLayout);

uint64_t getGlobalAddressSpace(mlir::DataLayout *dataLayout);

uint64_t getProgramAddressSpace(mlir::DataLayout *dataLayout);

} // namespace fir::factory

#endif // FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H
3 changes: 3 additions & 0 deletions flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ class ConvertFIRToLLVMPattern : public mlir::ConvertToLLVMPattern {
unsigned
getProgramAddressSpace(mlir::ConversionPatternRewriter &rewriter) const;

unsigned
getGlobalAddressSpace(mlir::ConversionPatternRewriter &rewriter) const;

const fir::FIRToLLVMPassOptions &options;

using ConvertToLLVMPattern::match;
Expand Down
14 changes: 14 additions & 0 deletions flang/lib/Optimizer/Builder/FIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1721,3 +1721,17 @@ uint64_t fir::factory::getAllocaAddressSpace(mlir::DataLayout *dataLayout) {
return mlir::cast<mlir::IntegerAttr>(addrSpace).getUInt();
return 0;
}

uint64_t fir::factory::getGlobalAddressSpace(mlir::DataLayout *dataLayout) {
if (dataLayout)
if (mlir::Attribute addrSpace = dataLayout->getGlobalMemorySpace())
return mlir::cast<mlir::IntegerAttr>(addrSpace).getUInt();
return 0;
}

uint64_t fir::factory::getProgramAddressSpace(mlir::DataLayout *dataLayout) {
if (dataLayout)
if (mlir::Attribute addrSpace = dataLayout->getProgramMemorySpace())
return mlir::cast<mlir::IntegerAttr>(addrSpace).getUInt();
return 0;
}
73 changes: 60 additions & 13 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,54 @@ addLLVMOpBundleAttrs(mlir::ConversionPatternRewriter &rewriter,
}

namespace {

mlir::Value replaceWithAddrOfOrASCast(mlir::ConversionPatternRewriter &rewriter,
mlir::Location loc,
std::uint64_t globalAS,
std::uint64_t programAS,
llvm::StringRef symName, mlir::Type type,
mlir::Operation *replaceOp = nullptr) {
if (mlir::isa<mlir::LLVM::LLVMPointerType>(type)) {
if (globalAS != programAS) {
auto llvmAddrOp = rewriter.create<mlir::LLVM::AddressOfOp>(
loc, getLlvmPtrType(rewriter.getContext(), globalAS), symName);
if (replaceOp)
return rewriter.replaceOpWithNewOp<mlir::LLVM::AddrSpaceCastOp>(
replaceOp, ::getLlvmPtrType(rewriter.getContext(), programAS),
llvmAddrOp);
return rewriter.create<mlir::LLVM::AddrSpaceCastOp>(
loc, getLlvmPtrType(rewriter.getContext(), programAS), llvmAddrOp);
}

if (replaceOp)
return rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(
replaceOp, getLlvmPtrType(rewriter.getContext(), globalAS), symName);
return rewriter.create<mlir::LLVM::AddressOfOp>(
loc, getLlvmPtrType(rewriter.getContext(), globalAS), symName);
}

if (replaceOp)
return rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(replaceOp, type,
symName);
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, type, symName);
}

/// Lower `fir.address_of` operation to `llvm.address_of` operation.
struct AddrOfOpConversion : public fir::FIROpConversion<fir::AddrOfOp> {
using FIROpConversion::FIROpConversion;

llvm::LogicalResult
matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto ty = convertType(addr.getType());
rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(
addr, ty, addr.getSymbol().getRootReference().getValue());
auto global = addr->getParentOfType<mlir::ModuleOp>()
.lookupSymbol<mlir::LLVM::GlobalOp>(addr.getSymbol());
replaceWithAddrOfOrASCast(
rewriter, addr->getLoc(),
global ? global.getAddrSpace() : getGlobalAddressSpace(rewriter),
getProgramAddressSpace(rewriter),
global ? global.getSymName()
: addr.getSymbol().getRootReference().getValue(),
convertType(addr.getType()), addr);
return mlir::success();
}
};
Expand Down Expand Up @@ -1255,14 +1293,19 @@ struct EmboxCommonConversion : public fir::FIROpConversion<OP> {
? fir::NameUniquer::getTypeDescriptorAssemblyName(recType.getName())
: fir::NameUniquer::getTypeDescriptorName(recType.getName());
mlir::Type llvmPtrTy = ::getLlvmPtrType(mod.getContext());
mlir::DataLayout dataLayout(mod);
if (auto global = mod.template lookupSymbol<fir::GlobalOp>(name)) {
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
global.getSymName());
return replaceWithAddrOfOrASCast(
rewriter, loc, fir::factory::getGlobalAddressSpace(&dataLayout),
fir::factory::getProgramAddressSpace(&dataLayout),
global.getSymName(), llvmPtrTy);
}
if (auto global = mod.template lookupSymbol<mlir::LLVM::GlobalOp>(name)) {
// The global may have already been translated to LLVM.
return rewriter.create<mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
global.getSymName());
return replaceWithAddrOfOrASCast(
rewriter, loc, global.getAddrSpace(),
fir::factory::getProgramAddressSpace(&dataLayout),
global.getSymName(), llvmPtrTy);
}
// Type info derived types do not have type descriptors since they are the
// types defining type descriptors.
Expand Down Expand Up @@ -2763,12 +2806,16 @@ struct TypeDescOpConversion : public fir::FIROpConversion<fir::TypeDescOp> {
: fir::NameUniquer::getTypeDescriptorName(recordType.getName());
auto llvmPtrTy = ::getLlvmPtrType(typeDescOp.getContext());
if (auto global = module.lookupSymbol<mlir::LLVM::GlobalOp>(typeDescName)) {
rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(
typeDescOp, llvmPtrTy, global.getSymName());
replaceWithAddrOfOrASCast(rewriter, typeDescOp->getLoc(),
global.getAddrSpace(),
getProgramAddressSpace(rewriter),
global.getSymName(), llvmPtrTy, typeDescOp);
return mlir::success();
} else if (auto global = module.lookupSymbol<fir::GlobalOp>(typeDescName)) {
rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(
typeDescOp, llvmPtrTy, global.getSymName());
replaceWithAddrOfOrASCast(rewriter, typeDescOp->getLoc(),
getGlobalAddressSpace(rewriter),
getProgramAddressSpace(rewriter),
global.getSymName(), llvmPtrTy, typeDescOp);
return mlir::success();
}
return mlir::failure();
Expand Down Expand Up @@ -2859,8 +2906,8 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
mlir::SymbolRefAttr comdat;
llvm::ArrayRef<mlir::NamedAttribute> attrs;
auto g = rewriter.create<mlir::LLVM::GlobalOp>(
loc, tyAttr, isConst, linkage, global.getSymName(), initAttr, 0, 0,
false, false, comdat, attrs, dbgExprs);
loc, tyAttr, isConst, linkage, global.getSymName(), initAttr, 0,
getGlobalAddressSpace(rewriter), false, false, comdat, attrs, dbgExprs);

if (global.getAlignment() && *global.getAlignment() > 0)
g.setAlignment(*global.getAlignment());
Expand Down
25 changes: 23 additions & 2 deletions flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,10 @@ unsigned ConvertFIRToLLVMPattern::getAllocaAddressSpace(
mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp();
assert(parentOp != nullptr &&
"expected insertion block to have parent operation");
if (auto module = parentOp->getParentOfType<mlir::ModuleOp>())
auto module = mlir::isa<mlir::ModuleOp>(parentOp)
? mlir::cast<mlir::ModuleOp>(parentOp)
: parentOp->getParentOfType<mlir::ModuleOp>();
if (module)
if (mlir::Attribute addrSpace =
mlir::DataLayout(module).getAllocaMemorySpace())
return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt();
Expand All @@ -358,11 +361,29 @@ unsigned ConvertFIRToLLVMPattern::getProgramAddressSpace(
mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp();
assert(parentOp != nullptr &&
"expected insertion block to have parent operation");
if (auto module = parentOp->getParentOfType<mlir::ModuleOp>())
auto module = mlir::isa<mlir::ModuleOp>(parentOp)
? mlir::cast<mlir::ModuleOp>(parentOp)
: parentOp->getParentOfType<mlir::ModuleOp>();
if (module)
if (mlir::Attribute addrSpace =
mlir::DataLayout(module).getProgramMemorySpace())
return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt();
return defaultAddressSpace;
}

unsigned ConvertFIRToLLVMPattern::getGlobalAddressSpace(
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp();
assert(parentOp != nullptr &&
"expected insertion block to have parent operation");
auto module = mlir::isa<mlir::ModuleOp>(parentOp)
? mlir::cast<mlir::ModuleOp>(parentOp)
: parentOp->getParentOfType<mlir::ModuleOp>();
if (module)
if (mlir::Attribute addrSpace =
mlir::DataLayout(module).getGlobalMemorySpace())
return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt();
return defaultAddressSpace;
}

} // namespace fir
6 changes: 4 additions & 2 deletions flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,10 @@ class MapInfoFinalizationPass
return llvm::to_underlying(
hasImplicitMap
? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT
: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT |
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS
: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
}

mlir::omp::MapInfoOp genDescriptorMemberMaps(mlir::omp::MapInfoOp op,
Expand Down
Loading