From b7809051397dddc70408bb4670b8094bb126c157 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 5 Apr 2021 10:37:56 -0700 Subject: [PATCH] Fix real matrix sqrt and log for edge cases (#40144) * Add failing tests * Improve 2x2 real sqrt for scalar diagonal * Avoid sylvester if zero is a solution * Avoid unnecessary sqrt * Use correct variable names * Avoid erroring due to NaNs * Exactly adapt Higham's algorithm, with hypot * Remove unreachable branch * Add failing tests * Handle overflow/underflow * Invert instead of dividing * Apply d=0 constraint from standardized real schur form * Handle underflow * Avoid underflow/overflow in log diagonal blocks * Add tests for log underflow/overflow --- stdlib/LinearAlgebra/src/triangular.jl | 46 ++++++++++++++++---------- stdlib/LinearAlgebra/test/dense.jl | 32 ++++++++++++++++++ 2 files changed, 61 insertions(+), 17 deletions(-) diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index cfc4e948d8d3d..1752d29a9ec65 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -2156,12 +2156,17 @@ end # 35(4), (2013), C394–C410. # Eq. 6.1 Base.@propagate_inbounds function _log_diag_block_2x2!(A, A0) - a, b, c, d = A0[1,1], A0[1,2], A0[2,1], A0[2,2] - bc = b * c - s = sqrt(-bc) + a, b, c = A0[1,1], A0[1,2], A0[2,1] + # avoid underflow/overflow for large/small b and c + s = sqrt(abs(b)) * sqrt(abs(c)) θ = atan(s, a) t = θ / s - a1 = log(a^2 - bc) / 2 + au = abs(a) + if au > s + a1 = log1p((s / au)^2) / 2 + log(au) + else + a1 = log1p((au / s)^2) / 2 + log(s) + end A[1,1] = a1 A[2,1] = c*t A[1,2] = b*t @@ -2435,25 +2440,30 @@ function _sqrt_quasitriu_offdiag_block!(R, A) return R end +# real square root of 2x2 diagonal block of quasi-triangular matrix from real Schur +# decomposition. Eqs 6.8-6.9 and Algorithm 6.5 of +# Higham, 2008, "Functions of Matrices: Theory and Computation", SIAM. Base.@propagate_inbounds function _sqrt_real_2x2!(R, A) - a11, a21, a12, a22 = A[1, 1], A[2, 1], A[1, 2], A[2, 2] - θ = (a11 + a22) / 2 - μ² = -(a11 - a22)^2 / 4 - a21 * a12 - μ = sqrt(μ²) - if θ > 0 - α = sqrt((sqrt(θ^2 + μ²) + θ) / 2) - else - α = μ / sqrt(2 * (sqrt(θ^2 + μ²) - θ)) - end + # in the real Schur form, A[1, 1] == A[2, 2], and A[2, 1] * A[1, 2] < 0 + θ, a21, a12 = A[1, 1], A[2, 1], A[1, 2] + # avoid overflow/underflow of μ + # for real sqrt, |d| ≤ 2 max(|a12|,|a21|) + μ = sqrt(abs(a12)) * sqrt(abs(a21)) + α = _real_sqrt(θ, μ) c = 2α - d = α - θ / c - R[1, 1] = a11 / c + d + R[1, 1] = α R[2, 1] = a21 / c R[1, 2] = a12 / c - R[2, 2] = a22 / c + d + R[2, 2] = α return R end +# real part of square root of θ+im*μ +@inline function _real_sqrt(θ, μ) + t = sqrt((abs(θ) + hypot(θ, μ)) / 2) + return θ ≥ 0 ? t : μ / 2t +end + Base.@propagate_inbounds function _sqrt_quasitriu_offdiag_block_1x1!(R, A, i, j) Rii = R[i, i] Rjj = R[j, j] @@ -2522,7 +2532,9 @@ Base.@propagate_inbounds function _sqrt_quasitriu_offdiag_block_2x2!(R, A, i, j) Rii = @view R[irange, irange] Rjj = @view R[jrange, jrange] Rij = @view R[irange, jrange] - _sylvester_2x2!(Rii, Rjj, Rij) + if !iszero(Rij) && !all(isnan, Rij) + _sylvester_2x2!(Rii, Rjj, Rij) + end return R end diff --git a/stdlib/LinearAlgebra/test/dense.jl b/stdlib/LinearAlgebra/test/dense.jl index 0c86fec9242b5..1cd9c8e6898b2 100644 --- a/stdlib/LinearAlgebra/test/dense.jl +++ b/stdlib/LinearAlgebra/test/dense.jl @@ -889,6 +889,38 @@ end end end +@testset "issue #40141" begin + x = [-1 -eps() 0 0; eps() -1 0 0; 0 0 -1 -eps(); 0 0 eps() -1] + @test sqrt(x)^2 ≈ x + + x2 = [-1 -eps() 0 0; 3eps() -1 0 0; 0 0 -1 -3eps(); 0 0 eps() -1] + @test sqrt(x2)^2 ≈ x2 + + x3 = [-1 -eps() 0 0; eps() -1 0 0; 0 0 -1 -eps(); 0 0 eps() Inf] + @test all(isnan, sqrt(x3)) + + # test overflow/underflow handled + x4 = [0 -1e200; 1e200 0] + @test sqrt(x4)^2 ≈ x4 + + x5 = [0 -1e-200; 1e-200 0] + @test sqrt(x5)^2 ≈ x5 + + x6 = [1.0 1e200; -1e-200 1.0] + @test sqrt(x6)^2 ≈ x6 +end + +@testset "matrix logarithm block diagonal underflow/overflow" begin + x1 = [0 -1e200; 1e200 0] + @test exp(log(x1)) ≈ x1 + + x2 = [0 -1e-200; 1e-200 0] + @test exp(log(x2)) ≈ x2 + + x3 = [1.0 1e200; -1e-200 1.0] + @test exp(log(x3)) ≈ x3 +end + @testset "issue #7181" begin A = [ 1 5 9 2 6 10