Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Clang][Sema][OpenMP] Allow num_teams to accept multiple expressions #99732

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions clang/docs/OpenMPSupport.rst
Original file line number Diff line number Diff line change
Expand Up @@ -363,5 +363,7 @@ considered for standardization. Please post on the
| device extension | `'ompx_bare' clause on 'target teams' construct | :good:`prototyped` | #66844, #70612 |
| | <https://www.osti.gov/servlets/purl/2205717>`_ | | |
+------------------------------+-----------------------------------------------------------------------------------+--------------------------+--------------------------------------------------------+
| device extension | Multi-dim `'num_teams' clause on 'target teams ompx_bare' construct | :good:`partial` | #99732, #101407 |
+------------------------------+-----------------------------------------------------------------------------------+--------------------------+--------------------------------------------------------+

.. _Discourse forums (Runtimes - OpenMP category): https://discourse.llvm.org/c/runtimes/openmp/35
3 changes: 3 additions & 0 deletions clang/docs/ReleaseNotes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,9 @@ Improvements
^^^^^^^^^^^^
- Improve the handling of mapping array-section for struct containing nested structs with user defined mappers

- `num_teams` now accepts multiple expressions when it is used along in ``target teams ompx_bare`` construct.
This allows the target region to be launched with multi-dim grid on GPUs.

Additional Information
======================

Expand Down
79 changes: 48 additions & 31 deletions clang/include/clang/AST/OpenMPClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -6369,60 +6369,77 @@ class OMPMapClause final : public OMPMappableExprListClause<OMPMapClause>,
/// \endcode
/// In this example directive '#pragma omp teams' has clause 'num_teams'
/// with single expression 'n'.
class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit {
friend class OMPClauseReader;
///
/// When 'ompx_bare' clause exists on a 'target' directive, 'num_teams' clause
/// can accept up to three expressions.
///
/// \code
/// #pragma omp target teams ompx_bare num_teams(x, y, z)
/// \endcode
class OMPNumTeamsClause final
: public OMPVarListClause<OMPNumTeamsClause>,
public OMPClauseWithPreInit,
private llvm::TrailingObjects<OMPNumTeamsClause, Expr *> {
friend OMPVarListClause;
friend TrailingObjects;

/// Location of '('.
SourceLocation LParenLoc;

/// NumTeams number.
Stmt *NumTeams = nullptr;
OMPNumTeamsClause(const ASTContext &C, SourceLocation StartLoc,
SourceLocation LParenLoc, SourceLocation EndLoc, unsigned N)
: OMPVarListClause(llvm::omp::OMPC_num_teams, StartLoc, LParenLoc, EndLoc,
N),
OMPClauseWithPreInit(this) {}

/// Set the NumTeams number.
///
/// \param E NumTeams number.
void setNumTeams(Expr *E) { NumTeams = E; }
/// Build an empty clause.
OMPNumTeamsClause(unsigned N)
: OMPVarListClause(llvm::omp::OMPC_num_teams, SourceLocation(),
SourceLocation(), SourceLocation(), N),
OMPClauseWithPreInit(this) {}

public:
/// Build 'num_teams' clause.
/// Creates clause with a list of variables \a VL.
///
/// \param E Expression associated with this clause.
/// \param HelperE Helper Expression associated with this clause.
/// \param CaptureRegion Innermost OpenMP region where expressions in this
/// clause must be captured.
/// \param C AST context.
/// \param StartLoc Starting location of the clause.
/// \param LParenLoc Location of '('.
/// \param EndLoc Ending location of the clause.
OMPNumTeamsClause(Expr *E, Stmt *HelperE, OpenMPDirectiveKind CaptureRegion,
SourceLocation StartLoc, SourceLocation LParenLoc,
SourceLocation EndLoc)
: OMPClause(llvm::omp::OMPC_num_teams, StartLoc, EndLoc),
OMPClauseWithPreInit(this), LParenLoc(LParenLoc), NumTeams(E) {
setPreInitStmt(HelperE, CaptureRegion);
}
/// \param VL List of references to the variables.
/// \param PreInit
static OMPNumTeamsClause *
Create(const ASTContext &C, OpenMPDirectiveKind CaptureRegion,
SourceLocation StartLoc, SourceLocation LParenLoc,
SourceLocation EndLoc, ArrayRef<Expr *> VL, Stmt *PreInit);

/// Build an empty clause.
OMPNumTeamsClause()
: OMPClause(llvm::omp::OMPC_num_teams, SourceLocation(),
SourceLocation()),
OMPClauseWithPreInit(this) {}
/// Creates an empty clause with \a N variables.
///
/// \param C AST context.
/// \param N The number of variables.
static OMPNumTeamsClause *CreateEmpty(const ASTContext &C, unsigned N);

/// Sets the location of '('.
void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }

/// Returns the location of '('.
SourceLocation getLParenLoc() const { return LParenLoc; }

/// Return NumTeams number.
Expr *getNumTeams() { return cast<Expr>(NumTeams); }
/// Return NumTeams expressions.
ArrayRef<Expr *> getNumTeams() { return getVarRefs(); }

/// Return NumTeams number.
Expr *getNumTeams() const { return cast<Expr>(NumTeams); }
/// Return NumTeams expressions.
ArrayRef<Expr *> getNumTeams() const {
return const_cast<OMPNumTeamsClause *>(this)->getNumTeams();
}

child_range children() { return child_range(&NumTeams, &NumTeams + 1); }
child_range children() {
return child_range(reinterpret_cast<Stmt **>(varlist_begin()),
reinterpret_cast<Stmt **>(varlist_end()));
}

const_child_range children() const {
return const_child_range(&NumTeams, &NumTeams + 1);
auto Children = const_cast<OMPNumTeamsClause *>(this)->children();
return const_child_range(Children.begin(), Children.end());
}

child_range used_children() {
Expand Down
2 changes: 1 addition & 1 deletion clang/include/clang/AST/RecursiveASTVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -3828,8 +3828,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPMapClause(OMPMapClause *C) {
template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPNumTeamsClause(
OMPNumTeamsClause *C) {
TRY_TO(VisitOMPClauseList(C));
TRY_TO(VisitOMPClauseWithPreInit(C));
TRY_TO(TraverseStmt(C->getNumTeams()));
return true;
}

Expand Down
2 changes: 2 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -11639,6 +11639,8 @@ def warn_omp_unterminated_declare_target : Warning<
InGroup<SourceUsesOpenMP>;
def err_ompx_bare_no_grid : Error<
"'ompx_bare' clauses requires explicit grid size via 'num_teams' and 'thread_limit' clauses">;
def err_omp_multi_expr_not_allowed: Error<"only one expression allowed in '%0' clause">;
def err_ompx_more_than_three_expr_not_allowed: Error<"at most three expressions are allowed in '%0' clause in 'target teams ompx_bare' construct">;
} // end of OpenMP category

let CategoryName = "Related Result Type Issue" in {
Expand Down
3 changes: 2 additions & 1 deletion clang/include/clang/Sema/SemaOpenMP.h
Original file line number Diff line number Diff line change
Expand Up @@ -1259,7 +1259,8 @@ class SemaOpenMP : public SemaBase {
const OMPVarListLocTy &Locs, bool NoDiagnose = false,
ArrayRef<Expr *> UnresolvedMappers = std::nullopt);
/// Called on well-formed 'num_teams' clause.
OMPClause *ActOnOpenMPNumTeamsClause(Expr *NumTeams, SourceLocation StartLoc,
OMPClause *ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc);
/// Called on well-formed 'thread_limit' clause.
Expand Down
26 changes: 23 additions & 3 deletions clang/lib/AST/OpenMPClause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1755,6 +1755,24 @@ OMPContainsClause *OMPContainsClause::CreateEmpty(const ASTContext &C,
return new (Mem) OMPContainsClause(K);
}

OMPNumTeamsClause *OMPNumTeamsClause::Create(
const ASTContext &C, OpenMPDirectiveKind CaptureRegion,
SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc,
ArrayRef<Expr *> VL, Stmt *PreInit) {
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size()));
OMPNumTeamsClause *Clause =
new (Mem) OMPNumTeamsClause(C, StartLoc, LParenLoc, EndLoc, VL.size());
Clause->setVarRefs(VL);
Clause->setPreInitStmt(PreInit, CaptureRegion);
return Clause;
}

OMPNumTeamsClause *OMPNumTeamsClause::CreateEmpty(const ASTContext &C,
unsigned N) {
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N));
return new (Mem) OMPNumTeamsClause(N);
}

//===----------------------------------------------------------------------===//
// OpenMP clauses printing methods
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2055,9 +2073,11 @@ void OMPClausePrinter::VisitOMPDeviceClause(OMPDeviceClause *Node) {
}

void OMPClausePrinter::VisitOMPNumTeamsClause(OMPNumTeamsClause *Node) {
OS << "num_teams(";
Node->getNumTeams()->printPretty(OS, nullptr, Policy, 0);
OS << ")";
if (!Node->varlist_empty()) {
OS << "num_teams";
VisitOMPClauseList(Node, '(');
OS << ")";
}
}

void OMPClausePrinter::VisitOMPThreadLimitClause(OMPThreadLimitClause *Node) {
Expand Down
3 changes: 1 addition & 2 deletions clang/lib/AST/StmtProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -857,9 +857,8 @@ void OMPClauseProfiler::VisitOMPAllocateClause(const OMPAllocateClause *C) {
VisitOMPClauseList(C);
}
void OMPClauseProfiler::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) {
VisitOMPClauseList(C);
VistOMPClauseWithPreInit(C);
if (C->getNumTeams())
Profiler->VisitStmt(C->getNumTeams());
}
void OMPClauseProfiler::VisitOMPThreadLimitClause(
const OMPThreadLimitClause *C) {
Expand Down
7 changes: 4 additions & 3 deletions clang/lib/CodeGen/CGOpenMPRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6036,8 +6036,9 @@ const Expr *CGOpenMPRuntime::getNumTeamsExprForTargetDirective(
dyn_cast_or_null<OMPExecutableDirective>(ChildStmt)) {
if (isOpenMPTeamsDirective(NestedDir->getDirectiveKind())) {
if (NestedDir->hasClausesOfKind<OMPNumTeamsClause>()) {
const Expr *NumTeams =
NestedDir->getSingleClause<OMPNumTeamsClause>()->getNumTeams();
const Expr *NumTeams = NestedDir->getSingleClause<OMPNumTeamsClause>()
->getNumTeams()
.front();
if (NumTeams->isIntegerConstantExpr(CGF.getContext()))
if (auto Constant =
NumTeams->getIntegerConstantExpr(CGF.getContext()))
Expand All @@ -6062,7 +6063,7 @@ const Expr *CGOpenMPRuntime::getNumTeamsExprForTargetDirective(
case OMPD_target_teams_distribute_parallel_for_simd: {
if (D.hasClausesOfKind<OMPNumTeamsClause>()) {
const Expr *NumTeams =
D.getSingleClause<OMPNumTeamsClause>()->getNumTeams();
D.getSingleClause<OMPNumTeamsClause>()->getNumTeams().front();
if (NumTeams->isIntegerConstantExpr(CGF.getContext()))
if (auto Constant = NumTeams->getIntegerConstantExpr(CGF.getContext()))
MinTeamsVal = MaxTeamsVal = Constant->getExtValue();
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/CodeGen/CGStmtOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6859,7 +6859,7 @@ static void emitCommonOMPTeamsDirective(CodeGenFunction &CGF,
const auto *NT = S.getSingleClause<OMPNumTeamsClause>();
const auto *TL = S.getSingleClause<OMPThreadLimitClause>();
if (NT || TL) {
const Expr *NumTeams = NT ? NT->getNumTeams() : nullptr;
const Expr *NumTeams = NT ? NT->getNumTeams().front() : nullptr;
const Expr *ThreadLimit = TL ? TL->getThreadLimit() : nullptr;

CGF.CGM.getOpenMPRuntime().emitNumTeamsClause(CGF, NumTeams, ThreadLimit,
Expand Down
8 changes: 7 additions & 1 deletion clang/lib/Parse/ParseOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3175,7 +3175,6 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
case OMPC_simdlen:
case OMPC_collapse:
case OMPC_ordered:
case OMPC_num_teams:
case OMPC_thread_limit:
case OMPC_priority:
case OMPC_grainsize:
Expand Down Expand Up @@ -3332,6 +3331,13 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
? ParseOpenMPSimpleClause(CKind, WrongDirective)
: ParseOpenMPClause(CKind, WrongDirective);
break;
case OMPC_num_teams:
if (!FirstClause) {
Diag(Tok, diag::err_omp_more_one_clause)
<< getOpenMPDirectiveName(DKind) << getOpenMPClauseName(CKind) << 0;
ErrorFound = true;
}
[[clang::fallthrough]];
case OMPC_private:
case OMPC_firstprivate:
case OMPC_lastprivate:
Expand Down
Loading
Loading