Skip to content

Commit

Permalink
[SYCL] Add SYCL Kernel entry point generation.
Browse files Browse the repository at this point in the history
Main changes are:
1. Added a search for the parallel_for and single_task kernel-invoking functions.
2. Added kernel name extraction from single_task/parallel_for template parameter.
3. Added SYCL kernel entry point generation.
4. Non-kernel code is no longer emitted for the SYCL device.

Signed-off-by: Vladimir Lazarev <vladimir.lazarev@intel.com>
  • Loading branch information
vladimirlaz committed Jan 22, 2019
1 parent f509e63 commit 03354a2
Show file tree
Hide file tree
Showing 6 changed files with 261 additions and 0 deletions.
11 changes: 11 additions & 0 deletions clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -10842,6 +10842,17 @@ class Sema {
Expr *E,
llvm::function_ref<void(Expr *, RecordDecl *, FieldDecl *, CharUnits)>
Action);

private:
// We store SYCL Kernels here and handle separately -- which is a hack.
// FIXME: It would be best to refactor this.
SmallVector<Decl*, 4> SyclKernel;

public:
/// Appends \p d to the list of stored SYCL kernel declarations.
void AddSyclKernel(Decl *d) {
  SyclKernel.push_back(d);
}
/// Returns a mutable reference to the list of stored SYCL kernel declarations.
SmallVector<Decl *, 4> &SyclKernels() {
  return SyclKernel;
}

void ConstructSYCLKernel(CXXMemberCallExpr* e);
};

