Skip to content

Commit

Permalink
Allow generators and iterators (#194)
Browse files Browse the repository at this point in the history
Co-authored-by: Milan Bouchet-Valat <nalimilan@club.fr>
  • Loading branch information
dkarrasch and nalimilan authored Dec 18, 2020
1 parent be3a901 commit f6ee353
Show file tree
Hide file tree
Showing 8 changed files with 490 additions and 222 deletions.
9 changes: 9 additions & 0 deletions src/Distances.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,15 @@ export
rmsd,
nrmsd

if VERSION < v"1.2-"
import Base: has_offset_axes
require_one_based_indexing(A...) =
!has_offset_axes(A...) ||
throw(ArgumentError("offset arrays are not supported but got an array with index other than 1"))
else
import Base: require_one_based_indexing
end

include("common.jl")
include("generic.jl")
include("metrics.jl")
Expand Down
47 changes: 31 additions & 16 deletions src/bhattacharyya.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,39 +9,54 @@ struct HellingerDist <: Metric end

# Bhattacharyya coefficient

function bhattacharyya_coeff(a::AbstractVector{T}, b::AbstractVector{T}) where {T <: Number}
if length(a) != length(b)
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
function bhattacharyya_coeff(a, b)
n = length(a)
if n != length(b)
throw(DimensionMismatch("first argument has length $n which does not match the length of the second, $(length(b))."))
end
sqab, asum, bsum = _bhattacharyya_coeff(a, b)
# We must normalize since we cannot assume that the vectors are normalized to probability vectors.
return sqab / sqrt(asum * bsum)
end

n = length(a)
@inline function _bhattacharyya_coeff(a, b)
Ta = _eltype(a)
Tb = _eltype(b)
T = typeof(sqrt(zero(promote_type(Ta, Tb))))
sqab = zero(T)
# We must normalize since we cannot assume that the vectors are normalized to probability vectors.
asum = zero(T)
bsum = zero(T)
asum = zero(Ta)
bsum = zero(Tb)

for (ai, bi) in zip(a, b)
sqab += sqrt(ai * bi)
asum += ai
bsum += bi
end
return sqab, asum, bsum
end
@inline function _bhattacharyya_coeff(a::AbstractVector{Ta}, b::AbstractVector{Tb}) where {Ta<:Number,Tb<:Number}
T = typeof(sqrt(oneunit(Ta)*oneunit(Tb)))
sqab = zero(T)
asum = zero(Ta)
bsum = zero(Tb)

@simd for i = 1:n
@simd for i in eachindex(a, b)
@inbounds ai = a[i]
@inbounds bi = b[i]
sqab += sqrt(ai * bi)
asum += ai
bsum += bi
end

sqab / sqrt(asum * bsum)
return sqab, asum, bsum
end

bhattacharyya_coeff(a::T, b::T) where {T <: Number} = throw("Bhattacharyya coefficient cannot be calculated for scalars")

# Faster pair- and column-wise versions TBD...


# Bhattacharyya distance
(::BhattacharyyaDist)(a::AbstractVector{T}, b::AbstractVector{T}) where {T <: Number} = -log(bhattacharyya_coeff(a, b))
(::BhattacharyyaDist)(a::T, b::T) where {T <: Number} = throw("Bhattacharyya distance cannot be calculated for scalars")
(::BhattacharyyaDist)(a, b) = -log(bhattacharyya_coeff(a, b))
bhattacharyya(a, b) = BhattacharyyaDist()(a, b)

# Hellinger distance
(::HellingerDist)(a::AbstractVector{T}, b::AbstractVector{T}) where {T <: Number} = sqrt(1 - bhattacharyya_coeff(a, b))
(::HellingerDist)(a::T, b::T) where {T <: Number} = throw("Hellinger distance cannot be calculated for scalars")
(::HellingerDist)(a, b) = sqrt(1 - bhattacharyya_coeff(a, b))
hellinger(a, b) = HellingerDist()(a, b)
16 changes: 8 additions & 8 deletions src/bregman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,26 @@ end
Bregman(F, ∇) = Bregman(F, ∇, LinearAlgebra.dot)

# Evaluation fuction
function (dist::Bregman)(p::AbstractVector, q::AbstractVector)
function (dist::Bregman)(p, q)
# Create cache vals.
FP_val = dist.F(p);
FQ_val = dist.F(q);
DQ_val = dist.(q);
p_size = size(p);
FP_val = dist.F(p)
FQ_val = dist.F(q)
DQ_val = dist.(q)
p_size = length(p)
# Check F codomain.
if !(isa(FP_val, Real) && isa(FQ_val, Real))
throw(ArgumentError("F Codomain Error: F doesn't map the vectors to real numbers"))
end
# Check vector size.
if !(p_size == size(q))
if p_size != length(q)
throw(DimensionMismatch("The vector p ($(size(p))) and q ($(size(q))) are different sizes."))
end
# Check gradient size.
if !(size(DQ_val) == p_size)
if length(DQ_val) != p_size
throw(DimensionMismatch("The gradient result is not the same size as p and q"))
end
# Return the Bregman divergence.
return FP_val - FQ_val - dist.inner(DQ_val, p-q);
return FP_val - FQ_val - dist.inner(DQ_val, p .- q)
end

# Convenience function.
Expand Down
6 changes: 6 additions & 0 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ function get_common_ncols(a::AbstractMatrix, b::AbstractMatrix)
return na
end

function get_common_length(a, b)
n = length(a)
length(b) == n || throw(DimensionMismatch("The lengths of a and b must match."))
return n
end

function get_pairwise_dims(r::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix)
ma, na = size(a)
mb, nb = size(b)
Expand Down
Loading

0 comments on commit f6ee353

Please sign in to comment.