Skip to content

Commit

Permalink
[Clang][OpenMP] Allow num_teams to accept multiple expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
shiltian committed Jul 23, 2024
1 parent 9e97f80 commit 7d27a0a
Show file tree
Hide file tree
Showing 11 changed files with 472 additions and 304 deletions.
79 changes: 48 additions & 31 deletions clang/include/clang/AST/OpenMPClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -6131,60 +6131,77 @@ class OMPMapClause final : public OMPMappableExprListClause<OMPMapClause>,
/// \endcode
/// In this example directive '#pragma omp teams' has clause 'num_teams'
/// with single expression 'n'.
class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit {
friend class OMPClauseReader;
///
/// When 'ompx_bare' clause exists on a 'target' directive, 'num_teams' clause
/// can accept up to three expressions.
///
/// \code
/// #pragma omp target teams ompx_bare num_teams(x, y, z)
/// \endcode
class OMPNumTeamsClause final
: public OMPVarListClause<OMPNumTeamsClause>,
public OMPClauseWithPreInit,
private llvm::TrailingObjects<OMPNumTeamsClause, Expr *> {
friend OMPVarListClause;
friend TrailingObjects;

/// Location of '('.
SourceLocation LParenLoc;

/// NumTeams number.
Stmt *NumTeams = nullptr;
OMPNumTeamsClause(const ASTContext &C, SourceLocation StartLoc,
SourceLocation LParenLoc, SourceLocation EndLoc, unsigned N)
: OMPVarListClause(llvm::omp::OMPC_num_teams, StartLoc, LParenLoc, EndLoc,
N),
OMPClauseWithPreInit(this) {}

/// Set the NumTeams number.
///
/// \param E NumTeams number.
void setNumTeams(Expr *E) { NumTeams = E; }
/// Build an empty clause.
OMPNumTeamsClause(unsigned N)
: OMPVarListClause(llvm::omp::OMPC_num_teams, SourceLocation(),
SourceLocation(), SourceLocation(), N),
OMPClauseWithPreInit(this) {}

public:
/// Build 'num_teams' clause.
/// Creates clause with a list of variables \a VL.
///
/// \param E Expression associated with this clause.
/// \param HelperE Helper Expression associated with this clause.
/// \param CaptureRegion Innermost OpenMP region where expressions in this
/// clause must be captured.
/// \param C AST context.
/// \param StartLoc Starting location of the clause.
/// \param LParenLoc Location of '('.
/// \param EndLoc Ending location of the clause.
OMPNumTeamsClause(Expr *E, Stmt *HelperE, OpenMPDirectiveKind CaptureRegion,
SourceLocation StartLoc, SourceLocation LParenLoc,
SourceLocation EndLoc)
: OMPClause(llvm::omp::OMPC_num_teams, StartLoc, EndLoc),
OMPClauseWithPreInit(this), LParenLoc(LParenLoc), NumTeams(E) {
setPreInitStmt(HelperE, CaptureRegion);
}
/// \param VL List of references to the variables.
/// \param PreInit
static OMPNumTeamsClause *Create(const ASTContext &C, SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc, ArrayRef<Expr *> VL,
Stmt *PreInit);

/// Build an empty clause.
OMPNumTeamsClause()
: OMPClause(llvm::omp::OMPC_num_teams, SourceLocation(),
SourceLocation()),
OMPClauseWithPreInit(this) {}
/// Creates an empty clause with \a N variables.
///
/// \param C AST context.
/// \param N The number of variables.
static OMPNumTeamsClause *CreateEmpty(const ASTContext &C, unsigned N);

/// Sets the location of '('.
void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }

/// Returns the location of '('.
SourceLocation getLParenLoc() const { return LParenLoc; }

/// Return NumTeams number.
Expr *getNumTeams() { return cast<Expr>(NumTeams); }
/// Return NumTeams number. By default, we return the first expression.
Expr *getNumTeams() { return getVarRefs().front(); }

/// Return NumTeams number.
Expr *getNumTeams() const { return cast<Expr>(NumTeams); }
/// Return NumTeams number. By default, we return the first expression.
Expr *getNumTeams() const {
return const_cast<OMPNumTeamsClause *>(this)->getNumTeams();
}

child_range children() { return child_range(&NumTeams, &NumTeams + 1); }
child_range children() {
return child_range(reinterpret_cast<Stmt **>(varlist_begin()),
reinterpret_cast<Stmt **>(varlist_end()));
}

const_child_range children() const {
return const_child_range(&NumTeams, &NumTeams + 1);
auto Children = const_cast<OMPNumTeamsClause *>(this)->children();
return const_child_range(Children.begin(), Children.end());
}

child_range used_children() {
Expand Down
2 changes: 1 addition & 1 deletion clang/include/clang/AST/RecursiveASTVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -3793,8 +3793,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPMapClause(OMPMapClause *C) {
template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPNumTeamsClause(
OMPNumTeamsClause *C) {
TRY_TO(VisitOMPClauseList(C));
TRY_TO(VisitOMPClauseWithPreInit(C));
TRY_TO(TraverseStmt(C->getNumTeams()));
return true;
}

Expand Down
3 changes: 2 additions & 1 deletion clang/include/clang/Sema/SemaOpenMP.h
Original file line number Diff line number Diff line change
Expand Up @@ -1227,7 +1227,8 @@ class SemaOpenMP : public SemaBase {
const OMPVarListLocTy &Locs, bool NoDiagnose = false,
ArrayRef<Expr *> UnresolvedMappers = std::nullopt);
/// Called on well-formed 'num_teams' clause.
OMPClause *ActOnOpenMPNumTeamsClause(Expr *NumTeams, SourceLocation StartLoc,
OMPClause *ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc);
/// Called on well-formed 'thread_limit' clause.
Expand Down
26 changes: 23 additions & 3 deletions clang/lib/AST/OpenMPClause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1720,6 +1720,24 @@ const Expr *OMPDoacrossClause::getLoopData(unsigned NumLoop) const {
return *It;
}

OMPNumTeamsClause *
OMPNumTeamsClause::Create(const ASTContext &C, SourceLocation StartLoc,
SourceLocation LParenLoc, SourceLocation EndLoc,
ArrayRef<Expr *> VL, Stmt *PreInit) {
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size()));
OMPNumTeamsClause *Clause =
new (Mem) OMPNumTeamsClause(C, StartLoc, LParenLoc, EndLoc, VL.size());
Clause->setVarRefs(VL);
Clause->setPreInitStmt(PreInit);
return Clause;
}

OMPNumTeamsClause *OMPNumTeamsClause::CreateEmpty(const ASTContext &C,
unsigned N) {
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N));
return new (Mem) OMPNumTeamsClause(N);
}

