Skip to content

Commit

Permalink
bit of cleanup of relative entropy cone (#838)
Browse files Browse the repository at this point in the history
Co-authored-by: Chris Coey <chriscoey@users.noreply.github.com>
  • Loading branch information
araujoms and chriscoey authored Apr 12, 2024
1 parent 4659a7e commit 96746e6
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 13 deletions.
17 changes: 7 additions & 10 deletions src/Cones/epitrrelentropytri.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ function update_feas(
VH = Hermitian(cone.V, :U)
WH = Hermitian(cone.W, :U)
if isposdef(VH) && isposdef(WH)
# TODO use LAPACK syev! instead of syevr! for efficiency
V_fact = cone.V_fact = eigen(VH)
W_fact = cone.W_fact = eigen(WH)
if isposdef(V_fact) && isposdef(W_fact)
Expand Down Expand Up @@ -345,7 +344,7 @@ function hess_prod!(prod::AbstractVecOrMat, arr::AbstractVecOrMat, cone::EpiTrRe
@views mul!(temp2[:, k], cone.Wsim_Δ3[:, :, k], Varr_simV[:, k])
end
@. temp = temp2 + temp2'
# destroys arr_W_mat
# overwrites arr_W_mat
Warr_simV = spectral_outer!(arr_W_mat, V_vecs', Hermitian(arr_W_mat, :U), temp2)
@. temp += Warr_simV * cone.Δ2_V
@. temp /= -z
Expand Down Expand Up @@ -465,8 +464,8 @@ function dder3(
V_part_2 = d3WlogVdV
@. V_part_2 += diff_dot_V_VW + diff_dot_V_VW'
mul!(V_part_2, V_dir_sim, V_dir_sim', true, zi)
mul!(mat, Hermitian(V_part_2, :U), V_vecs')
mul!(V_part_1, V_vecs, mat, true, zi)
mul!(mat, V_vecs, Hermitian(V_part_2, :U))
mul!(V_part_1, mat, V_vecs', true, zi)
@views dder3_V = dder3[V_idxs]
smat_to_svec!(dder3_V, V_part_1, rt2)
@. dder3_V += const1 * dzdV
Expand All @@ -479,8 +478,8 @@ function dder3(
ldiv!(Diagonal(W_λ), W_dir_sim)
W_part_2 = diff_dot_W_WW
mul!(W_part_2, W_dir_sim, W_dir_sim', true, -zi)
mul!(mat, Hermitian(W_part_2, :U), W_vecs')
mul!(W_part_1, W_vecs, mat, true, zi)
mul!(mat, W_vecs, Hermitian(W_part_2, :U))
mul!(W_part_1, mat, W_vecs', true, zi)
@views dder3_W = dder3[W_idxs]
smat_to_svec!(dder3_W, W_part_1, rt2)
@. dder3_W += const1 * dzdW
Expand Down Expand Up @@ -566,8 +565,7 @@ function d2zdV2!(
end
# mat2 = vecs * (mat3 + mat3) * vecs'
@. mat2 = mat3 + mat3'
mul!(mat3, Hermitian(mat2, :U), V)
mul!(mat2, V', mat3)
spectral_outer!(mat2, V', Hermitian(mat2, :U), mat3)
@views smat_to_svec!(d2zdV2[:, col_idx], mat2, rt2)
col_idx += 1
end
Expand All @@ -577,8 +575,7 @@ function d2zdV2!(
@views mul!(mat3[:, k], Wsim_Δ3[:, :, k], mat2[:, k])
end
@. mat2 = mat3 + mat3'
mul!(mat3, Hermitian(mat2, :U), V)
mul!(mat2, V', mat3)
spectral_outer!(mat2, V', Hermitian(mat2, :U), mat3)
@views smat_to_svec!(d2zdV2[:, col_idx], mat2, rt2)
col_idx += 1
end
Expand Down
28 changes: 25 additions & 3 deletions src/linearalgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ end

function spectral_outer!(
mat::AbstractMatrix{T},
vecs::Union{Matrix{T}, Adjoint{T, Matrix{T}}},
vecs::Matrix{T},
symm::Symmetric{T},
temp::Matrix{T},
) where {T <: Real}
Expand All @@ -141,17 +141,39 @@ function spectral_outer!(
return mat
end

function spectral_outer!(
mat::AbstractMatrix{T},
vecs::Adjoint{T, Matrix{T}},
symm::Symmetric{T},
temp::Matrix{T},
) where {T <: Real}
mul!(temp, symm, vecs')
mul!(mat, vecs, temp)
return mat
end

function spectral_outer!(
mat::AbstractMatrix{R},
vecs::Union{Matrix{R}, Adjoint{R, Matrix{R}}},
vecs::Matrix{R},
symm::Hermitian{R},
temp::Matrix{R},
) where {T <: Real, R <: RealOrComplex{T}}
) where {R <: RealOrComplex}
mul!(temp, vecs, symm)
mul!(mat, temp, vecs')
return mat
end

function spectral_outer!(
mat::AbstractMatrix{R},
vecs::Adjoint{R, Matrix{R}},
symm::Hermitian{R},
temp::Matrix{R},
) where {R <: RealOrComplex}
mul!(temp, symm, vecs')
mul!(mat, vecs, temp)
return mat
end

#=
nonsymmetric square: LU
=#
Expand Down

0 comments on commit 96746e6

Please sign in to comment.