diff --git a/stdlib/LinearAlgebra/src/tridiag.jl b/stdlib/LinearAlgebra/src/tridiag.jl index 235fa5636877a..da15e1680f53a 100644 --- a/stdlib/LinearAlgebra/src/tridiag.jl +++ b/stdlib/LinearAlgebra/src/tridiag.jl @@ -70,9 +70,14 @@ julia> A[2,1] SymTridiagonal(dv::V, ev::V) where {T,V<:AbstractVector{T}} = SymTridiagonal{T}(dv, ev) SymTridiagonal{T}(dv::V, ev::V) where {T,V<:AbstractVector{T}} = SymTridiagonal{T,V}(dv, ev) function SymTridiagonal{T}(dv::AbstractVector, ev::AbstractVector) where {T} - SymTridiagonal(convert(AbstractVector{T}, dv)::AbstractVector{T}, - convert(AbstractVector{T}, ev)::AbstractVector{T}) + d = convert(AbstractVector{T}, dv)::AbstractVector{T} + e = convert(AbstractVector{T}, ev)::AbstractVector{T} + typeof(d) == typeof(e) ? + SymTridiagonal{T}(d, e) : + throw(ArgumentError("diagonal vectors needed to be convertible to same type")) end +SymTridiagonal(d::AbstractVector{T}, e::AbstractVector{S}) where {T,S} = + SymTridiagonal{promote_type(T, S)}(d, e) """ SymTridiagonal(A::AbstractMatrix) @@ -513,11 +518,21 @@ julia> Tridiagonal(dl, d, du) """ Tridiagonal(dl::V, d::V, du::V) where {T,V<:AbstractVector{T}} = Tridiagonal{T,V}(dl, d, du) Tridiagonal(dl::V, d::V, du::V, du2::V) where {T,V<:AbstractVector{T}} = Tridiagonal{T,V}(dl, d, du, du2) +Tridiagonal(dl::AbstractVector{T}, d::AbstractVector{S}, du::AbstractVector{U}) where {T,S,U} = + Tridiagonal{promote_type(T, S, U)}(dl, d, du) +Tridiagonal(dl::AbstractVector{T}, d::AbstractVector{S}, du::AbstractVector{U}, du2::AbstractVector{V}) where {T,S,U,V} = + Tridiagonal{promote_type(T, S, U, V)}(dl, d, du, du2) function Tridiagonal{T}(dl::AbstractVector, d::AbstractVector, du::AbstractVector) where {T} - Tridiagonal(map(x->convert(AbstractVector{T}, x), (dl, d, du))...) + l, d, u = map(x->convert(AbstractVector{T}, x), (dl, d, du)) + typeof(l) == typeof(d) == typeof(u) ? + Tridiagonal(l, d, u) : + throw(ArgumentError("diagonal vectors needed to be convertible to same type")) end function Tridiagonal{T}(dl::AbstractVector, d::AbstractVector, du::AbstractVector, du2::AbstractVector) where {T} - Tridiagonal(map(x->convert(AbstractVector{T}, x), (dl, d, du, du2))...) + l, d, u, u2 = map(x->convert(AbstractVector{T}, x), (dl, d, du, du2)) + typeof(l) == typeof(d) == typeof(u) == typeof(u2) ? + Tridiagonal(l, d, u, u2) : + throw(ArgumentError("diagonal vectors needed to be convertible to same type")) end """ diff --git a/stdlib/LinearAlgebra/test/tridiag.jl b/stdlib/LinearAlgebra/test/tridiag.jl index 2d846683e38c3..0c07e5b160c58 100644 --- a/stdlib/LinearAlgebra/test/tridiag.jl +++ b/stdlib/LinearAlgebra/test/tridiag.jl @@ -95,12 +95,12 @@ end @test isa(ST, SymTridiagonal{elty,Vector{elty}}) TT = Tridiagonal{elty,Vector{elty}}(GenericArray(dl), d, GenericArray(dl)) @test isa(TT, Tridiagonal{elty,Vector{elty}}) - @test_throws MethodError SymTridiagonal(d, GenericArray(dl)) - @test_throws MethodError SymTridiagonal(GenericArray(d), dl) - @test_throws MethodError Tridiagonal(GenericArray(dl), d, GenericArray(dl)) - @test_throws MethodError Tridiagonal(dl, GenericArray(d), dl) - @test_throws MethodError SymTridiagonal{elty}(d, GenericArray(dl)) - @test_throws MethodError Tridiagonal{elty}(GenericArray(dl), d,GenericArray(dl)) + @test_throws ArgumentError SymTridiagonal(d, GenericArray(dl)) + @test_throws ArgumentError SymTridiagonal(GenericArray(d), dl) + @test_throws ArgumentError Tridiagonal(GenericArray(dl), d, GenericArray(dl)) + @test_throws ArgumentError Tridiagonal(dl, GenericArray(d), dl) + @test_throws ArgumentError SymTridiagonal{elty}(d, GenericArray(dl)) + @test_throws ArgumentError Tridiagonal{elty}(GenericArray(dl), d,GenericArray(dl)) STI = SymTridiagonal([1,2,3,4], [1,2,3]) TTI = Tridiagonal([1,2,3], [1,2,3,4], [1,2,3]) TTI2 = Tridiagonal([1,2,3], [1,2,3,4], [1,2,3], [1,2]) @@ -505,6 +505,11 @@ end @test SymTridiagonal([1, 2], [0])^3 == [1 0; 0 8] end +@testset "Issue #48505" begin + @test SymTridiagonal([1,2,3],[4,5.0]) == [1.0 4.0 0.0; 4.0 2.0 5.0; 0.0 5.0 3.0] + @test Tridiagonal([1, 2], [4, 5, 1], [6.0, 7]) == [4.0 6.0 0.0; 1.0 5.0 7.0; 0.0 2.0 1.0] +end + @testset "convert for SymTridiagonal" begin STF32 = SymTridiagonal{Float32}(fill(1f0, 5), fill(1f0, 4)) @test convert(SymTridiagonal{Float64}, STF32)::SymTridiagonal{Float64} == STF32