Skip to content

Commit

Permalink
[Clang][OpenMP] Allow num_teams to accept multiple expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
shiltian committed Jul 29, 2024
1 parent 378fe2f commit 37c28db
Show file tree
Hide file tree
Showing 15 changed files with 478 additions and 312 deletions.
79 changes: 48 additions & 31 deletions clang/include/clang/AST/OpenMPClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -6131,60 +6131,77 @@ class OMPMapClause final : public OMPMappableExprListClause<OMPMapClause>,
/// \endcode
/// In this example directive '#pragma omp teams' has clause 'num_teams'
/// with single expression 'n'.
class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit {
friend class OMPClauseReader;
///
/// When 'ompx_bare' clause exists on a 'target' directive, 'num_teams' clause
/// can accept up to three expressions.
///
/// \code
/// #pragma omp target teams ompx_bare num_teams(x, y, z)
/// \endcode
class OMPNumTeamsClause final
: public OMPVarListClause<OMPNumTeamsClause>,
public OMPClauseWithPreInit,
private llvm::TrailingObjects<OMPNumTeamsClause, Expr *> {
friend OMPVarListClause;
friend TrailingObjects;

/// Location of '('.
SourceLocation LParenLoc;

/// NumTeams number.
Stmt *NumTeams = nullptr;
OMPNumTeamsClause(const ASTContext &C, SourceLocation StartLoc,
SourceLocation LParenLoc, SourceLocation EndLoc, unsigned N)
: OMPVarListClause(llvm::omp::OMPC_num_teams, StartLoc, LParenLoc, EndLoc,
N),
OMPClauseWithPreInit(this) {}

/// Set the NumTeams number.
///
/// \param E NumTeams number.
void setNumTeams(Expr *E) { NumTeams = E; }
/// Build an empty clause.
OMPNumTeamsClause(unsigned N)
: OMPVarListClause(llvm::omp::OMPC_num_teams, SourceLocation(),
SourceLocation(), SourceLocation(), N),
OMPClauseWithPreInit(this) {}

public:
/// Build 'num_teams' clause.
/// Creates clause with a list of variables \a VL.
///
/// \param E Expression associated with this clause.
/// \param HelperE Helper Expression associated with this clause.
/// \param CaptureRegion Innermost OpenMP region where expressions in this
/// clause must be captured.
/// \param C AST context.
/// \param StartLoc Starting location of the clause.
/// \param LParenLoc Location of '('.
/// \param EndLoc Ending location of the clause.
OMPNumTeamsClause(Expr *E, Stmt *HelperE, OpenMPDirectiveKind CaptureRegion,
SourceLocation StartLoc, SourceLocation LParenLoc,
SourceLocation EndLoc)
: OMPClause(llvm::omp::OMPC_num_teams, StartLoc, EndLoc),
OMPClauseWithPreInit(this), LParenLoc(LParenLoc), NumTeams(E) {
setPreInitStmt(HelperE, CaptureRegion);
}
/// \param VL List of references to the variables.
/// \param PreInit
static OMPNumTeamsClause *Create(const ASTContext &C, SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc, ArrayRef<Expr *> VL,
Stmt *PreInit);

/// Build an empty clause.
OMPNumTeamsClause()
: OMPClause(llvm::omp::OMPC_num_teams, SourceLocation(),
SourceLocation()),
OMPClauseWithPreInit(this) {}
/// Creates an empty clause with \a N variables.
///
/// \param C AST context.
/// \param N The number of variables.
static OMPNumTeamsClause *CreateEmpty(const ASTContext &C, unsigned N);

/// Sets the location of '('.
void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }

/// Returns the location of '('.
SourceLocation getLParenLoc() const { return LParenLoc; }

/// Return NumTeams number.
Expr *getNumTeams() { return cast<Expr>(NumTeams); }
/// Return NumTeams expressions.
ArrayRef<Expr *> getNumTeams() { return getVarRefs(); }

/// Return NumTeams number.
Expr *getNumTeams() const { return cast<Expr>(NumTeams); }
/// Return NumTeams expressions.
ArrayRef<Expr *> getNumTeams() const {
return const_cast<OMPNumTeamsClause *>(this)->getNumTeams();
}

child_range children() { return child_range(&NumTeams, &NumTeams + 1); }
child_range children() {
return child_range(reinterpret_cast<Stmt **>(varlist_begin()),
reinterpret_cast<Stmt **>(varlist_end()));
}

const_child_range children() const {
return const_child_range(&NumTeams, &NumTeams + 1);
auto Children = const_cast<OMPNumTeamsClause *>(this)->children();
return const_child_range(Children.begin(), Children.end());
}

child_range used_children() {
Expand Down
2 changes: 1 addition & 1 deletion clang/include/clang/AST/RecursiveASTVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -3793,8 +3793,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPMapClause(OMPMapClause *C) {
template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPNumTeamsClause(
OMPNumTeamsClause *C) {
TRY_TO(VisitOMPClauseList(C));
TRY_TO(VisitOMPClauseWithPreInit(C));
TRY_TO(TraverseStmt(C->getNumTeams()));
return true;
}

Expand Down
3 changes: 2 additions & 1 deletion clang/include/clang/Sema/SemaOpenMP.h
Original file line number Diff line number Diff line change
Expand Up @@ -1226,7 +1226,8 @@ class SemaOpenMP : public SemaBase {
const OMPVarListLocTy &Locs, bool NoDiagnose = false,
ArrayRef<Expr *> UnresolvedMappers = std::nullopt);
/// Called on well-formed 'num_teams' clause.
OMPClause *ActOnOpenMPNumTeamsClause(Expr *NumTeams, SourceLocation StartLoc,
OMPClause *ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc);
/// Called on well-formed 'thread_limit' clause.
Expand Down
26 changes: 23 additions & 3 deletions clang/lib/AST/OpenMPClause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1720,6 +1720,24 @@ const Expr *OMPDoacrossClause::getLoopData(unsigned NumLoop) const {
return *It;
}

OMPNumTeamsClause *
OMPNumTeamsClause::Create(const ASTContext &C, SourceLocation StartLoc,
SourceLocation LParenLoc, SourceLocation EndLoc,
ArrayRef<Expr *> VL, Stmt *PreInit) {
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size()));
OMPNumTeamsClause *Clause =
new (Mem) OMPNumTeamsClause(C, StartLoc, LParenLoc, EndLoc, VL.size());
Clause->setVarRefs(VL);
Clause->setPreInitStmt(PreInit);
return Clause;
}

OMPNumTeamsClause *OMPNumTeamsClause::CreateEmpty(const ASTContext &C,
unsigned N) {
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N));
return new (Mem) OMPNumTeamsClause(N);
}

