Skip to content

Commit

Permalink
Annotate a few types to improve type inference and thus allocation
Browse files Browse the repository at this point in the history
Necessary for Krylov.
  • Loading branch information
abelsiqueira committed Nov 17, 2019
1 parent 49398fe commit 2a16cc3
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 38 deletions.
25 changes: 9 additions & 16 deletions docs/src/tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,28 +64,21 @@ op = LinearOperator(Float64, 10, 10, false, false,
nothing,
w -> [w[1]; w[1] + w[2]])
```
Notice, however, that type is not enforced, which can cause unintended consequences
Make sure that the type passed to `LinearOperator` is correct, otherwise errors may occur.
```@example ex1
using LinearOperators, FFTW # hide
dft = LinearOperator(Float64, 10, 10, false, false,
v -> fft(v),
nothing,
w -> ifft(w))
v = rand(10)
println("eltype(dft) = $(eltype(dft))")
println("eltype(v) = $(eltype(v))")
println("eltype(dft * v) = $(eltype(dft * v))")
```
or even errors
```jldoctest
using LinearOperators
A = [im 1.0; 0.0 1.0]
op = LinearOperator(Float64, 2, 2, false, false,
v -> A * v, u -> transpose(A) * u, w -> A' * w)
Matrix(op) # Tries to create Float64 matrix with contents of A
# output
ERROR: InexactError: Float64(0.0 + 1.0im)
[...]
```
println("eltype(dft) = $(eltype(dft))")
println("eltype(v) = $(eltype(v))")
println("eltype(dft.prod(v)) = $(eltype(dft.prod(v)))")
# dft * v # ERROR: expected Vector{Float64}
# Matrix(dft) # ERROR: tried to create a Matrix of Float64
```

## Limited memory BFGS and SR1

Two other useful operators are the Limited-Memory BFGS in forward and inverse form.
Expand Down
18 changes: 9 additions & 9 deletions src/LinearOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,10 @@ end


# Apply an operator to a vector.
function *(op :: AbstractLinearOperator, v :: AbstractVector)
function *(op :: AbstractLinearOperator{T}, v :: AbstractVector{S}) where {T,S}
size(v, 1) == size(op, 2) || throw(LinearOperatorException("shape mismatch"))
increase_nprod(op)
op.prod(v)
op.prod(v)::Vector{promote_type(T,S)}
end


Expand All @@ -251,14 +251,14 @@ end
Materialize an operator as a dense array using `op.ncol` products.
"""
function Base.Matrix(op :: AbstractLinearOperator)
function Base.Matrix(op :: AbstractLinearOperator{T}) where T
(m, n) = size(op)
A = Array{eltype(op)}(undef, m, n)
ei = zeros(eltype(op), n)
A = Array{T}(undef, m, n)
ei = zeros(T, n)
for i = 1 : n
ei[i] = 1
ei[i] = one(T)
A[:, i] = op * ei
ei[i] = 0
ei[i] = zero(T)
end
return A
end
Expand Down Expand Up @@ -531,8 +531,8 @@ Zero operator of size `nrow`-by-`ncol` and of data type `T` (defaults to
`Float64`).
"""
function opZeros(T :: DataType, nrow :: Int, ncol :: Int)
prod = @closure v -> zeros(T, nrow)
tprod = @closure u -> zeros(T, ncol)
prod = @closure v -> zeros(promote_type(T,eltype(v)), nrow)
tprod = @closure u -> zeros(promote_type(T,eltype(u)), ncol)
LinearOperator{T}(nrow, ncol, nrow == ncol, nrow == ncol, prod, tprod, tprod)
end

Expand Down
24 changes: 12 additions & 12 deletions src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ function show(io :: IO, op :: ConjugateLinearOperator)
show(io, op.parent)
end

function *(op :: AdjointLinearOperator, v :: AbstractVector)
length(v) == size(op.parent, 1) || throw(LinearOperatorException("shape mismatch"))
p = op.parent
function *(op :: AdjointLinearOperator{T}, v :: AbstractVector{S}) where {T,S}
p = op.parent::AbstractLinearOperator{T}
length(v) == size(p, 1) || throw(LinearOperatorException("shape mismatch"))
ishermitian(p) && return p * v
if p.ctprod !== nothing
increase_nctprod(p)
return p.ctprod(v)
return p.ctprod(v)::Vector{promote_type(T,S)}
end
tprod = p.tprod
increment_tprod = true
Expand All @@ -94,16 +94,16 @@ function *(op :: AdjointLinearOperator, v :: AbstractVector)
else
increase_nprod(p)
end
return conj.(tprod(conj.(v)))
return conj.(tprod(conj.(v)))::Vector{promote_type(T,S)}
end

function *(op :: TransposeLinearOperator, v :: AbstractVector)
length(v) == size(op.parent, 1) || throw(LinearOperatorException("shape mismatch"))
p = op.parent
function *(op :: TransposeLinearOperator{T}, v :: AbstractVector{S}) where {T,S}
p = op.parent::AbstractLinearOperator{T}
length(v) == size(p, 1) || throw(LinearOperatorException("shape mismatch"))
issymmetric(p) && return p * v
if p.tprod !== nothing
increase_ntprod(p)
return p.tprod(v)
return p.tprod(v)::Vector{promote_type(T,S)}
end
increment_ctprod = true
ctprod = p.ctprod
Expand All @@ -120,12 +120,12 @@ function *(op :: TransposeLinearOperator, v :: AbstractVector)
else
increase_nprod(p)
end
return conj.(ctprod(conj.(v)))
return conj.(ctprod(conj.(v)))::Vector{promote_type(T,S)}
end

function *(op :: ConjugateLinearOperator, v :: AbstractVector)
function *(op :: ConjugateLinearOperator{T}, v :: AbstractVector{S}) where {T,S}
p = op.parent
return conj.(p * conj.(v))
return conj.(p * conj.(v))::Vector{promote_type(T,S)}
end

-(op :: AdjointLinearOperator) = adjoint(-op.parent)
Expand Down
2 changes: 1 addition & 1 deletion test/test_linop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ function test_linop()
@test A == Matrix(opC)
opF = LinearOperator(Float64, 2, 2, false, false, prod, tprod, ctprod) # The type is a lie
@test eltype(opF) == Float64
@test_throws InexactError Matrix(opF)
@test_throws TypeError Matrix(opF)
end

# Issue #80
Expand Down

0 comments on commit 2a16cc3

Please sign in to comment.