Skip to content

Commit

Permalink
Change signature of pairwise! and colwise! (#239)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Jul 3, 2023
1 parent 3b90e79 commit d05bf6c
Show file tree
Hide file tree
Showing 7 changed files with 243 additions and 79 deletions.
19 changes: 17 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,29 @@ If the vector/matrix to store the results are pre-allocated, you may use the
storage (without creating a new array) using the following syntax
(`i` being either `1` or `2`):

```julia
colwise!(dist, r, X, Y)
pairwise!(dist, R, X, Y, dims=i)
pairwise!(dist, R, X, dims=i)
```

Please pay attention to the difference, the functions for inplace computation are
`colwise!` and `pairwise!` (instead of `colwise` and `pairwise`).

#### Deprecated alternative syntax

The syntax

```julia
colwise!(r, dist, X, Y)
pairwise!(R, dist, X, Y, dims=i)
pairwise!(R, dist, X, dims=i)
```

Please pay attention to the difference, the functions for inplace computation are
`colwise!` and `pairwise!` (instead of `colwise` and `pairwise`).
with the first two arguments (metric and results) interchanged is supported as well.
However, its use is discouraged since
[it is deprecated](https://github.com/JuliaStats/Distances.jl/pull/239) and will be
removed in a future release.

## Distance type hierarchy

Expand Down
2 changes: 2 additions & 0 deletions src/Distances.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,6 @@ include("mahalanobis.jl")
include("bhattacharyya.jl")
include("bregman.jl")

include("deprecated.jl")

end # module end
69 changes: 69 additions & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
Base.@deprecate pairwise!(r::AbstractMatrix, dist::PreMetric, a) pairwise!(dist, r, a)
Base.@deprecate pairwise!(r::AbstractMatrix, dist::PreMetric, a, b) pairwise!(dist, r, a, b)

Base.@deprecate pairwise!(
r::AbstractMatrix, dist::PreMetric, a::AbstractMatrix;
dims::Union{Nothing,Integer}=nothing
) pairwise!(dist, r, a; dims=dims)
Base.@deprecate pairwise!(
r::AbstractMatrix, dist::PreMetric, a::AbstractMatrix, b::AbstractMatrix;
dims::Union{Nothing,Integer}=nothing
) pairwise!(dist, r, a, b; dims=dims)

Base.@deprecate colwise!(r::AbstractArray, dist::PreMetric, a, b) colwise!(dist, r, a, b)

# docstrings for deprecated methods
@doc """
pairwise!(r::AbstractMatrix, dist::PreMetric, a)
Same as `pairwise!(dist, r, a)`.
!!! warning
Since this alternative syntax is deprecated and will be removed in a future release of
Distances.jl, its use is discouraged. Please call `pairwise!(dist, r, a)` instead.
""" pairwise!(r::AbstractMatrix, dist::PreMetric, a)
@doc """
pairwise!(r::AbstractMatrix, dist::PreMetric, a, b)
Same as `pairwise!(dist, r, a, b)`.
!!! warning
Since this alternative syntax is deprecated and will be removed in a future release of
Distances.jl, its use is discouraged. Please call `pairwise!(dist, r, a, b)` instead.
""" pairwise!(r::AbstractMatrix, dist::PreMetric, a, b)

@doc """
pairwise!(r::AbstractMatrix, dist::PreMetric, a::AbstractMatrix; dims)
Same as `pairwise!(dist, r, a; dims)`.
!!! warning
Since this alternative syntax is deprecated and will be removed in a future release of
Distances.jl, its use is discouraged. Please call `pairwise!(dist, r, a; dims)` instead.
""" pairwise!(
r::AbstractMatrix, dist::PreMetric, a::AbstractMatrix;
dims::Union{Nothing,Integer}
)
@doc """
pairwise!(r::AbstractMatrix, dist::PreMetric, a::AbstractMatrix, b::AbstractMatrix; dims)
Same as `pairwise!(dist, r, a, b; dims)`.
!!! warning
Since this alternative syntax is deprecated and will be removed in a future release of
Distances.jl, its use is discouraged. Please call `pairwise!(dist, r, a, b; dims)`
instead.
""" pairwise!(
r::AbstractMatrix, dist::PreMetric, a::AbstractMatrix, b::AbstractMatrix;
dims::Union{Nothing,Integer}
)

@doc """
colwise!(r::AbstractArray, dist::PreMetric, a, b)
Same as `colwise!(dist, r, a, b)`.
!!! warning
Since this alternative syntax is deprecated and will be removed in a future release of
Distances.jl, its use is discouraged. Please call `colwise!(dist, r, a, b)` instead.
""" colwise!(r::AbstractArray, dist::PreMetric, a, b)
60 changes: 30 additions & 30 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ __eltype(::Base.EltypeUnknown, a) = _eltype(typeof(first(a)))
# Generic column-wise evaluation

"""
colwise!(r::AbstractArray, metric::PreMetric, a, b)
colwise!(metric::PreMetric, r::AbstractArray, a, b)
Compute distances between corresponding elements of the iterable collections
`a` and `b` according to distance `metric`, and store the result in `r`.
`a` and `b` must have the same number of elements, `r` must be an array of length
`length(a) == length(b)`.
"""
function colwise!(r::AbstractArray, metric::PreMetric, a, b)
function colwise!(metric::PreMetric, r::AbstractArray, a, b)
require_one_based_indexing(r)
n = length(a)
length(b) == n || throw(DimensionMismatch("iterators have different lengths"))
Expand All @@ -65,7 +65,7 @@ function colwise!(r::AbstractArray, metric::PreMetric, a, b)
r
end

function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractVector, b::AbstractMatrix)
function colwise!(metric::PreMetric, r::AbstractArray, a::AbstractVector, b::AbstractMatrix)
require_one_based_indexing(r)
n = size(b, 2)
length(r) == n || throw(DimensionMismatch("Incorrect size of r."))
Expand All @@ -75,7 +75,7 @@ function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractVector, b::Abs
r
end

function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractMatrix, b::AbstractVector)
function colwise!(metric::PreMetric, r::AbstractArray, a::AbstractMatrix, b::AbstractVector)
require_one_based_indexing(r)
n = size(a, 2)
length(r) == n || throw(DimensionMismatch("Incorrect size of r."))
Expand All @@ -86,11 +86,11 @@ function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractMatrix, b::Abs
end

