Skip to content

Commit

Permalink
Parametrize Adjoint, Transpose and Conjugate
Browse files Browse the repository at this point in the history
  • Loading branch information
abelsiqueira committed Nov 18, 2019
1 parent 2a16cc3 commit 37aa757
Showing 1 changed file with 29 additions and 16 deletions.
45 changes: 29 additions & 16 deletions src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,31 @@ export AdjointLinearOperator, TransposeLinearOperator, ConjugateLinearOperator,
adjoint, transpose, conj

# From julialang:stdlib/LinearAlgebra/src/adjtrans.jl
struct AdjointLinearOperator{T} <: AbstractLinearOperator{T}
parent :: AbstractLinearOperator{T}
struct AdjointLinearOperator{T,S} <: AbstractLinearOperator{T}
parent :: S
function AdjointLinearOperator{T,S}(A :: S) where {T,S}
new(A)
end
end

struct TransposeLinearOperator{T} <: AbstractLinearOperator{T}
parent :: AbstractLinearOperator{T}
struct TransposeLinearOperator{T,S} <: AbstractLinearOperator{T}
parent :: S
function TransposeLinearOperator{T,S}(A :: S) where {T,S}
new(A)
end
end

struct ConjugateLinearOperator{T} <: AbstractLinearOperator{T}
parent :: AbstractLinearOperator{T}
struct ConjugateLinearOperator{T,S} <: AbstractLinearOperator{T}
parent :: S
function ConjugateLinearOperator{T,S}(A :: S) where {T,S}
new(A)
end
end

AdjointLinearOperator(A) = AdjointLinearOperator{eltype(A),typeof(A)}(A)
TransposeLinearOperator(A) = TransposeLinearOperator{eltype(A),typeof(A)}(A)
ConjugateLinearOperator(A) = ConjugateLinearOperator{eltype(A),typeof(A)}(A)

adjoint(A :: AbstractLinearOperator) = AdjointLinearOperator(A)
adjoint(A :: AdjointLinearOperator) = A.parent
transpose(A :: AbstractLinearOperator) = TransposeLinearOperator(A)
Expand Down Expand Up @@ -71,13 +84,13 @@ function show(io :: IO, op :: ConjugateLinearOperator)
show(io, op.parent)
end

function *(op :: AdjointLinearOperator{T}, v :: AbstractVector{S}) where {T,S}
p = op.parent::AbstractLinearOperator{T}
function *(op :: AdjointLinearOperator{T,S}, v :: AbstractVector{U}) where {T,S,U}
p = op.parent
length(v) == size(p, 1) || throw(LinearOperatorException("shape mismatch"))
ishermitian(p) && return p * v
if p.ctprod !== nothing
increase_nctprod(p)
return p.ctprod(v)::Vector{promote_type(T,S)}
return p.ctprod(v)::Vector{promote_type(T,U)}
end
tprod = p.tprod
increment_tprod = true
Expand All @@ -94,16 +107,16 @@ function *(op :: AdjointLinearOperator{T}, v :: AbstractVector{S}) where {T,S}
else
increase_nprod(p)
end
return conj.(tprod(conj.(v)))::Vector{promote_type(T,S)}
return conj.(tprod(conj.(v)))::Vector{promote_type(T,U)}
end

function *(op :: TransposeLinearOperator{T}, v :: AbstractVector{S}) where {T,S}
p = op.parent::AbstractLinearOperator{T}
function *(op :: TransposeLinearOperator{T,S}, v :: AbstractVector{U}) where {T,S,U}
p = op.parent
length(v) == size(p, 1) || throw(LinearOperatorException("shape mismatch"))
issymmetric(p) && return p * v
if p.tprod !== nothing
increase_ntprod(p)
return p.tprod(v)::Vector{promote_type(T,S)}
return p.tprod(v)::Vector{promote_type(T,U)}
end
increment_ctprod = true
ctprod = p.ctprod
Expand All @@ -120,12 +133,12 @@ function *(op :: TransposeLinearOperator{T}, v :: AbstractVector{S}) where {T,S}
else
increase_nprod(p)
end
return conj.(ctprod(conj.(v)))::Vector{promote_type(T,S)}
return conj.(ctprod(conj.(v)))::Vector{promote_type(T,U)}
end

function *(op :: ConjugateLinearOperator{T}, v :: AbstractVector{S}) where {T,S}
function *(op :: ConjugateLinearOperator{T,S}, v :: AbstractVector{U}) where {T,S,U}
p = op.parent
return conj.(p * conj.(v))::Vector{promote_type(T,S)}
return conj.(p * conj.(v))::Vector{promote_type(T,U)}
end

-(op :: AdjointLinearOperator) = adjoint(-op.parent)
Expand Down

0 comments on commit 37aa757

Please sign in to comment.