Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

better handling of unknown symmetry representations #10

Merged
merged 1 commit into from
Mar 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 78 additions & 17 deletions src/brillouin.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
domain, thus mapping the value of `x` on the parametrization to the full domain.
"""
symmetrize(f, bz, xs...) = map(x -> symmetrize(f, bz, x), xs)
symmetrize(f, bz, x) = symmetrize_(SymRep(f), bz, x)
symmetrize(f, bz, x) = symmetrize_(f isa AbstractSymRep ? f : SymRep(f), bz, x)

Check warning on line 97 in src/brillouin.jl

View check run for this annotation

Codecov / codecov/patch

src/brillouin.jl#L97

Added line #L97 was not covered by tests
symmetrize(f, bz, x::TrivialRepType) =
symmetrize_(TrivialRep(), bz, x)

Expand All @@ -105,17 +105,44 @@
the result of an integral on the FBZ from `x`, which was calculated on the IBZ.
"""
symmetrize_(::TrivialRep, bz::SymmetricBZ, x) = nsyms(bz)*x
function symmetrize_(::UnknownRep, ::SymmetricBZ, x)
@warn "Symmetric BZ detected but the integrand's symmetry representation is unknown. Define a trait for your integrand by extending SymRep"
x
end
symmetrize_(::UnknownRep, ::SymmetricBZ, x) = x

Check warning on line 108 in src/brillouin.jl

View check run for this annotation

Codecov / codecov/patch

src/brillouin.jl#L108

Added line #L108 was not covered by tests

symmetrize(_, ::FullBZ, x) = x
symmetrize(_, ::FullBZ, x::TrivialRepType) = x

symmetrize(f, bz, x::AuxValue) = AuxValue(symmetrize(f, bz, x.val, x.aux)...)
symmetrize(_, ::FullBZ, x::AuxValue) = x

struct SymmetricRule{R,U,B}
rule::R
rep::U
bz::B
end

Base.getindex(r::SymmetricRule, i) = getindex(r.rule, i)
Base.eltype(::Type{SymmetricRule{R,U,B}}) where {R,U,B} = eltype(R)

Check warning on line 123 in src/brillouin.jl

View check run for this annotation

Codecov / codecov/patch

src/brillouin.jl#L122-L123

Added lines #L122 - L123 were not covered by tests
Base.length(r::SymmetricRule) = length(r.rule)
Base.iterate(r::SymmetricRule, args...) = iterate(r.rule, args...)
rule_type(r::SymmetricRule) = rule_type(r.rule)

Check warning on line 126 in src/brillouin.jl

View check run for this annotation

Codecov / codecov/patch

src/brillouin.jl#L125-L126

Added lines #L125 - L126 were not covered by tests
function (r::SymmetricRule)(f::F, args...) where {F}
out = r.rule(f, args...)
return symmetrize(r.rep, r.bz, out)
end

struct SymmetricRuleDef{R,U,B}
rule::R
rep::U
bz::B
end

AutoSymPTR.nsyms(r::SymmetricRuleDef) = AutoSymPTR.nsyms(r.r)

Check warning on line 138 in src/brillouin.jl

View check run for this annotation

Codecov / codecov/patch

src/brillouin.jl#L138

Added line #L138 was not covered by tests
function (r::SymmetricRuleDef)(::Type{T}, v::Val{d}) where {T,d}
return SymmetricRule(r.rule(T, v), r.rep, r.bz)
end
function AutoSymPTR.nextrule(r::SymmetricRule, ruledef::SymmetricRuleDef)
return SymmetricRule(AutoSymPTR.nextrule(r.rule, ruledef.rule), ruledef.rep, ruledef.bz)
end

# Here we provide utilities to build BZs

"""
Expand Down Expand Up @@ -289,6 +316,7 @@
Brillouin zone becomes `[0,1]^d`, and integrands should have this periodicity. If the
integrand depends on the Brillouin zone basis, then it may have to be transformed to the
Cartesian coordinates as a post-processing step.
These algorithms also use the symmetries of the Brillouin zone and the integrand.
"""
abstract type AutoBZAlgorithm <: IntegralAlgorithm end

