From a0bfcfec7362b0433ebbe786f6f19cc4b9df6c9d Mon Sep 17 00:00:00 2001 From: Evan Weinberg Date: Tue, 30 May 2023 08:58:25 -0700 Subject: [PATCH 01/12] Potential fix for heavy quark residual restart woes --- lib/inv_cg_quda.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/lib/inv_cg_quda.cpp b/lib/inv_cg_quda.cpp index bfce1b422b..40d8e66fcf 100644 --- a/lib/inv_cg_quda.cpp +++ b/lib/inv_cg_quda.cpp @@ -227,8 +227,14 @@ namespace quda { // Detect whether this is a pure double solve or not; informs the necessity of some stability checks bool is_pure_double = (param.precision == QUDA_DOUBLE_PRECISION && param.precision_sloppy == QUDA_DOUBLE_PRECISION); + // Determine whether or not we're doing a heavy quark residual + const bool use_heavy_quark_res = + (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false; + bool heavy_quark_restart = false; + // whether to select alternative reliable updates - bool alternative_reliable = param.use_alternative_reliable; + // if we're computing the heavy quark residual, force "traditional" reliable updates + bool alternative_reliable = use_heavy_quark_res ? param.use_alternative_reliable : false; /** When CG is used as a preconditioner, and we disable the `advanced features`, these features are turned off: - Reliable updates @@ -356,10 +362,6 @@ namespace quda { blas::xpayz(rSloppy, beta, x_update_batch.get_current_field(), x_update_batch.get_current_field()); } - const bool use_heavy_quark_res = - (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false; - bool heavy_quark_restart = false; - if (!param.is_preconditioner) { profile.TPSTOP(QUDA_PROFILE_INIT); profile.TPSTART(QUDA_PROFILE_PREAMBLE); From 9768c6395db04c74ccb25e75fecb0fdcd2735adf Mon Sep 17 00:00:00 2001 From: Evan Weinberg Date: Tue, 6 Jun 2023 14:20:08 -0400 Subject: [PATCH 02/12] Added a separate, simpler (but less hyper-optimized) codepath for CG solves requesting a HQ tolerance --- include/invert_quda.h | 8 + lib/inv_cg_quda.cpp | 346 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 354 insertions(+) diff --git a/include/invert_quda.h b/include/invert_quda.h index b2699ad3b7..a472941b7d 100644 --- a/include/invert_quda.h +++ b/include/invert_quda.h @@ -785,6 +785,14 @@ namespace quda { void blocksolve(ColorSpinorField& out, ColorSpinorField& in); virtual bool hermitian() { return true; } /** CG is only for Hermitian systems */ + + protected: + /** + * @brief Separate codepath for performing a "simpler" CG solve when a heavy quark residual is requested. + * @param out Solution-vector. + * @param in Right-hand side. + */ + void hqsolve(ColorSpinorField& out, ColorSpinorField& in); }; class CGNE : public CG diff --git a/lib/inv_cg_quda.cpp b/lib/inv_cg_quda.cpp index 40d8e66fcf..b55a3ea02c 100644 --- a/lib/inv_cg_quda.cpp +++ b/lib/inv_cg_quda.cpp @@ -232,6 +232,12 @@ namespace quda { (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false; bool heavy_quark_restart = false; + if (use_heavy_quark_res) { + hqsolve(x, b); + if (param.is_preconditioner) commGlobalReductionPop(); + return; + } + // whether to select alternative reliable updates // if we're computing the heavy quark residual, force "traditional" reliable updates bool alternative_reliable = use_heavy_quark_res ? param.use_alternative_reliable : false; @@ -639,6 +645,346 @@ namespace quda { if (param.is_preconditioner) commGlobalReductionPop(); } + // Separate HQ residual codepath + void CG::hqsolve(ColorSpinorField &x, ColorSpinorField &b) { + + logQuda(QUDA_VERBOSE, "Performing a HQ CG solve\n"); + + // Verbose errors: HQ solves won't support deflation, pipelining + if (param.deflate) + errorQuda("HQ solves don't support deflation"); + if (param.is_preconditioner) + errorQuda("HQ solves cannot be preconditioners"); + + // Non-terminal errors: HQ solves don't support advanced reliable updates + if (param.use_alternative_reliable) + logQuda(QUDA_SUMMARIZE, "HQ solves don't support alternative reliable updates, reverting to traditional reliable updates\n"); + if (param.pipeline) + logQuda(QUDA_SUMMARIZE, "HQ solves don't support pipelining, disabling..."); + + profile.TPSTART(QUDA_PROFILE_INIT); + + double b2 = blas::norm2(b); + + // Detect whether this is a pure double solve or not; informs the necessity of some stability checks + bool is_pure_double = (param.precision == QUDA_DOUBLE_PRECISION && param.precision_sloppy == QUDA_DOUBLE_PRECISION); + + bool heavy_quark_restart = false; + + // Check to see that we're not trying to invert on a zero-field source + if (b2 == 0 && param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { + profile.TPSTOP(QUDA_PROFILE_INIT); + printfQuda("Warning: inverting on zero-field source\n"); + x = b; + param.true_res = 0.0; + param.true_res_hq = 0.0; + return; + } + + if (!init) { + ColorSpinorParam csParam(x); + csParam.create = QUDA_NULL_FIELD_CREATE; + rp = ColorSpinorField::Create(csParam); + yp = ColorSpinorField::Create(csParam); + + // sloppy fields + csParam.setPrecision(param.precision_sloppy); + pp = ColorSpinorField::Create(csParam); + App = ColorSpinorField::Create(csParam); + if(param.precision != param.precision_sloppy) { + rSloppyp = ColorSpinorField::Create(csParam); + xSloppyp = ColorSpinorField::Create(csParam); + } else { + rSloppyp = rp; + param.use_sloppy_partial_accumulator = false; + } + + // temporary fields + tmpp = ColorSpinorField::Create(csParam); + init = true; + } + + ColorSpinorField &r = *rp; + ColorSpinorField &y = *yp; + ColorSpinorField &p = *pp; + ColorSpinorField &Ap = *App; + ColorSpinorField &tmp = *tmpp; + ColorSpinorField &rSloppy = *rSloppyp; + ColorSpinorField &xSloppy = param.use_sloppy_partial_accumulator ? *xSloppyp : x; + + const double uhigh = precisionEpsilon(); // solver precision + + double beta = 0; + + // for detecting HQ residual stalls + // let |r2/b2| drop to epsilon tolerance * 1e-30, semi-arbitrarily, but + // with the intent of letting the solve grind as long as possible before + // triggering a `NaN`. Ignored for pure double solves because if + // pure double has stability issues, bigger problems are at hand. + const double hq_res_stall_check = is_pure_double ? 0. : uhigh * uhigh * 1e-60; + + // compute initial residual + double r2 = 0.0; + if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { + // Compute r = b - A * x + mat(r, x); + r2 = blas::xmyNorm(b, r); + if (b2 == 0) b2 = r2; + // y contains the original guess. + blas::copy(y, x); + } else { + if (&r != &b) blas::copy(r, b); + r2 = b2; + blas::zero(y); + } + + blas::zero(x); + if (&x != &xSloppy) blas::zero(xSloppy); + blas::copy(rSloppy,r); + blas::copy(p, rSloppy); + + double r2_old = 0.0; + + profile.TPSTOP(QUDA_PROFILE_INIT); + profile.TPSTART(QUDA_PROFILE_PREAMBLE); + + double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver + + double heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); + double heavy_quark_res_old = heavy_quark_res; // heavy quark residual + const int heavy_quark_check = param.heavy_quark_check; // how often to check the heavy quark residual + + double alpha, pAp; + + // set L2breakdown to be immediately true if we aren't requesting an L2 norm, alternatively, + // it only gets set to true after the L2 norm has "stalled out" + bool L2breakdown = !(param.residual_type & (QUDA_L2_RELATIVE_RESIDUAL | QUDA_L2_ABSOLUTE_RESIDUAL)); + const double L2breakdown_eps = 100. * uhigh; + + + profile.TPSTOP(QUDA_PROFILE_PREAMBLE); + profile.TPSTART(QUDA_PROFILE_COMPUTE); + blas::flops = 0; + + int k = 0; + + PrintStats("CG", k, r2, b2, heavy_quark_res); + + bool converged = convergence(r2, heavy_quark_res, stop, param.tol_hq); + + double rNorm = sqrt(r2); + double r0Norm = rNorm; + double maxrx = rNorm; + double maxrr = rNorm; + bool updateX = false; + bool updateR = false; + int steps_since_reliable = 1; + int rUpdate = 0; + + int resIncrease = 0; + int resIncreaseTotal = 0; + int hqresIncrease = 0; + int hqresRestartTotal = 0; + + while ( !converged && k < param.maxiter ) { + matSloppy(Ap, p); + double sigma; + + r2_old = r2; + + pAp = blas::reDotProduct(p, Ap); + + alpha = r2 / pAp; + + // here we are deploying the alternative beta computation + auto cg_norm = blas::axpyCGNorm(-alpha, Ap, rSloppy); + r2 = cg_norm.x; // (r_new, r_new) + sigma = cg_norm.y >= 0.0 ? cg_norm.y : r2; // use r2 if (r_k+1, r_k+1-r_k) breaks + + // reliable update conditions + rNorm = sqrt(r2); + + // from reliable_updates.h + if (rNorm > maxrx) maxrx = rNorm; + if (rNorm > maxrr) maxrr = rNorm; + updateX = (rNorm < param.delta * r0Norm && r0Norm <= maxrx); + updateR = ((rNorm < param.delta * maxrr && r0Norm <= maxrr) || updateX); + + if (updateX) + logQuda(QUDA_VERBOSE, "Triggered updateX via `(rNorm < param.delta * r0Norm && r0Norm <= maxrx)`\n"); + + if (updateR) + logQuda(QUDA_VERBOSE, "Triggered updateR via `((rNorm < param.delta * maxrr && r0Norm <= maxrr) || updateX)`\n"); + + // force a reliable update if we are within target tolerance (only if doing reliable updates) + if (convergence(r2, heavy_quark_res, stop, param.tol_hq) && param.delta >= param.tol) { + updateX = true; + logQuda(QUDA_VERBOSE, "Triggered updateX via convergence && delta > tol check\n"); + } + + if (L2breakdown + && (convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq) || (r2 / b2) < hq_res_stall_check) + && param.delta >= param.tol) { + updateX = true; + logQuda(QUDA_VERBOSE, "Triggered updateX via L2breakdown condition\n"); + } + + if (!(updateR || updateX)) { + beta = sigma / r2_old; // use the alternative beta computation + + blas::axpyZpbx(alpha, p, xSloppy, rSloppy, beta); + + if (k % heavy_quark_check == 0) { + if (&x != &xSloppy) { + blas::copy(tmp, y); + heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(xSloppy, tmp, rSloppy).z); + } else { + blas::copy(r, rSloppy); + heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(x, y, r).z); + } + } + + steps_since_reliable++; + + } else { + + + blas::axpy(alpha, p, xSloppy); + + blas::copy(x, xSloppy); // no op when these pointers alias + + blas::xpy(x, y); + mat(r, y); + r2 = blas::xmyNorm(b, r); + + blas::copy(rSloppy, r); // no op when these pointers alias + blas::zero(xSloppy); + + rNorm = sqrt(r2); + maxrr = rNorm; + maxrx = rNorm; + + // calculate new reliable HQ resididual + heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(y, r).z); + + // break-out check if we have reached the limit of the precision + if (sqrt(r2) > r0Norm && updateX && !L2breakdown) { // reuse r0Norm for this + resIncrease++; + resIncreaseTotal++; + warningQuda("new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)", + sqrt(r2), r0Norm, resIncreaseTotal); + + if (sqrt(r2) < L2breakdown_eps + || resIncrease > param.max_res_increase + || resIncreaseTotal > param.max_res_increase_total + || r2 < stop) { + L2breakdown = true; + warningQuda("L2 breakdown %e, %e", sqrt(r2), L2breakdown_eps); + } + } else { + resIncrease = 0; + } + + // if L2 broke down already we turn off reliable updates and restart the CG + if (L2breakdown) { + hqresRestartTotal++; // count the number of heavy quark restarts we've done + warningQuda("CG: Restarting without reliable updates for heavy-quark residual (total #inc %i)", + hqresRestartTotal); + heavy_quark_restart = true; + + if (heavy_quark_res > heavy_quark_res_old) { // check if new hq residual is greater than previous + hqresIncrease++; // count the number of consecutive increases + warningQuda("CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", + heavy_quark_res, heavy_quark_res_old); + // break out if we do not improve here anymore + if (hqresIncrease > param.max_hq_res_increase) { + warningQuda("CG: solver exiting due to too many heavy quark residual norm increases (%i/%i)", hqresIncrease, + param.max_hq_res_increase); + break; + } + } else { + hqresIncrease = 0; + } + } + + if (hqresRestartTotal > param.max_hq_res_restart_total) { + warningQuda("CG: solver exiting due to too many heavy quark residual restarts (%i/%i)", hqresRestartTotal, + param.max_hq_res_restart_total); + break; + } + + if (heavy_quark_restart) { + // perform a restart + logQuda(QUDA_SUMMARIZE, "HQ restart == hard CG restart\n"); + blas::copy(p, rSloppy); + heavy_quark_restart = false; + } else { + logQuda(QUDA_SUMMARIZE, "Regular restart == explicit gradient vector re-orthogonalization\n"); + // explicitly restore the orthogonality of the gradient vector + Complex rp = blas::cDotProduct(rSloppy, p) / (r2); + blas::caxpy(-rp, rSloppy, p); + + beta = r2 / r2_old; + blas::xpayz(rSloppy, beta, p, p); + } + + steps_since_reliable = 0; + r0Norm = sqrt(r2); + rUpdate++; + + heavy_quark_res_old = heavy_quark_res; + } + + k++; + + PrintStats("CG", k, r2, b2, heavy_quark_res); + // check convergence, if convergence is satisfied we only need to check that we had a reliable update for the heavy quarks recently + converged = convergence(r2, heavy_quark_res, stop, param.tol_hq); + + // check for recent enough reliable updates of the HQ residual if we use it + // L2 is converged or precision maxed out for L2 + bool L2done = L2breakdown or convergenceL2(r2, heavy_quark_res, stop, param.tol_hq); + // HQ is converged and if we do reliable update the HQ residual has been calculated using a reliable update + bool HQdone = (steps_since_reliable == 0 && param.delta > 0) + && convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq); + converged = L2done && HQdone; + + } + + blas::copy(x, xSloppy); + blas::xpy(y, x); + + profile.TPSTOP(QUDA_PROFILE_COMPUTE); + profile.TPSTART(QUDA_PROFILE_EPILOGUE); + + param.secs = profile.Last(QUDA_PROFILE_COMPUTE); + double gflops = (blas::flops + mat.flops() + matSloppy.flops() + matPrecon.flops() + matEig.flops()) * 1e-9; + param.gflops = gflops; + param.iter += k; + + if (k == param.maxiter) warningQuda("Exceeded maximum iterations %d", param.maxiter); + + if (getVerbosity() >= QUDA_VERBOSE) printfQuda("CG: Reliable updates = %d\n", rUpdate); + + if (param.compute_true_res) { + // compute the true residuals + mat(r, x); + param.true_res = sqrt(blas::xmyNorm(b, r) / b2); + param.true_res_hq = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); + } + + PrintSummary("CG", k, r2, b2, stop, param.tol_hq); + + // reset the flops counters + blas::flops = 0; + mat.flops(); + matSloppy.flops(); + matPrecon.flops(); + + profile.TPSTOP(QUDA_PROFILE_EPILOGUE); + + } + // use BlockCGrQ algortithm or BlockCG (with / without GS, see BLOCKCG_GS option) #define BCGRQ 1 #if BCGRQ From f66ec8f80b20d9c1441126a6c64d548a6c9b8c3f Mon Sep 17 00:00:00 2001 From: Evan Weinberg Date: Wed, 7 Jun 2023 01:05:23 -0400 Subject: [PATCH 03/12] Updated the logic for HQ reliable updates to match L2 reliable updates and only trigger when the HQ residual has dropped by an appropriate amount --- lib/inv_cg_quda.cpp | 245 +++++++++++++++++++++++++++++--------------- 1 file changed, 164 insertions(+), 81 deletions(-) diff --git a/lib/inv_cg_quda.cpp b/lib/inv_cg_quda.cpp index b55a3ea02c..6430553a77 100644 --- a/lib/inv_cg_quda.cpp +++ b/lib/inv_cg_quda.cpp @@ -712,15 +712,12 @@ namespace quda { ColorSpinorField &rSloppy = *rSloppyp; ColorSpinorField &xSloppy = param.use_sloppy_partial_accumulator ? *xSloppyp : x; - const double uhigh = precisionEpsilon(); // solver precision - - double beta = 0; - // for detecting HQ residual stalls // let |r2/b2| drop to epsilon tolerance * 1e-30, semi-arbitrarily, but // with the intent of letting the solve grind as long as possible before // triggering a `NaN`. Ignored for pure double solves because if // pure double has stability issues, bigger problems are at hand. + const double uhigh = precisionEpsilon(); // solver precision const double hq_res_stall_check = is_pure_double ? 0. : uhigh * uhigh * 1e-60; // compute initial residual @@ -750,15 +747,17 @@ namespace quda { double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver - double heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); - double heavy_quark_res_old = heavy_quark_res; // heavy quark residual - const int heavy_quark_check = param.heavy_quark_check; // how often to check the heavy quark residual + // compute the initial heavy quark residual + double hq_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); + + double alpha, beta, sigma, pAp; - double alpha, pAp; + // Whether or not we also need to compute the L2 norm + const bool L2_required = param.residual_type & (QUDA_L2_RELATIVE_RESIDUAL | QUDA_L2_ABSOLUTE_RESIDUAL); // set L2breakdown to be immediately true if we aren't requesting an L2 norm, alternatively, - // it only gets set to true after the L2 norm has "stalled out" - bool L2breakdown = !(param.residual_type & (QUDA_L2_RELATIVE_RESIDUAL | QUDA_L2_ABSOLUTE_RESIDUAL)); + // it only gets set to true after some heuristics suggest the L2 norm has "stalled out" + bool L2breakdown = !L2_required; const double L2breakdown_eps = 100. * uhigh; @@ -768,27 +767,58 @@ namespace quda { int k = 0; - PrintStats("CG", k, r2, b2, heavy_quark_res); + PrintStats("CG", k, r2, b2, hq_res); - bool converged = convergence(r2, heavy_quark_res, stop, param.tol_hq); + bool converged = convergence(r2, hq_res, stop, param.tol_hq); + + // Various parameters related to restarts + // Trackers for the L2 norm: + // rNorm: current iterated |r| + // r0Norm: computed |r| at the last reliable update double rNorm = sqrt(r2); double r0Norm = rNorm; - double maxrx = rNorm; - double maxrr = rNorm; + + // If the computed |r| goes above r0Norm between reliable updates, + // update this ceiling. This goes into "R" type reliable updates. + double maxrx = L2breakdown ? hq_res : rNorm; + double maxrr = L2breakdown ? hq_res : rNorm; + + // Trigger for explicitly counting residual updates and checking for L2breakdown bool updateX = false; + + // idk bool updateR = false; - int steps_since_reliable = 1; - int rUpdate = 0; + // count the number of times the computed residual has jumped above + // the previously computed residual, so long as the L2 norm hasn't + // broken down int resIncrease = 0; + + // count the total number of residual increases has increased independent of resetting int resIncreaseTotal = 0; + + // Trackers for the HQ residual + // hq0Res: computed HQ residual at the last reliable updated + double hq0Res = hq_res; + + // count the number of times the computed hq residual has jumped above + // the previously computed residual int hqresIncrease = 0; + + // count the total number of times a heavy quark restart has been triggered + // THIS NEEDS TO BE FIXED: IT CAN DEPEND ON ONLY THE COMPUTED L2 NORM DROPPING + // AFTER L2 BREAKDOWN! int hqresRestartTotal = 0; + // Count the steps since a reliable update and the total number of reliable updates. + // The steps since a reliable update is also used to make sure final convergence is + // based on the computed residual and not the iterated residual. + int rUpdate = 0; + int steps_since_reliable = 1; + while ( !converged && k < param.maxiter ) { matSloppy(Ap, p); - double sigma; r2_old = r2; @@ -800,117 +830,170 @@ namespace quda { auto cg_norm = blas::axpyCGNorm(-alpha, Ap, rSloppy); r2 = cg_norm.x; // (r_new, r_new) sigma = cg_norm.y >= 0.0 ? cg_norm.y : r2; // use r2 if (r_k+1, r_k+1-r_k) breaks - - // reliable update conditions rNorm = sqrt(r2); - // from reliable_updates.h - if (rNorm > maxrx) maxrx = rNorm; - if (rNorm > maxrr) maxrr = rNorm; - updateX = (rNorm < param.delta * r0Norm && r0Norm <= maxrx); - updateR = ((rNorm < param.delta * maxrr && r0Norm <= maxrr) || updateX); + // If the iterated norm has dropped by more than a factor of delta, trigger + // an update. + + // L2 based reliable update + if (!L2breakdown && (L2_required || convergenceL2(r2, hq_res, stop, param.tol_hq))) { + // if the iterated residual norm has gone above the most recent "baseline" norm, + // update the baseline norm. + if (rNorm > maxrx) maxrx = rNorm; + if (rNorm > maxrr) maxrr = rNorm; + + // Has the iterated norm dropped by a factor of delta from the last computed norm? + updateX = (rNorm < param.delta * r0Norm && r0Norm <= maxrx); - if (updateX) - logQuda(QUDA_VERBOSE, "Triggered updateX via `(rNorm < param.delta * r0Norm && r0Norm <= maxrx)`\n"); + // Has the iterated norm dropped by a factor of delta relative to the largest the + // iterated norm has been since the last update? + updateR = ((rNorm < param.delta * maxrr && r0Norm <= maxrr) || updateX); - if (updateR) - logQuda(QUDA_VERBOSE, "Triggered updateR via `((rNorm < param.delta * maxrr && r0Norm <= maxrr) || updateX)`\n"); + if (updateX) + logQuda(QUDA_VERBOSE, "Triggered L2 updateX via `(rNorm < param.delta * r0Norm && r0Norm <= maxrx)`\n"); + + if (updateR) + logQuda(QUDA_VERBOSE, "Triggered L2 updateR via `((rNorm < param.delta * maxrr && r0Norm <= maxrr) || updateX)`\n"); + } else { + // hqresidual based reliable update + if (hq_res > maxrx) maxrx = hq_res; + if (hq_res > maxrr) maxrr = hq_res; + + // I'm making the decision to use `param.delta` for the hq_res check because + // in some regards it's an L2-esque norm... + + // Has the iterated heavy quark residual dropped by a factor of delta^2 from the last + // computed norm? + updateX = (hq_res < param.delta * param.delta * hq0Res && r0Norm <= maxrx); + + // Has the iterated heavy quark residual dropped by a factor of delta relative + // to the largest the iterated norm has been since the last update? + updateR = ((hq_res < param.delta * param.delta * maxrr && hq0Res <= maxrr) || updateX); + + if (updateX) + logQuda(QUDA_VERBOSE, "Triggered HQ updateX via `(hq_res < param.delta * hq0Res && hq0Res <= maxrx)`\n"); + + if (updateR) + logQuda(QUDA_VERBOSE, "Triggered HQ updateR via `((hq_res < param.delta * maxrr && hq0Res <= maxrr) || updateX)`\n"); + } // force a reliable update if we are within target tolerance (only if doing reliable updates) - if (convergence(r2, heavy_quark_res, stop, param.tol_hq) && param.delta >= param.tol) { + if (convergence(r2, hq_res, stop, param.tol_hq) && param.delta >= param.tol) { updateX = true; logQuda(QUDA_VERBOSE, "Triggered updateX via convergence && delta > tol check\n"); } - if (L2breakdown - && (convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq) || (r2 / b2) < hq_res_stall_check) - && param.delta >= param.tol) { + // force a reliable update based on the HQ residual if L2 breakdown has already happened + if (L2breakdown && (convergenceHQ(r2, hq_res, stop, param.tol_hq) || (r2 / b2) < hq_res_stall_check) && + param.delta >= param.tol) { updateX = true; logQuda(QUDA_VERBOSE, "Triggered updateX via L2breakdown condition\n"); } if (!(updateR || updateX)) { + // No reliable update needed + beta = sigma / r2_old; // use the alternative beta computation blas::axpyZpbx(alpha, p, xSloppy, rSloppy, beta); - if (k % heavy_quark_check == 0) { + if (k % param.heavy_quark_check == 0) { if (&x != &xSloppy) { blas::copy(tmp, y); - heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(xSloppy, tmp, rSloppy).z); + hq_res = sqrt(blas::xpyHeavyQuarkResidualNorm(xSloppy, tmp, rSloppy).z); } else { blas::copy(r, rSloppy); - heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(x, y, r).z); + hq_res = sqrt(blas::xpyHeavyQuarkResidualNorm(x, y, r).z); } } steps_since_reliable++; } else { + // We're performing a reliable update: - + // Accumulate p into x, accumulate x into the total solution y, explicitly recompute the residual vector blas::axpy(alpha, p, xSloppy); - blas::copy(x, xSloppy); // no op when these pointers alias blas::xpy(x, y); mat(r, y); r2 = blas::xmyNorm(b, r); + hq_res = sqrt(blas::HeavyQuarkResidualNorm(y, r).z); + + // reset the L2 reliable update baseline + rNorm = sqrt(r2); blas::copy(rSloppy, r); // no op when these pointers alias blas::zero(xSloppy); - rNorm = sqrt(r2); - maxrr = rNorm; - maxrx = rNorm; + // check for L2 convergence + if (!L2breakdown && L2_required && convergenceL2(r2, hq_res, stop, param.tol_hq)) + L2breakdown = true; - // calculate new reliable HQ resididual - heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(y, r).z); + // Reset the reliable update baseline + if (!L2breakdown) { + maxrr = rNorm; + maxrx = rNorm; + } else { + maxrr = hq_res; + maxrx = hq_res; + + // Increment the total number of HQ restarts + hqresRestartTotal++; + + // set that we're doing a heavy quark restart + heavy_quark_restart = true; + + warningQuda("CG: Restarting without reliable updates for heavy-quark residual (total #inc %i)", + hqresRestartTotal); + + if (hqresRestartTotal > param.max_hq_res_restart_total) { + warningQuda("CG: solver exiting due to too many heavy quark residual restarts (%i/%i)", hqresRestartTotal, + param.max_hq_res_restart_total); + break; + } + } // break-out check if we have reached the limit of the precision - if (sqrt(r2) > r0Norm && updateX && !L2breakdown) { // reuse r0Norm for this + // we're reusing r0Norm for this check + if (rNorm > r0Norm && updateX && !L2breakdown) { resIncrease++; resIncreaseTotal++; warningQuda("new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)", sqrt(r2), r0Norm, resIncreaseTotal); - if (sqrt(r2) < L2breakdown_eps - || resIncrease > param.max_res_increase - || resIncreaseTotal > param.max_res_increase_total - || r2 < stop) { + if (rNorm < L2breakdown_eps || resIncrease > param.max_res_increase + || resIncreaseTotal > param.max_res_increase_total || r2 < stop) { L2breakdown = true; - warningQuda("L2 breakdown %e, %e", sqrt(r2), L2breakdown_eps); + warningQuda("L2 breakdown %e, %e", rNorm, L2breakdown_eps); + + // we now switch over to reliable updates based on hq values + // hq0Res is set below + maxrr = hq_res; + maxrx = hq_res; + } } else { resIncrease = 0; } // if L2 broke down already we turn off reliable updates and restart the CG - if (L2breakdown) { - hqresRestartTotal++; // count the number of heavy quark restarts we've done - warningQuda("CG: Restarting without reliable updates for heavy-quark residual (total #inc %i)", - hqresRestartTotal); - heavy_quark_restart = true; - - if (heavy_quark_res > heavy_quark_res_old) { // check if new hq residual is greater than previous - hqresIncrease++; // count the number of consecutive increases - warningQuda("CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", - heavy_quark_res, heavy_quark_res_old); - // break out if we do not improve here anymore - if (hqresIncrease > param.max_hq_res_increase) { - warningQuda("CG: solver exiting due to too many heavy quark residual norm increases (%i/%i)", hqresIncrease, - param.max_hq_res_increase); - break; - } - } else { - hqresIncrease = 0; + if (hq_res > hq0Res && updateX && L2breakdown) { + // count the number of consecutive increases + hqresIncrease++; + + warningQuda("CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", + hq_res, hq0Res); + + // break out if we do not improve here anymore + if (hqresIncrease > param.max_hq_res_increase) { + warningQuda("CG: solver exiting due to too many heavy quark residual norm increases (%i/%i)", hqresIncrease, + param.max_hq_res_increase); + break; } - } - - if (hqresRestartTotal > param.max_hq_res_restart_total) { - warningQuda("CG: solver exiting due to too many heavy quark residual restarts (%i/%i)", hqresRestartTotal, - param.max_hq_res_restart_total); - break; + } else { + hqresIncrease = 0; } if (heavy_quark_restart) { @@ -932,22 +1015,22 @@ namespace quda { r0Norm = sqrt(r2); rUpdate++; - heavy_quark_res_old = heavy_quark_res; + hq0Res = hq_res; } k++; - PrintStats("CG", k, r2, b2, heavy_quark_res); + PrintStats("CG", k, r2, b2, hq_res); // check convergence, if convergence is satisfied we only need to check that we had a reliable update for the heavy quarks recently - converged = convergence(r2, heavy_quark_res, stop, param.tol_hq); + converged = convergence(r2, hq_res, stop, param.tol_hq); // check for recent enough reliable updates of the HQ residual if we use it - // L2 is converged or precision maxed out for L2 - bool L2done = L2breakdown or convergenceL2(r2, heavy_quark_res, stop, param.tol_hq); - // HQ is converged and if we do reliable update the HQ residual has been calculated using a reliable update - bool HQdone = (steps_since_reliable == 0 && param.delta > 0) - && convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq); - converged = L2done && HQdone; + + // L2 is converged or precision maxed out for L2 + bool L2done = L2breakdown || convergenceL2(r2, hq_res, stop, param.tol_hq); + // HQ is converged and if we do reliable update the HQ residual has been calculated using a reliable update + bool HQdone = (steps_since_reliable == 0 && param.delta > 0) && convergenceHQ(r2, hq_res, stop, param.tol_hq); + converged = L2done && HQdone; } From d90e19f7d00132f7624ce74193df84d060fb5579 Mon Sep 17 00:00:00 2001 From: Evan Weinberg Date: Wed, 7 Jun 2023 18:14:23 -0400 Subject: [PATCH 04/12] A little bit of clean up and a lot of documentation --- lib/inv_cg_quda.cpp | 116 +++++++++++++++++++++++++++++--------------- 1 file changed, 76 insertions(+), 40 deletions(-) diff --git a/lib/inv_cg_quda.cpp b/lib/inv_cg_quda.cpp index 6430553a77..fca4f1bec7 100644 --- a/lib/inv_cg_quda.cpp +++ b/lib/inv_cg_quda.cpp @@ -784,31 +784,33 @@ namespace quda { double maxrx = L2breakdown ? hq_res : rNorm; double maxrr = L2breakdown ? hq_res : rNorm; - // Trigger for explicitly counting residual updates and checking for L2breakdown + // Triggers for explicitly counting residual updates and checking for L2breakdown. + // * updateX broadly maps to if the iterated residual has dropped by a factor of delta + // relative to the previously re-computed residual. + // * updateR broadly maps to if the iterated residual has dropped by a factor of delta + // relative to the max of the previously re-computed residual and all iterated residuals + // since the last reliable update. bool updateX = false; - - // idk bool updateR = false; - // count the number of times the computed residual has jumped above - // the previously computed residual, so long as the L2 norm hasn't - // broken down + // Counter for the number of times in a row the computed residual has jumped above the + // previously computed residual. int resIncrease = 0; - // count the total number of residual increases has increased independent of resetting + // Counter for the total number of times the computed residual has increased above the previously + // computed residual, independent of when it happened. int resIncreaseTotal = 0; // Trackers for the HQ residual // hq0Res: computed HQ residual at the last reliable updated double hq0Res = hq_res; - // count the number of times the computed hq residual has jumped above - // the previously computed residual + // Counter for the number of times in a row the computed heavy quark residual has + // jumped above the previously computed heavy quark residual. int hqresIncrease = 0; - // count the total number of times a heavy quark restart has been triggered - // THIS NEEDS TO BE FIXED: IT CAN DEPEND ON ONLY THE COMPUTED L2 NORM DROPPING - // AFTER L2 BREAKDOWN! + // Counter for the total number of times a reliable updated based on the heavy quark residual + // has been triggered. int hqresRestartTotal = 0; // Count the steps since a reliable update and the total number of reliable updates. @@ -833,11 +835,14 @@ namespace quda { rNorm = sqrt(r2); // If the iterated norm has dropped by more than a factor of delta, trigger - // an update. + // an update. The baseline we check against differs depending on if + // we're still checking the L2 norm, or if that has converged/broken down and we're + // now looking at the HQ residual. - // L2 based reliable update if (!L2breakdown && (L2_required || convergenceL2(r2, hq_res, stop, param.tol_hq))) { - // if the iterated residual norm has gone above the most recent "baseline" norm, + // L2 based reliable update + + // If the iterated residual norm has gone above the most recent "baseline" norm, // update the baseline norm. if (rNorm > maxrx) maxrx = rNorm; if (rNorm > maxrr) maxrr = rNorm; @@ -910,100 +915,128 @@ namespace quda { steps_since_reliable++; } else { - // We're performing a reliable update: + // We're performing a reliable update // Accumulate p into x, accumulate x into the total solution y, explicitly recompute the residual vector blas::axpy(alpha, p, xSloppy); blas::copy(x, xSloppy); // no op when these pointers alias - blas::xpy(x, y); mat(r, y); - r2 = blas::xmyNorm(b, r); - hq_res = sqrt(blas::HeavyQuarkResidualNorm(y, r).z); - - // reset the L2 reliable update baseline - rNorm = sqrt(r2); - blas::copy(rSloppy, r); // no op when these pointers alias blas::zero(xSloppy); - // check for L2 convergence - if (!L2breakdown && L2_required && convergenceL2(r2, hq_res, stop, param.tol_hq)) + // Recompute the exact residual and heavy quark residual + r2 = blas::xmyNorm(b, r); + rNorm = sqrt(r2); + hq_res = sqrt(blas::HeavyQuarkResidualNorm(y, r).z); + + // Check and see if we're "done" with the L2 norm. This could be because + // we were already done with it, we never needed it, or the L2 norm has finally converged. + if (!L2breakdown && convergenceL2(r2, hq_res, stop, param.tol_hq)) L2breakdown = true; - // Reset the reliable update baseline + // Depending on if we're still grinding on the L2 norm or if we've moved along to just + // the HQ norm, we reset the baselines for reliable updates that get used on the + // *next* iteration. We still need the baselines that were used for this iteration + // for the checks down below. if (!L2breakdown) { + // If we're still grinding on the L2 norm, the new baseline is the freshly + // recomputed |r|. maxrr = rNorm; maxrx = rNorm; } else { + // If we've made it to the HQ norm, the new baseline is the freshly recomputed + // heavy quark residual maxrr = hq_res; maxrx = hq_res; - // Increment the total number of HQ restarts - hqresRestartTotal++; - - // set that we're doing a heavy quark restart + // Once we're dealing with the heavy quark residual, we perform a *hard* CG + // restart at every reliable update via setting the search vector `p` to the current + // exact residual vector. heavy_quark_restart = true; + // And then we keep track of the fact we're doing a HQ residual reliable update... + hqresRestartTotal++; warningQuda("CG: Restarting without reliable updates for heavy-quark residual (total #inc %i)", hqresRestartTotal); if (hqresRestartTotal > param.max_hq_res_restart_total) { + // ...and if we've restarted too many times, flunk out of the solve. warningQuda("CG: solver exiting due to too many heavy quark residual restarts (%i/%i)", hqresRestartTotal, param.max_hq_res_restart_total); break; } } - // break-out check if we have reached the limit of the precision - // we're reusing r0Norm for this check + // Check and see if we've reached the limit of the precision. There isn't necessarily + // a great way to do this, so as a proxy we check to see if the new computed residual is + // larger than the computed residual from the last reliable update, and if this is the case + // enough times we throw up our hands, say "we're good here", and switch over to the HQ + // residual. if (rNorm > r0Norm && updateX && !L2breakdown) { + // Count the number of times in a row this has happened resIncrease++; + + // And count the total number of times this has happened outright resIncreaseTotal++; + + // ...tell the world about it too. warningQuda("new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)", sqrt(r2), r0Norm, resIncreaseTotal); + // If the norm is ridiculously small in magnitude, we've exceeded the maximums on various + // ways we keep track of residual increases, or the L2 norm converged, we say "we're good here" + // and move over to the HQ residual norm. if (rNorm < L2breakdown_eps || resIncrease > param.max_res_increase || resIncreaseTotal > param.max_res_increase_total || r2 < stop) { L2breakdown = true; warningQuda("L2 breakdown %e, %e", rNorm, L2breakdown_eps); - // we now switch over to reliable updates based on hq values - // hq0Res is set below + // We also have to do a logic correction, switching the reliable update baselines we set above + // from the L2 norm over to the HQ residual. maxrr = hq_res; maxrx = hq_res; } } else { + // This variable counts the number of times in a row the computed residual has gone up, + // so if it hasn't gone up this time around we reset this counter. resIncrease = 0; } - // if L2 broke down already we turn off reliable updates and restart the CG + // If we've done checking the L2 norm, we do a similar check of if the HQ residual has increased + // for multiple reliable updates in a row. if (hq_res > hq0Res && updateX && L2breakdown) { - // count the number of consecutive increases + // Count the number of consecutive increases hqresIncrease++; + // Tell the world about it warningQuda("CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", hq_res, hq0Res); - // break out if we do not improve here anymore + // And if it's increased too many times in a row, flunk out of the solve. if (hqresIncrease > param.max_hq_res_increase) { warningQuda("CG: solver exiting due to too many heavy quark residual norm increases (%i/%i)", hqresIncrease, param.max_hq_res_increase); break; } } else { + // This variable counts the number of times in a row the computed heavy quark residual has increased, + // so if it hasn't gone up this time around we reset the counter. hqresIncrease = 0; } + // Depending on if we're in the L2 norm part of the solve or a HQ residual part of the solve + // we "reset" the solve in a different way. if (heavy_quark_restart) { - // perform a restart + // If we're in the HQ residual part of the solve, we just do a hard CG restart. logQuda(QUDA_SUMMARIZE, "HQ restart == hard CG restart\n"); blas::copy(p, rSloppy); heavy_quark_restart = false; } else { + // If we're still in the L2 norm part of the solve, we explicitly restore + // the orthogonality of the gradient vector, recompute beta, update `p`, and carry on with our lives. logQuda(QUDA_SUMMARIZE, "Regular restart == explicit gradient vector re-orthogonalization\n"); - // explicitly restore the orthogonality of the gradient vector Complex rp = blas::cDotProduct(rSloppy, p) / (r2); blas::caxpy(-rp, rSloppy, p); @@ -1011,9 +1044,12 @@ namespace quda { blas::xpayz(rSloppy, beta, p, p); } + // Last, we increment the reliable update counter, reset the number of steps since the last reliable update, + // and reset the cached value of |r| and the heavy quark residual from the time of this + // reliable update. + rUpdate++; steps_since_reliable = 0; r0Norm = sqrt(r2); - rUpdate++; hq0Res = hq_res; } From babd616c7c3fd87f16fe7a0864383178d22015f1 Mon Sep 17 00:00:00 2001 From: Damon McDougall Date: Wed, 24 May 2023 15:18:44 -0500 Subject: [PATCH 05/12] Adding ROCm build gh action workflow --- .github/workflows/amd-build-ci.yml | 57 ++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 .github/workflows/amd-build-ci.yml diff --git a/.github/workflows/amd-build-ci.yml b/.github/workflows/amd-build-ci.yml new file mode 100644 index 0000000000..08f1ad6d5f --- /dev/null +++ b/.github/workflows/amd-build-ci.yml @@ -0,0 +1,57 @@ +name: amd-build-ci +run-name: ${{ github.actor }} is kicking off a ROCm build +on: pull_request +jobs: + rocm-build: + runs-on: [self-hosted, amd] + steps: + - uses: actions/checkout@v3 + - run: | + export ROCM_PATH=/opt/rocm-5.5.0 + export PATH=${ROCM_PATH}/bin:${ROCM_PATH}/llvm/bin:${PATH} + SRCROOT=`pwd` + BUILDROOT=`mktemp -d build-XXXXXXXX` + INSTALLROOT=`mktemp -d install-XXXXXXXX` + QUDA_GPU_ARCH=gfx90a + cmake ${SRCROOT} \ + -B ${BUILDROOT} \ + -DQUDA_TARGET_TYPE="HIP" \ + -DQUDA_GPU_ARCH=${QUDA_GPU_ARCH} \ + -DROCM_PATH=${ROCM_PATH} \ + -DQUDA_DIRAC_CLOVER=ON \ + -DQUDA_DIRAC_CLOVER_HASENBUSCH=OFF \ + -DQUDA_DIRAC_DOMAIN_WALL=OFF \ + -DQUDA_DIRAC_NDEG_TWISTED_MASS=OFF \ + -DQUDA_DIRAC_STAGGERED=ON \ + -DQUDA_DIRAC_TWISTED_MASS=OFF \ + -DQUDA_DIRAC_TWISTED_CLOVER=OFF \ + -DQUDA_DIRAC_WILSON=ON \ + -DQUDA_CLOVER_DYNAMIC=ON \ + -DQUDA_FORCE_HISQ=ON \ + -DQUDA_QDPJIT=OFF \ + -DQUDA_INTERFACE_QDPJIT=OFF \ + -DQUDA_INTERFACE_MILC=ON \ + -DQUDA_INTERFACE_CPS=OFF \ + -DQUDA_INTERFACE_QDP=ON \ + -DQUDA_INTERFACE_TIFR=OFF \ + -DQUDA_QMP=ON \ + -DQUDA_DOWNLOAD_USQCD=ON \ + -DQUDA_OPENMP=OFF \ + -DQUDA_MULTIGRID=ON \ + -DQUDA_MAX_MULTI_BLAS_N=9 \ + -DQUDA_DOWNLOAD_EIGEN=ON \ + -DQUDA_PRECISION=8 \ + -DCMAKE_INSTALL_PREFIX=${INSTALLROOT} \ + -DCMAKE_BUILD_TYPE="DEBUG" \ + -DCMAKE_CXX_COMPILER="${ROCM_PATH}/llvm/bin/clang++" \ + -DCMAKE_C_COMPILER="${ROCM_PATH}/llvm/bin/clang" \ + -DCMAKE_HIP_COMPILER="${ROCM_PATH}/llvm/bin/clang++" \ + -DBUILD_SHARED_LIBS=ON \ + -DQUDA_BUILD_SHAREDLIB=ON \ + -DQUDA_BUILD_ALL_TESTS=ON \ + -DQUDA_CTEST_DISABLE_BENCHMARKS=ON \ + -DCMAKE_C_STANDARD=99 + cmake --build ${BUILDROOT} -j 16 + cmake --install ${BUILDROOT} + rm -rf ${BUILDROOT} + rm -rf ${INSTALLROOT} From 3efa9893268350e0e2831d3eae951efa4cb27695 Mon Sep 17 00:00:00 2001 From: Damon McDougall Date: Sat, 10 Jun 2023 17:19:07 -0500 Subject: [PATCH 06/12] Build and workflow tweaks - Rename the workflow to include rocm in the name - Remove unused build options - Add more precisions to build workflow --- .github/workflows/{amd-build-ci.yml => rocm-build-ci.yml} | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) rename .github/workflows/{amd-build-ci.yml => rocm-build-ci.yml} (95%) diff --git a/.github/workflows/amd-build-ci.yml b/.github/workflows/rocm-build-ci.yml similarity index 95% rename from .github/workflows/amd-build-ci.yml rename to .github/workflows/rocm-build-ci.yml index 08f1ad6d5f..fa0d86202b 100644 --- a/.github/workflows/amd-build-ci.yml +++ b/.github/workflows/rocm-build-ci.yml @@ -27,7 +27,6 @@ jobs: -DQUDA_DIRAC_TWISTED_CLOVER=OFF \ -DQUDA_DIRAC_WILSON=ON \ -DQUDA_CLOVER_DYNAMIC=ON \ - -DQUDA_FORCE_HISQ=ON \ -DQUDA_QDPJIT=OFF \ -DQUDA_INTERFACE_QDPJIT=OFF \ -DQUDA_INTERFACE_MILC=ON \ @@ -38,9 +37,8 @@ jobs: -DQUDA_DOWNLOAD_USQCD=ON \ -DQUDA_OPENMP=OFF \ -DQUDA_MULTIGRID=ON \ - -DQUDA_MAX_MULTI_BLAS_N=9 \ -DQUDA_DOWNLOAD_EIGEN=ON \ - -DQUDA_PRECISION=8 \ + -DQUDA_PRECISION=14 \ -DCMAKE_INSTALL_PREFIX=${INSTALLROOT} \ -DCMAKE_BUILD_TYPE="DEBUG" \ -DCMAKE_CXX_COMPILER="${ROCM_PATH}/llvm/bin/clang++" \ From 9b48bbafaf37f3132e8a54ad3f293742deea4e5d Mon Sep 17 00:00:00 2001 From: Damon McDougall Date: Sat, 10 Jun 2023 17:28:41 -0500 Subject: [PATCH 07/12] Update workflow name to include 'rocm' --- .github/workflows/rocm-build-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/rocm-build-ci.yml b/.github/workflows/rocm-build-ci.yml index fa0d86202b..6e196aaaef 100644 --- a/.github/workflows/rocm-build-ci.yml +++ b/.github/workflows/rocm-build-ci.yml @@ -1,4 +1,4 @@ -name: amd-build-ci +name: rocm-build-ci run-name: ${{ github.actor }} is kicking off a ROCm build on: pull_request jobs: From e33d162b5b86012af38d361f122f56e46a2a11c8 Mon Sep 17 00:00:00 2001 From: Damon McDougall Date: Sat, 10 Jun 2023 18:49:58 -0500 Subject: [PATCH 08/12] Do release build instead of debug build --- .github/workflows/rocm-build-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/rocm-build-ci.yml b/.github/workflows/rocm-build-ci.yml index 6e196aaaef..8795607321 100644 --- a/.github/workflows/rocm-build-ci.yml +++ b/.github/workflows/rocm-build-ci.yml @@ -40,7 +40,7 @@ jobs: -DQUDA_DOWNLOAD_EIGEN=ON \ -DQUDA_PRECISION=14 \ -DCMAKE_INSTALL_PREFIX=${INSTALLROOT} \ - -DCMAKE_BUILD_TYPE="DEBUG" \ + -DCMAKE_BUILD_TYPE="RELEASE" \ -DCMAKE_CXX_COMPILER="${ROCM_PATH}/llvm/bin/clang++" \ -DCMAKE_C_COMPILER="${ROCM_PATH}/llvm/bin/clang" \ -DCMAKE_HIP_COMPILER="${ROCM_PATH}/llvm/bin/clang++" \ From 232fcfce56ff55a2392ed5b68deb04777417ca91 Mon Sep 17 00:00:00 2001 From: Damon McDougall Date: Sat, 10 Jun 2023 19:49:02 -0500 Subject: [PATCH 09/12] Do DEVEL build --- .github/workflows/rocm-build-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/rocm-build-ci.yml b/.github/workflows/rocm-build-ci.yml index 8795607321..de12f8ae5b 100644 --- a/.github/workflows/rocm-build-ci.yml +++ b/.github/workflows/rocm-build-ci.yml @@ -40,7 +40,7 @@ jobs: -DQUDA_DOWNLOAD_EIGEN=ON \ -DQUDA_PRECISION=14 \ -DCMAKE_INSTALL_PREFIX=${INSTALLROOT} \ - -DCMAKE_BUILD_TYPE="RELEASE" \ + -DCMAKE_BUILD_TYPE="DEVEL" \ -DCMAKE_CXX_COMPILER="${ROCM_PATH}/llvm/bin/clang++" \ -DCMAKE_C_COMPILER="${ROCM_PATH}/llvm/bin/clang" \ -DCMAKE_HIP_COMPILER="${ROCM_PATH}/llvm/bin/clang++" \ From 631184fdf09b7fefd74bf91fc3538a39823fee8d Mon Sep 17 00:00:00 2001 From: Evan Weinberg Date: Wed, 14 Jun 2023 16:39:48 -0400 Subject: [PATCH 10/12] Removed HQ checks from non-HQ CG path --- lib/inv_cg_quda.cpp | 110 +++++++++----------------------------------- 1 file changed, 22 insertions(+), 88 deletions(-) diff --git a/lib/inv_cg_quda.cpp b/lib/inv_cg_quda.cpp index fca4f1bec7..392f9b898d 100644 --- a/lib/inv_cg_quda.cpp +++ b/lib/inv_cg_quda.cpp @@ -224,13 +224,9 @@ namespace quda { const int Np = (param.solution_accumulator_pipeline == 0 ? 1 : param.solution_accumulator_pipeline); if (Np < 0 || Np > 16) errorQuda("Invalid value %d for solution_accumulator_pipeline\n", Np); - // Detect whether this is a pure double solve or not; informs the necessity of some stability checks - bool is_pure_double = (param.precision == QUDA_DOUBLE_PRECISION && param.precision_sloppy == QUDA_DOUBLE_PRECISION); - // Determine whether or not we're doing a heavy quark residual const bool use_heavy_quark_res = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false; - bool heavy_quark_restart = false; if (use_heavy_quark_res) { hqsolve(x, b); @@ -238,9 +234,14 @@ namespace quda { return; } + // This check is pointless in the current version of the code, but it's being proactively added + // just in case HQ residual solves are split into a separate file + if (use_heavy_quark_res) + errorQuda("The \"vanilla\" CG solver does not support HQ residual solves"); + // whether to select alternative reliable updates // if we're computing the heavy quark residual, force "traditional" reliable updates - bool alternative_reliable = use_heavy_quark_res ? param.use_alternative_reliable : false; + bool alternative_reliable = param.use_alternative_reliable; /** When CG is used as a preconditioner, and we disable the `advanced features`, these features are turned off: - Reliable updates @@ -281,8 +282,6 @@ namespace quda { param.use_sloppy_partial_accumulator = false; } - // temporary fields - tmpp = ColorSpinorField::Create(csParam); init = true; } @@ -305,7 +304,6 @@ namespace quda { ColorSpinorField &r = *rp; ColorSpinorField &y = *yp; ColorSpinorField &Ap = *App; - ColorSpinorField &tmp = *tmpp; ColorSpinorField &rSloppy = *rSloppyp; ColorSpinorField &xSloppy = param.use_sloppy_partial_accumulator ? *xSloppyp : x; @@ -322,13 +320,6 @@ namespace quda { Anorm = sqrt(blas::norm2(r)/b2); } - // for detecting HQ residual stalls - // let |r2/b2| drop to epsilon tolerance * 1e-30, semi-arbitrarily, but - // with the intent of letting the solve grind as long as possible before - // triggering a `NaN`. Ignored for pure double solves because if - // pure double has stability issues, bigger problems are at hand. - const double hq_res_stall_check = is_pure_double ? 0. : uhigh * uhigh * 1e-60; - // compute initial residual double r2 = 0.0; if (advanced_feature && param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { @@ -375,23 +366,9 @@ namespace quda { double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver - double heavy_quark_res = 0.0; // heavy quark res idual - double heavy_quark_res_old = 0.0; // heavy quark residual - - if (use_heavy_quark_res) { - heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); - heavy_quark_res_old = heavy_quark_res; // heavy quark residual - } - const int heavy_quark_check = param.heavy_quark_check; // how often to check the heavy quark residual - auto alpha = std::make_unique(Np); double pAp; - // set this to true if maxResIncrease has been exceeded but when we use heavy quark residual we still want to continue the CG - // only used if we use the heavy_quark_res - bool L2breakdown = false; - const double L2breakdown_eps = 100. * uhigh; - if (!param.is_preconditioner) { profile.TPSTOP(QUDA_PROFILE_PREAMBLE); profile.TPSTART(QUDA_PROFILE_COMPUTE); @@ -400,9 +377,9 @@ namespace quda { int k = 0; - PrintStats("CG", k, r2, b2, heavy_quark_res); + PrintStats("CG", k, r2, b2, 0.0); - bool converged = convergence(r2, heavy_quark_res, stop, param.tol_hq); + bool converged = convergenceL2(r2, 0.0, stop, 0.0); ReliableUpdatesParams ru_params; @@ -414,9 +391,7 @@ namespace quda { ru_params.maxResIncrease = param.max_res_increase; ru_params.maxResIncreaseTotal = param.max_res_increase_total; - ru_params.use_heavy_quark_res = use_heavy_quark_res; - ru_params.hqmaxresIncrease = param.max_hq_res_increase; - ru_params.hqmaxresRestartTotal = param.max_hq_res_restart_total; + ru_params.use_heavy_quark_res = false; // since we've removed HQ residual support ReliableUpdates ru(ru_params, r2); @@ -473,13 +448,7 @@ namespace quda { if (advanced_feature) { ru.evaluate(r2_old); // force a reliable update if we are within target tolerance (only if doing reliable updates) - if (convergence(r2, heavy_quark_res, stop, param.tol_hq) && param.delta >= param.tol) ru.set_updateX(); - - if (use_heavy_quark_res and L2breakdown - and (convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq) or (r2 / b2) < hq_res_stall_check) - and param.delta >= param.tol) { - ru.set_updateX(); - } + if (convergenceL2(r2, 0.0, stop, 0.0) && param.delta >= param.tol) ru.set_updateX(); } if (!ru.trigger()) { @@ -507,16 +476,6 @@ namespace quda { } } - if (use_heavy_quark_res && k % heavy_quark_check == 0) { - if (&x != &xSloppy) { - blas::copy(tmp, y); - heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(xSloppy, tmp, rSloppy).z); - } else { - blas::copy(r, rSloppy); - heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(x, y, r).z); - } - } - // alternative reliable updates if (advanced_feature) { ru.accumulate_norm(x_update_batch.get_current_alpha()); } } else { @@ -546,53 +505,28 @@ namespace quda { if (advanced_feature) { ru.update_norm(r2, y); } - // calculate new reliable HQ resididual - if (use_heavy_quark_res) heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(y, r).z); - if (advanced_feature) { - if (ru.reliable_break(r2, stop, L2breakdown, L2breakdown_eps)) { break; } - } - - // if L2 broke down already we turn off reliable updates and restart the CG - if (use_heavy_quark_res && ru.reliable_heavy_quark_break(L2breakdown, heavy_quark_res, heavy_quark_res_old, heavy_quark_restart)) { - break; + // needed as a "dummy parameter" to reliable_break. + bool L2breakdown = false; + if (ru.reliable_break(r2, stop, L2breakdown, 0)) { break; } } - if (use_heavy_quark_res and heavy_quark_restart) { - // perform a restart - x_update_batch.reset(); - blas::copy(x_update_batch.get_current_field(), rSloppy); - heavy_quark_restart = false; - } else { - // explicitly restore the orthogonality of the gradient vector - Complex rp = blas::cDotProduct(rSloppy, x_update_batch.get_current_field()) / (r2); - blas::caxpy(-rp, rSloppy, x_update_batch.get_current_field()); + // explicitly restore the orthogonality of the gradient vector + Complex rp = blas::cDotProduct(rSloppy, x_update_batch.get_current_field()) / (r2); + blas::caxpy(-rp, rSloppy, x_update_batch.get_current_field()); - beta = r2 / r2_old; - blas::xpayz(rSloppy, beta, x_update_batch.get_current_field(), x_update_batch.get_next_field()); - } + beta = r2 / r2_old; + blas::xpayz(rSloppy, beta, x_update_batch.get_current_field(), x_update_batch.get_next_field()); ru.reset(r2); - - heavy_quark_res_old = heavy_quark_res; } breakdown = false; k++; - PrintStats("CG", k, r2, b2, heavy_quark_res); - // check convergence, if convergence is satisfied we only need to check that we had a reliable update for the heavy quarks recently - converged = convergence(r2, heavy_quark_res, stop, param.tol_hq); - - // check for recent enough reliable updates of the HQ residual if we use it - if (use_heavy_quark_res) { - // L2 is converged or precision maxed out for L2 - bool L2done = L2breakdown or convergenceL2(r2, heavy_quark_res, stop, param.tol_hq); - // HQ is converged and if we do reliable update the HQ residual has been calculated using a reliable update - bool HQdone = (ru.steps_since_reliable == 0 and param.delta > 0) - and convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq); - converged = L2done and HQdone; - } + PrintStats("CG", k, r2, b2, 0.0); + // check convergence + converged = convergenceL2(r2, 0.0, stop, 0.0); // if we have converged and need to update any trailing solutions if (converged && ru.steps_since_reliable > 0 && !x_update_batch.is_container_full()) { @@ -630,7 +564,7 @@ namespace quda { param.true_res_hq = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); } - PrintSummary("CG", k, r2, b2, stop, param.tol_hq); + PrintSummary("CG", k, r2, b2, stop, 0.0); if (!param.is_preconditioner) { // reset the flops counters From d7304d522aceb5907ad5f4dd8086353ac5d557eb Mon Sep 17 00:00:00 2001 From: Evan Weinberg Date: Tue, 27 Jun 2023 13:19:11 -0700 Subject: [PATCH 11/12] clang-format --- include/invert_quda.h | 2 +- lib/inv_cg_quda.cpp | 76 +++++++++++++++---------------------------- 2 files changed, 27 insertions(+), 51 deletions(-) diff --git a/include/invert_quda.h b/include/invert_quda.h index a472941b7d..9a8b89c611 100644 --- a/include/invert_quda.h +++ b/include/invert_quda.h @@ -792,7 +792,7 @@ namespace quda { * @param out Solution-vector. * @param in Right-hand side. */ - void hqsolve(ColorSpinorField& out, ColorSpinorField& in); + void hqsolve(ColorSpinorField &out, ColorSpinorField &in); }; class CGNE : public CG diff --git a/lib/inv_cg_quda.cpp b/lib/inv_cg_quda.cpp index 392f9b898d..4ef012e0d7 100644 --- a/lib/inv_cg_quda.cpp +++ b/lib/inv_cg_quda.cpp @@ -225,8 +225,7 @@ namespace quda { if (Np < 0 || Np > 16) errorQuda("Invalid value %d for solution_accumulator_pipeline\n", Np); // Determine whether or not we're doing a heavy quark residual - const bool use_heavy_quark_res = - (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false; + const bool use_heavy_quark_res = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? true : false; if (use_heavy_quark_res) { hqsolve(x, b); @@ -236,8 +235,7 @@ namespace quda { // This check is pointless in the current version of the code, but it's being proactively added // just in case HQ residual solves are split into a separate file - if (use_heavy_quark_res) - errorQuda("The \"vanilla\" CG solver does not support HQ residual solves"); + if (use_heavy_quark_res) errorQuda("The \"vanilla\" CG solver does not support HQ residual solves"); // whether to select alternative reliable updates // if we're computing the heavy quark residual, force "traditional" reliable updates @@ -580,21 +578,20 @@ namespace quda { } // Separate HQ residual codepath - void CG::hqsolve(ColorSpinorField &x, ColorSpinorField &b) { + void CG::hqsolve(ColorSpinorField &x, ColorSpinorField &b) + { logQuda(QUDA_VERBOSE, "Performing a HQ CG solve\n"); // Verbose errors: HQ solves won't support deflation, pipelining - if (param.deflate) - errorQuda("HQ solves don't support deflation"); - if (param.is_preconditioner) - errorQuda("HQ solves cannot be preconditioners"); + if (param.deflate) errorQuda("HQ solves don't support deflation"); + if (param.is_preconditioner) errorQuda("HQ solves cannot be preconditioners"); // Non-terminal errors: HQ solves don't support advanced reliable updates if (param.use_alternative_reliable) - logQuda(QUDA_SUMMARIZE, "HQ solves don't support alternative reliable updates, reverting to traditional reliable updates\n"); - if (param.pipeline) - logQuda(QUDA_SUMMARIZE, "HQ solves don't support pipelining, disabling..."); + logQuda(QUDA_SUMMARIZE, + "HQ solves don't support alternative reliable updates, reverting to traditional reliable updates\n"); + if (param.pipeline) logQuda(QUDA_SUMMARIZE, "HQ solves don't support pipelining, disabling..."); profile.TPSTART(QUDA_PROFILE_INIT); @@ -625,7 +622,7 @@ namespace quda { csParam.setPrecision(param.precision_sloppy); pp = ColorSpinorField::Create(csParam); App = ColorSpinorField::Create(csParam); - if(param.precision != param.precision_sloppy) { + if (param.precision != param.precision_sloppy) { rSloppyp = ColorSpinorField::Create(csParam); xSloppyp = ColorSpinorField::Create(csParam); } else { @@ -671,7 +668,7 @@ namespace quda { blas::zero(x); if (&x != &xSloppy) blas::zero(xSloppy); - blas::copy(rSloppy,r); + blas::copy(rSloppy, r); blas::copy(p, rSloppy); double r2_old = 0.0; @@ -679,7 +676,7 @@ namespace quda { profile.TPSTOP(QUDA_PROFILE_INIT); profile.TPSTART(QUDA_PROFILE_PREAMBLE); - double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver + double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver // compute the initial heavy quark residual double hq_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); @@ -694,7 +691,6 @@ namespace quda { bool L2breakdown = !L2_required; const double L2breakdown_eps = 100. * uhigh; - profile.TPSTOP(QUDA_PROFILE_PREAMBLE); profile.TPSTART(QUDA_PROFILE_COMPUTE); blas::flops = 0; @@ -753,7 +749,7 @@ namespace quda { int rUpdate = 0; int steps_since_reliable = 1; - while ( !converged && k < param.maxiter ) { + while (!converged && k < param.maxiter) { matSloppy(Ap, p); r2_old = r2; @@ -764,8 +760,8 @@ namespace quda { // here we are deploying the alternative beta computation auto cg_norm = blas::axpyCGNorm(-alpha, Ap, rSloppy); - r2 = cg_norm.x; // (r_new, r_new) - sigma = cg_norm.y >= 0.0 ? cg_norm.y : r2; // use r2 if (r_k+1, r_k+1-r_k) breaks + r2 = cg_norm.x; // (r_new, r_new) + sigma = cg_norm.y >= 0.0 ? cg_norm.y : r2; // use r2 if (r_k+1, r_k+1-r_k) breaks rNorm = sqrt(r2); // If the iterated norm has dropped by more than a factor of delta, trigger @@ -787,12 +783,6 @@ namespace quda { // Has the iterated norm dropped by a factor of delta relative to the largest the // iterated norm has been since the last update? updateR = ((rNorm < param.delta * maxrr && r0Norm <= maxrr) || updateX); - - if (updateX) - logQuda(QUDA_VERBOSE, "Triggered L2 updateX via `(rNorm < param.delta * r0Norm && r0Norm <= maxrx)`\n"); - - if (updateR) - logQuda(QUDA_VERBOSE, "Triggered L2 updateR via `((rNorm < param.delta * maxrr && r0Norm <= maxrr) || updateX)`\n"); } else { // hqresidual based reliable update if (hq_res > maxrx) maxrx = hq_res; @@ -808,31 +798,20 @@ namespace quda { // Has the iterated heavy quark residual dropped by a factor of delta relative // to the largest the iterated norm has been since the last update? updateR = ((hq_res < param.delta * param.delta * maxrr && hq0Res <= maxrr) || updateX); - - if (updateX) - logQuda(QUDA_VERBOSE, "Triggered HQ updateX via `(hq_res < param.delta * hq0Res && hq0Res <= maxrx)`\n"); - - if (updateR) - logQuda(QUDA_VERBOSE, "Triggered HQ updateR via `((hq_res < param.delta * maxrr && hq0Res <= maxrr) || updateX)`\n"); } // force a reliable update if we are within target tolerance (only if doing reliable updates) - if (convergence(r2, hq_res, stop, param.tol_hq) && param.delta >= param.tol) { - updateX = true; - logQuda(QUDA_VERBOSE, "Triggered updateX via convergence && delta > tol check\n"); - } + if (convergence(r2, hq_res, stop, param.tol_hq) && param.delta >= param.tol) updateX = true; // force a reliable update based on the HQ residual if L2 breakdown has already happened - if (L2breakdown && (convergenceHQ(r2, hq_res, stop, param.tol_hq) || (r2 / b2) < hq_res_stall_check) && - param.delta >= param.tol) { + if (L2breakdown && (convergenceHQ(r2, hq_res, stop, param.tol_hq) || (r2 / b2) < hq_res_stall_check) + && param.delta >= param.tol) updateX = true; - logQuda(QUDA_VERBOSE, "Triggered updateX via L2breakdown condition\n"); - } if (!(updateR || updateX)) { // No reliable update needed - beta = sigma / r2_old; // use the alternative beta computation + beta = sigma / r2_old; // use the alternative beta computation blas::axpyZpbx(alpha, p, xSloppy, rSloppy, beta); @@ -866,8 +845,7 @@ namespace quda { // Check and see if we're "done" with the L2 norm. This could be because // we were already done with it, we never needed it, or the L2 norm has finally converged. - if (!L2breakdown && convergenceL2(r2, hq_res, stop, param.tol_hq)) - L2breakdown = true; + if (!L2breakdown && convergenceL2(r2, hq_res, stop, param.tol_hq)) L2breakdown = true; // Depending on if we're still grinding on the L2 norm or if we've moved along to just // the HQ norm, we reset the baselines for reliable updates that get used on the @@ -915,14 +893,15 @@ namespace quda { resIncreaseTotal++; // ...tell the world about it too. - warningQuda("new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)", - sqrt(r2), r0Norm, resIncreaseTotal); + warningQuda( + "new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)", + sqrt(r2), r0Norm, resIncreaseTotal); // If the norm is ridiculously small in magnitude, we've exceeded the maximums on various // ways we keep track of residual increases, or the L2 norm converged, we say "we're good here" // and move over to the HQ residual norm. if (rNorm < L2breakdown_eps || resIncrease > param.max_res_increase - || resIncreaseTotal > param.max_res_increase_total || r2 < stop) { + || resIncreaseTotal > param.max_res_increase_total || r2 < stop) { L2breakdown = true; warningQuda("L2 breakdown %e, %e", rNorm, L2breakdown_eps); @@ -930,7 +909,6 @@ namespace quda { // from the L2 norm over to the HQ residual. maxrr = hq_res; maxrx = hq_res; - } } else { // This variable counts the number of times in a row the computed residual has gone up, @@ -945,8 +923,8 @@ namespace quda { hqresIncrease++; // Tell the world about it - warningQuda("CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", - hq_res, hq0Res); + warningQuda("CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", hq_res, + hq0Res); // And if it's increased too many times in a row, flunk out of the solve. if (hqresIncrease > param.max_hq_res_increase) { @@ -1001,7 +979,6 @@ namespace quda { // HQ is converged and if we do reliable update the HQ residual has been calculated using a reliable update bool HQdone = (steps_since_reliable == 0 && param.delta > 0) && convergenceHQ(r2, hq_res, stop, param.tol_hq); converged = L2done && HQdone; - } blas::copy(x, xSloppy); @@ -1035,7 +1012,6 @@ namespace quda { matPrecon.flops(); profile.TPSTOP(QUDA_PROFILE_EPILOGUE); - } // use BlockCGrQ algortithm or BlockCG (with / without GS, see BLOCKCG_GS option) From 3efa9913c20ddfe7ffc0afbbe86e679d3437df60 Mon Sep 17 00:00:00 2001 From: Evan Weinberg Date: Mon, 3 Jul 2023 12:47:50 -0700 Subject: [PATCH 12/12] Respond to review feedback --- lib/inv_cg_quda.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/inv_cg_quda.cpp b/lib/inv_cg_quda.cpp index 4ef012e0d7..b88193083c 100644 --- a/lib/inv_cg_quda.cpp +++ b/lib/inv_cg_quda.cpp @@ -238,7 +238,6 @@ namespace quda { if (use_heavy_quark_res) errorQuda("The \"vanilla\" CG solver does not support HQ residual solves"); // whether to select alternative reliable updates - // if we're computing the heavy quark residual, force "traditional" reliable updates bool alternative_reliable = param.use_alternative_reliable; /** When CG is used as a preconditioner, and we disable the `advanced features`, these features are turned off: