Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Clang][Sema][OpenMP] Allow num_teams to accept multiple expressions #99732

Conversation

shiltian
Copy link
Contributor

@shiltian shiltian commented Jul 20, 2024

By the OpenMP standard, num_teams clause can only accept one expression (for now). In this patch, we extend it to allow to accept multiple expressions when it is used with target teams ompx_bare construct. This will allow to launch a multi-dim grid, same as CUDA/HIP.

@shiltian shiltian marked this pull request as ready for review July 20, 2024 02:08
Copy link
Contributor Author

shiltian commented Jul 20, 2024

@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:modules C++20 modules and Clang Header Modules clang:openmp OpenMP related changes to Clang labels Jul 20, 2024
@llvmbot
Copy link
Member

llvmbot commented Jul 20, 2024

@llvm/pr-subscribers-clang-codegen
@llvm/pr-subscribers-clang-modules

@llvm/pr-subscribers-clang

Author: Shilei Tian (shiltian)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/99732.diff

7 Files Affected:

  • (modified) clang/include/clang/AST/OpenMPClause.h (+48-31)
  • (modified) clang/include/clang/AST/RecursiveASTVisitor.h (+1-1)
  • (modified) clang/include/clang/Sema/SemaOpenMP.h (+2-1)
  • (modified) clang/lib/AST/OpenMPClause.cpp (+23-3)
  • (modified) clang/lib/AST/StmtProfile.cpp (+1-2)
  • (modified) clang/lib/Parse/ParseOpenMP.cpp (+1-1)
  • (modified) clang/lib/Sema/SemaOpenMP.cpp (+23-18)
diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
index 325a1baa44614..2e82ccac28dc8 100644
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -6131,43 +6131,54 @@ class OMPMapClause final : public OMPMappableExprListClause<OMPMapClause>,
 /// \endcode
 /// In this example directive '#pragma omp teams' has clause 'num_teams'
 /// with single expression 'n'.