Expand All @@ -297,14 +325,30 @@
return init_cacheval(f, dom, p, alg)
end

function do_solve(f, bz::SymmetricBZ, p, bzalg::AutoBZAlgorithm, cacheval; _kws...)
function do_solve(f, bz::SymmetricBZ, p, bzalg::AutoBZAlgorithm, cacheval; kws...)
do_solve_autobz(bz_to_standard, f, bz, p, bzalg, cacheval; kws...)
end

const WARN_UNKNOWN_SYMMETRY = """
A symmetric BZ was used with an integrand whose symmetry representation is unknown.
For correctness, the calculation will be repeated on the full BZ.
However, it is better either to integrate without symmetries or to use symmetries by extending SymRep for your type.
"""
function do_solve_autobz(bz_to_standard, f, bz, p, bzalg::AutoBZAlgorithm, cacheval; _kws...)
bz_, dom, alg = bz_to_standard(bz, bzalg)

j = abs(det(bz_.B)) # rescale tolerance to (I)BZ coordinate and get the right number of digits
kws = NamedTuple(_kws)
kws_ = haskey(kws, :abstol) ? merge(kws, (abstol=kws.abstol / (j * nsyms(bz_)),)) : kws

sol = do_solve(f, dom, p, alg, cacheval; kws_...)
# TODO find a way to throw a warning when constructing the problem instead of after a solve
SymRep(f) isa UnknownRep && !(bz_ isa FullBZ) && !(sol.u isa TrivialRepType) && begin
@warn WARN_UNKNOWN_SYMMETRY
fbz = SymmetricBZ(bz_.A, bz_.B, lattice_bz_limits(bz_.B), nothing)
_cacheval = init_cacheval(f, fbz, p, bzalg)
return do_solve(f, fbz, p, bzalg, _cacheval; _kws...)

Check warning on line 350 in src/brillouin.jl

View check run for this annotation

Codecov / codecov/patch

src/brillouin.jl#L347-L350

Added lines #L347 - L350 were not covered by tests
end
val = j*symmetrize(f, bz_, sol.u)
err = sol.resid === nothing ? nothing : j*symmetrize(f, bz_, sol.resid)
return IntegralSolution(val, err, sol.retcode, sol.numevals)
Expand Down Expand Up @@ -374,6 +418,30 @@
function bz_to_standard(bz::SymmetricBZ, alg::AutoPTR)
return bz, canonical_ptr_basis(bz.B), AutoSymPTRJL(norm=alg.norm, a=alg.a, nmin=alg.nmin, nmax=alg.nmax, n₀=alg.n₀, Δn=alg.Δn, keepmost=alg.keepmost, syms=bz.syms, nthreads=alg.nthreads)
end
function init_cacheval(f, bz::SymmetricBZ, p, bzalg::AutoPTR)
bz_, dom, alg = bz_to_standard(bz, bzalg)
f isa NestedBatchIntegrand && throw(ArgumentError("AutoSymPTRJL doesn't support nested batching"))
rule = SymmetricRuleDef(init_rule(dom, alg), SymRep(f), bz_)
cache = AutoSymPTR.alloc_cache(eltype(dom), Val(ndims(dom)), rule)
buffer = init_buffer(f, alg.nthreads)
return (rule=rule, cache=cache, buffer=buffer)
end
function do_solve_autobz(bz_to_standard, f, bz, p, bzalg::AutoPTR, cacheval; _kws...)
bz_, dom, alg = bz_to_standard(bz, bzalg)
j = abs(det(bz_.B)) # rescale tolerance to (I)BZ coordinate and get the right number of digits
kws = NamedTuple(_kws)
kws_ = haskey(kws, :abstol) ? merge(kws, (abstol=kws.abstol / j,)) : kws

