-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
put LinearAlgebra and SparseArrays in extensions (#841)
* put LinearAlgebra and SparseArrays in extensions * typo * fix folder path * a typo * spelling * imports * imports
- Loading branch information
Showing
10 changed files
with
197 additions
and
152 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
11 changes: 11 additions & 0 deletions
11
ext/DimensionalDataLinearAlgebraExt/DimensionalDataLinearAlgebraExt.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
module DimensionalDataLinearAlgebraExt | ||
|
||
using DimensionalData | ||
using LinearAlgebra | ||
|
||
const DD = DimensionalData | ||
|
||
include("matmul.jl") | ||
include("methods.jl") | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
using LinearAlgebra: AbstractTriangular, AbstractRotation | ||
|
||
using DimensionalData: AnonDim, strict_matmul, comparedims | ||
|
||
# Copied from symmetric.jl | ||
const AdjTransVec = Union{Transpose{<:Any,<:AbstractVector},Adjoint{<:Any,<:AbstractVector}} | ||
const RealHermSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}} | ||
const RealHermSymComplexHerm{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Hermitian{Complex{T},S}} | ||
const RealHermSymComplexSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Symmetric{Complex{T},S}} | ||
|
||
# Ambiguities | ||
for (a, b) in ( | ||
(AbstractDimVector, AbstractDimMatrix), | ||
(AbstractDimMatrix, AbstractDimVector), | ||
(AbstractDimMatrix, AbstractDimMatrix), | ||
(AbstractDimMatrix, AbstractVector), | ||
(AbstractDimVector, AbstractMatrix), | ||
(AbstractDimMatrix, AbstractMatrix), | ||
(AbstractMatrix, AbstractDimVector), | ||
(AbstractVector, AbstractDimMatrix), | ||
(AbstractMatrix, AbstractDimMatrix), | ||
(AbstractDimVector, Adjoint{<:Any,<:AbstractMatrix}), | ||
(AbstractDimVector, AdjTransVec), | ||
(AbstractDimVector, Transpose{<:Any,<:AbstractMatrix}), | ||
(AbstractDimMatrix, Diagonal), | ||
(AbstractDimMatrix, Adjoint{<:Any,<:RealHermSymComplexHerm}), | ||
(AbstractDimMatrix, Adjoint{<:Any,<:AbstractTriangular}), | ||
(AbstractDimMatrix, Transpose{<:Any,<:AbstractTriangular}), | ||
(AbstractDimMatrix, Transpose{<:Any,<:RealHermSymComplexSym}), | ||
(AbstractDimMatrix, AbstractTriangular), | ||
(Diagonal, AbstractDimVector), | ||
(Diagonal, AbstractDimMatrix), | ||
(Transpose{<:Any,<:AbstractTriangular}, AbstractDimVector), | ||
(Transpose{<:Any,<:AbstractTriangular}, AbstractDimMatrix), | ||
(Transpose{<:Any,<:AbstractVector}, AbstractDimVector), | ||
(Transpose{<:Real,<:AbstractVector}, AbstractDimVector), | ||
(Transpose{<:Any,<:AbstractVector}, AbstractDimMatrix), | ||
(Transpose{<:Any,<:RealHermSymComplexSym}, AbstractDimMatrix), | ||
(Transpose{<:Any,<:RealHermSymComplexSym}, AbstractDimVector), | ||
(AbstractTriangular, AbstractDimVector), | ||
(AbstractTriangular, AbstractDimMatrix), | ||
(Adjoint{<:Any,<:AbstractTriangular}, AbstractDimVector), | ||
(Adjoint{<:Any,<:AbstractVector}, AbstractDimMatrix), | ||
(Adjoint{<:Any,<:RealHermSymComplexHerm}, AbstractDimMatrix), | ||
(Adjoint{<:Any,<:AbstractTriangular}, AbstractDimMatrix), | ||
(Adjoint{<:Number,<:AbstractVector}, AbstractDimVector{<:Number}), | ||
(AdjTransVec, AbstractDimVector), | ||
(Adjoint{<:Any,<:RealHermSymComplexHerm}, AbstractDimVector), | ||
) | ||
@eval Base.:*(A::$a, B::$b) = _rebuildmul(A, B) | ||
end | ||
|
||
Base.:*(A::AbstractDimVector, B::Adjoint{T,<:AbstractRotation}) where T = _rebuildmul(A, B) | ||
Base.:*(A::Adjoint{T,<:AbstractRotation}, B::AbstractDimMatrix) where T = _rebuildmul(A, B) | ||
Base.:*(A::Transpose{<:Any,<:AbstractMatrix{T}}, B::AbstractDimArray{S,1}) where {T,S} = _rebuildmul(A, B) | ||
Base.:*(A::Adjoint{<:Any,<:AbstractMatrix{T}}, B::AbstractDimArray{S,1}) where {T,S} = _rebuildmul(A, B) | ||
|
||
function _rebuildmul(A::AbstractDimVector, B::AbstractDimMatrix) | ||
# Vector has no dim 2 to compare | ||
rebuild(A, parent(A) * parent(B), (first(dims(A)), last(dims(B)),)) | ||
end | ||
function _rebuildmul(A::AbstractDimMatrix, B::AbstractDimVector) | ||
_comparedims_mul(A, B) | ||
rebuild(A, parent(A) * parent(B), (first(dims(A)),)) | ||
end | ||
function _rebuildmul(A::AbstractDimMatrix, B::AbstractDimMatrix) | ||
_comparedims_mul(A, B) | ||
rebuild(A, parent(A) * parent(B), (first(dims(A)), last(dims(B)))) | ||
end | ||
function _rebuildmul(A::AbstractDimVector, B::AbstractMatrix) | ||
rebuild(A, parent(A) * B, (first(dims(A)), AnonDim(Base.OneTo(size(B, 2))))) | ||
end | ||
function _rebuildmul(A::AbstractDimMatrix, B::AbstractVector) | ||
newdata = parent(A) * B | ||
if newdata isa AbstractArray | ||
rebuild(A, parent(A) * B, (first(dims(A)),)) | ||
else | ||
newdata | ||
end | ||
end | ||
function _rebuildmul(A::AbstractDimMatrix, B::AbstractMatrix) | ||
rebuild(A, parent(A) * B, (first(dims(A)), AnonDim(Base.OneTo(size(B, 2))))) | ||
end | ||
function _rebuildmul(A::AbstractVector, B::AbstractDimMatrix) | ||
rebuild(B, A * parent(B), (AnonDim(Base.OneTo(size(A, 1))), last(dims(B)))) | ||
end | ||
function _rebuildmul(A::AbstractMatrix, B::AbstractDimVector) | ||
newdata = A * parent(B) | ||
if newdata isa AbstractArray | ||
rebuild(B, A * parent(B), (AnonDim(Base.OneTo(1)),)) | ||
else | ||
newdata | ||
end | ||
end | ||
function _rebuildmul(A::AbstractMatrix, B::AbstractDimMatrix) | ||
rebuild(B, A * parent(B), (AnonDim(Base.OneTo(size(A, 1))), last(dims(B)))) | ||
end | ||
|
||
function _comparedims_mul(a, b) | ||
# Dont need to compare length if we compare values | ||
isstrict = strict_matmul() | ||
comparedims(last(dims(a)), first(dims(b)); | ||
order=isstrict, val=isstrict, length=false | ||
) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Ambiguity | ||
Base.copyto!(dst::AbstractDimArray{T,2} where T, src::LinearAlgebra.AbstractQ) = | ||
(copyto!(parent(dst), src); dst) | ||
|
||
# We need to override copy_similar because our `similar` doesn't work with size changes | ||
# Fixed in Base in https://github.com/JuliaLang/julia/pull/53210 | ||
LinearAlgebra.copy_similar(A::AbstractDimArray, ::Type{T}) where {T} = copyto!(similar(A, T), A) | ||
|
||
# See methods.jl | ||
@eval begin | ||
@inline LinearAlgebra.Transpose(A::AbstractDimArray{<:Any,2}) = | ||
rebuild(A, LinearAlgebra.Transpose(parent(A)), reverse(dims(A))) | ||
@inline LinearAlgebra.Transpose(A::AbstractDimArray{<:Any,1}) = | ||
rebuild(A, LinearAlgebra.Transpose(parent(A)), (AnonDim(NoLookup(Base.OneTo(1))), dims(A)...)) | ||
@inline function LinearAlgebra.Transpose(s::AbstractDimStack) | ||
maplayers(s) do l | ||
ndims(l) > 1 ? LinearAlgebra.Transpose(l) : l | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
module DimensionalDataSparseArraysExt | ||
|
||
using DimensionalData | ||
using SparseArrays | ||
|
||
# Ambiguity | ||
Base.copyto!(dst::AbstractDimArray{T,2}, src::SparseArrays.CHOLMOD.Dense{T}) where T<:Union{Float64,ComplexF64} = | ||
(copyto!(parent(dst), src); dst) | ||
Base.copyto!(dst::AbstractDimArray{T}, src::SparseArrays.CHOLMOD.Dense{T}) where T<:Union{Float64,ComplexF64} = | ||
(copyto!(parent(dst), src); dst) | ||
Base.copyto!(dst::DimensionalData.AbstractDimArray, src::SparseArrays.CHOLMOD.Dense) = | ||
(copyto!(parent(dst), src); dst) | ||
Base.copyto!(dst::AbstractDimArray{T,2} where T, src::SparseArrays.AbstractSparseMatrixCSC) = | ||
(copyto!(parent(dst), src); dst) | ||
Base.copyto!(dst::SparseArrays.AbstractCompressedVector, src::AbstractDimArray{T, 1} where T) = | ||
(copyto!(dst, parent(src)); dst) | ||
|
||
function Base.copyto!( | ||
dst::AbstractDimArray{<:Any,2}, | ||
dst_i::CartesianIndices{2, R} where R<:Tuple{OrdinalRange{Int64, Int64}, OrdinalRange{Int64, Int64}}, | ||
src::SparseArrays.AbstractSparseMatrixCSC{<:Any}, | ||
src_i::CartesianIndices{2, R} where R<:Tuple{OrdinalRange{Int64, Int64}, OrdinalRange{Int64, Int64}} | ||
) | ||
copyto!(parent(dst), dst_i, src, src_i) | ||
return dst | ||
end | ||
Base.copy!(dst::SparseArrays.AbstractCompressedVector{T}, src::AbstractDimArray{T, 1}) where T = | ||
(copy!(dst, parent(src)); dst) | ||
Base.copy!(dst::SparseArrays.SparseVector, src::AbstractDimArray{T,1}) where T = | ||
(copy!(dst, parent(src)); dst) | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.