Skip to content

Commit

Permalink
Add Spline{-1} support (#168)
Browse files Browse the repository at this point in the history
* Add Spline{-1} support

* add tests
  • Loading branch information
dlfivefifty authored Dec 5, 2023
1 parent c42fedf commit 1cb30aa
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 4 deletions.
51 changes: 48 additions & 3 deletions src/bases/splines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,41 @@ function getindex(B::HeavisideSpline{T}, x::Number, k::Int) where T
x axes(B,1) && 1 k  size(B,2)|| throw(BoundsError())

p = B.points
n = length(p)

p[k] < x < p[k+1] && return one(T)
p[k] == x && return one(T)/2
p[k+1] == x && return one(T)/2
return zero(T)
end

function getindex(B::Spline{-1,T}, x::Number, k::Int) where T
x axes(B,1) && 1 k size(B,2)|| throw(BoundsError())

p = B.points
p[k+1] == x && return convert(T,Inf)
zero(T)
end



grid(L::HeavisideSpline, n...) = L.points[1:end-1] .+ diff(L.points)/2
plotgrid(L::HeavisideSpline, n...) = [L.points'; L.points'][2:end-1]
function plotgridvalues(f::ApplyQuasiVector{<:Any,typeof(*),<:Tuple{HeavisideSpline,Any}})
g = plotgrid(basis(f))
c = coefficients(f)
g,vec([c'; c'])
end

function plotgrid(L::Spline{-1}, n...)
p = L.points[2:end-1]
vec([p'; p'; p'])
end
function plotgridvalues(f::ApplyQuasiVector{<:Any,typeof(*),<:Tuple{Spline{-1},Any}})
g = plotgrid(basis(f))
c = coefficients(f)
g,vec([zeros(1,length(c)); c'; fill(NaN,1,length(c))])
end


# Splines sample same number of points regardless of length.
grid(L::HeavisideSpline, ::Integer) = L.points[1:end-1] .+ diff(L.points)/2
grid(L::LinearSpline, ::Integer) = L.points
Expand Down Expand Up @@ -88,6 +115,17 @@ function diff(L::LinearSpline{T}; dims::Integer=1) where T
ApplyQuasiMatrix(*, HeavisideSpline{T}(x), D)
end

function diff(L::HeavisideSpline{T}; dims::Integer=1) where T
dims == 1 || error("not implemented")
n = size(L,2)
x = L.points
D = BandedMatrix{T}(undef, (n-1,n), (0,1))
d = diff(x)
D[band(0)] .= -one(T)
D[band(1)] .= one(T)
ApplyQuasiMatrix(*, Spline{-1,T}(x), D)
end


##
# sum
Expand All @@ -99,6 +137,7 @@ function _sum(A::HeavisideSpline, dims)
end

function _sum(P::LinearSpline, dims)
dims == 1 || error("not implemented")
d = diff(P.points)
ret = Array{float(eltype(d))}(undef, length(d)+1)
ret[1] = d[1]/2
Expand All @@ -109,4 +148,10 @@ function _sum(P::LinearSpline, dims)
permutedims(ret)
end

_cumsum(H::HeavisideSpline{T}, dims) where T = LinearSpline(H.points) * tril(Ones{T}(length(H.points),length(H.points)-1) .* diff(H.points)',-1)
function _sum(P::Spline{-1,T}, dims) where T
dims == 1 || error("not implemented")
Ones{T}(1, size(P,2))
end

_cumsum(H::HeavisideSpline{T}, dims) where T = LinearSpline(H.points) * tril(Ones{T}(length(H.points),length(H.points)-1) .* diff(H.points)',-1)
_cumsum(S::Spline{-1,T}, dims) where T = HeavisideSpline(S.points) * tril(ones(T,length(S.points)-1,length(S.points)-2),-1)
10 changes: 9 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using ContinuumArrays, QuasiArrays, IntervalSets, DomainSets, FillArrays, LinearAlgebra, BandedMatrices, InfiniteArrays, Test, Base64
import ContinuumArrays: ℵ₁, materialize, AffineQuasiVector, BasisLayout, AdjointBasisLayout, SubBasisLayout, ℵ₁,
MappedBasisLayout, AdjointMappedBasisLayouts, MappedWeightedBasisLayout, TransformFactorization, Weight, WeightedBasisLayout, SubWeightedBasisLayout, WeightLayout,
basis, invmap, Map, checkpoints, plotgrid, plotgrid_layout, mul, plotvalues
basis, invmap, Map, checkpoints, plotgrid, plotgrid_layout, mul, plotvalues, plotgridvalues
import QuasiArrays: SubQuasiArray, MulQuasiMatrix, Vec, Inclusion, QuasiDiagonal, LazyQuasiArrayApplyStyle, LazyQuasiArrayStyle
import LazyArrays: MemoryLayout, ApplyStyle, Applied, colsupport, arguments, ApplyLayout, LdivStyle, MulStyle

Expand Down Expand Up @@ -94,6 +94,14 @@ include("test_basisconcat.jl")
a = affine(0..1, 1..5)
v = L[a,:] * c
@test plotvalues(v) == v[plotgrid(v)]

H = HeavisideSpline(1:5)
u = H * (2:5)
x,v = plotgridvalues(u)
@test u[[1+4eps(),2-4eps(),2+4eps(),3-4eps(),3+4eps(),4-4eps(),4+4eps(),5-4eps()]] v

x,v = plotgridvalues(diff(u))
@test x == [2,2,2,3,3,3,4,4,4]
end

include("test_recipesbaseext.jl")
Expand Down
18 changes: 18 additions & 0 deletions test/test_splines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -569,4 +569,22 @@ import ContinuumArrays: basis, AdjointBasisLayout, ExpansionLayout, BasisLayout,
@test Pl*(exp.(x .+ y')) plan_transform(L, Block(1,1), 2) * (plan_transform(L, Block(1,1), 1) * exp.(x .+ y'))
end
end

@testset "Dirac" begin
H = HeavisideSpline(0:5)
S = Spline{-1}(0:5)
@test iszero(S[0.1,1])
@test iszero(S[0.1,1:4])
@test isinf(S[1,1])
@test iszero(S[1,2])
@test iszero(S[0,:])
@test_throws BoundsError S[0.1,0]
@test_throws BoundsError S[-1,1]

@test S \ diff(H) == diagm(0 => fill(-1,4), 1 => fill(1, 4))[1:end-1,:]

u = S * (1:4)
@test sum(u) == 10
@test cumsum(u)[5-4eps()] == 10
end
end

0 comments on commit 1cb30aa

Please sign in to comment.