Skip to content

Commit

Permalink
change function type in integrate
Browse files Browse the repository at this point in the history
  • Loading branch information
svretina committed Feb 4, 2025
1 parent 5168c79 commit b1f2899
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 44 deletions.
68 changes: 32 additions & 36 deletions src/FastTanhSinhQuadrature.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,7 @@ function tanhsinh(::Type{T}, n::Int) where {T<:AbstractFloat}
x = ordinate.(t)
w = weight.(t)
N = length(x)
if n < 100
return SVector{N,T}(x), SVector{N,T}(w), h
else
return x, w, h
end
return x, w, h
end

tanhsinh(n::Int) = tanhsinh(Float64, n)
Expand Down Expand Up @@ -70,6 +66,7 @@ end
xp = x₀ + Δxxi
xm = x₀ - Δxxi
if xp xmax || xm xmin
@show i
x[i] = x₀
w[i] = zero(T)
end
Expand Down Expand Up @@ -123,8 +120,8 @@ function integrate(f::Function, n::Int)
end

# [-1,1] by default 1D
function integrate(f::Function, x::AbstractVector{T}, w::AbstractVector{T},
h::T) where {T<:Real}
function integrate(f::X, x::AbstractVector{T}, w::AbstractVector{T},
h::T) where {T<:Real,X}
s = weight(zero(T)) * f(zero(T))
# ncalls[1] += 1
for i in 1:length(x)
Expand All @@ -140,25 +137,23 @@ end
@fastmath @inbounds begin
Δx = (xmax - xmin) / 2
x₀ = (xmax + xmin) / 2
@show x₀, Δx
s = weight(zero(T)) * f(x₀)
#ncalls[1] += 1
for i in 1:length(x)
xp = x₀ + Δx * x[i]
xm = x₀ - Δx * x[i]
if xm > xmin
if xm > xmin ## this is a problematic check if xmin>xmax
s += w[i] * f(xm)
#ncalls[1] += 1
end
if xp < xmax
s += w[i] * f(xp)
# ncalls[1] += 1
end
end
end
return Δx * h * s
end


## carefull, this is unsafe for a function with a singularity at the endpoints
## if you want to use this with a singular function, then first run
## remove_endpoints! on your weights and points and then use this function
Expand All @@ -175,6 +170,16 @@ end
return @fastmath Δx * h * s
end

# [-1, 1] by default
@inline function integrate_avx(f::S, x::AbstractVector{T}, w::AbstractVector{T}, h::T) where {T<:Real,S}
μηδεν = zero(T)
@fastmath s = weight(μηδεν) * f(μηδεν)
@turbo for i in 1:length(x)
s += w[i] * (f(-x[i]) + f(x[i]))
end
return @fastmath h * s
end

