-
Notifications
You must be signed in to change notification settings - Fork 12.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Clang][Sema][OpenMP] Allow num_teams
to accept multiple expressions
#99732
[Clang][Sema][OpenMP] Allow num_teams
to accept multiple expressions
#99732
Conversation
This stack of pull requests is managed by Graphite. Learn more about stacking. |
@llvm/pr-subscribers-clang-codegen @llvm/pr-subscribers-clang Author: Shilei Tian (shiltian) ChangesFull diff: https://github.com/llvm/llvm-project/pull/99732.diff 7 Files Affected:
diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
index 325a1baa44614..2e82ccac28dc8 100644
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -6131,43 +6131,54 @@ class OMPMapClause final : public OMPMappableExprListClause<OMPMapClause>,
/// \endcode
/// In this example directive '#pragma omp teams' has clause 'num_teams'
/// with single expression 'n'.
-class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit {
- friend class OMPClauseReader;
+///
+/// When 'ompx_bare' clause exists on a 'target' directive, 'num_teams' clause
+/// can accept up to three expressions.
+///
+/// \code
+/// #pragma omp target teams ompx_bare num_teams(x, y, z)
+/// \endcode
+class OMPNumTeamsClause final
+ : public OMPVarListClause<OMPNumTeamsClause>,
+ public OMPClauseWithPreInit,
+ private llvm::TrailingObjects<OMPNumTeamsClause, Expr *> {
+ friend OMPVarListClause;
+ friend TrailingObjects;
/// Location of '('.
SourceLocation LParenLoc;
- /// NumTeams number.
- Stmt *NumTeams = nullptr;
+ OMPNumTeamsClause(const ASTContext &C, SourceLocation StartLoc,
+ SourceLocation LParenLoc, SourceLocation EndLoc, unsigned N)
+ : OMPVarListClause(llvm::omp::OMPC_num_teams, StartLoc, LParenLoc, EndLoc,
+ N),
+ OMPClauseWithPreInit(this) {}
- /// Set the NumTeams number.
- ///
- /// \param E NumTeams number.
- void setNumTeams(Expr *E) { NumTeams = E; }
+ /// Build an empty clause.
+ OMPNumTeamsClause(unsigned N)
+ : OMPVarListClause(llvm::omp::OMPC_num_teams, SourceLocation(),
+ SourceLocation(), SourceLocation(), N),
+ OMPClauseWithPreInit(this) {}
public:
- /// Build 'num_teams' clause.
+ /// Creates clause with a list of variables \a VL.
///
- /// \param E Expression associated with this clause.
- /// \param HelperE Helper Expression associated with this clause.
- /// \param CaptureRegion Innermost OpenMP region where expressions in this
- /// clause must be captured.
+ /// \param C AST context.
/// \param StartLoc Starting location of the clause.
/// \param LParenLoc Location of '('.
/// \param EndLoc Ending location of the clause.
- OMPNumTeamsClause(Expr *E, Stmt *HelperE, OpenMPDirectiveKind CaptureRegion,
- SourceLocation StartLoc, SourceLocation LParenLoc,
- SourceLocation EndLoc)
- : OMPClause(llvm::omp::OMPC_num_teams, StartLoc, EndLoc),
- OMPClauseWithPreInit(this), LParenLoc(LParenLoc), NumTeams(E) {
- setPreInitStmt(HelperE, CaptureRegion);
- }
+ /// \param VL List of references to the variables.
+ /// \param PreInit
+ static OMPNumTeamsClause *Create(const ASTContext &C, SourceLocation StartLoc,
+ SourceLocation LParenLoc,
+ SourceLocation EndLoc, ArrayRef<Expr *> VL,
+ Stmt *PreInit);
- /// Build an empty clause.
- OMPNumTeamsClause()
- : OMPClause(llvm::omp::OMPC_num_teams, SourceLocation(),
- SourceLocation()),
- OMPClauseWithPreInit(this) {}
+ /// Creates an empty clause with \a N variables.
+ ///
+ /// \param C AST context.
+ /// \param N The number of variables.
+ static OMPNumTeamsClause *CreateEmpty(const ASTContext &C, unsigned N);
/// Sets the location of '('.
void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
@@ -6175,16 +6186,22 @@ class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit {
/// Returns the location of '('.
SourceLocation getLParenLoc() const { return LParenLoc; }
- /// Return NumTeams number.
- Expr *getNumTeams() { return cast<Expr>(NumTeams); }
+ /// Return NumTeams number. By default, we return the first expression.
+ Expr *getNumTeams() { return getVarRefs().front(); }
- /// Return NumTeams number.
- Expr *getNumTeams() const { return cast<Expr>(NumTeams); }
+ /// Return NumTeams number. By default, we return the first expression.
+ Expr *getNumTeams() const {
+ return const_cast<OMPNumTeamsClause *>(this)->getNumTeams();
+ }
- child_range children() { return child_range(&NumTeams, &NumTeams + 1); }
+ child_range children() {
+ return child_range(reinterpret_cast<Stmt **>(varlist_begin()),
+ reinterpret_cast<Stmt **>(varlist_end()));
+ }
const_child_range children() const {
- return const_child_range(&NumTeams, &NumTeams + 1);
+ auto Children = const_cast<OMPNumTeamsClause *>(this)->children();
+ return const_child_range(Children.begin(), Children.end());
}
child_range used_children() {
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index e3c0cb46799f7..beb7b3597c2a8 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3793,8 +3793,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPMapClause(OMPMapClause *C) {
template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPNumTeamsClause(
OMPNumTeamsClause *C) {
+ TRY_TO(VisitOMPClauseList(C));
TRY_TO(VisitOMPClauseWithPreInit(C));
- TRY_TO(TraverseStmt(C->getNumTeams()));
return true;
}
diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h
index 54d81f91ffebc..bf5fbc670b05c 100644
--- a/clang/include/clang/Sema/SemaOpenMP.h
+++ b/clang/include/clang/Sema/SemaOpenMP.h
@@ -1227,7 +1227,8 @@ class SemaOpenMP : public SemaBase {
const OMPVarListLocTy &Locs, bool NoDiagnose = false,
ArrayRef<Expr *> UnresolvedMappers = std::nullopt);
/// Called on well-formed 'num_teams' clause.
- OMPClause *ActOnOpenMPNumTeamsClause(Expr *NumTeams, SourceLocation StartLoc,
+ OMPClause *ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
+ SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc);
/// Called on well-formed 'thread_limit' clause.
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
index 042a5df5906ca..ee9e9a0d39a92 100644
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -1720,6 +1720,24 @@ const Expr *OMPDoacrossClause::getLoopData(unsigned NumLoop) const {
return *It;
}
+OMPNumTeamsClause *
+OMPNumTeamsClause::Create(const ASTContext &C, SourceLocation StartLoc,
+ SourceLocation LParenLoc, SourceLocation EndLoc,
+ ArrayRef<Expr *> VL, Stmt *PreInit) {
+ void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size()));
+ OMPNumTeamsClause *Clause =
+ new (Mem) OMPNumTeamsClause(C, StartLoc, LParenLoc, EndLoc, VL.size());
+ Clause->setVarRefs(VL);
+ Clause->setPreInitStmt(PreInit);
+ return Clause;
+}
+
+OMPNumTeamsClause *OMPNumTeamsClause::CreateEmpty(const ASTContext &C,
+ unsigned N) {
+ void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N));
+ return new (Mem) OMPNumTeamsClause(N);
+}
+
//===----------------------------------------------------------------------===//
// OpenMP clauses printing methods
//===----------------------------------------------------------------------===//
@@ -1977,9 +1995,11 @@ void OMPClausePrinter::VisitOMPDeviceClause(OMPDeviceClause *Node) {
}
void OMPClausePrinter::VisitOMPNumTeamsClause(OMPNumTeamsClause *Node) {
- OS << "num_teams(";
- Node->getNumTeams()->printPretty(OS, nullptr, Policy, 0);
- OS << ")";
+ if (!Node->varlist_empty()) {
+ OS << "num_teams";
+ VisitOMPClauseList(Node, '(');
+ OS << ")";
+ }
}
void OMPClausePrinter::VisitOMPThreadLimitClause(OMPThreadLimitClause *Node) {
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 89d2a422509d8..b782a4ab8367e 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -843,9 +843,8 @@ void OMPClauseProfiler::VisitOMPAllocateClause(const OMPAllocateClause *C) {
VisitOMPClauseList(C);
}
void OMPClauseProfiler::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) {
+ VisitOMPClauseList(C);
VistOMPClauseWithPreInit(C);
- if (C->getNumTeams())
- Profiler->VisitStmt(C->getNumTeams());
}
void OMPClauseProfiler::VisitOMPThreadLimitClause(
const OMPThreadLimitClause *C) {
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index f5b44d210680c..e851bb4ac7fef 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -3098,7 +3098,6 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
case OMPC_simdlen:
case OMPC_collapse:
case OMPC_ordered:
- case OMPC_num_teams:
case OMPC_thread_limit:
case OMPC_priority:
case OMPC_grainsize:
@@ -3279,6 +3278,7 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
case OMPC_affinity:
case OMPC_doacross:
case OMPC_enter:
+ case OMPC_num_teams:
if (getLangOpts().OpenMP >= 52 && DKind == OMPD_ordered &&
CKind == OMPC_depend)
Diag(Tok, diag::warn_omp_depend_in_ordered_deprecated);
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 3bd981cb442aa..a4e0ce730ae05 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -15041,9 +15041,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
case OMPC_ordered:
Res = ActOnOpenMPOrderedClause(StartLoc, EndLoc, LParenLoc, Expr);
break;
- case OMPC_num_teams:
- Res = ActOnOpenMPNumTeamsClause(Expr, StartLoc, LParenLoc, EndLoc);
- break;
case OMPC_thread_limit:
Res = ActOnOpenMPThreadLimitClause(Expr, StartLoc, LParenLoc, EndLoc);
break;
@@ -15147,6 +15144,7 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
case OMPC_affinity:
case OMPC_when:
case OMPC_bind:
+ case OMPC_num_teams:
default:
llvm_unreachable("Clause is not allowed.");
}
@@ -17010,6 +17008,9 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
static_cast<OpenMPDoacrossClauseModifier>(ExtraModifier),
ExtraModifierLoc, ColonLoc, VarList, StartLoc, LParenLoc, EndLoc);
break;
+ case OMPC_num_teams:
+ Res = ActOnOpenMPNumTeamsClause(VarList, StartLoc, LParenLoc, EndLoc);
+ break;
case OMPC_if:
case OMPC_depobj:
case OMPC_final:
@@ -17040,7 +17041,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
case OMPC_device:
case OMPC_threads:
case OMPC_simd:
- case OMPC_num_teams:
case OMPC_thread_limit:
case OMPC_priority:
case OMPC_grainsize:
@@ -21703,32 +21703,37 @@ const ValueDecl *SemaOpenMP::getOpenMPDeclareMapperVarName() const {
return cast<DeclRefExpr>(DSAStack->getDeclareMapperVarRef())->getDecl();
}
-OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(Expr *NumTeams,
+OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc) {
- Expr *ValExpr = NumTeams;
- Stmt *HelperValStmt = nullptr;
-
- // OpenMP [teams Constrcut, Restrictions]
- // The num_teams expression must evaluate to a positive integer value.
- if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams,
- /*StrictlyPositive=*/true))
+ if (VarList.empty())
return nullptr;
OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective();
OpenMPDirectiveKind CaptureRegion = getOpenMPCaptureRegionForClause(
DKind, OMPC_num_teams, getLangOpts().OpenMP);
- if (CaptureRegion != OMPD_unknown &&
- !SemaRef.CurContext->isDependentContext()) {
+
+ if (CaptureRegion == OMPD_unknown || SemaRef.CurContext->isDependentContext())
+ return OMPNumTeamsClause::Create(getASTContext(), StartLoc, LParenLoc,
+ EndLoc, VarList, /*PreInit=*/nullptr);
+
+ llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
+ SmallVector<Expr *, 3> Vars;
+ for (Expr *ValExpr : VarList) {
+ // OpenMP [teams Constrcut, Restrictions]
+ // The num_teams expression must evaluate to a positive integer value.
+ if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams,
+ /*StrictlyPositive=*/true))
+ return nullptr;
ValExpr = SemaRef.MakeFullExpr(ValExpr).get();
- llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
ValExpr = tryBuildCapture(SemaRef, ValExpr, Captures).get();
- HelperValStmt = buildPreInits(getASTContext(), Captures);
+ Vars.push_back(ValExpr);
}
- return new (getASTContext()) OMPNumTeamsClause(
- ValExpr, HelperValStmt, CaptureRegion, StartLoc, LParenLoc, EndLoc);
+ Stmt *PreInit = buildPreInits(getASTContext(), Captures);
+ return OMPNumTeamsClause::Create(getASTContext(), StartLoc, LParenLoc, EndLoc,
+ Vars, PreInit);
}
OMPClause *SemaOpenMP::ActOnOpenMPThreadLimitClause(Expr *ThreadLimit,
|
299dca8
to
7d27a0a
Compare
48918a1
to
37c28db
Compare
num_teams
to accept multiple expressionsnum_teams
to accept multiple expressions
37c28db
to
4686b80
Compare
faa0d23
to
7de6bc7
Compare
7de6bc7
to
0906b80
Compare
0906b80
to
fefe6d3
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Update OpenMPSupport.rst and include info about changes to release notes
fefe6d3
to
27422f2
Compare
We do not inform OpenMP committee here, just users :) For users it would be good to have this info |
Ah, we don't even have an entry for |
27422f2
to
cbb4e5d
Compare
ping |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is almost ready. Only missing thing is the >3 argument check and test.
c983a11
to
1ddc3ba
Compare
d27b786
to
0bd0ff8
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG
0bd0ff8
to
f9e58d7
Compare
By the OpenMP standard,
num_teams
clause can only accept one expression (for now). In this patch, we extend it to allow to accept multiple expressions when it is used withtarget teams ompx_bare
construct. This will allow to launch a multi-dim grid, same as CUDA/HIP.