From 576463f71ae3e08e60d3e519f7dfc82af85c141f Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Tue, 5 Dec 2023 23:23:33 +0000 Subject: [PATCH] 4-tensor Mul/InvPlan (#172) * 4-tensor Mul/InvPlan * Update test_splines.jl --- Project.toml | 2 +- src/plans.jl | 139 ++++++++++++++++++++----------------------- test/test_splines.jl | 22 +++++++ 3 files changed, 87 insertions(+), 76 deletions(-) diff --git a/Project.toml b/Project.toml index 4b977c4..d7dde3e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ContinuumArrays" uuid = "7ae1f121-cc2c-504b-ac30-9b923412ae5c" -version = "0.17" +version = "0.17.1" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/src/plans.jl b/src/plans.jl index acfb8a5..497041c 100644 --- a/src/plans.jl +++ b/src/plans.jl @@ -42,47 +42,6 @@ InvPlan(fact, dims) = InvPlan((fact,), dims) size(F::InvPlan) = size.(F.factorizations, 1) -function *(P::InvPlan{<:Any,<:Tuple,Int}, x::AbstractVector) - @assert P.dims == 1 - only(P.factorizations) \ x # Only a single factorization when dims isa Int -end - -function *(P::InvPlan{<:Any,<:Tuple,Int}, X::AbstractMatrix) - if P.dims == 1 - only(P.factorizations) \ X # Only a single factorization when dims isa Int - else - @assert P.dims == 2 - permutedims(only(P.factorizations) \ permutedims(X)) - end -end - -function *(P::InvPlan{<:Any,<:Tuple,Int}, X::AbstractArray{<:Any,3}) - Y = similar(X) - if P.dims == 1 - for j in axes(X,3) - Y[:,:,j] = only(P.factorizations) \ X[:,:,j] - end - elseif P.dims == 2 - for k in axes(X,1) - Y[k,:,:] = only(P.factorizations) \ X[k,:,:] - end - else - @assert P.dims == 3 - for k in axes(X,1), j in axes(X,2) - Y[k,j,:] = only(P.factorizations) \ X[k,j,:] - end - end - Y -end - -function *(P::InvPlan, X::AbstractArray) - for d in P.dims - X = InvPlan(P.factorizations[d], d) * X - end - X -end - - """ MulPlan(matrix, dims) @@ -96,44 +55,74 @@ end MulPlan(mats::Tuple, dims) = MulPlan{eltype(mats), typeof(mats), typeof(dims)}(mats, dims) MulPlan(mats::AbstractMatrix, dims) = MulPlan((mats,), dims) -function *(P::MulPlan{<:Any,<:Tuple,Int}, x::AbstractVector) - @assert P.dims == 1 - only(P.matrices) * x -end - -function *(P::MulPlan{<:Any,<:Tuple,Int}, X::AbstractMatrix) - if P.dims == 1 - only(P.matrices) * X - else - @assert P.dims == 2 - permutedims(only(P.matrices) * permutedims(X)) - end -end - -function *(P::MulPlan{<:Any,<:Tuple,Int}, X::AbstractArray{<:Any,3}) - Y = similar(X) - if P.dims == 1 - for j in axes(X,3) - Y[:,:,j] = only(P.matrices) * X[:,:,j] +for (Pln,op,fld) in ((:MulPlan, :*, :(:matrices)), (:InvPlan, :\, :(:factorizations))) + @eval begin + function *(P::$Pln{<:Any,<:Tuple,Int}, x::AbstractVector) + @assert P.dims == 1 + $op(only(getfield(P, $fld)), x) # Only a single factorization when dims isa Int end - elseif P.dims == 2 - for k in axes(X,1) - Y[k,:,:] = only(P.matrices) * X[k,:,:] + + function *(P::$Pln{<:Any,<:Tuple,Int}, X::AbstractMatrix) + if P.dims == 1 + $op(only(getfield(P, $fld)), X) # Only a single factorization when dims isa Int + else + @assert P.dims == 2 + permutedims($op(only(getfield(P, $fld)), permutedims(X))) + end end - else - @assert P.dims == 3 - for k in axes(X,1), j in axes(X,2) - Y[k,j,:] = only(P.matrices) * X[k,j,:] + + function *(P::$Pln{<:Any,<:Tuple,Int}, X::AbstractArray{<:Any,3}) + Y = similar(X) + if P.dims == 1 + for j in axes(X,3) + Y[:,:,j] = $op(only(getfield(P, $fld)), X[:,:,j]) + end + elseif P.dims == 2 + for k in axes(X,1) + Y[k,:,:] = $op(only(getfield(P, $fld)), X[k,:,:]) + end + else + @assert P.dims == 3 + for k in axes(X,1), j in axes(X,2) + Y[k,j,:] = $op(only(getfield(P, $fld)), X[k,j,:]) + end + end + Y + end + + function *(P::$Pln{<:Any,<:Tuple,Int}, X::AbstractArray{<:Any,4}) + Y = similar(X) + if P.dims == 1 + for j in axes(X,3), l in axes(X,4) + Y[:,:,j,l] = $op(only(getfield(P, $fld)), X[:,:,j,l]) + end + elseif P.dims == 2 + for k in axes(X,1), l in axes(X,4) + Y[k,:,:,l] = $op(only(getfield(P, $fld)), X[k,:,:,l]) + end + elseif P.dims == 3 + for k in axes(X,1), j in axes(X,2) + Y[k,j,:,:] = $op(only(getfield(P, $fld)), X[k,j,:,:]) + end + elseif P.dims == 4 + for k in axes(X,1), j in axes(X,2), l in axes(X,3) + Y[k,j,l,:] = $op(only(getfield(P, $fld)), X[k,j,l,:]) + end + end + Y + end + + + + *(P::$Pln{<:Any,<:Tuple,Int}, X::AbstractArray) = error("Overload") + + function *(P::$Pln, X::AbstractArray) + for (fac,dim) in zip(getfield(P, $fld), P.dims) + X = $Pln(fac, dim) * X + end + X end end - Y -end - -function *(P::MulPlan, X::AbstractArray) - for d in P.dims - X = MulPlan(P.matrices[d], d) * X - end - X end *(A::AbstractMatrix, P::MulPlan) = MulPlan(Ref(A) .* P.matrices, P.dims) diff --git a/test/test_splines.jl b/test/test_splines.jl index cecdea0..ede1f20 100644 --- a/test/test_splines.jl +++ b/test/test_splines.jl @@ -526,6 +526,28 @@ import ContinuumArrays: basis, AdjointBasisLayout, ExpansionLayout, BasisLayout, X[k, j, :] = L[g,:] \ X[k, j, :] end @test PX ≈ X + + n = size(L,2) + X = randn(n, n, n, n) + P = plan_transform(L, X) + PX = P * X + for k = 1:n, j = 1:n, l = 1:n + X[:, k, j, l] = L[g,:] \ X[:, k, j, l] + end + for k = 1:n, j = 1:n, l = 1:n + X[k, :, j, l] = L[g,:] \ X[k, :, j, l] + end + for k = 1:n, j = 1:n, l = 1:n + X[k, j, :, l] = L[g,:] \ X[k, j, :, l] + end + for k = 1:n, j = 1:n, l = 1:n + X[k, j, l, :] = L[g,:] \ X[k, j, l, :] + end + @test PX ≈ X + + X = randn(n, n, n, n, n) + P = plan_transform(L, X) + @test_throws ErrorException P * X end @testset "Mul coefficients" begin