diff --git a/src/toeplitzhankel.jl b/src/toeplitzhankel.jl index 75e4ef88..b5387c74 100644 --- a/src/toeplitzhankel.jl +++ b/src/toeplitzhankel.jl @@ -309,6 +309,21 @@ end # th_jac2jac ### +struct ComposePlan{T, Plans} <: Plan{T} + plans::Plans +end + +ComposePlan{T}(A...) where T = ComposePlan{T,typeof(A)}(A) +ComposePlan(A...) = ComposePlan{mapreduce(eltype, promote_type, A)}(A...) + +function *(P::ComposePlan, A::AbstractArray) + ret = A + for p in reverse(P.plans) + ret = p*ret + end + ret +end + function alternatesign!(v) @inbounds for k = 2:2:length(v) v[k] = -v[k] @@ -343,13 +358,27 @@ function _jac2jacTH_TLC(::Type{S}, mn, α, β, γ, δ, d) where {S} (T, DL .* C, DR .* C) end -plan_th_jac2jac!(::Type{S}, mn, α, β, γ, δ, dims::Int) where {S} = ToeplitzHankelPlan(_jac2jacTH_TLC(S, mn, α, β, γ, δ, dims)..., dims) +function plan_th_jac2jac!(::Type{S}, mn, α, β, γ, δ, dims::Int) where {S} + if α == γ || β == δ + ToeplitzHankelPlan(_jac2jacTH_TLC(S, mn, α, β, γ, δ, dims)..., dims) + else + P1 = ToeplitzHankelPlan(_jac2jacTH_TLC(S, mn, α, β, α, δ, dims)..., dims) + P2 = ToeplitzHankelPlan(_jac2jacTH_TLC(S, mn, α, δ, γ, δ, dims)..., dims) + ComposePlan(P2, P1) + end +end function plan_th_jac2jac!(::Type{S}, mn::NTuple{2,Int}, α, β, γ, δ, dims::NTuple{2,Int}) where {S} @assert dims == (1,2) - T1,L1,C1 = _jac2jacTH_TLC(S, mn, α, β, γ, δ, 1) - T2,L2,C2 = _jac2jacTH_TLC(S, mn, α, β, γ, δ, 2) - ToeplitzHankelPlan((T1,T2), (L1,L2), (C1,C2), dims) + if α == γ || β == δ + T1,L1,C1 = _jac2jacTH_TLC(S, mn, α, β, γ, δ, 1) + T2,L2,C2 = _jac2jacTH_TLC(S, mn, α, β, γ, δ, 2) + ToeplitzHankelPlan((T1,T2), (L1,L2), (C1,C2), dims) + else + P1 = plan_th_jac2jac!(S, mn, α, β, α, δ, dims) + P2 = plan_th_jac2jac!(S, mn, α, δ, γ, δ, dims) + ComposePlan(P2, P1) + end end diff --git a/test/toeplitzhankeltests.jl b/test/toeplitzhankeltests.jl index 5d194c69..5d53fc8f 100644 --- a/test/toeplitzhankeltests.jl +++ b/test/toeplitzhankeltests.jl @@ -11,10 +11,14 @@ import FastTransforms: th_leg2cheb, th_cheb2leg, th_leg2chebu, th_ultra2ultra,th @test th_ultra2ultra(x,0.1, 0.2) ≈ lib_ultra2ultra(x, 0.1, 0.2) @test th_jac2jac(x,0.1, 0.2,0.1,0.4) ≈ lib_jac2jac(x, 0.1, 0.2,0.1,0.4) @test th_jac2jac(x,0.1, 0.2,0.3,0.2) ≈ lib_jac2jac(x, 0.1, 0.2,0.3,0.2) + @test th_jac2jac(x,0.1, 0.2,0.3,0.4) ≈ lib_jac2jac(x, 0.1, 0.2,0.3,0.4) - @test th_cheb2leg(th_leg2cheb(x)) ≈ x atol=1E-9 - @test th_leg2cheb(th_cheb2leg(x)) ≈ x atol=1E-10 + @test th_cheb2leg(th_leg2cheb(x)) ≈ x + @test th_leg2cheb(th_cheb2leg(x)) ≈ x + @test th_ultra2ultra(th_ultra2ultra(x, 0.1, 0.6), 0.6, 0.1) ≈ x + @test th_jac2jac(th_jac2jac(x, 0.1, 0.6, 0.1, 0.8), 0.1, 0.8, 0.1, 0.6) ≈ x + @test th_jac2jac(th_jac2jac(x, 0.1, 0.6, 0.2, 0.8), 0.2, 0.8, 0.1, 0.6) ≈ x end for X in (randn(5,4), randn(5,4) + im*randn(5,4)) @@ -53,6 +57,9 @@ import FastTransforms: th_leg2cheb, th_cheb2leg, th_leg2chebu, th_ultra2ultra,th @test th_jac2jac(X, 0.1, 0.6, 0.1, 0.8, 2) ≈ vcat([permutedims(jac2jac(X[k,:], 0.1, 0.6, 0.1, 0.8)) for k=1:size(X,1)]...) @test th_jac2jac(X, 0.1, 0.6, 0.1, 0.8) ≈ th_jac2jac(th_jac2jac(X, 0.1, 0.6, 0.1, 0.8, 1), 0.1, 0.6, 0.1, 0.8, 2) + @test th_jac2jac(X, 0.1, 0.6, 0.2, 0.8, 1) ≈ hcat([jac2jac(X[:,j], 0.1, 0.6, 0.2, 0.8) for j=1:size(X,2)]...) + @test th_jac2jac(X, 0.1, 0.6, 0.2, 0.8, 2) ≈ vcat([permutedims(jac2jac(X[k,:], 0.1, 0.6, 0.2, 0.8)) for k=1:size(X,1)]...) + @test th_jac2jac(X, 0.1, 0.6, 0.1, 0.8) == plan_th_jac2jac!(X, 0.1, 0.6, 0.1, 0.8, 1:2)*copy(X) @test th_jac2jac(X, 0.1, 0.6, 0.1, 0.8) == plan_th_jac2jac!(X, 0.1, 0.6, 0.1, 0.8, 1:2)*copy(X)