sol = do_solve(f, dom, p, alg, cacheval; kws_...)
# TODO find a way to throw a warning when constructing the problem instead of after a solve
SymRep(f) isa UnknownRep && !(bz_ isa FullBZ) && !(sol.u isa TrivialRepType) && begin
@warn WARN_UNKNOWN_SYMMETRY
fbz = SymmetricBZ(bz_.A, bz_.B, lattice_bz_limits(bz_.B), nothing)
_cacheval = init_cacheval(f, fbz, p, bzalg)
return do_solve(f, fbz, p, bzalg, _cacheval; _kws...)

Check warning on line 441 in src/brillouin.jl

View check run for this annotation

Codecov / codecov/patch

src/brillouin.jl#L438-L441

Added lines #L438 - L441 were not covered by tests
end
return IntegralSolution(sol.u * j, sol.resid * j, sol.retcode, sol.numevals)
end

"""
TAI(; norm=norm, initdivs=1)
Expand Down Expand Up @@ -421,18 +489,11 @@
"""
AutoPTR_IAI(; reltol=1.0, ptr=AutoPTR(), iai=IAI(), kws...) = AbsoluteEstimate(ptr, iai; reltol=reltol, kws...)


# do not export this, just for internal use
struct AutoBZEvalCounter{B,D,A} <: AutoBZAlgorithm
bz::B
dom::D
alg::A
end

function bz_to_standard(bz::SymmetricBZ, alg::AutoBZEvalCounter)
return alg.bz, alg.dom, EvalCounter(alg.alg)
function count_bz_to_standard(bz, alg)
_bz, dom, _alg = bz_to_standard(bz, alg)
return _bz, dom, EvalCounter(_alg)
end

function do_solve(f, bz::SymmetricBZ, p, alg::EvalCounter{<:AutoBZAlgorithm}, cacheval; kws...)
return do_solve(f, bz, p, AutoBZEvalCounter(bz_to_standard(bz, alg.alg)...), cacheval; kws...)
return do_solve_autobz(count_bz_to_standard, f, bz, p, alg.alg, cacheval; kws...)

Check warning on line 498 in src/brillouin.jl

View check run for this annotation

Codecov / codecov/patch

src/brillouin.jl#L498

Added line #L498 was not covered by tests
end
10 changes: 8 additions & 2 deletions src/fourier.jl
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,13 @@ function init_cacheval(f::FourierIntegrand, dom::Basis, p, alg::AutoSymPTRJL)
buffer = init_buffer(f, alg.nthreads)
return (rule=rule, cache=cache, buffer=buffer)
end

function init_cacheval(f::FourierIntegrand, bz::SymmetricBZ, p, bzalg::AutoPTR)
bz_, dom, alg = bz_to_standard(bz, bzalg)
rule = SymmetricRuleDef(init_fourier_rule(f.w, dom, alg), SymRep(f), bz_)
cache = AutoSymPTR.alloc_cache(eltype(dom), Val(ndims(dom)), rule)
buffer = init_buffer(f, alg.nthreads)
return (rule=rule, cache=cache, buffer=buffer)
end
function init_fourier_rule(s::AbstractFourierSeries, bz::SymmetricBZ, alg::PTR)
dom = Basis(bz.B)
return FourierMonkhorstPack(s, eltype(dom), Val(ndims(dom)), alg.npt, bz.syms)
Expand Down Expand Up @@ -520,5 +526,5 @@ end

# method needed to resolve ambiguities
function do_solve(f::FourierIntegrand, bz::SymmetricBZ, p, alg::EvalCounter{<:AutoBZAlgorithm}, cacheval; kws...)
return do_solve(f, bz, p, AutoBZEvalCounter(bz_to_standard(bz, alg.alg)...), cacheval; kws...)
return do_solve_autobz(count_bz_to_standard, f, bz, p, alg.alg, cacheval; kws...)
end
Loading