From c02b5269ecf8a37464b0648d6c7d7b9fa5450263 Mon Sep 17 00:00:00 2001 From: lxvm Date: Sun, 10 Mar 2024 14:59:28 -0400 Subject: [PATCH] better handling of unknown symmetry representations --- src/brillouin.jl | 95 +++++++++++++++++++++++++++++++++++++++--------- src/fourier.jl | 10 ++++- 2 files changed, 86 insertions(+), 19 deletions(-) diff --git a/src/brillouin.jl b/src/brillouin.jl index 64c9ff4..84b1b71 100644 --- a/src/brillouin.jl +++ b/src/brillouin.jl @@ -94,7 +94,7 @@ Transform `x` by the symmetries of the parametrization used to reduce the domain, thus mapping the value of `x` on the parametrization to the full domain. """ symmetrize(f, bz, xs...) = map(x -> symmetrize(f, bz, x), xs) -symmetrize(f, bz, x) = symmetrize_(SymRep(f), bz, x) +symmetrize(f, bz, x) = symmetrize_(f isa AbstractSymRep ? f : SymRep(f), bz, x) symmetrize(f, bz, x::TrivialRepType) = symmetrize_(TrivialRep(), bz, x) @@ -105,10 +105,7 @@ Transform `x` under representation `rep` using the symmetries in `bz` to obtain the result of an integral on the FBZ from `x`, which was calculated on the IBZ. """ symmetrize_(::TrivialRep, bz::SymmetricBZ, x) = nsyms(bz)*x -function symmetrize_(::UnknownRep, ::SymmetricBZ, x) - @warn "Symmetric BZ detected but the integrand's symmetry representation is unknown. Define a trait for your integrand by extending SymRep" - x -end +symmetrize_(::UnknownRep, ::SymmetricBZ, x) = x symmetrize(_, ::FullBZ, x) = x symmetrize(_, ::FullBZ, x::TrivialRepType) = x @@ -116,6 +113,36 @@ symmetrize(_, ::FullBZ, x::TrivialRepType) = x symmetrize(f, bz, x::AuxValue) = AuxValue(symmetrize(f, bz, x.val, x.aux)...) symmetrize(_, ::FullBZ, x::AuxValue) = x +struct SymmetricRule{R,U,B} + rule::R + rep::U + bz::B +end + +Base.getindex(r::SymmetricRule, i) = getindex(r.rule, i) +Base.eltype(::Type{SymmetricRule{R,U,B}}) where {R,U,B} = eltype(R) +Base.length(r::SymmetricRule) = length(r.rule) +Base.iterate(r::SymmetricRule, args...) = iterate(r.rule, args...) +rule_type(r::SymmetricRule) = rule_type(r.rule) +function (r::SymmetricRule)(f::F, args...) where {F} + out = r.rule(f, args...) + return symmetrize(r.rep, r.bz, out) +end + +struct SymmetricRuleDef{R,U,B} + rule::R + rep::U + bz::B +end + +AutoSymPTR.nsyms(r::SymmetricRuleDef) = AutoSymPTR.nsyms(r.r) +function (r::SymmetricRuleDef)(::Type{T}, v::Val{d}) where {T,d} + return SymmetricRule(r.rule(T, v), r.rep, r.bz) +end +function AutoSymPTR.nextrule(r::SymmetricRule, ruledef::SymmetricRuleDef) + return SymmetricRule(AutoSymPTR.nextrule(r.rule, ruledef.rule), ruledef.rep, ruledef.bz) +end + # Here we provide utilities to build BZs """ @@ -289,6 +316,7 @@ All integration problems on the BZ get rescaled to fractional coordinates so tha Brillouin zone becomes `[0,1]^d`, and integrands should have this periodicity. If the integrand depends on the Brillouin zone basis, then it may have to be transformed to the Cartesian coordinates as a post-processing step. +These algorithms also use the symmetries of the Brillouin zone and the integrand. """ abstract type AutoBZAlgorithm <: IntegralAlgorithm end @@ -297,7 +325,16 @@ function init_cacheval(f, bz::SymmetricBZ, p, bzalg::AutoBZAlgorithm) return init_cacheval(f, dom, p, alg) end -function do_solve(f, bz::SymmetricBZ, p, bzalg::AutoBZAlgorithm, cacheval; _kws...) +function do_solve(f, bz::SymmetricBZ, p, bzalg::AutoBZAlgorithm, cacheval; kws...) + do_solve_autobz(bz_to_standard, f, bz, p, bzalg, cacheval; kws...) +end + +const WARN_UNKNOWN_SYMMETRY = """ +A symmetric BZ was used with an integrand whose symmetry representation is unknown. +For correctness, the calculation will be repeated on the full BZ. +However, it is better either to integrate without symmetries or to use symmetries by extending SymRep for your type. +""" +function do_solve_autobz(bz_to_standard, f, bz, p, bzalg::AutoBZAlgorithm, cacheval; _kws...) bz_, dom, alg = bz_to_standard(bz, bzalg) j = abs(det(bz_.B)) # rescale tolerance to (I)BZ coordinate and get the right number of digits @@ -305,6 +342,13 @@ function do_solve(f, bz::SymmetricBZ, p, bzalg::AutoBZAlgorithm, cacheval; _kws. kws_ = haskey(kws, :abstol) ? merge(kws, (abstol=kws.abstol / (j * nsyms(bz_)),)) : kws sol = do_solve(f, dom, p, alg, cacheval; kws_...) + # TODO find a way to throw a warning when constructing the problem instead of after a solve + SymRep(f) isa UnknownRep && !(bz_ isa FullBZ) && !(sol.u isa TrivialRepType) && begin + @warn WARN_UNKNOWN_SYMMETRY + fbz = SymmetricBZ(bz_.A, bz_.B, lattice_bz_limits(bz_.B), nothing) + _cacheval = init_cacheval(f, fbz, p, bzalg) + return do_solve(f, fbz, p, bzalg, _cacheval; _kws...) + end val = j*symmetrize(f, bz_, sol.u) err = sol.resid === nothing ? nothing : j*symmetrize(f, bz_, sol.resid) return IntegralSolution(val, err, sol.retcode, sol.numevals) @@ -374,6 +418,30 @@ end function bz_to_standard(bz::SymmetricBZ, alg::AutoPTR) return bz, canonical_ptr_basis(bz.B), AutoSymPTRJL(norm=alg.norm, a=alg.a, nmin=alg.nmin, nmax=alg.nmax, n₀=alg.n₀, Δn=alg.Δn, keepmost=alg.keepmost, syms=bz.syms, nthreads=alg.nthreads) end +function init_cacheval(f, bz::SymmetricBZ, p, bzalg::AutoPTR) + bz_, dom, alg = bz_to_standard(bz, bzalg) + f isa NestedBatchIntegrand && throw(ArgumentError("AutoSymPTRJL doesn't support nested batching")) + rule = SymmetricRuleDef(init_rule(dom, alg), SymRep(f), bz_) + cache = AutoSymPTR.alloc_cache(eltype(dom), Val(ndims(dom)), rule) + buffer = init_buffer(f, alg.nthreads) + return (rule=rule, cache=cache, buffer=buffer) +end +function do_solve_autobz(bz_to_standard, f, bz, p, bzalg::AutoPTR, cacheval; _kws...) + bz_, dom, alg = bz_to_standard(bz, bzalg) + j = abs(det(bz_.B)) # rescale tolerance to (I)BZ coordinate and get the right number of digits + kws = NamedTuple(_kws) + kws_ = haskey(kws, :abstol) ? merge(kws, (abstol=kws.abstol / j,)) : kws + + sol = do_solve(f, dom, p, alg, cacheval; kws_...) + # TODO find a way to throw a warning when constructing the problem instead of after a solve + SymRep(f) isa UnknownRep && !(bz_ isa FullBZ) && !(sol.u isa TrivialRepType) && begin + @warn WARN_UNKNOWN_SYMMETRY + fbz = SymmetricBZ(bz_.A, bz_.B, lattice_bz_limits(bz_.B), nothing) + _cacheval = init_cacheval(f, fbz, p, bzalg) + return do_solve(f, fbz, p, bzalg, _cacheval; _kws...) + end + return IntegralSolution(sol.u * j, sol.resid * j, sol.retcode, sol.numevals) +end """ TAI(; norm=norm, initdivs=1) @@ -421,18 +489,11 @@ the naive `nested_quadgk`. """ AutoPTR_IAI(; reltol=1.0, ptr=AutoPTR(), iai=IAI(), kws...) = AbsoluteEstimate(ptr, iai; reltol=reltol, kws...) - -# do not export this, just for internal use -struct AutoBZEvalCounter{B,D,A} <: AutoBZAlgorithm - bz::B - dom::D - alg::A -end - -function bz_to_standard(bz::SymmetricBZ, alg::AutoBZEvalCounter) - return alg.bz, alg.dom, EvalCounter(alg.alg) +function count_bz_to_standard(bz, alg) + _bz, dom, _alg = bz_to_standard(bz, alg) + return _bz, dom, EvalCounter(_alg) end function do_solve(f, bz::SymmetricBZ, p, alg::EvalCounter{<:AutoBZAlgorithm}, cacheval; kws...) - return do_solve(f, bz, p, AutoBZEvalCounter(bz_to_standard(bz, alg.alg)...), cacheval; kws...) + return do_solve_autobz(count_bz_to_standard, f, bz, p, alg.alg, cacheval; kws...) end diff --git a/src/fourier.jl b/src/fourier.jl index dd18353..98a0457 100644 --- a/src/fourier.jl +++ b/src/fourier.jl @@ -351,7 +351,13 @@ function init_cacheval(f::FourierIntegrand, dom::Basis, p, alg::AutoSymPTRJL) buffer = init_buffer(f, alg.nthreads) return (rule=rule, cache=cache, buffer=buffer) end - +function init_cacheval(f::FourierIntegrand, bz::SymmetricBZ, p, bzalg::AutoPTR) + bz_, dom, alg = bz_to_standard(bz, bzalg) + rule = SymmetricRuleDef(init_fourier_rule(f.w, dom, alg), SymRep(f), bz_) + cache = AutoSymPTR.alloc_cache(eltype(dom), Val(ndims(dom)), rule) + buffer = init_buffer(f, alg.nthreads) + return (rule=rule, cache=cache, buffer=buffer) +end function init_fourier_rule(s::AbstractFourierSeries, bz::SymmetricBZ, alg::PTR) dom = Basis(bz.B) return FourierMonkhorstPack(s, eltype(dom), Val(ndims(dom)), alg.npt, bz.syms) @@ -520,5 +526,5 @@ end # method needed to resolve ambiguities function do_solve(f::FourierIntegrand, bz::SymmetricBZ, p, alg::EvalCounter{<:AutoBZAlgorithm}, cacheval; kws...) - return do_solve(f, bz, p, AutoBZEvalCounter(bz_to_standard(bz, alg.alg)...), cacheval; kws...) + return do_solve_autobz(count_bz_to_standard, f, bz, p, alg.alg, cacheval; kws...) end