diff --git a/include/invert_quda.h b/include/invert_quda.h index b2699ad3b7..9a8b89c611 100644 --- a/include/invert_quda.h +++ b/include/invert_quda.h @@ -785,6 +785,14 @@ namespace quda { void blocksolve(ColorSpinorField& out, ColorSpinorField& in); virtual bool hermitian() { return true; } /** CG is only for Hermitian systems */ + + protected: + /** + * @brief Separate codepath for performing a "simpler" CG solve when a heavy quark residual is requested. + * @param out Solution-vector. + * @param in Right-hand side. + */ + void hqsolve(ColorSpinorField &out, ColorSpinorField &in); }; class CGNE : public CG diff --git a/lib/inv_cg_quda.cpp b/lib/inv_cg_quda.cpp index bfce1b422b..b88193083c 100644 --- a/lib/inv_cg_quda.cpp +++ b/lib/inv_cg_quda.cpp @@ -224,8 +224,18 @@ namespace quda { const int Np = (param.solution_accumulator_pipeline == 0 ? 1 : param.solution_accumulator_pipeline); if (Np < 0 || Np > 16) errorQuda("Invalid value %d for solution_accumulator_pipeline\n", Np); - // Detect whether this is a pure double solve or not; informs the necessity of some stability checks - bool is_pure_double = (param.precision == QUDA_DOUBLE_PRECISION && param.precision_sloppy == QUDA_DOUBLE_PRECISION); + // Determine whether or not we're doing a heavy quark residual + const bool use_heavy_quark_res = (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? 
true : false; + + if (use_heavy_quark_res) { + hqsolve(x, b); + if (param.is_preconditioner) commGlobalReductionPop(); + return; + } + + // This check is pointless in the current version of the code, but it's being proactively added + // just in case HQ residual solves are split into a separate file + if (use_heavy_quark_res) errorQuda("The \"vanilla\" CG solver does not support HQ residual solves"); // whether to select alternative reliable updates bool alternative_reliable = param.use_alternative_reliable; @@ -269,8 +279,6 @@ namespace quda { param.use_sloppy_partial_accumulator = false; } - // temporary fields - tmpp = ColorSpinorField::Create(csParam); init = true; } @@ -293,7 +301,6 @@ namespace quda { ColorSpinorField &r = *rp; ColorSpinorField &y = *yp; ColorSpinorField &Ap = *App; - ColorSpinorField &tmp = *tmpp; ColorSpinorField &rSloppy = *rSloppyp; ColorSpinorField &xSloppy = param.use_sloppy_partial_accumulator ? *xSloppyp : x; @@ -310,13 +317,6 @@ namespace quda { Anorm = sqrt(blas::norm2(r)/b2); } - // for detecting HQ residual stalls - // let |r2/b2| drop to epsilon tolerance * 1e-30, semi-arbitrarily, but - // with the intent of letting the solve grind as long as possible before - // triggering a `NaN`. Ignored for pure double solves because if - // pure double has stability issues, bigger problems are at hand. - const double hq_res_stall_check = is_pure_double ? 0. : uhigh * uhigh * 1e-60; - // compute initial residual double r2 = 0.0; if (advanced_feature && param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { @@ -356,10 +356,6 @@ namespace quda { blas::xpayz(rSloppy, beta, x_update_batch.get_current_field(), x_update_batch.get_current_field()); } - const bool use_heavy_quark_res = - (param.residual_type & QUDA_HEAVY_QUARK_RESIDUAL) ? 
true : false; - bool heavy_quark_restart = false; - if (!param.is_preconditioner) { profile.TPSTOP(QUDA_PROFILE_INIT); profile.TPSTART(QUDA_PROFILE_PREAMBLE); @@ -367,23 +363,9 @@ namespace quda { double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver - double heavy_quark_res = 0.0; // heavy quark res idual - double heavy_quark_res_old = 0.0; // heavy quark residual - - if (use_heavy_quark_res) { - heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); - heavy_quark_res_old = heavy_quark_res; // heavy quark residual - } - const int heavy_quark_check = param.heavy_quark_check; // how often to check the heavy quark residual - auto alpha = std::make_unique(Np); double pAp; - // set this to true if maxResIncrease has been exceeded but when we use heavy quark residual we still want to continue the CG - // only used if we use the heavy_quark_res - bool L2breakdown = false; - const double L2breakdown_eps = 100. * uhigh; - if (!param.is_preconditioner) { profile.TPSTOP(QUDA_PROFILE_PREAMBLE); profile.TPSTART(QUDA_PROFILE_COMPUTE); @@ -392,9 +374,9 @@ namespace quda { int k = 0; - PrintStats("CG", k, r2, b2, heavy_quark_res); + PrintStats("CG", k, r2, b2, 0.0); - bool converged = convergence(r2, heavy_quark_res, stop, param.tol_hq); + bool converged = convergenceL2(r2, 0.0, stop, 0.0); ReliableUpdatesParams ru_params; @@ -406,9 +388,7 @@ namespace quda { ru_params.maxResIncrease = param.max_res_increase; ru_params.maxResIncreaseTotal = param.max_res_increase_total; - ru_params.use_heavy_quark_res = use_heavy_quark_res; - ru_params.hqmaxresIncrease = param.max_hq_res_increase; - ru_params.hqmaxresRestartTotal = param.max_hq_res_restart_total; + ru_params.use_heavy_quark_res = false; // since we've removed HQ residual support ReliableUpdates ru(ru_params, r2); @@ -465,13 +445,7 @@ namespace quda { if (advanced_feature) { ru.evaluate(r2_old); // force a reliable update if we are within target tolerance (only if doing reliable 
updates) - if (convergence(r2, heavy_quark_res, stop, param.tol_hq) && param.delta >= param.tol) ru.set_updateX(); - - if (use_heavy_quark_res and L2breakdown - and (convergenceHQ(r2, heavy_quark_res, stop, param.tol_hq) or (r2 / b2) < hq_res_stall_check) - and param.delta >= param.tol) { - ru.set_updateX(); - } + if (convergenceL2(r2, 0.0, stop, 0.0) && param.delta >= param.tol) ru.set_updateX(); } if (!ru.trigger()) { @@ -499,16 +473,6 @@ namespace quda { } } - if (use_heavy_quark_res && k % heavy_quark_check == 0) { - if (&x != &xSloppy) { - blas::copy(tmp, y); - heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(xSloppy, tmp, rSloppy).z); - } else { - blas::copy(r, rSloppy); - heavy_quark_res = sqrt(blas::xpyHeavyQuarkResidualNorm(x, y, r).z); - } - } - // alternative reliable updates if (advanced_feature) { ru.accumulate_norm(x_update_batch.get_current_alpha()); } } else { @@ -538,53 +502,28 @@ namespace quda { if (advanced_feature) { ru.update_norm(r2, y); } - // calculate new reliable HQ resididual - if (use_heavy_quark_res) heavy_quark_res = sqrt(blas::HeavyQuarkResidualNorm(y, r).z); - if (advanced_feature) { - if (ru.reliable_break(r2, stop, L2breakdown, L2breakdown_eps)) { break; } + // needed as a "dummy parameter" to reliable_break. 
+ bool L2breakdown = false; + if (ru.reliable_break(r2, stop, L2breakdown, 0)) { break; } } - // if L2 broke down already we turn off reliable updates and restart the CG - if (use_heavy_quark_res && ru.reliable_heavy_quark_break(L2breakdown, heavy_quark_res, heavy_quark_res_old, heavy_quark_restart)) { - break; - } + // explicitly restore the orthogonality of the gradient vector + Complex rp = blas::cDotProduct(rSloppy, x_update_batch.get_current_field()) / (r2); + blas::caxpy(-rp, rSloppy, x_update_batch.get_current_field()); - if (use_heavy_quark_res and heavy_quark_restart) { - // perform a restart - x_update_batch.reset(); - blas::copy(x_update_batch.get_current_field(), rSloppy); - heavy_quark_restart = false; - } else { - // explicitly restore the orthogonality of the gradient vector - Complex rp = blas::cDotProduct(rSloppy, x_update_batch.get_current_field()) / (r2); - blas::caxpy(-rp, rSloppy, x_update_batch.get_current_field()); - - beta = r2 / r2_old; - blas::xpayz(rSloppy, beta, x_update_batch.get_current_field(), x_update_batch.get_next_field()); - } + beta = r2 / r2_old; + blas::xpayz(rSloppy, beta, x_update_batch.get_current_field(), x_update_batch.get_next_field()); ru.reset(r2); - - heavy_quark_res_old = heavy_quark_res; } breakdown = false; k++; - PrintStats("CG", k, r2, b2, heavy_quark_res); - // check convergence, if convergence is satisfied we only need to check that we had a reliable update for the heavy quarks recently - converged = convergence(r2, heavy_quark_res, stop, param.tol_hq); - - // check for recent enough reliable updates of the HQ residual if we use it - if (use_heavy_quark_res) { - // L2 is converged or precision maxed out for L2 - bool L2done = L2breakdown or convergenceL2(r2, heavy_quark_res, stop, param.tol_hq); - // HQ is converged and if we do reliable update the HQ residual has been calculated using a reliable update - bool HQdone = (ru.steps_since_reliable == 0 and param.delta > 0) - and convergenceHQ(r2, heavy_quark_res, 
stop, param.tol_hq); - converged = L2done and HQdone; - } + PrintStats("CG", k, r2, b2, 0.0); + // check convergence + converged = convergenceL2(r2, 0.0, stop, 0.0); // if we have converged and need to update any trailing solutions if (converged && ru.steps_since_reliable > 0 && !x_update_batch.is_container_full()) { @@ -622,7 +561,7 @@ namespace quda { param.true_res_hq = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); } - PrintSummary("CG", k, r2, b2, stop, param.tol_hq); + PrintSummary("CG", k, r2, b2, stop, 0.0); if (!param.is_preconditioner) { // reset the flops counters @@ -637,6 +576,443 @@ namespace quda { if (param.is_preconditioner) commGlobalReductionPop(); } + // Separate HQ residual codepath: CG solve for x given right-hand side b that converges on the heavy quark residual (after the L2 residual, when one is also requested) + void CG::hqsolve(ColorSpinorField &x, ColorSpinorField &b) + { + + logQuda(QUDA_VERBOSE, "Performing a HQ CG solve\n"); + + // Terminal errors: HQ solves support neither deflation nor use as a preconditioner + if (param.deflate) errorQuda("HQ solves don't support deflation"); + if (param.is_preconditioner) errorQuda("HQ solves cannot be preconditioners"); + + // Non-terminal notices: HQ solves support neither alternative reliable updates nor pipelining; both are only logged + if (param.use_alternative_reliable) + logQuda(QUDA_SUMMARIZE, + "HQ solves don't support alternative reliable updates, reverting to traditional reliable updates\n"); + if (param.pipeline) logQuda(QUDA_SUMMARIZE, "HQ solves don't support pipelining, disabling..."); + + profile.TPSTART(QUDA_PROFILE_INIT); + + double b2 = blas::norm2(b); + + // Detect whether this is a pure double solve or not; informs the necessity of some stability checks + bool is_pure_double = (param.precision == QUDA_DOUBLE_PRECISION && param.precision_sloppy == QUDA_DOUBLE_PRECISION); + + bool heavy_quark_restart = false; + + // Check to see that we're not trying to invert on a zero-field source + if (b2 == 0 && param.compute_null_vector == QUDA_COMPUTE_NULL_VECTOR_NO) { + profile.TPSTOP(QUDA_PROFILE_INIT); + printfQuda("Warning: inverting on zero-field source\n"); + x = b; + 
param.true_res = 0.0; + param.true_res_hq = 0.0; + return; + } + + if (!init) { + ColorSpinorParam csParam(x); + csParam.create = QUDA_NULL_FIELD_CREATE; + rp = ColorSpinorField::Create(csParam); + yp = ColorSpinorField::Create(csParam); + + // sloppy fields + csParam.setPrecision(param.precision_sloppy); + pp = ColorSpinorField::Create(csParam); + App = ColorSpinorField::Create(csParam); + if (param.precision != param.precision_sloppy) { + rSloppyp = ColorSpinorField::Create(csParam); + xSloppyp = ColorSpinorField::Create(csParam); + } else { + rSloppyp = rp; + param.use_sloppy_partial_accumulator = false; + } + + // temporary field (sloppy precision) used for the iterated HQ residual; NOTE(review): `init` is shared with operator(), whose init block does not create tmpp, so confirm a CG instance is never reused across residual types + tmpp = ColorSpinorField::Create(csParam); + init = true; + } + + ColorSpinorField &r = *rp; + ColorSpinorField &y = *yp; + ColorSpinorField &p = *pp; + ColorSpinorField &Ap = *App; + ColorSpinorField &tmp = *tmpp; + ColorSpinorField &rSloppy = *rSloppyp; + ColorSpinorField &xSloppy = param.use_sloppy_partial_accumulator ? *xSloppyp : x; + + // for detecting HQ residual stalls + // let |r2/b2| drop to epsilon tolerance * 1e-30, semi-arbitrarily, but + // with the intent of letting the solve grind as long as possible before + // triggering a `NaN`. Ignored for pure double solves because if + // pure double has stability issues, bigger problems are at hand. + const double uhigh = precisionEpsilon(); // solver precision + const double hq_res_stall_check = is_pure_double ? 0. : uhigh * uhigh * 1e-60; + + // compute initial residual + double r2 = 0.0; + if (param.use_init_guess == QUDA_USE_INIT_GUESS_YES) { + // Compute r = b - A * x + mat(r, x); + r2 = blas::xmyNorm(b, r); + if (b2 == 0) b2 = r2; + // y contains the original guess. 
+ blas::copy(y, x); + } else { + if (&r != &b) blas::copy(r, b); + r2 = b2; + blas::zero(y); + } + + blas::zero(x); + if (&x != &xSloppy) blas::zero(xSloppy); + blas::copy(rSloppy, r); + blas::copy(p, rSloppy); + + double r2_old = 0.0; + + profile.TPSTOP(QUDA_PROFILE_INIT); + profile.TPSTART(QUDA_PROFILE_PREAMBLE); + + double stop = stopping(param.tol, b2, param.residual_type); // stopping condition of solver + + // compute the initial heavy quark residual; NOTE(review): x was zeroed just above, so this value does not include an initial guess (held in y), confirm intended + double hq_res = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); + + double alpha, beta, sigma, pAp; + + // Whether or not we also need to compute the L2 norm + const bool L2_required = param.residual_type & (QUDA_L2_RELATIVE_RESIDUAL | QUDA_L2_ABSOLUTE_RESIDUAL); + + // set L2breakdown to be immediately true if we aren't requesting an L2 norm, alternatively, + // it only gets set to true once the L2 norm has converged at a reliable update or after some heuristics suggest the L2 norm has "stalled out" + bool L2breakdown = !L2_required; + const double L2breakdown_eps = 100. * uhigh; + + profile.TPSTOP(QUDA_PROFILE_PREAMBLE); + profile.TPSTART(QUDA_PROFILE_COMPUTE); + blas::flops = 0; + + int k = 0; + + PrintStats("CG", k, r2, b2, hq_res); + + bool converged = convergence(r2, hq_res, stop, param.tol_hq); + + // Various parameters related to restarts + + // Trackers for the L2 norm: + // rNorm: current iterated |r| + // r0Norm: computed |r| at the last reliable update + double rNorm = sqrt(r2); + double r0Norm = rNorm; + + // If the iterated |r| (or, once L2breakdown is set, the HQ residual) goes above this ceiling between reliable updates, + // update this ceiling. This goes into "R" type reliable updates. + double maxrx = L2breakdown ? hq_res : rNorm; + double maxrr = L2breakdown ? hq_res : rNorm; + + // Triggers for explicitly counting residual updates and checking for L2breakdown. + // * updateX broadly maps to if the iterated residual has dropped by a factor of delta + // relative to the previously re-computed residual. 
+ // * updateR broadly maps to if the iterated residual has dropped by a factor of delta + // relative to the max of the previously re-computed residual and all iterated residuals + // since the last reliable update. + bool updateX = false; + bool updateR = false; + + // Counter for the number of times in a row the computed residual has jumped above the + // previously computed residual. + int resIncrease = 0; + + // Counter for the total number of times the computed residual has increased above the previously + // computed residual, independent of when it happened. + int resIncreaseTotal = 0; + + // Trackers for the HQ residual + // hq0Res: computed HQ residual at the last reliable update + double hq0Res = hq_res; + + // Counter for the number of times in a row the computed heavy quark residual has + // jumped above the previously computed heavy quark residual. + int hqresIncrease = 0; + + // Counter for the total number of times a reliable update based on the heavy quark residual + // has been triggered. + int hqresRestartTotal = 0; + + // Count the steps since a reliable update and the total number of reliable updates. + // The steps since a reliable update is also used to make sure final convergence is + // based on the computed residual and not the iterated residual. + int rUpdate = 0; + int steps_since_reliable = 1; + + while (!converged && k < param.maxiter) { + matSloppy(Ap, p); + + r2_old = r2; + + pAp = blas::reDotProduct(p, Ap); + + alpha = r2 / pAp; + + // here we are deploying the alternative beta computation + auto cg_norm = blas::axpyCGNorm(-alpha, Ap, rSloppy); + r2 = cg_norm.x; // (r_new, r_new) + sigma = cg_norm.y >= 0.0 ? cg_norm.y : r2; // use r2 if (r_k+1, r_k+1-r_k) breaks + rNorm = sqrt(r2); + + // If the iterated norm has dropped by more than a factor of delta, trigger + // an update. 
The baseline we check against differs depending on if + // we're still checking the L2 norm, or if that has converged/broken down and we're + // now looking at the HQ residual. + + if (!L2breakdown && (L2_required || convergenceL2(r2, hq_res, stop, param.tol_hq))) { + // L2 based reliable update + + // If the iterated residual norm has gone above the most recent "baseline" norm, + // update the baseline norm. + if (rNorm > maxrx) maxrx = rNorm; + if (rNorm > maxrr) maxrr = rNorm; + + // Has the iterated norm dropped by a factor of delta from the last computed norm? + updateX = (rNorm < param.delta * r0Norm && r0Norm <= maxrx); + + // Has the iterated norm dropped by a factor of delta relative to the largest the + // iterated norm has been since the last update? + updateR = ((rNorm < param.delta * maxrr && r0Norm <= maxrr) || updateX); + } else { + // hqresidual based reliable update + if (hq_res > maxrx) maxrx = hq_res; + if (hq_res > maxrr) maxrr = hq_res; + + // I'm making the decision to use `param.delta` for the hq_res check because + // in some regards it's an L2-esque norm... + + // Has the iterated heavy quark residual dropped by a factor of delta^2 from the last + // computed norm? NOTE(review): the guard below compares the L2 norm r0Norm against the HQ based maxrx, whereas updateR uses hq0Res; confirm intended + updateX = (hq_res < param.delta * param.delta * hq0Res && r0Norm <= maxrx); + + // Has the iterated heavy quark residual dropped by a factor of delta^2 relative + // to the largest the iterated norm has been since the last update? 
+ updateR = ((hq_res < param.delta * param.delta * maxrr && hq0Res <= maxrr) || updateX); + } + + // force a reliable update if we are within target tolerance (only if doing reliable updates) + if (convergence(r2, hq_res, stop, param.tol_hq) && param.delta >= param.tol) updateX = true; + + // force a reliable update based on the HQ residual if L2 breakdown has already happened + if (L2breakdown && (convergenceHQ(r2, hq_res, stop, param.tol_hq) || (r2 / b2) < hq_res_stall_check) + && param.delta >= param.tol) + updateX = true; + + if (!(updateR || updateX)) { + // No reliable update needed + + beta = sigma / r2_old; // use the alternative beta computation + + blas::axpyZpbx(alpha, p, xSloppy, rSloppy, beta); + // refresh the iterated HQ residual whenever k is a multiple of heavy_quark_check; NOTE(review): assumes param.heavy_quark_check != 0 + if (k % param.heavy_quark_check == 0) { + if (&x != &xSloppy) { + blas::copy(tmp, y); + hq_res = sqrt(blas::xpyHeavyQuarkResidualNorm(xSloppy, tmp, rSloppy).z); + } else { + blas::copy(r, rSloppy); + hq_res = sqrt(blas::xpyHeavyQuarkResidualNorm(x, y, r).z); + } + } + + steps_since_reliable++; + + } else { + // We're performing a reliable update + + // Accumulate p into x, accumulate x into the total solution y, explicitly recompute the residual vector + blas::axpy(alpha, p, xSloppy); + blas::copy(x, xSloppy); // no op when these pointers alias + blas::xpy(x, y); + mat(r, y); + blas::copy(rSloppy, r); // no op when these pointers alias + blas::zero(xSloppy); + + // Recompute the exact residual and heavy quark residual + r2 = blas::xmyNorm(b, r); + rNorm = sqrt(r2); + hq_res = sqrt(blas::HeavyQuarkResidualNorm(y, r).z); + + // Check and see if we're "done" with the L2 norm. This could be because 
+ // we were already done with it, we never needed it, or the L2 norm has finally converged. + if (!L2breakdown && convergenceL2(r2, hq_res, stop, param.tol_hq)) L2breakdown = true; + + // Depending on if we're still grinding on the L2 norm or if we've moved along to just + // the HQ norm, we reset the baselines for reliable updates that get used on the + // *next* iteration. We still need the baselines that were used for this iteration + // for the checks down below. + if (!L2breakdown) { + // If we're still grinding on the L2 norm, the new baseline is the freshly + // recomputed |r|. + maxrr = rNorm; + maxrx = rNorm; + } else { + // If we've made it to the HQ norm, the new baseline is the freshly recomputed + // heavy quark residual + maxrr = hq_res; + maxrx = hq_res; + + // Once we're dealing with the heavy quark residual, we perform a *hard* CG + // restart at every reliable update via setting the search vector `p` to the current + // exact residual vector. + heavy_quark_restart = true; + + // And then we keep track of the fact we're doing a HQ residual reliable update... + hqresRestartTotal++; + warningQuda("CG: Restarting without reliable updates for heavy-quark residual (total #inc %i)", + hqresRestartTotal); + + if (hqresRestartTotal > param.max_hq_res_restart_total) { + // ...and if we've restarted too many times, flunk out of the solve. + warningQuda("CG: solver exiting due to too many heavy quark residual restarts (%i/%i)", hqresRestartTotal, + param.max_hq_res_restart_total); + break; + } + } + + // Check and see if we've reached the limit of the precision. There isn't necessarily + // a great way to do this, so as a proxy we check to see if the new computed residual is + // larger than the computed residual from the last reliable update, and if this is the case + // enough times we throw up our hands, say "we're good here", and switch over to the HQ + // residual. 
+ if (rNorm > r0Norm && updateX && !L2breakdown) { + // Count the number of times in a row this has happened + resIncrease++; + + // And count the total number of times this has happened outright + resIncreaseTotal++; + + // ...tell the world about it too. + warningQuda( + "new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)", + sqrt(r2), r0Norm, resIncreaseTotal); + + // If the norm is ridiculously small in magnitude, we've exceeded the maximums on various + // ways we keep track of residual increases, or the L2 norm converged, we say "we're good here" + // and move over to the HQ residual norm. + if (rNorm < L2breakdown_eps || resIncrease > param.max_res_increase + || resIncreaseTotal > param.max_res_increase_total || r2 < stop) { + L2breakdown = true; + warningQuda("L2 breakdown %e, %e", rNorm, L2breakdown_eps); + + // We also have to do a logic correction, switching the reliable update baselines we set above + // from the L2 norm over to the HQ residual. + maxrr = hq_res; + maxrx = hq_res; + } + } else { + // This variable counts the number of times in a row the computed residual has gone up, + // so if it hasn't gone up this time around we reset this counter. + resIncrease = 0; + } + + // If we're done checking the L2 norm, we do a similar check of if the HQ residual has increased + // for multiple reliable updates in a row. + if (hq_res > hq0Res && updateX && L2breakdown) { + // Count the number of consecutive increases + hqresIncrease++; + + // Tell the world about it + warningQuda("CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", hq_res, + hq0Res); + + // And if it's increased too many times in a row, flunk out of the solve. 
+ if (hqresIncrease > param.max_hq_res_increase) { + warningQuda("CG: solver exiting due to too many heavy quark residual norm increases (%i/%i)", hqresIncrease, + param.max_hq_res_increase); + break; + } + } else { + // This variable counts the number of times in a row the computed heavy quark residual has increased, + // so if it hasn't gone up this time around we reset the counter. + hqresIncrease = 0; + } + + // Depending on if we're in the L2 norm part of the solve or a HQ residual part of the solve + // we "reset" the solve in a different way. + if (heavy_quark_restart) { + // If we're in the HQ residual part of the solve, we just do a hard CG restart. + logQuda(QUDA_SUMMARIZE, "HQ restart == hard CG restart\n"); + blas::copy(p, rSloppy); + heavy_quark_restart = false; + } else { + // If we're still in the L2 norm part of the solve, we explicitly restore + // the orthogonality of the gradient vector, recompute beta, update `p`, and carry on with our lives. Note that the local `rp` below shadows the member field pointer `rp`. + logQuda(QUDA_SUMMARIZE, "Regular restart == explicit gradient vector re-orthogonalization\n"); + Complex rp = blas::cDotProduct(rSloppy, p) / (r2); + blas::caxpy(-rp, rSloppy, p); + + beta = r2 / r2_old; + blas::xpayz(rSloppy, beta, p, p); + } + + // Last, we increment the reliable update counter, reset the number of steps since the last reliable update, + // and reset the cached value of |r| and the heavy quark residual from the time of this + // reliable update. 
+ rUpdate++; + steps_since_reliable = 0; + r0Norm = sqrt(r2); + + hq0Res = hq_res; + } + + k++; + + PrintStats("CG", k, r2, b2, hq_res); + // check convergence, if convergence is satisfied we only need to check that we had a reliable update for the heavy quarks recently + converged = convergence(r2, hq_res, stop, param.tol_hq); + + // the final verdict below overwrites `converged` from above: it also requires that the HQ residual comes from a reliable update on this iteration + + // L2 is converged or precision maxed out for L2 + bool L2done = L2breakdown || convergenceL2(r2, hq_res, stop, param.tol_hq); + // HQ is converged and if we do reliable update the HQ residual has been calculated using a reliable update + bool HQdone = (steps_since_reliable == 0 && param.delta > 0) && convergenceHQ(r2, hq_res, stop, param.tol_hq); + converged = L2done && HQdone; + } + + blas::copy(x, xSloppy); + blas::xpy(y, x); + + profile.TPSTOP(QUDA_PROFILE_COMPUTE); + profile.TPSTART(QUDA_PROFILE_EPILOGUE); + + param.secs = profile.Last(QUDA_PROFILE_COMPUTE); + double gflops = (blas::flops + mat.flops() + matSloppy.flops() + matPrecon.flops() + matEig.flops()) * 1e-9; + param.gflops = gflops; + param.iter += k; + + if (k == param.maxiter) warningQuda("Exceeded maximum iterations %d", param.maxiter); + + if (getVerbosity() >= QUDA_VERBOSE) printfQuda("CG: Reliable updates = %d\n", rUpdate); + + if (param.compute_true_res) { + // compute the true residuals + mat(r, x); + param.true_res = sqrt(blas::xmyNorm(b, r) / b2); + param.true_res_hq = sqrt(blas::HeavyQuarkResidualNorm(x, r).z); + } + + PrintSummary("CG", k, r2, b2, stop, param.tol_hq); + + // reset the flops counters + blas::flops = 0; + mat.flops(); + matSloppy.flops(); + matPrecon.flops(); + + profile.TPSTOP(QUDA_PROFILE_EPILOGUE); + } + // use BlockCGrQ algortithm or BlockCG (with / without GS, see BLOCKCG_GS option) #define BCGRQ 1 #if BCGRQ