Annotate a few types to improve type inference and thus allocation #127

Merged
25 changes: 9 additions & 16 deletions docs/src/tutorial.md
@@ -64,28 +64,21 @@ op = LinearOperator(Float64, 10, 10, false, false,
nothing,
w -> [w[1]; w[1] + w[2]])
```
Notice, however, that type is not enforced, which can cause unintended consequences
Make sure that the type passed to `LinearOperator` is correct, otherwise errors may occur.
```@example ex1
using LinearOperators, FFTW # hide
dft = LinearOperator(Float64, 10, 10, false, false,
v -> fft(v),
nothing,
w -> ifft(w))
v = rand(10)
println("eltype(dft) = $(eltype(dft))")
println("eltype(v) = $(eltype(v))")
println("eltype(dft * v) = $(eltype(dft * v))")
```
or even errors
```jldoctest
using LinearOperators
A = [im 1.0; 0.0 1.0]
op = LinearOperator(Float64, 2, 2, false, false,
v -> A * v, u -> transpose(A) * u, w -> A' * w)
Matrix(op) # Tries to create Float64 matrix with contents of A
# output
ERROR: InexactError: Float64(0.0 + 1.0im)
[...]
```
println("eltype(dft) = $(eltype(dft))")
println("eltype(v) = $(eltype(v))")
println("eltype(dft.prod(v)) = $(eltype(dft.prod(v)))")
# dft * v # ERROR: expected Vector{Float64}
# Matrix(dft) # ERROR: tried to create a Matrix of Float64
```

## Limited memory BFGS and SR1

Two other useful operators are the Limited-Memory BFGS in forward and inverse form.
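A follow-up sketch to the tutorial change above (assuming `ComplexF64` is accepted as the operator's type argument and that FFTW is available): declaring the operator with the element type its products actually return avoids the mismatch entirely.

```julia
# Illustrative only: build the DFT operator with the element type that fft actually returns.
using LinearOperators, FFTW
dft = LinearOperator(ComplexF64, 10, 10, false, false,
                     v -> fft(v),
                     nothing,
                     w -> ifft(w))
v = rand(10)
eltype(dft * v)    # expected: ComplexF64, so the new return-type assertion is satisfied
```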
18 changes: 9 additions & 9 deletions src/LinearOperators.jl
@@ -239,10 +239,10 @@ end


# Apply an operator to a vector.
function *(op :: AbstractLinearOperator, v :: AbstractVector)
function *(op :: AbstractLinearOperator{T}, v :: AbstractVector{S}) where {T,S}
size(v, 1) == size(op, 2) || throw(LinearOperatorException("shape mismatch"))
increase_nprod(op)
op.prod(v)
op.prod(v)::Vector{promote_type(T,S)}
end
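A minimal standalone sketch of what the new assertion buys (`toy_apply` is illustrative, not package API): when the stored product closure's return type cannot be inferred, asserting `::Vector{promote_type(T,S)}` gives callers a concrete element type, and an operator whose declared type is a lie now fails loudly with a `TypeError`.

```julia
# Assert the product's result to the promoted element type, as the new method does.
toy_apply(f, v::AbstractVector{S}, ::Type{T}) where {T, S} =
    f(v)::Vector{promote_type(T, S)}

toy_apply(v -> 2 .* v, [1.0, 2.0], Float64)        # OK: Vector{Float64}, as asserted
# toy_apply(v -> 2 .* v, [1.0, 2.0], ComplexF64)   # TypeError: got Vector{Float64},
#                                                  # expected Vector{ComplexF64}
```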


@@ -251,14 +251,14 @@ end

Materialize an operator as a dense array using `op.ncol` products.
"""
function Base.Matrix(op :: AbstractLinearOperator)
function Base.Matrix(op :: AbstractLinearOperator{T}) where T
(m, n) = size(op)
A = Array{eltype(op)}(undef, m, n)
ei = zeros(eltype(op), n)
A = Array{T}(undef, m, n)
ei = zeros(T, n)
for i = 1 : n
ei[i] = 1
ei[i] = one(T)
A[:, i] = op * ei
ei[i] = 0
ei[i] = zero(T)
end
return A
end
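A hedged usage sketch (constructor arguments follow the tutorial and tests in this PR; the output is the expected result, not verified here): with the element type now taken from the operator's parameter `T`, materializing a small `Float64` operator yields a `Matrix{Float64}` built column by column from products with the canonical basis vectors.

```julia
using LinearOperators
op = LinearOperator(Float64, 2, 2, false, false,
                    v -> [v[1] + v[2]; v[2]],   # op * e_i gives the i-th column
                    nothing, nothing)
Matrix(op)    # expected: 2×2 Matrix{Float64}: [1.0 1.0; 0.0 1.0]
```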
@@ -531,8 +531,8 @@ Zero operator of size `nrow`-by-`ncol` and of data type `T` (defaults to
`Float64`).
"""
function opZeros(T :: DataType, nrow :: Int, ncol :: Int)
prod = @closure v -> zeros(T, nrow)
tprod = @closure u -> zeros(T, ncol)
prod = @closure v -> zeros(promote_type(T,eltype(v)), nrow)
tprod = @closure u -> zeros(promote_type(T,eltype(u)), ncol)
LinearOperator{T}(nrow, ncol, nrow == ncol, nrow == ncol, prod, tprod, tprod)
end
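A sketch of what the promotion in `opZeros` buys (output annotated as expected rather than verified): a `Float64` zero operator applied to a complex vector now returns complex zeros, which also satisfies the `promote_type` assertion added to `*` above.

```julia
using LinearOperators
Z = opZeros(Float64, 3, 3)
Z * rand(ComplexF64, 3)    # expected: 3-element Vector{ComplexF64} of zeros
```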

45 changes: 29 additions & 16 deletions src/adjtrans.jl
@@ -2,18 +2,31 @@ export AdjointLinearOperator, TransposeLinearOperator, ConjugateLinearOperator,
adjoint, transpose, conj

# From julialang:stdlib/LinearAlgebra/src/adjtrans.jl
struct AdjointLinearOperator{T} <: AbstractLinearOperator{T}
parent :: AbstractLinearOperator{T}
struct AdjointLinearOperator{T,S} <: AbstractLinearOperator{T}
parent :: S
function AdjointLinearOperator{T,S}(A :: S) where {T,S}
new(A)
end
end

struct TransposeLinearOperator{T} <: AbstractLinearOperator{T}
parent :: AbstractLinearOperator{T}
struct TransposeLinearOperator{T,S} <: AbstractLinearOperator{T}
parent :: S
function TransposeLinearOperator{T,S}(A :: S) where {T,S}
new(A)
end
end

struct ConjugateLinearOperator{T} <: AbstractLinearOperator{T}
parent :: AbstractLinearOperator{T}
struct ConjugateLinearOperator{T,S} <: AbstractLinearOperator{T}
parent :: S
function ConjugateLinearOperator{T,S}(A :: S) where {T,S}
new(A)
end
end

AdjointLinearOperator(A) = AdjointLinearOperator{eltype(A),typeof(A)}(A)
TransposeLinearOperator(A) = TransposeLinearOperator{eltype(A),typeof(A)}(A)
ConjugateLinearOperator(A) = ConjugateLinearOperator{eltype(A),typeof(A)}(A)

adjoint(A :: AbstractLinearOperator) = AdjointLinearOperator(A)
adjoint(A :: AdjointLinearOperator) = A.parent
transpose(A :: AbstractLinearOperator) = TransposeLinearOperator(A)
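The wrappers now carry the parent's concrete type as a second parameter `S` instead of storing it behind the abstract `AbstractLinearOperator{T}`. A minimal sketch, outside the package, of why that matters for inference (`SlowWrap` and `FastWrap` are hypothetical names):

```julia
struct SlowWrap{T}
    parent :: AbstractVector{T}   # abstract field type: `w.parent` infers abstractly
end

struct FastWrap{T, S <: AbstractVector{T}}
    parent :: S                   # concrete field type: `w.parent` infers exactly
end

sum_parent(w) = sum(w.parent)
sum_parent(SlowWrap{Float64}([1.0, 2.0]))                    # works, but is not type-stable
sum_parent(FastWrap{Float64, Vector{Float64}}([1.0, 2.0]))   # fully specialized
```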
@@ -71,13 +84,13 @@ function show(io :: IO, op :: ConjugateLinearOperator)
show(io, op.parent)
end

function *(op :: AdjointLinearOperator, v :: AbstractVector)
length(v) == size(op.parent, 1) || throw(LinearOperatorException("shape mismatch"))
function *(op :: AdjointLinearOperator{T,S}, v :: AbstractVector{U}) where {T,S,U}
p = op.parent
length(v) == size(p, 1) || throw(LinearOperatorException("shape mismatch"))
ishermitian(p) && return p * v
if p.ctprod !== nothing
increase_nctprod(p)
return p.ctprod(v)
return p.ctprod(v)::Vector{promote_type(T,U)}
end
tprod = p.tprod
increment_tprod = true
@@ -94,16 +107,16 @@ function *(op :: AdjointLinearOperator, v :: AbstractVector)
else
increase_nprod(p)
end
return conj.(tprod(conj.(v)))
return conj.(tprod(conj.(v)))::Vector{promote_type(T,U)}
end

function *(op :: TransposeLinearOperator, v :: AbstractVector)
length(v) == size(op.parent, 1) || throw(LinearOperatorException("shape mismatch"))
function *(op :: TransposeLinearOperator{T,S}, v :: AbstractVector{U}) where {T,S,U}
p = op.parent
length(v) == size(p, 1) || throw(LinearOperatorException("shape mismatch"))
issymmetric(p) && return p * v
if p.tprod !== nothing
increase_ntprod(p)
return p.tprod(v)
return p.tprod(v)::Vector{promote_type(T,U)}
end
increment_ctprod = true
ctprod = p.ctprod
@@ -120,12 +133,12 @@ function *(op :: TransposeLinearOperator, v :: AbstractVector)
else
increase_nprod(p)
end
return conj.(ctprod(conj.(v)))
return conj.(ctprod(conj.(v)))::Vector{promote_type(T,U)}
end

function *(op :: ConjugateLinearOperator, v :: AbstractVector)
function *(op :: ConjugateLinearOperator{T,S}, v :: AbstractVector{U}) where {T,S,U}
p = op.parent
return conj.(p * conj.(v))
return conj.(p * conj.(v))::Vector{promote_type(T,U)}
end

-(op :: AdjointLinearOperator) = adjoint(-op.parent)
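The `tprod`/`ctprod` fallbacks above rely on a conjugation identity; a quick standalone check on a plain matrix, independent of the package:

```julia
# For any matrix A and vector v, A' * v == conj.(transpose(A) * conj.(v)),
# and symmetrically transpose(A) * v can be recovered from the adjoint product.
using LinearAlgebra
A = [1.0 + 2.0im  3.0; 0.0  1.0 - 1.0im]
v = rand(ComplexF64, 2)
A' * v ≈ conj.(transpose(A) * conj.(v))    # true
```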
2 changes: 1 addition & 1 deletion test/test_linop.jl
@@ -502,7 +502,7 @@ function test_linop()
@test A == Matrix(opC)
opF = LinearOperator(Float64, 2, 2, false, false, prod, tprod, ctprod) # The type is a lie
@test eltype(opF) == Float64
@test_throws InexactError Matrix(opF)
@test_throws TypeError Matrix(opF)
end
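Note on the changed expectation: the operator's declared type is still a lie, but the failure now surfaces as a failed type assertion rather than a failed conversion. A standalone sketch of the two failure modes (`prod_fun` is illustrative):

```julia
A = [im 1.0; 0.0 1.0]
prod_fun = v -> A * v                      # actually returns a Vector{ComplexF64}
# Float64(0.0 + 1.0im)                     # InexactError: the old failure, during conversion
# prod_fun([1.0, 0.0])::Vector{Float64}    # TypeError: the new failure, at the assertion
```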

# Issue #80