From 2a16cc377590e1b61efb0c778bb2afa10d904056 Mon Sep 17 00:00:00 2001 From: Abel Soares Siqueira Date: Sat, 16 Nov 2019 14:52:34 -0300 Subject: [PATCH 1/2] Annotate a few types to improve type inference and thus allocation Necessary for Krylov. --- docs/src/tutorial.md | 25 +++++++++---------------- src/LinearOperators.jl | 18 +++++++++--------- src/adjtrans.jl | 24 ++++++++++++------------ test/test_linop.jl | 2 +- 4 files changed, 31 insertions(+), 38 deletions(-) diff --git a/docs/src/tutorial.md b/docs/src/tutorial.md index fe349508..40c0a30b 100644 --- a/docs/src/tutorial.md +++ b/docs/src/tutorial.md @@ -64,28 +64,21 @@ op = LinearOperator(Float64, 10, 10, false, false, nothing, w -> [w[1]; w[1] + w[2]]) ``` -Notice, however, that type is not enforced, which can cause unintended consequences +Make sure that the type passed to `LinearOperator` is correct, otherwise errors may occur. ```@example ex1 +using LinearOperators, FFTW # hide dft = LinearOperator(Float64, 10, 10, false, false, v -> fft(v), nothing, w -> ifft(w)) v = rand(10) -println("eltype(dft) = $(eltype(dft))") -println("eltype(v) = $(eltype(v))") -println("eltype(dft * v) = $(eltype(dft * v))") -``` -or even errors -```jldoctest -using LinearOperators -A = [im 1.0; 0.0 1.0] -op = LinearOperator(Float64, 2, 2, false, false, - v -> A * v, u -> transpose(A) * u, w -> A' * w) -Matrix(op) # Tries to create Float64 matrix with contents of A -# output -ERROR: InexactError: Float64(0.0 + 1.0im) -[...] -``` +println("eltype(dft) = $(eltype(dft))") +println("eltype(v) = $(eltype(v))") +println("eltype(dft.prod(v)) = $(eltype(dft.prod(v)))") +# dft * v # ERROR: expected Vector{Float64} +# Matrix(dft) # ERROR: tried to create a Matrix of Float64 +``` + ## Limited memory BFGS and SR1 Two other useful operators are the Limited-Memory BFGS in forward and inverse form. diff --git a/src/LinearOperators.jl b/src/LinearOperators.jl index 2bf23cfc..8800c00d 100644 --- a/src/LinearOperators.jl +++ b/src/LinearOperators.jl @@ -239,10 +239,10 @@ end # Apply an operator to a vector. -function *(op :: AbstractLinearOperator, v :: AbstractVector) +function *(op :: AbstractLinearOperator{T}, v :: AbstractVector{S}) where {T,S} size(v, 1) == size(op, 2) || throw(LinearOperatorException("shape mismatch")) increase_nprod(op) - op.prod(v) + op.prod(v)::Vector{promote_type(T,S)} end @@ -251,14 +251,14 @@ end Materialize an operator as a dense array using `op.ncol` products. """ -function Base.Matrix(op :: AbstractLinearOperator) +function Base.Matrix(op :: AbstractLinearOperator{T}) where T (m, n) = size(op) - A = Array{eltype(op)}(undef, m, n) - ei = zeros(eltype(op), n) + A = Array{T}(undef, m, n) + ei = zeros(T, n) for i = 1 : n - ei[i] = 1 + ei[i] = one(T) A[:, i] = op * ei - ei[i] = 0 + ei[i] = zero(T) end return A end @@ -531,8 +531,8 @@ Zero operator of size `nrow`-by-`ncol` and of data type `T` (defaults to `Float64`). """ function opZeros(T :: DataType, nrow :: Int, ncol :: Int) - prod = @closure v -> zeros(T, nrow) - tprod = @closure u -> zeros(T, ncol) + prod = @closure v -> zeros(promote_type(T,eltype(v)), nrow) + tprod = @closure u -> zeros(promote_type(T,eltype(u)), ncol) LinearOperator{T}(nrow, ncol, nrow == ncol, nrow == ncol, prod, tprod, tprod) end diff --git a/src/adjtrans.jl b/src/adjtrans.jl index dcb91ccf..d54aaa7c 100644 --- a/src/adjtrans.jl +++ b/src/adjtrans.jl @@ -71,13 +71,13 @@ function show(io :: IO, op :: ConjugateLinearOperator) show(io, op.parent) end -function *(op :: AdjointLinearOperator, v :: AbstractVector) - length(v) == size(op.parent, 1) || throw(LinearOperatorException("shape mismatch")) - p = op.parent +function *(op :: AdjointLinearOperator{T}, v :: AbstractVector{S}) where {T,S} + p = op.parent::AbstractLinearOperator{T} + length(v) == size(p, 1) || throw(LinearOperatorException("shape mismatch")) ishermitian(p) && return p * v if p.ctprod !== nothing increase_nctprod(p) - return p.ctprod(v) + return p.ctprod(v)::Vector{promote_type(T,S)} end tprod = p.tprod increment_tprod = true @@ -94,16 +94,16 @@ function *(op :: AdjointLinearOperator, v :: AbstractVector) else increase_nprod(p) end - return conj.(tprod(conj.(v))) + return conj.(tprod(conj.(v)))::Vector{promote_type(T,S)} end -function *(op :: TransposeLinearOperator, v :: AbstractVector) - length(v) == size(op.parent, 1) || throw(LinearOperatorException("shape mismatch")) - p = op.parent +function *(op :: TransposeLinearOperator{T}, v :: AbstractVector{S}) where {T,S} + p = op.parent::AbstractLinearOperator{T} + length(v) == size(p, 1) || throw(LinearOperatorException("shape mismatch")) issymmetric(p) && return p * v if p.tprod !== nothing increase_ntprod(p) - return p.tprod(v) + return p.tprod(v)::Vector{promote_type(T,S)} end increment_ctprod = true ctprod = p.ctprod @@ -120,12 +120,12 @@ function *(op :: TransposeLinearOperator, v :: AbstractVector) else increase_nprod(p) end - return conj.(ctprod(conj.(v))) + return conj.(ctprod(conj.(v)))::Vector{promote_type(T,S)} end -function *(op :: ConjugateLinearOperator, v :: AbstractVector) +function *(op :: ConjugateLinearOperator{T}, v :: AbstractVector{S}) where {T,S} p = op.parent - return conj.(p * conj.(v)) + return conj.(p * conj.(v))::Vector{promote_type(T,S)} end -(op :: AdjointLinearOperator) = adjoint(-op.parent) diff --git a/test/test_linop.jl b/test/test_linop.jl index 0ef9eb7e..3d295dae 100644 --- a/test/test_linop.jl +++ b/test/test_linop.jl @@ -502,7 +502,7 @@ function test_linop() @test A == Matrix(opC) opF = LinearOperator(Float64, 2, 2, false, false, prod, tprod, ctprod) # The type is a lie @test eltype(opF) == Float64 - @test_throws InexactError Matrix(opF) + @test_throws TypeError Matrix(opF) end # Issue #80 From 37aa757a9b55ea6e0a0f0100a6bfaa06bf6d73d7 Mon Sep 17 00:00:00 2001 From: Abel Soares Siqueira Date: Mon, 18 Nov 2019 17:38:36 -0300 Subject: [PATCH 2/2] Parametrize Adjoint, Transpose and Conjugate --- src/adjtrans.jl | 45 +++++++++++++++++++++++++++++---------------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/src/adjtrans.jl b/src/adjtrans.jl index d54aaa7c..1c3299e2 100644 --- a/src/adjtrans.jl +++ b/src/adjtrans.jl @@ -2,18 +2,31 @@ export AdjointLinearOperator, TransposeLinearOperator, ConjugateLinearOperator, adjoint, transpose, conj # From julialang:stdlib/LinearAlgebra/src/adjtrans.jl -struct AdjointLinearOperator{T} <: AbstractLinearOperator{T} - parent :: AbstractLinearOperator{T} +struct AdjointLinearOperator{T,S} <: AbstractLinearOperator{T} + parent :: S + function AdjointLinearOperator{T,S}(A :: S) where {T,S} + new(A) + end end -struct TransposeLinearOperator{T} <: AbstractLinearOperator{T} - parent :: AbstractLinearOperator{T} +struct TransposeLinearOperator{T,S} <: AbstractLinearOperator{T} + parent :: S + function TransposeLinearOperator{T,S}(A :: S) where {T,S} + new(A) + end end -struct ConjugateLinearOperator{T} <: AbstractLinearOperator{T} - parent :: AbstractLinearOperator{T} +struct ConjugateLinearOperator{T,S} <: AbstractLinearOperator{T} + parent :: S + function ConjugateLinearOperator{T,S}(A :: S) where {T,S} + new(A) + end end +AdjointLinearOperator(A) = AdjointLinearOperator{eltype(A),typeof(A)}(A) +TransposeLinearOperator(A) = TransposeLinearOperator{eltype(A),typeof(A)}(A) +ConjugateLinearOperator(A) = ConjugateLinearOperator{eltype(A),typeof(A)}(A) + adjoint(A :: AbstractLinearOperator) = AdjointLinearOperator(A) adjoint(A :: AdjointLinearOperator) = A.parent transpose(A :: AbstractLinearOperator) = TransposeLinearOperator(A) @@ -71,13 +84,13 @@ function show(io :: IO, op :: ConjugateLinearOperator) show(io, op.parent) end -function *(op :: AdjointLinearOperator{T}, v :: AbstractVector{S}) where {T,S} - p = op.parent::AbstractLinearOperator{T} +function *(op :: AdjointLinearOperator{T,S}, v :: AbstractVector{U}) where {T,S,U} + p = op.parent length(v) == size(p, 1) || throw(LinearOperatorException("shape mismatch")) ishermitian(p) && return p * v if p.ctprod !== nothing increase_nctprod(p) - return p.ctprod(v)::Vector{promote_type(T,S)} + return p.ctprod(v)::Vector{promote_type(T,U)} end tprod = p.tprod increment_tprod = true @@ -94,16 +107,16 @@ function *(op :: AdjointLinearOperator{T}, v :: AbstractVector{S}) where {T,S} else increase_nprod(p) end - return conj.(tprod(conj.(v)))::Vector{promote_type(T,S)} + return conj.(tprod(conj.(v)))::Vector{promote_type(T,U)} end -function *(op :: TransposeLinearOperator{T}, v :: AbstractVector{S}) where {T,S} - p = op.parent::AbstractLinearOperator{T} +function *(op :: TransposeLinearOperator{T,S}, v :: AbstractVector{U}) where {T,S,U} + p = op.parent length(v) == size(p, 1) || throw(LinearOperatorException("shape mismatch")) issymmetric(p) && return p * v if p.tprod !== nothing increase_ntprod(p) - return p.tprod(v)::Vector{promote_type(T,S)} + return p.tprod(v)::Vector{promote_type(T,U)} end increment_ctprod = true ctprod = p.ctprod @@ -120,12 +133,12 @@ function *(op :: TransposeLinearOperator{T}, v :: AbstractVector{S}) where {T,S} else increase_nprod(p) end - return conj.(ctprod(conj.(v)))::Vector{promote_type(T,S)} + return conj.(ctprod(conj.(v)))::Vector{promote_type(T,U)} end -function *(op :: ConjugateLinearOperator{T}, v :: AbstractVector{S}) where {T,S} +function *(op :: ConjugateLinearOperator{T,S}, v :: AbstractVector{U}) where {T,S,U} p = op.parent - return conj.(p * conj.(v))::Vector{promote_type(T,S)} + return conj.(p * conj.(v))::Vector{promote_type(T,U)} end -(op :: AdjointLinearOperator) = adjoint(-op.parent)