//===----------------------------------------------------------------------===//
// OpenMP clauses printing methods
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1977,9 +1995,11 @@ void OMPClausePrinter::VisitOMPDeviceClause(OMPDeviceClause *Node) {
}

void OMPClausePrinter::VisitOMPNumTeamsClause(OMPNumTeamsClause *Node) {
OS << "num_teams(";
Node->getNumTeams()->printPretty(OS, nullptr, Policy, 0);
OS << ")";
if (!Node->varlist_empty()) {
OS << "num_teams";
VisitOMPClauseList(Node, '(');
OS << ")";
}
}

void OMPClausePrinter::VisitOMPThreadLimitClause(OMPThreadLimitClause *Node) {
Expand Down
3 changes: 1 addition & 2 deletions clang/lib/AST/StmtProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -843,9 +843,8 @@ void OMPClauseProfiler::VisitOMPAllocateClause(const OMPAllocateClause *C) {
VisitOMPClauseList(C);
}
void OMPClauseProfiler::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) {
VisitOMPClauseList(C);
VistOMPClauseWithPreInit(C);
if (C->getNumTeams())
Profiler->VisitStmt(C->getNumTeams());
}
void OMPClauseProfiler::VisitOMPThreadLimitClause(
const OMPThreadLimitClause *C) {
Expand Down
7 changes: 4 additions & 3 deletions clang/lib/CodeGen/CGOpenMPRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6036,8 +6036,9 @@ const Expr *CGOpenMPRuntime::getNumTeamsExprForTargetDirective(
dyn_cast_or_null<OMPExecutableDirective>(ChildStmt)) {
if (isOpenMPTeamsDirective(NestedDir->getDirectiveKind())) {
if (NestedDir->hasClausesOfKind<OMPNumTeamsClause>()) {
const Expr *NumTeams =
NestedDir->getSingleClause<OMPNumTeamsClause>()->getNumTeams();
const Expr *NumTeams = NestedDir->getSingleClause<OMPNumTeamsClause>()
->getNumTeams()
.front();
if (NumTeams->isIntegerConstantExpr(CGF.getContext()))
if (auto Constant =
NumTeams->getIntegerConstantExpr(CGF.getContext()))
Expand All @@ -6062,7 +6063,7 @@ const Expr *CGOpenMPRuntime::getNumTeamsExprForTargetDirective(
case OMPD_target_teams_distribute_parallel_for_simd: {
if (D.hasClausesOfKind<OMPNumTeamsClause>()) {
const Expr *NumTeams =
D.getSingleClause<OMPNumTeamsClause>()->getNumTeams();
D.getSingleClause<OMPNumTeamsClause>()->getNumTeams().front();
if (NumTeams->isIntegerConstantExpr(CGF.getContext()))
if (auto Constant = NumTeams->getIntegerConstantExpr(CGF.getContext()))
MinTeamsVal = MaxTeamsVal = Constant->getExtValue();
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/CodeGen/CGStmtOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6859,7 +6859,7 @@ static void emitCommonOMPTeamsDirective(CodeGenFunction &CGF,
const auto *NT = S.getSingleClause<OMPNumTeamsClause>();
const auto *TL = S.getSingleClause<OMPThreadLimitClause>();
if (NT || TL) {
const Expr *NumTeams = NT ? NT->getNumTeams() : nullptr;
const Expr *NumTeams = NT ? NT->getNumTeams().front() : nullptr;
const Expr *ThreadLimit = TL ? TL->getThreadLimit() : nullptr;

CGF.CGM.getOpenMPRuntime().emitNumTeamsClause(CGF, NumTeams, ThreadLimit,
Expand Down
8 changes: 7 additions & 1 deletion clang/lib/Parse/ParseOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3098,7 +3098,6 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
case OMPC_simdlen:
case OMPC_collapse:
case OMPC_ordered:
case OMPC_num_teams:
case OMPC_thread_limit:
case OMPC_priority:
case OMPC_grainsize:
Expand Down Expand Up @@ -3252,6 +3251,13 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
? ParseOpenMPSimpleClause(CKind, WrongDirective)
: ParseOpenMPClause(CKind, WrongDirective);
break;
case OMPC_num_teams:
if (!FirstClause) {
Diag(Tok, diag::err_omp_more_one_clause)
<< getOpenMPDirectiveName(DKind) << getOpenMPClauseName(CKind) << 0;
ErrorFound = true;
}
[[clang::fallthrough]];
case OMPC_private:
case OMPC_firstprivate:
case OMPC_lastprivate:
Expand Down
53 changes: 35 additions & 18 deletions clang/lib/Sema/SemaOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13785,6 +13785,16 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDirective(
return StmtError();
}

auto NumTeamsClauseItr =
llvm::find_if(Clauses, llvm::IsaPred<OMPNumTeamsClause>);
if (NumTeamsClauseItr != Clauses.end()) {
ArrayRef<const Expr *> NumTeams =
cast<OMPNumTeamsClause>(*NumTeamsClauseItr)->getNumTeams();
if (!HasBareClause && NumTeams.size() > 1) {
return StmtError();
}
}

return OMPTargetTeamsDirective::Create(getASTContext(), StartLoc, EndLoc,
Clauses, AStmt);
}
Expand Down Expand Up @@ -14925,9 +14935,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
case OMPC_ordered:
Res = ActOnOpenMPOrderedClause(StartLoc, EndLoc, LParenLoc, Expr);
break;
case OMPC_num_teams:
Res = ActOnOpenMPNumTeamsClause(Expr, StartLoc, LParenLoc, EndLoc);
break;
case OMPC_thread_limit:
Res = ActOnOpenMPThreadLimitClause(Expr, StartLoc, LParenLoc, EndLoc);
break;
Expand Down Expand Up @@ -15031,6 +15038,7 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
case OMPC_affinity:
case OMPC_when:
case OMPC_bind:
case OMPC_num_teams:
default:
llvm_unreachable("Clause is not allowed.");
}
Expand Down Expand Up @@ -16894,6 +16902,9 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
static_cast<OpenMPDoacrossClauseModifier>(ExtraModifier),
ExtraModifierLoc, ColonLoc, VarList, StartLoc, LParenLoc, EndLoc);
break;
case OMPC_num_teams:
Res = ActOnOpenMPNumTeamsClause(VarList, StartLoc, LParenLoc, EndLoc);
break;
case OMPC_if:
case OMPC_depobj:
case OMPC_final:
Expand Down Expand Up @@ -16924,7 +16935,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
case OMPC_device:
case OMPC_threads:
case OMPC_simd:
case OMPC_num_teams:
case OMPC_thread_limit:
case OMPC_priority:
case OMPC_grainsize:
Expand Down Expand Up @@ -21587,32 +21597,39 @@ const ValueDecl *SemaOpenMP::getOpenMPDeclareMapperVarName() const {
return cast<DeclRefExpr>(DSAStack->getDeclareMapperVarRef())->getDecl();
}

OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(Expr *NumTeams,
OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc) {
Expr *ValExpr = NumTeams;
Stmt *HelperValStmt = nullptr;

// OpenMP [teams Constrcut, Restrictions]
// The num_teams expression must evaluate to a positive integer value.
if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams,
/*StrictlyPositive=*/true))
if (VarList.empty())
return nullptr;

for (Expr *ValExpr : VarList) {
// OpenMP [teams Constrcut, Restrictions]
// The num_teams expression must evaluate to a positive integer value.
if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams,
/*StrictlyPositive=*/true))
return nullptr;
}

OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective();
OpenMPDirectiveKind CaptureRegion = getOpenMPCaptureRegionForClause(
DKind, OMPC_num_teams, getLangOpts().OpenMP);
if (CaptureRegion != OMPD_unknown &&
!SemaRef.CurContext->isDependentContext()) {
if (CaptureRegion == OMPD_unknown || SemaRef.CurContext->isDependentContext())
return OMPNumTeamsClause::Create(getASTContext(), StartLoc, LParenLoc,
EndLoc, VarList, /*PreInit=*/nullptr);

llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
SmallVector<Expr *, 3> Vars;
for (Expr *ValExpr : VarList) {
ValExpr = SemaRef.MakeFullExpr(ValExpr).get();
llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
ValExpr = tryBuildCapture(SemaRef, ValExpr, Captures).get();
HelperValStmt = buildPreInits(getASTContext(), Captures);
Vars.push_back(ValExpr);
}

return new (getASTContext()) OMPNumTeamsClause(
ValExpr, HelperValStmt, CaptureRegion, StartLoc, LParenLoc, EndLoc);
Stmt *PreInit = buildPreInits(getASTContext(), Captures);
return OMPNumTeamsClause::Create(getASTContext(), StartLoc, LParenLoc, EndLoc,
Vars, PreInit);
}

OMPClause *SemaOpenMP::ActOnOpenMPThreadLimitClause(Expr *ThreadLimit,
Expand Down
7 changes: 4 additions & 3 deletions clang/lib/Sema/TreeTransform.h
Original file line number Diff line number Diff line change
Expand Up @@ -2065,10 +2065,11 @@ class TreeTransform {
///
/// By default, performs semantic analysis to build the new statement.
/// Subclasses may override this routine to provide different behavior.
OMPClause *RebuildOMPNumTeamsClause(Expr *NumTeams, SourceLocation StartLoc,
OMPClause *RebuildOMPNumTeamsClause(ArrayRef<Expr *> VarList,
SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc) {
return getSema().OpenMP().ActOnOpenMPNumTeamsClause(NumTeams, StartLoc,
return getSema().OpenMP().ActOnOpenMPNumTeamsClause(VarList, StartLoc,
LParenLoc, EndLoc);
}

Expand Down Expand Up @@ -10872,7 +10873,7 @@ TreeTransform<Derived>::TransformOMPAllocateClause(OMPAllocateClause *C) {
template <typename Derived>
OMPClause *
TreeTransform<Derived>::TransformOMPNumTeamsClause(OMPNumTeamsClause *C) {
ExprResult E = getDerived().TransformExpr(C->getNumTeams());
ExprResult E = getDerived().TransformExpr(C->getNumTeams().front());
if (E.isInvalid())
return nullptr;
return getDerived().RebuildOMPNumTeamsClause(
Expand Down
10 changes: 8 additions & 2 deletions clang/lib/Serialization/ASTReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
#include "llvm/ADT/IntrusiveRefCntPtr.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
Expand Down Expand Up @@ -10562,7 +10563,7 @@ OMPClause *OMPClauseReader::readClause() {
break;
}
case llvm::omp::OMPC_num_teams:
C = new (Context) OMPNumTeamsClause();
C = OMPNumTeamsClause::CreateEmpty(Context, Record.readInt());
break;
case llvm::omp::OMPC_thread_limit:
C = new (Context) OMPThreadLimitClause();
Expand Down Expand Up @@ -11350,8 +11351,13 @@ void OMPClauseReader::VisitOMPAllocateClause(OMPAllocateClause *C) {

void OMPClauseReader::VisitOMPNumTeamsClause(OMPNumTeamsClause *C) {
VisitOMPClauseWithPreInit(C);
C->setNumTeams(Record.readSubExpr());
C->setLParenLoc(Record.readSourceLocation());
unsigned NumVars = C->varlist_size();
SmallVector<Expr *, 16> Vars;
Vars.reserve(NumVars);
for ([[maybe_unused]] unsigned I : llvm::seq<unsigned>(NumVars))
Vars.push_back(Record.readSubExpr());
C->setVarRefs(Vars);
}

void OMPClauseReader::VisitOMPThreadLimitClause(OMPThreadLimitClause *C) {
Expand Down
Loading

0 comments on commit 37c28db

Please sign in to comment.