"""
colwise!(r::AbstractArray, metric::PreMetric,
colwise!(metric::PreMetric, r::AbstractArray,
a::AbstractMatrix, b::AbstractMatrix)
colwise!(r::AbstractArray, metric::PreMetric,
colwise!(metric::PreMetric, r::AbstractArray,
a::AbstractVector, b::AbstractMatrix)
colwise!(r::AbstractArray, metric::PreMetric,
colwise!(metric::PreMetric, r::AbstractArray,
a::AbstractMatrix, b::AbstractVector)
Compute distances between each corresponding columns of `a` and `b` according
Expand All @@ -105,7 +105,7 @@ vector. `r` must be an array of length `maximum(size(a, 2), size(b, 2))`.
If both `a` and `b` are vectors, the generic, iterator-based method of
`colwise` applies.
"""
function colwise!(r::AbstractArray, metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix)
function colwise!(metric::PreMetric, r::AbstractArray, a::AbstractMatrix, b::AbstractMatrix)
require_one_based_indexing(r, a, b)
n = get_common_ncols(a, b)
length(r) == n || throw(DimensionMismatch("Incorrect size of r."))
Expand All @@ -126,7 +126,7 @@ Compute distances between corresponding elements of the iterable collections
function colwise(metric::PreMetric, a, b)
n = get_common_length(a, b)
r = Vector{result_type(metric, a, b)}(undef, n)
colwise!(r, metric, a, b)
colwise!(metric, r, a, b)
end

