Skip to content

Commit

Permalink
Merge pull request #8 from capkuro/develop
Browse files Browse the repository at this point in the history
Added fastdtw
  • Loading branch information
pierre-rouanet authored Feb 16, 2017
2 parents fd625b2 + 54e60ac commit 2207f88
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
35 changes: 34 additions & 1 deletion dtw.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from numpy import array, zeros, argmin, inf
from numpy import array, zeros, argmin, inf, equal
from scipy.spatial.distance import cdist

def dtw(x, y, dist):
"""
Expand Down Expand Up @@ -32,6 +33,38 @@ def dtw(x, y, dist):
path = _traceback(D0)
return D1[-1, -1] / sum(D1.shape), C, D1, path

def fastdtw(x, y, dist):
"""
Computes Dynamic Time Warping (DTW) of two sequences in a faster way.
Instead of iterating through each element and calculating each distance,
this uses the cdist function from scipy (https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cdist.html)
:param array x: N1*M array
:param array y: N2*M array
:param string or func dist: distance parameter for cdist. When string is given, cdist uses optimized functions for the distance metrics.
If a string is passed, the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'.
Returns the minimum distance, the cost matrix, the accumulated cost matrix, and the wrap path.
"""
assert len(x)
assert len(y)
r, c = len(x), len(y)
D0 = zeros((r + 1, c + 1))
D0[0, 1:] = inf
D0[1:, 0] = inf
D1 = D0[1:, 1:]
D0[1:,1:] = cdist(x,y,dist)
C = D1.copy()
for i in range(r):
for j in range(c):
D1[i, j] += min(D0[i, j], D0[i, j+1], D0[i+1, j])
if len(x)==1:
path = zeros(len(y)), range(len(y))
elif len(y) == 1:
path = range(len(x)), zeros(len(x))
else:
path = _traceback(D0)
return D1[-1, -1] / sum(D1.shape), C, D1, path

def _traceback(D):
i, j = array(D.shape) - 2
p, q = [i], [j]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
url='https://github.com/pierre-rouanet/dtw',
license='GNU GENERAL PUBLIC LICENSE Version 3',

install_requires=['numpy'],
install_requires=['numpy', 'scipy'],
setup_requires=['setuptools_git >= 0.3', ],

py_modules=['dtw'],
Expand Down

0 comments on commit 2207f88

Please sign in to comment.