-class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit {
-  friend class OMPClauseReader;
+///
+/// When 'ompx_bare' clause exists on a 'target' directive, 'num_teams' clause
+/// can accept up to three expressions.
+///
+/// \code
+/// #pragma omp target teams ompx_bare num_teams(x, y, z)
+/// \endcode
+class OMPNumTeamsClause final
+    : public OMPVarListClause<OMPNumTeamsClause>,
+      public OMPClauseWithPreInit,
+      private llvm::TrailingObjects<OMPNumTeamsClause, Expr *> {
+  friend OMPVarListClause;
+  friend TrailingObjects;
 
   /// Location of '('.
   SourceLocation LParenLoc;
 
-  /// NumTeams number.
-  Stmt *NumTeams = nullptr;
+  OMPNumTeamsClause(const ASTContext &C, SourceLocation StartLoc,
+                    SourceLocation LParenLoc, SourceLocation EndLoc, unsigned N)
+      : OMPVarListClause(llvm::omp::OMPC_num_teams, StartLoc, LParenLoc, EndLoc,
+                         N),
+        OMPClauseWithPreInit(this) {}
 
-  /// Set the NumTeams number.
-  ///
-  /// \param E NumTeams number.
-  void setNumTeams(Expr *E) { NumTeams = E; }
+  /// Build an empty clause.
+  OMPNumTeamsClause(unsigned N)
+      : OMPVarListClause(llvm::omp::OMPC_num_teams, SourceLocation(),
+                         SourceLocation(), SourceLocation(), N),
+        OMPClauseWithPreInit(this) {}
 
 public:
-  /// Build 'num_teams' clause.
+  /// Creates clause with a list of variables \a VL.
   ///
-  /// \param E Expression associated with this clause.
-  /// \param HelperE Helper Expression associated with this clause.
-  /// \param CaptureRegion Innermost OpenMP region where expressions in this
-  /// clause must be captured.
+  /// \param C AST context.
   /// \param StartLoc Starting location of the clause.
   /// \param LParenLoc Location of '('.
   /// \param EndLoc Ending location of the clause.
-  OMPNumTeamsClause(Expr *E, Stmt *HelperE, OpenMPDirectiveKind CaptureRegion,
-                    SourceLocation StartLoc, SourceLocation LParenLoc,
-                    SourceLocation EndLoc)
-      : OMPClause(llvm::omp::OMPC_num_teams, StartLoc, EndLoc),
-        OMPClauseWithPreInit(this), LParenLoc(LParenLoc), NumTeams(E) {
-    setPreInitStmt(HelperE, CaptureRegion);
-  }
+  /// \param VL List of references to the variables.
+  /// \param PreInit
+  static OMPNumTeamsClause *Create(const ASTContext &C, SourceLocation StartLoc,
+                                   SourceLocation LParenLoc,
+                                   SourceLocation EndLoc, ArrayRef<Expr *> VL,
+                                   Stmt *PreInit);
 
-  /// Build an empty clause.
-  OMPNumTeamsClause()
-      : OMPClause(llvm::omp::OMPC_num_teams, SourceLocation(),
-                  SourceLocation()),
-        OMPClauseWithPreInit(this) {}
+  /// Creates an empty clause with \a N variables.
+  ///
+  /// \param C AST context.
+  /// \param N The number of variables.
+  static OMPNumTeamsClause *CreateEmpty(const ASTContext &C, unsigned N);
 
   /// Sets the location of '('.
   void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
@@ -6175,16 +6186,22 @@ class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit {
   /// Returns the location of '('.
   SourceLocation getLParenLoc() const { return LParenLoc; }
 
-  /// Return NumTeams number.
-  Expr *getNumTeams() { return cast<Expr>(NumTeams); }
+  /// Return NumTeams number. By default, we return the first expression.
+  Expr *getNumTeams() { return getVarRefs().front(); }
 
-  /// Return NumTeams number.
-  Expr *getNumTeams() const { return cast<Expr>(NumTeams); }
+  /// Return NumTeams number. By default, we return the first expression.
+  Expr *getNumTeams() const {
+    return const_cast<OMPNumTeamsClause *>(this)->getNumTeams();
+  }
 
-  child_range children() { return child_range(&NumTeams, &NumTeams + 1); }
+  child_range children() {
+    return child_range(reinterpret_cast<Stmt **>(varlist_begin()),
+                       reinterpret_cast<Stmt **>(varlist_end()));
+  }
 
   const_child_range children() const {
-    return const_child_range(&NumTeams, &NumTeams + 1);
+    auto Children = const_cast<OMPNumTeamsClause *>(this)->children();
+    return const_child_range(Children.begin(), Children.end());
   }
 
   child_range used_children() {
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index e3c0cb46799f7..beb7b3597c2a8 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3793,8 +3793,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPMapClause(OMPMapClause *C) {
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::VisitOMPNumTeamsClause(
     OMPNumTeamsClause *C) {
+  TRY_TO(VisitOMPClauseList(C));
   TRY_TO(VisitOMPClauseWithPreInit(C));
-  TRY_TO(TraverseStmt(C->getNumTeams()));
   return true;
 }
 
diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h
index 54d81f91ffebc..bf5fbc670b05c 100644
--- a/clang/include/clang/Sema/SemaOpenMP.h
+++ b/clang/include/clang/Sema/SemaOpenMP.h
@@ -1227,7 +1227,8 @@ class SemaOpenMP : public SemaBase {
       const OMPVarListLocTy &Locs, bool NoDiagnose = false,
       ArrayRef<Expr *> UnresolvedMappers = std::nullopt);
   /// Called on well-formed 'num_teams' clause.
-  OMPClause *ActOnOpenMPNumTeamsClause(Expr *NumTeams, SourceLocation StartLoc,
+  OMPClause *ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
+                                       SourceLocation StartLoc,
                                        SourceLocation LParenLoc,
                                        SourceLocation EndLoc);
   /// Called on well-formed 'thread_limit' clause.
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
index 042a5df5906ca..ee9e9a0d39a92 100644
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -1720,6 +1720,24 @@ const Expr *OMPDoacrossClause::getLoopData(unsigned NumLoop) const {
   return *It;
 }
 
+OMPNumTeamsClause *
+OMPNumTeamsClause::Create(const ASTContext &C, SourceLocation StartLoc,
+                          SourceLocation LParenLoc, SourceLocation EndLoc,
+                          ArrayRef<Expr *> VL, Stmt *PreInit) {
+  void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size()));
+  OMPNumTeamsClause *Clause =
+      new (Mem) OMPNumTeamsClause(C, StartLoc, LParenLoc, EndLoc, VL.size());
+  Clause->setVarRefs(VL);
+  Clause->setPreInitStmt(PreInit);
+  return Clause;
+}
+
+OMPNumTeamsClause *OMPNumTeamsClause::CreateEmpty(const ASTContext &C,
+                                                  unsigned N) {
+  void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N));
+  return new (Mem) OMPNumTeamsClause(N);
+}
+
 //===----------------------------------------------------------------------===//
 //  OpenMP clauses printing methods
 //===----------------------------------------------------------------------===//
@@ -1977,9 +1995,11 @@ void OMPClausePrinter::VisitOMPDeviceClause(OMPDeviceClause *Node) {
 }
 
 void OMPClausePrinter::VisitOMPNumTeamsClause(OMPNumTeamsClause *Node) {
-  OS << "num_teams(";
-  Node->getNumTeams()->printPretty(OS, nullptr, Policy, 0);
-  OS << ")";
+  if (!Node->varlist_empty()) {
+    OS << "num_teams";
+    VisitOMPClauseList(Node, '(');
+    OS << ")";
+  }
 }
 
 void OMPClausePrinter::VisitOMPThreadLimitClause(OMPThreadLimitClause *Node) {
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 89d2a422509d8..b782a4ab8367e 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -843,9 +843,8 @@ void OMPClauseProfiler::VisitOMPAllocateClause(const OMPAllocateClause *C) {
   VisitOMPClauseList(C);
 }
 void OMPClauseProfiler::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) {
+  VisitOMPClauseList(C);
   VistOMPClauseWithPreInit(C);
-  if (C->getNumTeams())
-    Profiler->VisitStmt(C->getNumTeams());
 }
 void OMPClauseProfiler::VisitOMPThreadLimitClause(
     const OMPThreadLimitClause *C) {
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index f5b44d210680c..e851bb4ac7fef 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -3098,7 +3098,6 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
   case OMPC_simdlen:
   case OMPC_collapse:
   case OMPC_ordered:
-  case OMPC_num_teams:
   case OMPC_thread_limit:
   case OMPC_priority:
   case OMPC_grainsize:
@@ -3279,6 +3278,7 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
   case OMPC_affinity:
   case OMPC_doacross:
   case OMPC_enter:
+  case OMPC_num_teams:
     if (getLangOpts().OpenMP >= 52 && DKind == OMPD_ordered &&
         CKind == OMPC_depend)
       Diag(Tok, diag::warn_omp_depend_in_ordered_deprecated);
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 3bd981cb442aa..a4e0ce730ae05 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -15041,9 +15041,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
   case OMPC_ordered:
     Res = ActOnOpenMPOrderedClause(StartLoc, EndLoc, LParenLoc, Expr);
     break;
-  case OMPC_num_teams:
-    Res = ActOnOpenMPNumTeamsClause(Expr, StartLoc, LParenLoc, EndLoc);
-    break;
   case OMPC_thread_limit:
     Res = ActOnOpenMPThreadLimitClause(Expr, StartLoc, LParenLoc, EndLoc);
     break;
@@ -15147,6 +15144,7 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
   case OMPC_affinity:
   case OMPC_when:
   case OMPC_bind:
+  case OMPC_num_teams:
   default:
     llvm_unreachable("Clause is not allowed.");
   }
@@ -17010,6 +17008,9 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
         static_cast<OpenMPDoacrossClauseModifier>(ExtraModifier),
         ExtraModifierLoc, ColonLoc, VarList, StartLoc, LParenLoc, EndLoc);
     break;