//===----------------------------------------------------------------------===//
// OpenMP clauses printing methods
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1977,9 +1995,11 @@ void OMPClausePrinter::VisitOMPDeviceClause(OMPDeviceClause *Node) {
}

void OMPClausePrinter::VisitOMPNumTeamsClause(OMPNumTeamsClause *Node) {
OS << "num_teams(";
Node->getNumTeams()->printPretty(OS, nullptr, Policy, 0);
OS << ")";
if (!Node->varlist_empty()) {
OS << "num_teams";
VisitOMPClauseList(Node, '(');
OS << ")";
}
}

void OMPClausePrinter::VisitOMPThreadLimitClause(OMPThreadLimitClause *Node) {
Expand Down
3 changes: 1 addition & 2 deletions clang/lib/AST/StmtProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -843,9 +843,8 @@ void OMPClauseProfiler::VisitOMPAllocateClause(const OMPAllocateClause *C) {
VisitOMPClauseList(C);
}
void OMPClauseProfiler::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) {
VisitOMPClauseList(C);
VistOMPClauseWithPreInit(C);
if (C->getNumTeams())
Profiler->VisitStmt(C->getNumTeams());
}
void OMPClauseProfiler::VisitOMPThreadLimitClause(
const OMPThreadLimitClause *C) {
Expand Down
8 changes: 7 additions & 1 deletion clang/lib/Parse/ParseOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3098,7 +3098,6 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
case OMPC_simdlen:
case OMPC_collapse:
case OMPC_ordered:
case OMPC_num_teams:
case OMPC_thread_limit:
case OMPC_priority:
case OMPC_grainsize:
Expand Down Expand Up @@ -3252,6 +3251,13 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
? ParseOpenMPSimpleClause(CKind, WrongDirective)
: ParseOpenMPClause(CKind, WrongDirective);
break;
case OMPC_num_teams:
if (!FirstClause) {
Diag(Tok, diag::err_omp_more_one_clause)
<< getOpenMPDirectiveName(DKind) << getOpenMPClauseName(CKind) << 0;
ErrorFound = true;
}
[[clang::fallthrough]];
case OMPC_private:
case OMPC_firstprivate:
case OMPC_lastprivate:
Expand Down
58 changes: 40 additions & 18 deletions clang/lib/Sema/SemaOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13901,6 +13901,20 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDirective(
return StmtError();
}

const OMPClause *NumTeamsClause = nullptr;
bool HasNumTeamsClause = llvm::any_of(Clauses, [&](const OMPClause *C) {
NumTeamsClause = C;
return C->getClauseKind() == OMPC_num_teams;
});

if (HasNumTeamsClause) {
ArrayRef<const Expr *> NumTeams =
cast<OMPNumTeamsClause>(NumTeamsClause)->getVarRefs();
if (!HasBareClause && NumTeams.size() > 1) {
return StmtError();
}
}

return OMPTargetTeamsDirective::Create(getASTContext(), StartLoc, EndLoc,
Clauses, AStmt);
}
Expand Down Expand Up @@ -15041,9 +15055,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
case OMPC_ordered:
Res = ActOnOpenMPOrderedClause(StartLoc, EndLoc, LParenLoc, Expr);
break;
case OMPC_num_teams:
Res = ActOnOpenMPNumTeamsClause(Expr, StartLoc, LParenLoc, EndLoc);
break;
case OMPC_thread_limit:
Res = ActOnOpenMPThreadLimitClause(Expr, StartLoc, LParenLoc, EndLoc);
break;
Expand Down Expand Up @@ -15147,6 +15158,7 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
case OMPC_affinity:
case OMPC_when:
case OMPC_bind:
case OMPC_num_teams:
default:
llvm_unreachable("Clause is not allowed.");
}
Expand Down Expand Up @@ -17010,6 +17022,9 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
static_cast<OpenMPDoacrossClauseModifier>(ExtraModifier),
ExtraModifierLoc, ColonLoc, VarList, StartLoc, LParenLoc, EndLoc);
break;
case OMPC_num_teams:
Res = ActOnOpenMPNumTeamsClause(VarList, StartLoc, LParenLoc, EndLoc);
break;
case OMPC_if:
case OMPC_depobj:
case OMPC_final:
Expand Down Expand Up @@ -17040,7 +17055,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
case OMPC_device:
case OMPC_threads:
case OMPC_simd:
case OMPC_num_teams:
case OMPC_thread_limit:
case OMPC_priority:
case OMPC_grainsize:
Expand Down Expand Up @@ -21703,32 +21717,40 @@ const ValueDecl *SemaOpenMP::getOpenMPDeclareMapperVarName() const {
return cast<DeclRefExpr>(DSAStack->getDeclareMapperVarRef())->getDecl();
}

OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(Expr *NumTeams,
OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc) {
Expr *ValExpr = NumTeams;
Stmt *HelperValStmt = nullptr;

// OpenMP [teams Constrcut, Restrictions]
// The num_teams expression must evaluate to a positive integer value.
if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams,
/*StrictlyPositive=*/true))
if (VarList.empty())
return nullptr;

OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective();
OpenMPDirectiveKind CaptureRegion = getOpenMPCaptureRegionForClause(
DKind, OMPC_num_teams, getLangOpts().OpenMP);
if (CaptureRegion != OMPD_unknown &&
!SemaRef.CurContext->isDependentContext()) {

for (Expr *ValExpr : VarList) {
// OpenMP [teams Constrcut, Restrictions]
// The num_teams expression must evaluate to a positive integer value.
if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams,
/*StrictlyPositive=*/true))
return nullptr;
}

if (CaptureRegion == OMPD_unknown || SemaRef.CurContext->isDependentContext())
return OMPNumTeamsClause::Create(getASTContext(), StartLoc, LParenLoc,
EndLoc, VarList, /*PreInit=*/nullptr);

llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
SmallVector<Expr *, 3> Vars;
for (Expr *ValExpr : VarList) {
ValExpr = SemaRef.MakeFullExpr(ValExpr).get();
llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
ValExpr = tryBuildCapture(SemaRef, ValExpr, Captures).get();
HelperValStmt = buildPreInits(getASTContext(), Captures);
Vars.push_back(ValExpr);
}