"""
Expand All @@ -148,25 +148,25 @@ vector.
function colwise(metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix)
n = get_common_ncols(a, b)
r = Vector{result_type(metric, a, b)}(undef, n)
colwise!(r, metric, a, b)
colwise!(metric, r, a, b)
end

function colwise(metric::PreMetric, a::AbstractVector, b::AbstractMatrix)
n = size(b, 2)
r = Vector{result_type(metric, a, b)}(undef, n)
colwise!(r, metric, a, b)
colwise!(metric, r, a, b)
end

function colwise(metric::PreMetric, a::AbstractMatrix, b::AbstractVector)
n = size(a, 2)
r = Vector{result_type(metric, a, b)}(undef, n)
colwise!(r, metric, a, b)
colwise!(metric, r, a, b)
end


# Generic pairwise evaluation

function _pairwise!(r::AbstractMatrix, metric::PreMetric, a, b=a)
function _pairwise!(metric::PreMetric, r::AbstractMatrix, a, b=a)
require_one_based_indexing(r)
na = length(a)
nb = length(b)
Expand All @@ -177,7 +177,7 @@ function _pairwise!(r::AbstractMatrix, metric::PreMetric, a, b=a)
r
end

function _pairwise!(r::AbstractMatrix, metric::PreMetric,
function _pairwise!(metric::PreMetric, r::AbstractMatrix,
a::AbstractMatrix, b::AbstractMatrix=a)
require_one_based_indexing(r, a, b)
na = size(a, 2)
Expand All @@ -192,7 +192,7 @@ function _pairwise!(r::AbstractMatrix, metric::PreMetric,
r
end

function _pairwise!(r::AbstractMatrix, metric::SemiMetric, a)
function _pairwise!(metric::SemiMetric, r::AbstractMatrix, a)
require_one_based_indexing(r)
n = length(a)
size(r) == (n, n) || throw(DimensionMismatch("Incorrect size of r."))
Expand All @@ -208,7 +208,7 @@ function _pairwise!(r::AbstractMatrix, metric::SemiMetric, a)
r
end

function _pairwise!(r::AbstractMatrix, metric::SemiMetric, a::AbstractMatrix)
function _pairwise!(metric::SemiMetric, r::AbstractMatrix, a::AbstractMatrix)
require_one_based_indexing(r)
n = size(a, 2)
size(r) == (n, n) || throw(DimensionMismatch("Incorrect size of r."))
Expand Down Expand Up @@ -237,7 +237,7 @@ function deprecated_dims(dims::Union{Nothing,Integer})
end

"""
pairwise!(r::AbstractMatrix, metric::PreMetric,
pairwise!(metric::PreMetric, r::AbstractMatrix,
a::AbstractMatrix, b::AbstractMatrix=a; dims)
Compute distances between each pair of rows (if `dims=1`) or columns (if `dims=2`)
Expand All @@ -247,7 +247,7 @@ If a single matrix `a` is provided, compute distances between its rows or column
`a` and `b` must have the same numbers of columns if `dims=1`, or of rows if `dims=2`.
`r` must be a matrix with size `size(a, dims) × size(b, dims)`.
"""
function pairwise!(r::AbstractMatrix, metric::PreMetric,
function pairwise!(metric::PreMetric, r::AbstractMatrix,
a::AbstractMatrix, b::AbstractMatrix;
dims::Union{Nothing,Integer}=nothing)
dims = deprecated_dims(dims)
Expand All @@ -266,13 +266,13 @@ function pairwise!(r::AbstractMatrix, metric::PreMetric,
size(r) == (na, nb) ||
throw(DimensionMismatch("Incorrect size of r (got $(size(r)), expected $((na, nb)))."))
if dims == 1
_pairwise!(r, metric, permutedims(a), permutedims(b))
_pairwise!(metric, r, permutedims(a), permutedims(b))
else
_pairwise!(r, metric, a, b)
_pairwise!(metric, r, a, b)
end
end

function pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractMatrix;
function pairwise!(metric::PreMetric, r::AbstractMatrix, a::AbstractMatrix;
dims::Union{Nothing,Integer}=nothing)
dims = deprecated_dims(dims)
dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)"))
Expand All @@ -284,23 +284,23 @@ function pairwise!(r::AbstractMatrix, metric::PreMetric, a::AbstractMatrix;
size(r) == (n, n) ||
throw(DimensionMismatch("Incorrect size of r (got $(size(r)), expected $((n, n)))."))
if dims == 1
_pairwise!(r, metric, permutedims(a))
_pairwise!(metric, r, permutedims(a))
else
_pairwise!(r, metric, a)
_pairwise!(metric, r, a)
end
end

"""
pairwise!(r::AbstractMatrix, metric::PreMetric, a, b=a)
pairwise!(metric::PreMetric, r::AbstractMatrix, a, b=a)
Compute distances between each element of collection `a` and each element of
collection `b` according to distance `metric`, and store the result in `r`.
If a single iterable `a` is provided, compute distances between its elements.
`r` must be a matrix with size `length(a) × length(b)`.
"""
pairwise!(r::AbstractMatrix, metric::PreMetric, a, b) = _pairwise!(r, metric, a, b)
pairwise!(r::AbstractMatrix, metric::PreMetric, a) = _pairwise!(r, metric, a)
pairwise!(metric::PreMetric, r::AbstractMatrix, a, b) = _pairwise!(metric, r, a, b)
pairwise!(metric::PreMetric, r::AbstractMatrix, a) = _pairwise!(metric, r, a)

"""
pairwise(metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix=a; dims)
Expand All @@ -318,7 +318,7 @@ function pairwise(metric::PreMetric, a::AbstractMatrix, b::AbstractMatrix;
m = size(a, dims)
n = size(b, dims)
r = Matrix{result_type(metric, a, b)}(undef, m, n)
pairwise!(r, metric, a, b, dims=dims)
pairwise!(metric, r, a, b, dims=dims)
end

function pairwise(metric::PreMetric, a::AbstractMatrix;
Expand All @@ -327,7 +327,7 @@ function pairwise(metric::PreMetric, a::AbstractMatrix;
dims in (1, 2) || throw(ArgumentError("dims should be 1 or 2 (got $dims)"))
n = size(a, dims)
r = Matrix{result_type(metric, a, a)}(undef, n, n)
pairwise!(r, metric, a, dims=dims)
pairwise!(metric, r, a, dims=dims)
end

"""
Expand All @@ -341,11 +341,11 @@ function pairwise(metric::PreMetric, a, b)
m = length(a)
n = length(b)
r = Matrix{result_type(metric, a, b)}(undef, m, n)
_pairwise!(r, metric, a, b)
_pairwise!(metric, r, a, b)
end

function pairwise(metric::PreMetric, a)
n = length(a)
r = Matrix{result_type(metric, a, a)}(undef, n, n)
_pairwise!(r, metric, a)
_pairwise!(metric, r, a)
end
30 changes: 15 additions & 15 deletions src/mahalanobis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,34 +95,34 @@ end
sqmahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = SqMahalanobis(Q)(a, b)
mahalanobis(a::AbstractVector, b::AbstractVector, Q::AbstractMatrix) = Mahalanobis(Q)(a, b)

function _colwise!(r, dist, a, b)
function _colwise!(dist, r, a, b)
Q = dist.qmat
get_colwise_dims(size(Q, 1), r, a, b)
z = a .- b
dot_percol!(r, Q * z, z)
end

function colwise!(r::AbstractArray, dist::SqMahalanobis, a::AbstractMatrix, b::AbstractMatrix)
_colwise!(r, dist, a, b)
function colwise!(dist::SqMahalanobis, r::AbstractArray, a::AbstractMatrix, b::AbstractMatrix)
_colwise!(dist, r, a, b)
end
function colwise!(r::AbstractArray, dist::SqMahalanobis, a::AbstractVector, b::AbstractMatrix)
_colwise!(r, dist, a, b)
function colwise!(dist::SqMahalanobis, r::AbstractArray, a::AbstractVector, b::AbstractMatrix)
_colwise!(dist, r, a, b)
end
function colwise!(r::AbstractArray, dist::SqMahalanobis, a::AbstractMatrix, b::AbstractVector)
_colwise!(r, dist, a, b)
function colwise!(dist::SqMahalanobis, r::AbstractArray, a::AbstractMatrix, b::AbstractVector)
_colwise!(dist, r, a, b)
end

function colwise!(r::AbstractArray, dist::Mahalanobis, a::AbstractMatrix, b::AbstractMatrix)
sqrt!(_colwise!(r, dist, a, b))
function colwise!(dist::Mahalanobis, r::AbstractArray, a::AbstractMatrix, b::AbstractMatrix)
sqrt!(_colwise!(dist, r, a, b))
end
function colwise!(r::AbstractArray, dist::Mahalanobis, a::AbstractVector, b::AbstractMatrix)
sqrt!(_colwise!(r, dist, a, b))
function colwise!(dist::Mahalanobis, r::AbstractArray, a::AbstractVector, b::AbstractMatrix)
sqrt!(_colwise!(dist, r, a, b))
end
function colwise!(r::AbstractArray, dist::Mahalanobis, a::AbstractMatrix, b::AbstractVector)
sqrt!(_colwise!(r, dist, a, b))
function colwise!(dist::Mahalanobis, r::AbstractArray, a::AbstractMatrix, b::AbstractVector)
sqrt!(_colwise!(dist, r, a, b))
end

function _pairwise!(r::AbstractMatrix, dist::Union{SqMahalanobis,Mahalanobis}, a::AbstractMatrix, b::AbstractMatrix)
function _pairwise!(dist::Union{SqMahalanobis,Mahalanobis}, r::AbstractMatrix, a::AbstractMatrix, b::AbstractMatrix)
Q = dist.qmat
m, na, nb = get_pairwise_dims(size(Q, 1), r, a, b)

Expand All @@ -140,7 +140,7 @@ function _pairwise!(r::AbstractMatrix, dist::Union{SqMahalanobis,Mahalanobis}, a
r
end

function _pairwise!(r::AbstractMatrix, dist::Union{SqMahalanobis,Mahalanobis}, a::AbstractMatrix)
function _pairwise!(dist::Union{SqMahalanobis,Mahalanobis}, r::AbstractMatrix, a::AbstractMatrix)
Q = dist.qmat
m, n = get_pairwise_dims(size(Q, 1), r, a)

Expand Down
Loading

0 comments on commit d05bf6c

Please sign in to comment.