Annotate a few types to improve type inference and thus allocation #127

Merged
Changes from 1 commit
25 changes: 9 additions & 16 deletions docs/src/tutorial.md
@@ -64,28 +64,21 @@ op = LinearOperator(Float64, 10, 10, false, false,
nothing,
w -> [w[1]; w[1] + w[2]])
```
Notice, however, that type is not enforced, which can cause unintended consequences
Make sure that the type passed to `LinearOperator` is correct, otherwise errors may occur.
```@example ex1
using LinearOperators, FFTW # hide
dft = LinearOperator(Float64, 10, 10, false, false,
v -> fft(v),
nothing,
w -> ifft(w))
v = rand(10)
println("eltype(dft) = $(eltype(dft))")
println("eltype(v) = $(eltype(v))")
println("eltype(dft * v) = $(eltype(dft * v))")
```
or even errors
```jldoctest
using LinearOperators
A = [im 1.0; 0.0 1.0]
op = LinearOperator(Float64, 2, 2, false, false,
v -> A * v, u -> transpose(A) * u, w -> A' * w)
Matrix(op) # Tries to create Float64 matrix with contents of A
# output
ERROR: InexactError: Float64(0.0 + 1.0im)
[...]
```
println("eltype(dft) = $(eltype(dft))")
println("eltype(v) = $(eltype(v))")
println("eltype(dft.prod(v)) = $(eltype(dft.prod(v)))")
# dft * v # ERROR: expected Vector{Float64}
# Matrix(dft) # ERROR: tried to create a Matrix of Float64
```
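A hedged sketch of the recommended usage (assuming the `FFTW` package is available alongside `LinearOperators`): declaring the operator with the element type its products actually return keeps `eltype(dft)` and `eltype(dft * v)` consistent, instead of lying with `Float64`:

```julia
using LinearOperators, FFTW

# fft returns complex vectors, so declare the operator as ComplexF64.
dft = LinearOperator(ComplexF64, 10, 10, false, false,
                     v -> fft(v),
                     nothing,
                     w -> ifft(w))
v = rand(10)
w = dft * v
println("eltype(dft) = $(eltype(dft))")  # ComplexF64
println("eltype(w)   = $(eltype(w))")    # ComplexF64 as well
```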

## Limited memory BFGS and SR1

Two other useful operators are the Limited-Memory BFGS in forward and inverse form.
18 changes: 9 additions & 9 deletions src/LinearOperators.jl
@@ -239,10 +239,10 @@ end


# Apply an operator to a vector.
function *(op :: AbstractLinearOperator, v :: AbstractVector)
function *(op :: AbstractLinearOperator{T}, v :: AbstractVector{S}) where {T,S}
size(v, 1) == size(op, 2) || throw(LinearOperatorException("shape mismatch"))
increase_nprod(op)
op.prod(v)
op.prod(v)::Vector{promote_type(T,S)}
end
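A minimal sketch (plain Julia; the names below are illustrative, not part of the package) of what the `::Vector{promote_type(T,S)}` annotation above buys: `promote_type` computes the common element type for the compiler, and the assertion turns a mistyped `prod` into an immediate `TypeError` instead of a silent `Vector{Any}`:

```julia
# promote_type picks the common element type of the operator and the vector.
@assert promote_type(Float64, ComplexF64) == ComplexF64

# The `::Vector{...}` assertion fails loudly when `prod` lies about its type.
lying_prod(v) = Any[v...]                 # returns Vector{Any}
apply(v) = lying_prod(v)::Vector{promote_type(Float64, eltype(v))}
try
    apply([1.0, 2.0])
catch e
    println(typeof(e))                    # TypeError
end
```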


@@ -251,14 +251,14 @@ end
Materialize an operator as a dense array using `op.ncol` products.
"""
function Base.Matrix(op :: AbstractLinearOperator)
function Base.Matrix(op :: AbstractLinearOperator{T}) where T
(m, n) = size(op)
A = Array{eltype(op)}(undef, m, n)
ei = zeros(eltype(op), n)
A = Array{T}(undef, m, n)
ei = zeros(T, n)
for i = 1 : n
ei[i] = 1
ei[i] = one(T)
A[:, i] = op * ei
ei[i] = 0
ei[i] = zero(T)
end
return A
end
@@ -531,8 +531,8 @@ Zero operator of size `nrow`-by-`ncol` and of data type `T` (defaults to
`Float64`).
"""
function opZeros(T :: DataType, nrow :: Int, ncol :: Int)
prod = @closure v -> zeros(T, nrow)
tprod = @closure u -> zeros(T, ncol)
prod = @closure v -> zeros(promote_type(T,eltype(v)), nrow)
tprod = @closure u -> zeros(promote_type(T,eltype(u)), ncol)
LinearOperator{T}(nrow, ncol, nrow == ncol, nrow == ncol, prod, tprod, tprod)
end
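A hedged illustration of the promoted `opZeros` above (assuming `LinearOperators` with this change applied): a `Float64` zero operator applied to a complex vector now yields a complex result rather than one of the declared type only:

```julia
using LinearOperators

Z = opZeros(Float64, 3, 3)
v = [1.0 + 2.0im, 0.0 + 0.0im, 3.0im]
w = Z * v
# With promote_type in the closures, the result element type follows the input.
println("eltype(w) = $(eltype(w))")  # expected ComplexF64 after this change
```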

24 changes: 12 additions & 12 deletions src/adjtrans.jl
@@ -71,13 +71,13 @@ function show(io :: IO, op :: ConjugateLinearOperator)
show(io, op.parent)
end

function *(op :: AdjointLinearOperator, v :: AbstractVector)
length(v) == size(op.parent, 1) || throw(LinearOperatorException("shape mismatch"))
p = op.parent
function *(op :: AdjointLinearOperator{T}, v :: AbstractVector{S}) where {T,S}
p = op.parent::AbstractLinearOperator{T}
Member:

It seems that `AbstractLinearOperator{T}` is too generic for the compiler. I just replaced it with:

p = op.parent::Union{LinearOperator{T}, PreallocatedLinearOperator{T}}

and `increase_nctprod(p)` plus `length(v) == size(p, 1)` stop allocating 16 bytes at each call.

Member:

That's too specific though. Maybe we need to parametrize AdjointLinearOperator with the type of the parent?

Member Author:

I've included a new commit with the parametrized version.

Member:

Thanks. Does it resolve the allocation issue?

Member Author:

It looks like it does, but I'll wait for Alexis' tests.

Member:

It works with my Krylov tests. 🤘 Thank you Abel! You can merge it and add a new tag.

Member:

Before merging, would you mind checking that this doesn't recreate the hvcat issues from #97?

Member:

```
N = 16

@time include("lops.jl")
  0.208034 seconds (621.17 k allocations: 31.138 MiB, 5.06% gc time)
  0.574005 seconds (1.23 M allocations: 64.762 MiB, 1.83% gc time)
5-element Array{Float64,1}:
 -10.027547809205767
   6.387858545582767
   4.945048065166462
   9.668653838837873
   3.2258245551624194

N = 25

@time include("lops.jl")
  0.212824 seconds (621.21 k allocations: 31.152 MiB, 4.82% gc time)
  0.561557 seconds (1.23 M allocations: 64.921 MiB, 1.83% gc time)
5-element Array{Float64,1}:
 10.628747752647548
 -2.02582122181704
 -4.198864112230334
  3.0656881247622376
 10.586622803189744

N = 50

@time include("lops.jl")
  0.221327 seconds (621.31 k allocations: 31.201 MiB, 4.93% gc time)
  0.592618 seconds (1.23 M allocations: 64.986 MiB, 1.84% gc time)
5-element Array{Float64,1}:
 -24.72968159374163
 -17.87085267785233
  21.676648685374687
   8.68226076917717
   2.8795559248799285

N = 100

@time include("lops.jl")
  0.213366 seconds (621.51 k allocations: 31.370 MiB, 4.85% gc time)
  0.586298 seconds (1.24 M allocations: 65.185 MiB, 1.76% gc time)
5-element Array{Float64,1}:
  -6.961226546652531
  27.177113229951946
  -3.5866280698379254
  23.01856931520228
 -23.290794085006873
```

Member:

Thank you!

length(v) == size(p, 1) || throw(LinearOperatorException("shape mismatch"))
ishermitian(p) && return p * v
if p.ctprod !== nothing
increase_nctprod(p)
return p.ctprod(v)
return p.ctprod(v)::Vector{promote_type(T,S)}
end
tprod = p.tprod
increment_tprod = true
@@ -94,16 +94,16 @@ function *(op :: AdjointLinearOperator, v :: AbstractVector)
else
increase_nprod(p)
end
return conj.(tprod(conj.(v)))
return conj.(tprod(conj.(v)))::Vector{promote_type(T,S)}
end

function *(op :: TransposeLinearOperator, v :: AbstractVector)
length(v) == size(op.parent, 1) || throw(LinearOperatorException("shape mismatch"))
p = op.parent
function *(op :: TransposeLinearOperator{T}, v :: AbstractVector{S}) where {T,S}
p = op.parent::AbstractLinearOperator{T}
length(v) == size(p, 1) || throw(LinearOperatorException("shape mismatch"))
issymmetric(p) && return p * v
if p.tprod !== nothing
increase_ntprod(p)
return p.tprod(v)
return p.tprod(v)::Vector{promote_type(T,S)}
end
increment_ctprod = true
ctprod = p.ctprod
@@ -120,12 +120,12 @@ function *(op :: TransposeLinearOperator, v :: AbstractVector)
else
increase_nprod(p)
end
return conj.(ctprod(conj.(v)))
return conj.(ctprod(conj.(v)))::Vector{promote_type(T,S)}
end
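The conjugation fallback above relies on the identity `transpose(A) * v == conj.(A' * conj.(v))`. A quick standalone check in plain Julia (not part of the package):

```julia
using LinearAlgebra

A = [1.0 + 2.0im  3.0;
     0.5im        4.0 - 1.0im]
v = [1.0 - im, 2.0 + 3.0im]
# transpose(A) == conj(adjoint(A)), hence the conj-sandwich fallback.
@assert transpose(A) * v ≈ conj.(A' * conj.(v))
```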

function *(op :: ConjugateLinearOperator, v :: AbstractVector)
function *(op :: ConjugateLinearOperator{T}, v :: AbstractVector{S}) where {T,S}
p = op.parent
return conj.(p * conj.(v))
return conj.(p * conj.(v))::Vector{promote_type(T,S)}
end

-(op :: AdjointLinearOperator) = adjoint(-op.parent)
2 changes: 1 addition & 1 deletion test/test_linop.jl
@@ -502,7 +502,7 @@ function test_linop()
@test A == Matrix(opC)
opF = LinearOperator(Float64, 2, 2, false, false, prod, tprod, ctprod) # The type is a lie
@test eltype(opF) == Float64
@test_throws InexactError Matrix(opF)
@test_throws TypeError Matrix(opF)
end

# Issue #80