+  case OMPC_num_teams:
+    Res = ActOnOpenMPNumTeamsClause(VarList, StartLoc, LParenLoc, EndLoc);
+    break;
   case OMPC_if:
   case OMPC_depobj:
   case OMPC_final:
@@ -17040,7 +17041,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
   case OMPC_device:
   case OMPC_threads:
   case OMPC_simd:
-  case OMPC_num_teams:
   case OMPC_thread_limit:
   case OMPC_priority:
   case OMPC_grainsize:
@@ -21703,32 +21703,37 @@ const ValueDecl *SemaOpenMP::getOpenMPDeclareMapperVarName() const {
   return cast<DeclRefExpr>(DSAStack->getDeclareMapperVarRef())->getDecl();
 }
 
-OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(Expr *NumTeams,
+OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
                                                  SourceLocation StartLoc,
                                                  SourceLocation LParenLoc,
                                                  SourceLocation EndLoc) {
-  Expr *ValExpr = NumTeams;
-  Stmt *HelperValStmt = nullptr;
-
-  // OpenMP [teams Constrcut, Restrictions]
-  // The num_teams expression must evaluate to a positive integer value.
-  if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams,
-                                 /*StrictlyPositive=*/true))
+  if (VarList.empty())
     return nullptr;
 
   OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective();
   OpenMPDirectiveKind CaptureRegion = getOpenMPCaptureRegionForClause(
       DKind, OMPC_num_teams, getLangOpts().OpenMP);
-  if (CaptureRegion != OMPD_unknown &&
-      !SemaRef.CurContext->isDependentContext()) {
+
+  if (CaptureRegion == OMPD_unknown || SemaRef.CurContext->isDependentContext())
+    return OMPNumTeamsClause::Create(getASTContext(), StartLoc, LParenLoc,
+                                     EndLoc, VarList, /*PreInit=*/nullptr);
+
+  llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
+  SmallVector<Expr *, 3> Vars;
+  for (Expr *ValExpr : VarList) {
+    // OpenMP [teams Constrcut, Restrictions]
+    // The num_teams expression must evaluate to a positive integer value.
+    if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams,
+                                   /*StrictlyPositive=*/true))
+      return nullptr;
     ValExpr = SemaRef.MakeFullExpr(ValExpr).get();
-    llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
     ValExpr = tryBuildCapture(SemaRef, ValExpr, Captures).get();