## 2D
@inline function integrate(f::S, xmin::SVector{2,T}, xmax::SVector{2,T},
x::AbstractVector{T}, w::AbstractVector{T}, h::T) where {T<:Real,S}
Expand All @@ -199,18 +204,6 @@ function integrate(f::X, xmin::AbstractVector{S}, xmax::AbstractVector{S}, x::Ab
return integrate(f, SVector{n,T}(xmin), SVector{n,T}(xmax), x, w, h)
end

function _integrate(f::Function, D::Int, x::AbstractVector{T}, w::AbstractVector{T}, h::T) where {T<:Real}
if D == 2
f2(x1) = quad(y -> f(x1, y), x, w, h)
return quad(x1 -> f2(x1), x, w, h)
elseif D == 3
g1(x1, y1) = quad(z -> f(x1, y1, z), x, w, h)
g2(x1) = quad(y -> g1(x1, y), x, w, h)
return quad(x1 -> g2(x1), x, w, h)
end
return zero(T)
end

# helper function for generality
@inline function quad(f::Function, xmin::T, xmax::T, x::AbstractVector{T}, w::AbstractVector{T}, h::T) where {T<:Real}
if xmin == xmax
Expand All @@ -235,13 +228,22 @@ end
if (xmin[1] == xmax[1]) || (xmin[2] == xmax[2])
return zero(T)
end
#ncalls = @MVector [0]
# if (xmin[1] == -1) && (xmin[2] == -1) && (xmax[1] == 1) && (xmax[2] == -1)
# return _integrate(f, 2, x, w, h)#, ncalls[1]
# else
# return integrate(f, xmin, xmax, x, w, h)#, ncalls[1]
# end
return integrate(f, xmin, xmax, x, w, h)#, ncalls[1]
if all(xmin .< xmax)
return integrate(f, xmin, xmax, x, w, h)
else
sign = 1
@inbounds for i in 1:2
low[i] = xmin[i]
up[i] = xmax[i]
if xmin[i] > xmax[i]
sign *= -1
tmp = xmin[i]
xmin[i] = xmax[i]
xmax[i] = tmp
end
end
return sign * integrate(f, xmin, xmax, x, w, h)
end
end

# 3D
Expand All @@ -250,12 +252,6 @@ end
if (xmin[1] == xmax[1]) || (xmin[2] == xmax[2]) || (xmin[3] == xmax[3])
return zero(T)
end
#ncalls = @MVector [0]
# if (xmin[1] == -1) && (xmin[2] == -1) && (xmin[3] == -1) && (xmax[1] == 1) && (xmax[2] == -1) && (xmax[3] == -1)
# return _integrate(f, 3, x, w, h)#, ncalls[1]
# else
# return integrate(f, xmin, xmax, x, w, h)#, ncalls[1]
# end
return integrate(f, xmin, xmax, x, w, h)#, ncalls[1]
end

Expand Down
42 changes: 34 additions & 8 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using FastTanhSinhQuadrature
using Test
using DoubleFloats
using Random
using StaticArrays

const Types = [Float32, Float64, Double64, BigFloat]
const rtol = Dict(Float32 => 10 * sqrt(eps(Float32)),
Expand Down Expand Up @@ -49,17 +50,42 @@ Random.seed!(0)
f(x) = b0 + b1 * x + b2 * x^2
g(x) = c0 + c1 * x + c2 * x^2

F = integrate(f, x, w, h)
G = integrate(g, x, w, h)
F1 = integrate(f, x, w, h)
F2 = integrate_avx(f, x, w, h)
G1 = integrate(g, x, w, h)
G2 = integrate_avx(g, x, w, h)

afg(x) = a * f(x) + g(x)
@test integrate(afg, x, w, h) a * F + G
@test integrate(afg, x, w, h) a * F1 + G1
@test integrate_avx(afg, x, w, h) a * F2 + G2

d = T(rand(-9:9)) / 10

#@test integrate(f, one(T), -one(T), x, w, h) ≈ -F
# @test integrate(f, one(T), -one(T), x, w, h) ≈ -F1
@test integrate_avx(f, one(T), -one(T), x, w, h) -F2

F0 = integrate(f, -one(T), a, x, w, h)
F1 = integrate(f, a, one(T), x, w, h)
@test isapprox(F0 + F1, F, rtol=rtol[T])
F01 = integrate(f, -one(T), a, x, w, h)
F11 = integrate(f, a, one(T), x, w, h)
F02 = integrate_avx(f, -one(T), a, x, w, h)
F12 = integrate_avx(f, a, one(T), x, w, h)
@test isapprox(F01 + F11, F1, rtol=rtol[T])
@test isapprox(F02 + F12, F2, rtol=rtol[T])
end

@testset "2D polynomials for [-1, 1], T=$T" for T in Types
x, w, h = tanhsinh(T, 80)
ψ(x, y) = one(T)
f(x, y) = x * y
g(x, y) = x^2 * y^2

low = SVector{2,T}(-1.0, -1.0)
up = SVector{2,T}(1.0, 1.0)

Ψ = integrate(ψ, low, up, x, w, h)
Ψ2 = integrate(ψ, up, low, x, w, h)
F = integrate(f, low, up, x, w, h)
G = integrate(g, low, up, x, w, h)
@test Ψ 4one(T)
@test Ψ2 4one(T)
@test F zero(T)
@test G T(4) / T(9)
end

0 comments on commit b1f2899

Please sign in to comment.