Skip to content

Commit

Permalink
Annotate a few types to improve type inference and thus allocation
Browse files Browse the repository at this point in the history
Necessary for Krylov.
  • Loading branch information
abelsiqueira committed Nov 16, 2019
1 parent 49398fe commit 1042fa7
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 22 deletions.
18 changes: 9 additions & 9 deletions src/LinearOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,10 @@ end


# Apply an operator to a vector.
function *(op :: AbstractLinearOperator, v :: AbstractVector)
function *(op :: AbstractLinearOperator{T}, v :: AbstractVector{S}) where {T,S}
size(v, 1) == size(op, 2) || throw(LinearOperatorException("shape mismatch"))
increase_nprod(op)
op.prod(v)
op.prod(v)::Vector{promote_type(T,S)}
end


Expand All @@ -251,14 +251,14 @@ end
Materialize an operator as a dense array using `op.ncol` products.
"""
function Base.Matrix(op :: AbstractLinearOperator)
function Base.Matrix(op :: AbstractLinearOperator{T}) where T
(m, n) = size(op)
A = Array{eltype(op)}(undef, m, n)
ei = zeros(eltype(op), n)
A = Array{T}(undef, m, n)
ei = zeros(T, n)
for i = 1 : n
ei[i] = 1
ei[i] = one(T)
A[:, i] = op * ei
ei[i] = 0
ei[i] = zero(T)
end
return A
end
Expand Down Expand Up @@ -531,8 +531,8 @@ Zero operator of size `nrow`-by-`ncol` and of data type `T` (defaults to
`Float64`).
"""
function opZeros(T :: DataType, nrow :: Int, ncol :: Int)
prod = @closure v -> zeros(T, nrow)
tprod = @closure u -> zeros(T, ncol)
prod = @closure v -> zeros(promote_type(T,eltype(v)), nrow)
tprod = @closure u -> zeros(promote_type(T,eltype(u)), ncol)
LinearOperator{T}(nrow, ncol, nrow == ncol, nrow == ncol, prod, tprod, tprod)
end

Expand Down
24 changes: 12 additions & 12 deletions src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ function show(io :: IO, op :: ConjugateLinearOperator)
show(io, op.parent)
end

function *(op :: AdjointLinearOperator, v :: AbstractVector)
length(v) == size(op.parent, 1) || throw(LinearOperatorException("shape mismatch"))
p = op.parent
function *(op :: AdjointLinearOperator{T}, v :: AbstractVector{S}) where {T,S}
p = op.parent::AbstractLinearOperator{T}
length(v) == size(p, 1) || throw(LinearOperatorException("shape mismatch"))
ishermitian(p) && return p * v
if p.ctprod !== nothing
increase_nctprod(p)
return p.ctprod(v)
return p.ctprod(v)::Vector{promote_type(T,S)}
end
tprod = p.tprod
increment_tprod = true
Expand All @@ -94,16 +94,16 @@ function *(op :: AdjointLinearOperator, v :: AbstractVector)
else
increase_nprod(p)
end
return conj.(tprod(conj.(v)))
return conj.(tprod(conj.(v)))::Vector{promote_type(T,S)}
end

function *(op :: TransposeLinearOperator, v :: AbstractVector)
length(v) == size(op.parent, 1) || throw(LinearOperatorException("shape mismatch"))
p = op.parent
function *(op :: TransposeLinearOperator{T}, v :: AbstractVector{S}) where {T,S}
p = op.parent::AbstractLinearOperator{T}
length(v) == size(p, 1) || throw(LinearOperatorException("shape mismatch"))
issymmetric(p) && return p * v
if p.tprod !== nothing
increase_ntprod(p)
return p.tprod(v)
return p.tprod(v)::Vector{promote_type(T,S)}
end
increment_ctprod = true
ctprod = p.ctprod
Expand All @@ -120,12 +120,12 @@ function *(op :: TransposeLinearOperator, v :: AbstractVector)
else
increase_nprod(p)
end
return conj.(ctprod(conj.(v)))
return conj.(ctprod(conj.(v)))::Vector{promote_type(T,S)}
end

function *(op :: ConjugateLinearOperator, v :: AbstractVector)
function *(op :: ConjugateLinearOperator{T}, v :: AbstractVector{S}) where {T,S}
p = op.parent
return conj.(p * conj.(v))
return conj.(p * conj.(v))::Vector{promote_type(T,S)}
end

-(op :: AdjointLinearOperator) = adjoint(-op.parent)
Expand Down
2 changes: 1 addition & 1 deletion test/test_linop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ function test_linop()
@test A == Matrix(opC)
opF = LinearOperator(Float64, 2, 2, false, false, prod, tprod, ctprod) # The type is a lie
@test eltype(opF) == Float64
@test_throws InexactError Matrix(opF)
@test_throws TypeError Matrix(opF)
end

# Issue #80
Expand Down

0 comments on commit 1042fa7

Please sign in to comment.