/// RAII object that enters a new expression evaluation context.
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/CodeGen/CodeGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2128,6 +2128,11 @@ void CodeGenModule::EmitGlobal(GlobalDecl GD) {
if (Global->hasAttr<IFuncAttr>())
return emitIFuncDefinition(GD);

if (LangOpts.SYCL) {
if (!Global->hasAttr<OpenCLKernelAttr>())
return;
}

// If this is a cpu_dispatch multiversion function, emit the resolver.
if (Global->hasAttr<CPUDispatchAttr>())
return emitCPUDispatchDefinition(GD);
Expand Down
6 changes: 6 additions & 0 deletions clang/lib/Parse/ParseAST.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,12 @@ void clang::ParseAST(Sema &S, bool PrintStats, bool SkipFunctionBodies) {
for (Decl *D : S.WeakTopLevelDecls())
Consumer->HandleTopLevelDecl(DeclGroupRef(D));

if (S.getLangOpts().SYCL) {
for (Decl *D : S.SyclKernels()) {
Consumer->HandleTopLevelDecl(DeclGroupRef(D));
}
}

Consumer->HandleTranslationUnit(S.getASTContext());

// Finalize the template instantiation observer chain.
Expand Down
1 change: 1 addition & 0 deletions clang/lib/Sema/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ add_clang_library(clangSema
SemaStmt.cpp
SemaStmtAsm.cpp
SemaStmtAttr.cpp
SemaSYCL.cpp
SemaTemplate.cpp
SemaTemplateDeduction.cpp
SemaTemplateInstantiate.cpp
Expand Down
9 changes: 9 additions & 0 deletions clang/lib/Sema/SemaOverload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13012,6 +13012,15 @@ Sema::BuildCallToMemberFunction(Scope *S, Expr *MemExprE,
CXXMemberCallExpr::Create(Context, MemExprE, Args, ResultType, VK,
RParenLoc, Proto->getNumParams());

if (getLangOpts().SYCL) {
auto Func = TheCall->getMethodDecl();
auto Name = Func->getQualifiedNameAsString();
if (Name == "cl::sycl::handler::parallel_for" ||
Name == "cl::sycl::handler::single_task") {
ConstructSYCLKernel(TheCall);
}
}

// Check for a valid return type.
if (CheckCallReturnType(Method->getReturnType(), MemExpr->getMemberLoc(),
TheCall, Method))
Expand Down
229 changes: 229 additions & 0 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
//===- SemaSYCL.cpp - Semantic Analysis for SYCL constructs ---------------===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
// This implements Semantic Analysis for SYCL constructs.
//===----------------------------------------------------------------------===//

#include "clang/AST/AST.h"
#include "clang/Sema/Sema.h"
#include "llvm/ADT/SmallVector.h"

using namespace clang;

/// Returns the kernel body of a kernel-invocation call as a lambda expression.
///
/// The body is taken to be the last argument of \p e.
///
/// \returns the last argument if it is a LambdaExpr; nullptr when the call has
///          no arguments or when the last argument is anything else (for
///          example a functor object, which is not handled yet).
LambdaExpr *getBodyAsLambda(CXXMemberCallExpr *e) {
  // getNumArgs() is unsigned: without this guard a zero-argument call would
  // wrap around and index far past the end of the argument list.
  const unsigned NumArgs = e->getNumArgs();
  if (NumArgs == 0)
    return nullptr;
  return dyn_cast<LambdaExpr>(e->getArg(NumArgs - 1));
}

/// Builds the declaration of a kernel entry point: a `void` function named
/// \p Name, created in the translation-unit context, with one parameter per
/// entry of \p ArgTys.
///
/// \param Context  AST context that owns the created nodes.
/// \param Name     Used both as the function's identifier and as its asm label.
/// \param ArgTys   Parameter types, in order.
/// \param ArgDecls Declarations parallel to \p ArgTys; entry i supplies the
///                 identifier and type-source info of parameter i. Must have
///                 at least as many entries as \p ArgTys.
/// \returns the new FunctionDecl carrying implicit OpenCLKernelAttr and
///          AsmLabelAttr attributes. No body is attached here.
FunctionDecl *CreateSYCLKernelFunction(ASTContext &Context, StringRef Name,
                                       ArrayRef<QualType> ArgTys,
                                       ArrayRef<DeclaratorDecl *> ArgDecls) {
  // Signature: void(ArgTys...).
  QualType VoidTy = Context.VoidTy;
  FunctionProtoType::ExtProtoInfo ProtoInfo;
  QualType KernelTy = Context.getFunctionType(VoidTy, ArgTys, ProtoInfo);
  DeclarationName KernelName(&Context.Idents.get(Name));

  FunctionDecl *Kernel = FunctionDecl::Create(
      Context, Context.getTranslationUnitDecl(), SourceLocation(),
      SourceLocation(), KernelName, KernelTy,
      Context.getTrivialTypeSourceInfo(VoidTy), SC_None);

  // One ParmVarDecl per argument type, named after the matching declaration.
  llvm::SmallVector<ParmVarDecl *, 16> KernelParams;
  for (unsigned Idx = 0, NumArgs = ArgTys.size(); Idx != NumArgs; ++Idx) {
    DeclaratorDecl *Source = ArgDecls[Idx];
    ParmVarDecl *Param = ParmVarDecl::Create(
        Context, Kernel, SourceLocation(), SourceLocation(),
        Source->getIdentifier(), ArgTys[Idx], Source->getTypeSourceInfo(),
        SC_None, /*DefArg=*/nullptr);
    Param->setScopeInfo(/*scopeDepth=*/0, /*parameterIndex=*/Idx);
    Param->setIsUsed();
    KernelParams.push_back(Param);
  }
  Kernel->setParams(KernelParams);

  // TODO: Add SYCL specific attribute for kernel and all functions called
  // by kernel.
  Kernel->addAttr(OpenCLKernelAttr::CreateImplicit(Context));
  Kernel->addAttr(AsmLabelAttr::CreateImplicit(Context, Name));
  return Kernel;
}

/// Builds the body of the kernel entry point \p DC for the kernel-invocation
/// call \p e.
///
/// When the kernel body of \p e is a lambda, the generated statements are, in
/// order:
///   1. a declaration of a local variable of the lambda's closure type,
///      initialized from an init-list with one entry per capture, where the
///      i-th capture is read from the i-th parameter of \p DC;
///   2. one local variable declaration per parameter of the lambda's call
///      operator (no initializer is attached to these here);
///   3. a call to the lambda's operator() on the local closure object with
///      those variables as arguments.
/// When the kernel body is not a lambda, the result is an empty compound
/// statement.
///
/// \param S  Sema instance whose ASTContext owns the created nodes.
/// \param e  The kernel-invocation call (see getBodyAsLambda).
/// \param DC The kernel entry point. Must be a FunctionDecl with at least as
///           many parameters as the lambda has captures; the parameter
///           iterator is advanced once per capture without a bounds check.
/// \returns the compound statement to be used as the kernel function body.
CompoundStmt *CreateSYCLKernelBody(Sema &S, CXXMemberCallExpr *e,
                                   DeclContext *DC) {

  llvm::SmallVector<Stmt *, 16> BodyStmts;

  // TODO: case when kernel is functor
  // TODO: possible refactoring when functor case will be completed
  LambdaExpr *LE = getBodyAsLambda(e);
  if (!LE)
    return CompoundStmt::Create(S.Context, BodyStmts, SourceLocation(),
                                SourceLocation());

  // Create Lambda object
  CXXRecordDecl *LC = LE->getLambdaClass();
  auto *Lambda_VD = VarDecl::Create(
      S.Context, DC, SourceLocation(), SourceLocation(), LC->getIdentifier(),
      QualType(LC->getTypeForDecl(), 0), LC->getLambdaTypeInfo(), SC_None);
  Stmt *DS = new (S.Context)
      DeclStmt(DeclGroupRef(Lambda_VD), SourceLocation(), SourceLocation());
  BodyStmts.push_back(DS);
  auto *Lambda_DRE = DeclRefExpr::Create(
      S.Context, NestedNameSpecifierLoc(), SourceLocation(), Lambda_VD, false,
      DeclarationNameInfo(), QualType(LC->getTypeForDecl(), 0), VK_LValue);

  // Init Lambda fields
  llvm::SmallVector<Expr *, 16> InitCaptures;

  // DC is required to be the kernel FunctionDecl; cast<> states that
  // precondition instead of dereferencing an unchecked dyn_cast<> result.
  auto *TargetFunc = cast<FunctionDecl>(DC);
  // Walks the kernel parameters in lockstep with the lambda captures.
  auto TargetFuncParam = TargetFunc->param_begin();
  for (const auto &CaptureField : LE->captures()) {
    VarDecl *CapturedVar = CaptureField.getCapturedVar();
    QualType ParamType = (*TargetFuncParam)->getOriginalType();
    auto *ParamRef = DeclRefExpr::Create(
        S.Context, NestedNameSpecifierLoc(), SourceLocation(),
        *TargetFuncParam, false, DeclarationNameInfo(), ParamType, VK_LValue);

    // Read the kernel parameter as an rvalue.
    Expr *Res = ImplicitCastExpr::Create(
        S.Context, ParamType, CK_LValueToRValue, ParamRef, nullptr, VK_RValue);

    Expr *InitCapture = new (S.Context) InitListExpr(
        S.Context, SourceLocation(), /*initExprs*/ Res, SourceLocation());
    // Note: this replaces the initializer on the captured variable's own
    // declaration, not only on the closure field.
    CapturedVar->setInit(InitCapture);
    InitCapture->setType(CapturedVar->getType());
    InitCaptures.push_back(InitCapture);
    ++TargetFuncParam;
  }

  Expr *InitLambdaCaptures = new (S.Context)
      InitListExpr(S.Context, SourceLocation(), /*initExprs*/ InitCaptures,
                   SourceLocation());
  InitLambdaCaptures->setType(Lambda_VD->getType());
  Lambda_VD->setInit(InitLambdaCaptures);

  // Create Lambda operator () call
  FunctionDecl *LO = LE->getCallOperator();
  ArrayRef<ParmVarDecl *> Args = LO->parameters();
  // The first operand of the operator call is the closure object itself.
  llvm::SmallVector<Expr *, 16> ParamStmts;
  ParamStmts.push_back(Lambda_DRE);

  // Collect arguments for () operator
  for (auto *Arg : Args) {
    QualType ArgType = Arg->getOriginalType();
    // Declare variable for parameter and pass it to call
    auto *param_VD =
        VarDecl::Create(S.Context, DC, SourceLocation(), SourceLocation(),
                        Arg->getIdentifier(), ArgType,
                        S.Context.getTrivialTypeSourceInfo(ArgType), SC_None);
    Stmt *param_DS = new (S.Context)
        DeclStmt(DeclGroupRef(param_VD), SourceLocation(), SourceLocation());
    BodyStmts.push_back(param_DS);
    auto *ArgRef = DeclRefExpr::Create(S.Context, NestedNameSpecifierLoc(),
                                       SourceLocation(), param_VD, false,
                                       DeclarationNameInfo(), ArgType,
                                       VK_LValue);
    ParamStmts.push_back(ArgRef);
  }

  // Create ref for call operator
  DeclRefExpr *DRE = new (S.Context)
      DeclRefExpr(S.Context, LO, false, LO->getType(), VK_LValue,
                  SourceLocation());
  QualType ResultTy = LO->getReturnType();
  ExprValueKind VK = Expr::getValueKindForType(ResultTy);
  ResultTy = ResultTy.getNonLValueExprType(S.Context);

  CXXOperatorCallExpr *TheCall = CXXOperatorCallExpr::Create(
      S.Context, OO_Call, DRE, ParamStmts, ResultTy, VK, SourceLocation(),
      FPOptions(), clang::CallExpr::ADLCallKind::NotADL);
  BodyStmts.push_back(TheCall);

  return CompoundStmt::Create(S.Context, BodyStmts, SourceLocation(),
                              SourceLocation());
}

/// Computes the kernel parameter list from the captured declarations.
///
/// For every declaration in \p ArgDecls, one type is appended to \p ArgTys and
/// one new VarDecl named "_arg_<original name>" (created in the
/// translation-unit context) is appended to \p NewArgDecls.
///
/// The parameter type is the declaration's own type, except when its base type
/// identifier is "accessor" and it is a class template specialization whose
/// first template argument is a type T: then the parameter type is a pointer
/// to T with the opencl_global address space.
///
/// \param Context     AST context that owns the created nodes.
/// \param ArgDecls    Input declarations; each is expected to have an
///                    identifier.
/// \param NewArgDecls Output: the newly created declarations, in input order.
/// \param ArgTys      Output: the parameter types, in input order.
void BuildArgTys(ASTContext &Context,
                 llvm::SmallVector<DeclaratorDecl *, 16> &ArgDecls,
                 llvm::SmallVector<DeclaratorDecl *, 16> &NewArgDecls,
                 llvm::SmallVector<QualType, 16> &ArgTys) {
  DeclContext *DC = Context.getTranslationUnitDecl();

  for (auto *V : ArgDecls) {
    QualType ArgTy = V->getType();
    QualType ActualArgType = ArgTy;

    // getBaseTypeIdentifier() returns null for types with no named
    // declaration behind them (e.g. builtin types such as int), so it must be
    // checked before use; such types are passed through unchanged.
    const IdentifierInfo *BaseTypeId = ArgTy.getBaseTypeIdentifier();
    // TODO: harden this check with additional validation that this class is
    // declared in cl::sycl namespace
    if (BaseTypeId && BaseTypeId->getName() == "accessor") {
      if (const auto *AccessorRD = ArgTy->getAsCXXRecordDecl()) {
        const auto *AccessorSpec =
            dyn_cast<ClassTemplateSpecializationDecl>(AccessorRD);
        // Only rewrite when the first template argument really is a type;
        // getAsType() is not valid for any other argument kind.
        if (AccessorSpec && AccessorSpec->getTemplateArgs().size() > 0 &&
            AccessorSpec->getTemplateArgs()[0].getKind() ==
                TemplateArgument::Type) {
          QualType PointeeType =
              AccessorSpec->getTemplateArgs()[0].getAsType();
          Qualifiers Quals = PointeeType.getQualifiers();
          Quals.setAddressSpace(LangAS::opencl_global);
          PointeeType =
              Context.getQualifiedType(PointeeType.getUnqualifiedType(), Quals);
          QualType PointerType = Context.getPointerType(PointeeType);
          // NOTE(review): the same qualifiers (including the address space)
          // are also applied to the pointer type itself, not only to the
          // pointee -- confirm this is intended.
          ActualArgType =
              Context.getQualifiedType(PointerType.getUnqualifiedType(), Quals);
        }
      }
    }

    // Name the new declaration "_arg_<original name>".
    IdentifierInfo *VarName = nullptr;
    SmallString<8> Str;
    llvm::raw_svector_ostream OS(Str);
    OS << "_arg_" << V->getIdentifier()->getName();
    VarName = &Context.Idents.get(OS.str());

    auto *NewVarDecl = VarDecl::Create(
        Context, DC, SourceLocation(), SourceLocation(), VarName, ActualArgType,
        Context.getTrivialTypeSourceInfo(ActualArgType), SC_None);
    ArgTys.push_back(ActualArgType);
    NewArgDecls.push_back(NewVarDecl);
  }
}

/// Generates a kernel entry point for the kernel-invocation call \p e and
/// records it via AddSyclKernel().
///
/// Only calls whose kernel body is a lambda are handled. The kernel gets one
/// parameter per lambda capture (see BuildArgTys). Its name is the base type
/// identifier of the callee's first template argument; when that type is a
/// class template specialization whose first argument is a type, that type is
/// appended to the name as "_<type>_".
///
/// No kernel is generated (the function returns without effect) when:
///   - the kernel body is not a lambda;
///   - the lambda has a capture that is not a variable capture (e.g. `this`);
///   - the callee is not a template specialization whose first template
///     argument is a type;
///   - the kernel name type has no identifier.
/// TODO: emit diagnostics for the unsupported cases instead of skipping them
/// silently.
void Sema::ConstructSYCLKernel(CXXMemberCallExpr *e) {
  // TODO: Case when kernel is functor
  LambdaExpr *LE = getBodyAsLambda(e);
  if (!LE)
    return;

  llvm::SmallVector<DeclaratorDecl *, 16> ArgDecls;
  for (const auto &V : LE->captures()) {
    // getCapturedVar() is only valid for variable captures; `this` and
    // VLA-type captures have no variable and are not supported yet.
    if (!V.capturesVariable())
      return;
    ArgDecls.push_back(V.getCapturedVar());
  }

  // Get Name for our kernel.
  FunctionDecl *FuncDecl = e->getMethodDecl();
  if (!FuncDecl)
    return;
  // Null when the callee is not a function template specialization.
  const TemplateArgumentList *TemplateArgs =
      FuncDecl->getTemplateSpecializationArgs();
  if (!TemplateArgs || TemplateArgs->size() == 0 ||
      TemplateArgs->get(0).getKind() != TemplateArgument::Type)
    return;
  QualType KernelNameType = TemplateArgs->get(0).getAsType();
  // Null for types with no named declaration (e.g. builtin types).
  const IdentifierInfo *KernelNameId = KernelNameType.getBaseTypeIdentifier();
  if (!KernelNameId)
    return;
  std::string Name = KernelNameId->getName().str();

  if (const auto *NameRD = KernelNameType->getAsCXXRecordDecl()) {
    const auto *NameSpec = dyn_cast<ClassTemplateSpecializationDecl>(NameRD);
    if (NameSpec && NameSpec->getTemplateArgs().size() > 0 &&
        NameSpec->getTemplateArgs()[0].getKind() == TemplateArgument::Type) {
      QualType ParamType = NameSpec->getTemplateArgs()[0].getAsType();
      Name += "_" + ParamType.getAsString() + "_";
    }
  }

  llvm::SmallVector<QualType, 16> ArgTys;
  llvm::SmallVector<DeclaratorDecl *, 16> NewArgDecls;
  BuildArgTys(getASTContext(), ArgDecls, NewArgDecls, ArgTys);

  FunctionDecl *SYCLKernel =
      CreateSYCLKernelFunction(getASTContext(), Name, ArgTys, NewArgDecls);

  CompoundStmt *SYCLKernelBody = CreateSYCLKernelBody(*this, e, SYCLKernel);
  SYCLKernel->setBody(SYCLKernelBody);

  AddSyclKernel(SYCLKernel);
}

0 comments on commit 03354a2

Please sign in to comment.