Skip to content

Commit

Permalink
More on Gpu kernel fusing (#1332)
Browse files Browse the repository at this point in the history
## Summary

* Add Gpu::KernelInfo argument to ParallelFor to allow the user to indicate
  whether the kernel is an candidate for fusing.

* For MFIter, if the local size is less or equal to 3, the fuse region is
  turned on and small kernels marked fusable will be fused.

* Add launch macros for fusing.

* Add fusing to a number of functions used by linear solvers.  Note that
  there are a lot more amrex functions need to be updated for fusing.

* Optimize reduction for bottom solve.

* Consolidate memcpy in communication functions.

* Option to use device memory in communication kernels for packing and
  unpacking buffers.  But it's currently turned off because the performance
  was not improved in testing.  In fact, it was worse than using pinned
  memory.  But this might change in the future.  So the option is kept.

## Checklist

The proposed changes:
- [ ] fix a bug or incorrect behavior in AMReX
- [x] add new capabilities to AMReX
- [ ] changes answers in the test suite to more than roundoff level
- [ ] are likely to significantly affect the results of downstream AMReX users
- [ ] are described in the proposed changes to the AMReX documentation, if appropriate
  • Loading branch information
WeiqunZhang authored Sep 21, 2020
1 parent 626b5f5 commit 1cec808
Show file tree
Hide file tree
Showing 24 changed files with 1,309 additions and 428 deletions.
2 changes: 1 addition & 1 deletion Src/Base/AMReX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ amrex::Initialize (int& argc, char**& argv, bool build_parm_parse,
BL_PROFILE_INITPARAMS();
#endif
machine::Initialize();
#ifdef AMREX_USE_CUDA
#ifdef AMREX_USE_GPU
Gpu::Fuser::Initialize();
#endif

Expand Down
192 changes: 128 additions & 64 deletions Src/Base/AMReX_FBI.H

Large diffs are not rendered by default.

24 changes: 12 additions & 12 deletions Src/Base/AMReX_FabArray.H
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ public:
#endif

static void pack_send_buffer_gpu (FabArray<FAB> const& src, int scomp, int ncomp,
Vector<char*>& send_data,
Vector<char*> const& send_data,
Vector<std::size_t> const& send_size,
Vector<const CopyComTagsContainer*> const& send_cctc);

Expand All @@ -731,7 +731,7 @@ public:
#endif

static void pack_send_buffer_cpu (FabArray<FAB> const& src, int scomp, int ncomp,
Vector<char*>& send_data,
Vector<char*> const& send_data,
Vector<std::size_t> const& send_size,
Vector<const CopyComTagsContainer*> const& send_cctc);

Expand Down Expand Up @@ -1582,7 +1582,7 @@ FabArray<FAB>::setVal (value_type val,
{
const Box& bx = fai.growntilebox(nghost);
auto fab = this->array(fai);
AMREX_HOST_DEVICE_PARALLEL_FOR_4D ( bx, ncomp, i, j, k, n,
AMREX_HOST_DEVICE_PARALLEL_FOR_4D_FUSIBLE ( bx, ncomp, i, j, k, n,
{
fab(i,j,k,n+comp) = val;
});
Expand Down Expand Up @@ -1625,7 +1625,7 @@ FabArray<FAB>::setVal (value_type val,

if (b.ok()) {
auto fab = this->array(fai);
AMREX_HOST_DEVICE_PARALLEL_FOR_4D ( b, ncomp, i, j, k, n,
AMREX_HOST_DEVICE_PARALLEL_FOR_4D_FUSIBLE ( b, ncomp, i, j, k, n,
{
fab(i,j,k,n+comp) = val;
});
Expand Down Expand Up @@ -1655,7 +1655,7 @@ FabArray<FAB>::abs (int comp, int ncomp, const IntVect& nghost)
{
const Box& bx = mfi.growntilebox(nghost);
auto fab = this->array(mfi);
AMREX_HOST_DEVICE_PARALLEL_FOR_4D ( bx, ncomp, i, j, k, n,
AMREX_HOST_DEVICE_PARALLEL_FOR_4D_FUSIBLE ( bx, ncomp, i, j, k, n,
{
fab(i,j,k,n+comp) = amrex::Math::abs(fab(i,j,k,n+comp));
});
Expand All @@ -1674,7 +1674,7 @@ FabArray<FAB>::plus (value_type val, int comp, int num_comp, int nghost)
{
const Box& bx = mfi.growntilebox(nghost);
auto fab = this->array(mfi);
AMREX_HOST_DEVICE_PARALLEL_FOR_4D ( bx, num_comp, i, j, k, n,
AMREX_HOST_DEVICE_PARALLEL_FOR_4D_FUSIBLE ( bx, num_comp, i, j, k, n,
{
fab(i,j,k,n+comp) += val;
});
Expand All @@ -1694,7 +1694,7 @@ FabArray<FAB>::plus (value_type val, const Box& region, int comp, int num_comp,
const Box& bx = mfi.growntilebox(nghost) & region;
if (bx.ok()) {
auto fab = this->array(mfi);
AMREX_HOST_DEVICE_PARALLEL_FOR_4D ( bx, num_comp, i, j, k, n,
AMREX_HOST_DEVICE_PARALLEL_FOR_4D_FUSIBLE ( bx, num_comp, i, j, k, n,
{
fab(i,j,k,n+comp) += val;
});
Expand All @@ -1714,7 +1714,7 @@ FabArray<FAB>::mult (value_type val, int comp, int num_comp, int nghost)
{
const Box& bx = mfi.growntilebox(nghost);
auto fab = this->array(mfi);
AMREX_HOST_DEVICE_PARALLEL_FOR_4D ( bx, num_comp, i, j, k, n,
AMREX_HOST_DEVICE_PARALLEL_FOR_4D_FUSIBLE ( bx, num_comp, i, j, k, n,
{
fab(i,j,k,n+comp) *= val;
});
Expand All @@ -1734,7 +1734,7 @@ FabArray<FAB>::mult (value_type val, const Box& region, int comp, int num_comp,
const Box& bx = mfi.growntilebox(nghost) & region;
if (bx.ok()) {
auto fab = this->array(mfi);
AMREX_HOST_DEVICE_PARALLEL_FOR_4D ( bx, num_comp, i, j, k, n,
AMREX_HOST_DEVICE_PARALLEL_FOR_4D_FUSIBLE ( bx, num_comp, i, j, k, n,
{
fab(i,j,k,n+comp) *= val;
});
Expand All @@ -1754,7 +1754,7 @@ FabArray<FAB>::invert (value_type numerator, int comp, int num_comp, int nghost)
{
const Box& bx = mfi.growntilebox(nghost);
auto fab = this->array(mfi);
AMREX_HOST_DEVICE_PARALLEL_FOR_4D ( bx, num_comp, i, j, k, n,
AMREX_HOST_DEVICE_PARALLEL_FOR_4D_FUSIBLE ( bx, num_comp, i, j, k, n,
{
fab(i,j,k,n+comp) = numerator / fab(i,j,k,n+comp);
});
Expand All @@ -1774,7 +1774,7 @@ FabArray<FAB>::invert (value_type numerator, const Box& region, int comp, int nu
const Box& bx = mfi.growntilebox(nghost) & region;
if (bx.ok()) {
auto fab = this->array(mfi);
AMREX_HOST_DEVICE_PARALLEL_FOR_4D ( bx, num_comp, i, j, k, n,
AMREX_HOST_DEVICE_PARALLEL_FOR_4D_FUSIBLE ( bx, num_comp, i, j, k, n,
{
fab(i,j,k,n+comp) = numerator / fab(i,j,k,n+comp);
});
Expand Down Expand Up @@ -1970,7 +1970,7 @@ FabArray<FAB>::BuildMask (const Box& phys_domain, const Periodicity& period,
Box const& fbx = mfi.growntilebox();
Box const& gbx = fbx & domain;
Box const& vbx = mfi.validbox();
AMREX_HOST_DEVICE_FOR_4D(fbx, ncomp, i, j, k, n,
AMREX_HOST_DEVICE_FOR_4D_FUSIBLE(fbx, ncomp, i, j, k, n,
{
IntVect iv(AMREX_D_DECL(i,j,k));
if (vbx.contains(iv)) {
Expand Down
2 changes: 1 addition & 1 deletion Src/Base/AMReX_FabArrayBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ FabArrayBase::Initialize ()

#ifdef AMREX_USE_GPU
if (ParallelDescriptor::UseGpuAwareMpi()) {
the_fa_arena = The_Device_Arena();
the_fa_arena = The_Arena();
} else {
the_fa_arena = The_Pinned_Arena();
}
Expand Down
10 changes: 3 additions & 7 deletions Src/Base/AMReX_FabArrayCommI.H
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,7 @@ FabArray<FAB>::FBEP_nowait (int scomp, int ncomp, const IntVect& nghost,
{
the_send_data = static_cast<char*>(amrex::The_FA_Arena()->alloc(total_volume));
for (int i = 0, N = send_size.size(); i < N; ++i) {
if (send_size[i] > 0) {
send_data[i] = the_send_data + offset[i];
}
send_data[i] = the_send_data + offset[i];
}
} else {
the_send_data = nullptr;
Expand Down Expand Up @@ -495,9 +493,7 @@ FabArray<FAB>::ParallelCopy (const FabArray<FAB>& src,
{
the_send_data = static_cast<char*>(amrex::The_FA_Arena()->alloc(total_volume));
for (int i = 0, N = send_size.size(); i < N; ++i) {
if (send_size[i] > 0) {
send_data[i] = the_send_data + offset[i];
}
send_data[i] = the_send_data + offset[i];
}
}

Expand Down Expand Up @@ -749,9 +745,9 @@ FabArray<FAB>::PostRcvs (const MapOfCopyComTagContainers& m_RcvTags,

for (int i = 0; i < nrecv; ++i)
{
recv_data[i] = the_recv_data + offset[i];
if (recv_size[i] > 0)
{
recv_data[i] = the_recv_data + offset[i];
const int rank = ParallelContext::global_to_local_rank(recv_from[i]);
const int comm_data_type = ParallelDescriptor::select_comm_data_type(recv_size[i]);
if (comm_data_type == 1) {
Expand Down
18 changes: 10 additions & 8 deletions Src/Base/AMReX_FabArrayUtility.H
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,13 @@ ReduceSum_device (FabArray<FAB1> const& fa1, FabArray<FAB2> const& fa2,
using value_type = typename FAB1::value_type;
value_type sm = 0;

BL_PROFILE("ReduceSum_device");

{
ReduceOps<ReduceOpSum> reduce_op;
ReduceData<value_type> reduce_data(reduce_op);
using ReduceTuple = typename decltype(reduce_data)::Type;

Gpu::FuseReductionSafeGuard rsg(true);
for (MFIter mfi(fa1); mfi.isValid(); ++mfi)
{
const Box& bx = amrex::grow(mfi.validbox(),nghost);
Expand Down Expand Up @@ -1467,7 +1469,7 @@ Add (FabArray<FAB>& dst, FabArray<FAB> const& src, int srccomp, int dstcomp, int
{
auto const srcFab = src.array(mfi);
auto dstFab = dst.array(mfi);
AMREX_HOST_DEVICE_PARALLEL_FOR_4D ( bx, numcomp, i, j, k, n,
AMREX_HOST_DEVICE_PARALLEL_FOR_4D_FUSIBLE ( bx, numcomp, i, j, k, n,
{
dstFab(i,j,k,n+dstcomp) += srcFab(i,j,k,n+srccomp);
});
Expand Down Expand Up @@ -1499,7 +1501,7 @@ Copy (FabArray<FAB>& dst, FabArray<FAB> const& src, int srccomp, int dstcomp, in
{
auto const srcFab = src.array(mfi);
auto dstFab = dst.array(mfi);
AMREX_HOST_DEVICE_PARALLEL_FOR_4D ( bx, numcomp, i, j, k, n,
AMREX_HOST_DEVICE_PARALLEL_FOR_4D_FUSIBLE ( bx, numcomp, i, j, k, n,
{
dstFab(i,j,k,dstcomp+n) = srcFab(i,j,k,srccomp+n);
});
Expand Down Expand Up @@ -1531,7 +1533,7 @@ Subtract (FabArray<FAB>& dst, FabArray<FAB> const& src, int srccomp, int dstcomp
{
auto const srcFab = src.array(mfi);
auto dstFab = dst.array(mfi);
AMREX_HOST_DEVICE_PARALLEL_FOR_4D ( bx, numcomp, i, j, k, n,
AMREX_HOST_DEVICE_PARALLEL_FOR_4D_FUSIBLE ( bx, numcomp, i, j, k, n,
{
dstFab(i,j,k,n+dstcomp) -= srcFab(i,j,k,n+srccomp);
});
Expand Down Expand Up @@ -1563,7 +1565,7 @@ Multiply (FabArray<FAB>& dst, FabArray<FAB> const& src, int srccomp, int dstcomp
{
auto const srcFab = src.array(mfi);
auto dstFab = dst.array(mfi);
AMREX_HOST_DEVICE_PARALLEL_FOR_4D ( bx, numcomp, i, j, k, n,
AMREX_HOST_DEVICE_PARALLEL_FOR_4D_FUSIBLE ( bx, numcomp, i, j, k, n,
{
dstFab(i,j,k,n+dstcomp) *= srcFab(i,j,k,n+srccomp);
});
Expand Down Expand Up @@ -1595,7 +1597,7 @@ Divide (FabArray<FAB>& dst, FabArray<FAB> const& src, int srccomp, int dstcomp,
{
auto const srcFab = src.array(mfi);
auto dstFab = dst.array(mfi);
AMREX_HOST_DEVICE_PARALLEL_FOR_4D ( bx, numcomp, i, j, k, n,
AMREX_HOST_DEVICE_PARALLEL_FOR_4D_FUSIBLE ( bx, numcomp, i, j, k, n,
{
dstFab(i,j,k,n+dstcomp) /= srcFab(i,j,k,n+srccomp);
});
Expand Down Expand Up @@ -1625,7 +1627,7 @@ Abs (FabArray<FAB>& fa, int icomp, int numcomp, const IntVect& nghost)
if (bx.ok())
{
auto const& fab = fa.array(mfi);
AMREX_HOST_DEVICE_PARALLEL_FOR_4D ( bx, numcomp, i, j, k, n,
AMREX_HOST_DEVICE_PARALLEL_FOR_4D_FUSIBLE ( bx, numcomp, i, j, k, n,
{
fab(i,j,k,n+icomp) = amrex::Math::abs(fab(i,j,k,n+icomp));
});
Expand Down Expand Up @@ -1682,7 +1684,7 @@ OverrideSync (FabArray<FAB> & fa, FabArray<IFAB> const& msk, const Periodicity&
const Box& bx = mfi.tilebox();
auto fab = fa.array(mfi);
auto const ifab = msk.array(mfi);
AMREX_HOST_DEVICE_PARALLEL_FOR_4D ( bx, ncomp, i, j, k, n,
AMREX_HOST_DEVICE_PARALLEL_FOR_4D_FUSIBLE ( bx, ncomp, i, j, k, n,
{
if (!ifab(i,j,k)) fab(i,j,k,n) = 0;
});
Expand Down
1 change: 1 addition & 0 deletions Src/Base/AMReX_Gpu.H
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ namespace amrex { namespace Cuda {} }
#endif

#include <AMReX_GpuQualifiers.H>
#include <AMReX_GpuKernelInfo.H>
#include <AMReX_GpuPrint.H>
#include <AMReX_GpuAssert.H>
#include <AMReX_GpuTypes.H>
Expand Down
69 changes: 52 additions & 17 deletions Src/Base/AMReX_GpuFuse.H
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
namespace amrex {
namespace Gpu {

#ifdef AMREX_USE_GPU

#ifdef AMREX_USE_CUDA

typedef void (*Lambda1DLauncher)(char*,int);
Expand Down Expand Up @@ -229,23 +231,6 @@ private:
}
};

Long getFuseSizeThreshold ();
Long setFuseSizeThreshold (Long new_threshold);
int getFuseNumKernelsThreshold ();
int setFuseNumKernelsThreshold (int new_threshold);
bool inFuseRegion ();
bool setFuseRegion (bool flag);

struct FuseSafeGuard
{
explicit FuseSafeGuard (bool flag) noexcept
: m_old(setFuseRegion(flag)) {}
~FuseSafeGuard () { setFuseRegion(m_old); }
private:
bool m_old;
};


template <typename F>
void
Register (Box const& bx, F&& f)
Expand Down Expand Up @@ -273,6 +258,56 @@ LaunchFusedKernels ()
Fuser::getInstance().Launch();
}

#else

class Fuser
{
public:
static Fuser& getInstance ();
static void Initialize ();
static void Finalize ();
private:
static std::unique_ptr<Fuser> m_instance;
};

inline void LaunchFusedKernels () {}

#endif

Long getFuseSizeThreshold ();
Long setFuseSizeThreshold (Long new_threshold);
int getFuseNumKernelsThreshold ();
int setFuseNumKernelsThreshold (int new_threshold);
bool inFuseRegion ();
bool setFuseRegion (bool flag);
bool inFuseReductionRegion ();
bool setFuseReductionRegion (bool flag);

struct FuseSafeGuard
{
explicit FuseSafeGuard (bool flag) noexcept
: m_old(setFuseRegion(flag)) {}
~FuseSafeGuard () { setFuseRegion(m_old); }
private:
bool m_old;
};

struct FuseReductionSafeGuard
{
explicit FuseReductionSafeGuard (bool flag) noexcept
: m_old(setFuseReductionRegion(flag)) {}
~FuseReductionSafeGuard () { setFuseReductionRegion(m_old); }
private:
bool m_old;
};

#else

struct FuseSafeGuard
{
explicit FuseSafeGuard (bool) {}
};

#endif

}}
Expand Down
Loading

0 comments on commit 1cec808

Please sign in to comment.