diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h index 4555a6dfee5d4..1d2f68aee70e0 100644 --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -10852,7 +10852,7 @@ class Sema { void AddSyclKernel(Decl * d) { SyclKernel.push_back(d); } SmallVector &SyclKernels() { return SyclKernel; } - void ConstructSYCLKernel(CXXMemberCallExpr* e); + void ConstructSYCLKernel(FunctionDecl* KernelHelper); }; /// RAII object that enters a new expression evaluation context. diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp index 136b76652f37a..94f7979f66f28 100644 --- a/clang/lib/Sema/SemaOverload.cpp +++ b/clang/lib/Sema/SemaOverload.cpp @@ -13012,15 +13012,6 @@ Sema::BuildCallToMemberFunction(Scope *S, Expr *MemExprE, CXXMemberCallExpr::Create(Context, MemExprE, Args, ResultType, VK, RParenLoc, Proto->getNumParams()); - if (getLangOpts().SYCL) { - auto Func = TheCall->getMethodDecl(); - auto Name = Func->getQualifiedNameAsString(); - if (Name == "cl::sycl::handler::parallel_for" || - Name == "cl::sycl::handler::single_task") { - ConstructSYCLKernel(TheCall); - } - } - // Check for a valid return type. if (CheckCallReturnType(Method->getReturnType(), MemExpr->getMemberLoc(), TheCall, Method)) diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index 24163552a0946..1a46b82e0392e 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -12,12 +12,44 @@ #include "clang/AST/AST.h" #include "clang/Sema/Sema.h" #include "llvm/ADT/SmallVector.h" +#include "TreeTransform.h" using namespace clang; -LambdaExpr *getBodyAsLambda(CXXMemberCallExpr *e) { - auto LastArg = e->getArg(e->getNumArgs() - 1); - return dyn_cast(LastArg); +typedef llvm::DenseMap DeclMap; + +class KernelBodyTransform : public TreeTransform { +public: + KernelBodyTransform(llvm::DenseMap &Map, + Sema &S) + : TreeTransform(S), DMap(Map), SemaRef(S) {} + bool AlwaysRebuild() { return true; } + + ExprResult TransformDeclRefExpr(DeclRefExpr *DRE) { + auto Ref = dyn_cast(DRE->getDecl()); + if (Ref) { + auto NewDecl = DMap[Ref]; + if (NewDecl) { + return DeclRefExpr::Create( + SemaRef.getASTContext(), DRE->getQualifierLoc(), + DRE->getTemplateKeywordLoc(), NewDecl, false, DRE->getNameInfo(), + NewDecl->getType(), DRE->getValueKind()); + } + } + return DRE; + } + +private: + DeclMap DMap; + Sema &SemaRef; +}; + +CXXRecordDecl* getBodyAsLambda(FunctionDecl *FD) { + auto FirstArg = (*FD->param_begin()); + if (FirstArg) + if (FirstArg->getType()->getAsCXXRecordDecl()->isLambda()) + return FirstArg->getType()->getAsCXXRecordDecl(); + return nullptr; } FunctionDecl *CreateSYCLKernelFunction(ASTContext &Context, StringRef Name, @@ -54,17 +86,16 @@ FunctionDecl *CreateSYCLKernelFunction(ASTContext &Context, StringRef Name, return Result; } -CompoundStmt *CreateSYCLKernelBody(Sema &S, CXXMemberCallExpr *e, +CompoundStmt *CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelHelper, DeclContext *DC) { llvm::SmallVector BodyStmts; // TODO: case when kernel is functor // TODO: possible refactoring when functor case will be completed - LambdaExpr *LE = getBodyAsLambda(e); - if (LE) { + CXXRecordDecl *LC = getBodyAsLambda(KernelHelper); + if (LC) { // Create Lambda object - CXXRecordDecl *LC = LE->getLambdaClass(); auto LambdaVD = VarDecl::Create( S.Context, DC, SourceLocation(), SourceLocation(), LC->getIdentifier(), QualType(LC->getTypeForDecl(), 0), LC->getLambdaTypeInfo(), SC_None); @@ -137,43 +168,23 @@ CompoundStmt *CreateSYCLKernelBody(Sema &S, CXXMemberCallExpr *e, TargetFuncParam++; } - // Create Lambda operator () call - FunctionDecl *LO = LE->getCallOperator(); - ArrayRef Args = LO->parameters(); - llvm::SmallVector ParamStmts(1); - ParamStmts[0] = dyn_cast(LambdaDRE); - - // Collect arguments for () operator - for (auto Arg : Args) { - QualType ArgType = Arg->getOriginalType(); - // Declare variable for parameter and pass it to call - auto param_VD = - VarDecl::Create(S.Context, DC, SourceLocation(), SourceLocation(), - Arg->getIdentifier(), ArgType, - S.Context.getTrivialTypeSourceInfo(ArgType), SC_None); - Stmt *param_DS = new (S.Context) - DeclStmt(DeclGroupRef(param_VD), SourceLocation(), SourceLocation()); - BodyStmts.push_back(param_DS); - auto DRE = DeclRefExpr::Create(S.Context, NestedNameSpecifierLoc(), - SourceLocation(), param_VD, false, - DeclarationNameInfo(), ArgType, VK_LValue); - Expr *Res = ImplicitCastExpr::Create( - S.Context, ArgType, CK_LValueToRValue, DRE, nullptr, VK_RValue); - ParamStmts.push_back(Res); - } + // In function from headers lambda is function parameter, we need + // to replace all refs to this lambda with our vardecl. + // I used TreeTransform here, but I'm not sure that it is good solution + // Also I used map and I'm not sure about it too. + Stmt* FunctionBody = KernelHelper->getBody(); + DeclMap DMap; + ParmVarDecl* LambdaParam = *(KernelHelper->param_begin()); + // DeclRefExpr with valid source location but with decl which is not marked + // as used is invalid. + LambdaVD->setIsUsed(); + DMap[LambdaParam] = LambdaVD; + // Without PushFunctionScope I had segfault. Maybe we also need to do pop. + S.PushFunctionScope(); + KernelBodyTransform KBT(DMap, S); + Stmt* NewBody = KBT.TransformStmt(FunctionBody).get(); + BodyStmts.push_back(NewBody); - // Create ref for call operator - DeclRefExpr *DRE = new (S.Context) - DeclRefExpr(S.Context, LO, false, LO->getType(), VK_LValue, - SourceLocation()); - QualType ResultTy = LO->getReturnType(); - ExprValueKind VK = Expr::getValueKindForType(ResultTy); - ResultTy = ResultTy.getNonLValueExprType(S.Context); - - CXXOperatorCallExpr *TheCall = CXXOperatorCallExpr::Create( - S.Context, OO_Call, DRE, ParamStmts, ResultTy, VK, SourceLocation(), - FPOptions(), clang::CallExpr::ADLCallKind::NotADL ); - BodyStmts.push_back(TheCall); } return CompoundStmt::Create(S.Context, BodyStmts, SourceLocation(), SourceLocation()); @@ -222,9 +233,9 @@ void BuildArgTys(ASTContext &Context, } } -void Sema::ConstructSYCLKernel(CXXMemberCallExpr *e) { +void Sema::ConstructSYCLKernel(FunctionDecl *KernelHelper) { // TODO: Case when kernel is functor - LambdaExpr *LE = getBodyAsLambda(e); + CXXRecordDecl *LE = getBodyAsLambda(KernelHelper); if (LE) { llvm::SmallVector ArgDecls; @@ -238,9 +249,8 @@ void Sema::ConstructSYCLKernel(CXXMemberCallExpr *e) { BuildArgTys(getASTContext(), ArgDecls, NewArgDecls, ArgTys); // Get Name for our kernel. - FunctionDecl *FuncDecl = e->getMethodDecl(); const TemplateArgumentList *TemplateArgs = - FuncDecl->getTemplateSpecializationArgs(); + KernelHelper->getTemplateSpecializationArgs(); QualType KernelNameType = TemplateArgs->get(0).getAsType(); std::string Name = KernelNameType.getBaseTypeIdentifier()->getName().str(); @@ -256,7 +266,7 @@ void Sema::ConstructSYCLKernel(CXXMemberCallExpr *e) { FunctionDecl *SYCLKernel = CreateSYCLKernelFunction(getASTContext(), Name, ArgTys, NewArgDecls); - CompoundStmt *SYCLKernelBody = CreateSYCLKernelBody(*this, e, SYCLKernel); + CompoundStmt *SYCLKernelBody = CreateSYCLKernelBody(*this, KernelHelper, SYCLKernel); SYCLKernel->setBody(SYCLKernelBody); AddSyclKernel(SYCLKernel); diff --git a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp index 31353e45baa83..7068c8ecc1148 100644 --- a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp +++ b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp @@ -5231,14 +5231,28 @@ void Sema::PerformPendingInstantiations(bool LocalOnly) { Function, [this, Inst, DefinitionRequired](FunctionDecl *CurFD) { InstantiateFunctionDefinition(/*FIXME:*/ Inst.second, CurFD, true, DefinitionRequired, true); - if (CurFD->isDefined()) + if (CurFD->isDefined()) { + // Because all SYCL kernel functions are template functions - they + // have deferred instantination. We need bodies of these functions + // so we are checking for SYCL kernel attribute after instantination. + if (getLangOpts().SYCL && CurFD->hasAttr()) { + ConstructSYCLKernel(CurFD); + } CurFD->setInstantiationIsPending(false); + } }); } else { InstantiateFunctionDefinition(/*FIXME:*/ Inst.second, Function, true, DefinitionRequired, true); - if (Function->isDefined()) + if (Function->isDefined()) { + // Because all SYCL kernel functions are template functions - they + // have deferred instantination. We need bodies of these functions + // so we are checking for SYCL kernel attribute after instantination. + if (getLangOpts().SYCL && Function->hasAttr()) { + ConstructSYCLKernel(Function); + } Function->setInstantiationIsPending(false); + } } continue; } diff --git a/clang/test/CodeGenSYCL/kernel-with-id.cpp b/clang/test/CodeGenSYCL/kernel-with-id.cpp index f9ed74a3ec304..0e4af7b2a0188 100644 --- a/clang/test/CodeGenSYCL/kernel-with-id.cpp +++ b/clang/test/CodeGenSYCL/kernel-with-id.cpp @@ -16,10 +16,9 @@ int main() { deviceQueue.submit([&](cl::sycl::handler &cgh) { auto accessorA = bufferA.template get_access(cgh); -// CHECK: %wiID = alloca %"struct.cl::sycl::id", align 8 // CHECK: call spir_func void @_ZN2cl4sycl8accessorIiLi1ELNS0_6access4modeE1026ELNS2_6targetE2014ELNS2_11placeholderE0EE13__set_pointerEPU3AS1i(%"class.cl::sycl::accessor"* %1, i32 addrspace(1)* %2) -// CHECK: call spir_func void @"_ZZZ4mainENK3$_0clERN2cl4sycl7handlerEENKUlNS1_2idILm1EEEE_clES5_"(%class.anon* %0, %"struct.cl::sycl::id"* byval align 8 %wiID) -// CHECK: %call = call spir_func i64 @_Z13get_global_idj(i32 0) +// CHECK: %call = call spir_func i64 @_Z13get_global_idj(i32 %{{.*}}) +// CHECK: call spir_func void @"_ZZZ4mainENK3$_0clERN2cl4sycl7handlerEENKUlNS1_2idILm1EEEE_clES5_"(%class.anon* %0, %"struct.cl::sycl::id"* byval align 8 %{{.*}}) cgh.parallel_for(numOfItems, [=](cl::sycl::id<1> wiID) { accessorA[wiID] = accessorA[wiID] * accessorA[wiID];