From 3071f34b16059ea7279c2af5694f902155d0806e Mon Sep 17 00:00:00 2001 From: Sebastian Stock <42280794+sostock@users.noreply.github.com> Date: Tue, 23 Mar 2021 08:53:54 +0100 Subject: [PATCH] Preserve `Symmetric`/`Hermitian` shape in more cases (#40126) --- stdlib/LinearAlgebra/src/symmetric.jl | 4 ++++ stdlib/LinearAlgebra/test/symmetric.jl | 25 +++++++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/stdlib/LinearAlgebra/src/symmetric.jl b/stdlib/LinearAlgebra/src/symmetric.jl index 8f1073817b71d..e206cfe7178d3 100644 --- a/stdlib/LinearAlgebra/src/symmetric.jl +++ b/stdlib/LinearAlgebra/src/symmetric.jl @@ -482,6 +482,10 @@ for f in (:+, :-) @eval begin $f(A::Hermitian, B::Symmetric{<:Real}) = $f(A, Hermitian(parent(B), sym_uplo(B.uplo))) $f(A::Symmetric{<:Real}, B::Hermitian) = $f(Hermitian(parent(A), sym_uplo(A.uplo)), B) + $f(A::SymTridiagonal, B::Symmetric) = Symmetric($f(A, B.data), sym_uplo(B.uplo)) + $f(A::Symmetric, B::SymTridiagonal) = Symmetric($f(A.data, B), sym_uplo(A.uplo)) + $f(A::SymTridiagonal{<:Real}, B::Hermitian) = Hermitian($f(A, B.data), sym_uplo(B.uplo)) + $f(A::Hermitian, B::SymTridiagonal{<:Real}) = Hermitian($f(A.data, B), sym_uplo(A.uplo)) end end diff --git a/stdlib/LinearAlgebra/test/symmetric.jl b/stdlib/LinearAlgebra/test/symmetric.jl index 7d99dd32889fd..d23eecb5be46e 100644 --- a/stdlib/LinearAlgebra/test/symmetric.jl +++ b/stdlib/LinearAlgebra/test/symmetric.jl @@ -727,4 +727,29 @@ end end end +@testset "Addition/subtraction with SymTridiagonal" begin + TR = SymTridiagonal(randn(Float64,5), randn(Float64,4)) + TC = SymTridiagonal(randn(ComplexF64,5), randn(ComplexF64,4)) + SR = Symmetric(randn(Float64,5,5)) + SC = Symmetric(randn(ComplexF64,5,5)) + HR = Hermitian(randn(Float64,5,5)) + HC = Hermitian(randn(ComplexF64,5,5)) + for op = (+,-) + for T = (TR, TC), S = (SR, SC) + @test op(T, S) == op(Array(T), S) + @test op(S, T) == op(S, Array(T)) + @test op(T, S) isa Symmetric + @test op(S, T) isa Symmetric + end + for H = (HR, HC) + for T = (TR, TC) + @test op(T, H) == op(Array(T), H) + @test op(H, T) == op(H, Array(T)) + end + @test op(TR, H) isa Hermitian + @test op(H, TR) isa Hermitian + end + end +end + end # module TestSymmetric