return new (getASTContext()) OMPNumTeamsClause(
ValExpr, HelperValStmt, CaptureRegion, StartLoc, LParenLoc, EndLoc);
Stmt *PreInit = buildPreInits(getASTContext(), Captures);
return OMPNumTeamsClause::Create(getASTContext(), StartLoc, LParenLoc, EndLoc,
Vars, PreInit);
}

OMPClause *SemaOpenMP::ActOnOpenMPThreadLimitClause(Expr *ThreadLimit,
Expand Down
9 changes: 7 additions & 2 deletions clang/lib/Serialization/ASTReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10562,7 +10562,7 @@ OMPClause *OMPClauseReader::readClause() {
break;
}
case llvm::omp::OMPC_num_teams:
C = new (Context) OMPNumTeamsClause();
C = OMPNumTeamsClause::CreateEmpty(Context, Record.readInt());
break;
case llvm::omp::OMPC_thread_limit:
C = new (Context) OMPThreadLimitClause();
Expand Down Expand Up @@ -11350,8 +11350,13 @@ void OMPClauseReader::VisitOMPAllocateClause(OMPAllocateClause *C) {

void OMPClauseReader::VisitOMPNumTeamsClause(OMPNumTeamsClause *C) {
VisitOMPClauseWithPreInit(C);
C->setNumTeams(Record.readSubExpr());
C->setLParenLoc(Record.readSourceLocation());
unsigned NumVars = C->varlist_size();
SmallVector<Expr *, 16> Vars;
Vars.reserve(NumVars);
for (unsigned i = 0; i != NumVars; ++i)
Vars.push_back(Record.readSubExpr());
C->setVarRefs(Vars);
}

void OMPClauseReader::VisitOMPThreadLimitClause(OMPThreadLimitClause *C) {
Expand Down
4 changes: 3 additions & 1 deletion clang/lib/Serialization/ASTWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7528,9 +7528,11 @@ void OMPClauseWriter::VisitOMPAllocateClause(OMPAllocateClause *C) {
}

void OMPClauseWriter::VisitOMPNumTeamsClause(OMPNumTeamsClause *C) {
Record.push_back(C->varlist_size());
VisitOMPClauseWithPreInit(C);
Record.AddStmt(C->getNumTeams());
Record.AddSourceLocation(C->getLParenLoc());
for (auto *VE : C->varlists())
Record.AddStmt(VE);
}

void OMPClauseWriter::VisitOMPThreadLimitClause(OMPThreadLimitClause *C) {
Expand Down
Loading

0 comments on commit 7d27a0a

Please sign in to comment.