-    HelperValStmt = buildPreInits(getASTContext(), Captures);
+    Vars.push_back(ValExpr);
   }
 
-  return new (getASTContext()) OMPNumTeamsClause(
-      ValExpr, HelperValStmt, CaptureRegion, StartLoc, LParenLoc, EndLoc);
+  Stmt *PreInit = buildPreInits(getASTContext(), Captures);
+  return OMPNumTeamsClause::Create(getASTContext(), StartLoc, LParenLoc, EndLoc,
+                                   Vars, PreInit);
 }
 
 OMPClause *SemaOpenMP::ActOnOpenMPThreadLimitClause(Expr *ThreadLimit,

@shiltian shiltian marked this pull request as draft July 20, 2024 02:12
@shiltian shiltian force-pushed the users/shiltian/07-19-_clang_openmp_allow_num_teams_to_accept_multiple_expressions branch from 299dca8 to 7d27a0a Compare July 23, 2024 03:24
@shiltian shiltian force-pushed the users/shiltian/07-19-_clang_openmp_allow_num_teams_to_accept_multiple_expressions branch 2 times, most recently from 48918a1 to 37c28db Compare July 29, 2024 02:24
@shiltian shiltian changed the title [Clang][OpenMP] Allow num_teams to accept multiple expressions [Clang][Sema][OpenMP] Allow num_teams to accept multiple expressions Jul 29, 2024
@shiltian shiltian force-pushed the users/shiltian/07-19-_clang_openmp_allow_num_teams_to_accept_multiple_expressions branch from 37c28db to 4686b80 Compare July 30, 2024 03:13
@shiltian shiltian marked this pull request as ready for review July 30, 2024 03:14
@shiltian shiltian force-pushed the users/shiltian/07-19-_clang_openmp_allow_num_teams_to_accept_multiple_expressions branch 2 times, most recently from faa0d23 to 7de6bc7 Compare July 31, 2024 20:36
@shiltian shiltian force-pushed the users/shiltian/07-19-_clang_openmp_allow_num_teams_to_accept_multiple_expressions branch from 7de6bc7 to 0906b80 Compare July 31, 2024 20:52
@shiltian shiltian force-pushed the users/shiltian/07-19-_clang_openmp_allow_num_teams_to_accept_multiple_expressions branch from 0906b80 to fefe6d3 Compare July 31, 2024 20:59
Copy link
Member

@alexey-bataev alexey-bataev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update OpenMPSupport.rst and include info about changes to release notes

@shiltian shiltian force-pushed the users/shiltian/07-19-_clang_openmp_allow_num_teams_to_accept_multiple_expressions branch from fefe6d3 to 27422f2 Compare August 1, 2024 23:01
@alexey-bataev
Copy link
Member

Update OpenMPSupport.rst and include info about changes to release notes

This is not part of OpenMP standard so I don't think we need to update OpenMPSupport.rst.

We do not inform OpenMP committee here, just users :) For users it would be good to have this info

@shiltian
Copy link
Contributor Author

shiltian commented Aug 1, 2024

Ah, we don't even have an entry for ompx_bare in OpenMPSupport.rst. I will do a follow-up patch to add it there.

@shiltian shiltian force-pushed the users/shiltian/07-19-_clang_openmp_allow_num_teams_to_accept_multiple_expressions branch from 27422f2 to cbb4e5d Compare August 2, 2024 17:35
@shiltian
Copy link
Contributor Author

shiltian commented Aug 2, 2024

ping

Copy link
Member

@jdoerfert jdoerfert left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is almost ready. Only missing thing is the >3 argument check and test.

@shiltian shiltian force-pushed the users/shiltian/07-19-_clang_openmp_allow_num_teams_to_accept_multiple_expressions branch 2 times, most recently from c983a11 to 1ddc3ba Compare August 6, 2024 04:10
@shiltian shiltian force-pushed the users/shiltian/07-19-_clang_openmp_allow_num_teams_to_accept_multiple_expressions branch 2 times, most recently from d27b786 to 0bd0ff8 Compare August 6, 2024 12:31
Copy link
Member

@alexey-bataev alexey-bataev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG

@shiltian shiltian force-pushed the users/shiltian/07-19-_clang_openmp_allow_num_teams_to_accept_multiple_expressions branch from 0bd0ff8 to f9e58d7 Compare August 6, 2024 12:43
@shiltian shiltian merged commit cee594c into main Aug 6, 2024
8 of 10 checks passed
@shiltian shiltian deleted the users/shiltian/07-19-_clang_openmp_allow_num_teams_to_accept_multiple_expressions branch August 6, 2024 14:55
shiltian added a commit that referenced this pull request Aug 6, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:codegen clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:modules C++20 modules and Clang Header Modules clang:openmp OpenMP related changes to Clang clang Clang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants