├── .gitignore
├── LICENSE
├── Makefile
├── README.rst
├── examples
│   ├── plot_barycenter.py
│   ├── plot_chainer_MLP.py
│   └── plot_interpolation.py
├── sdtw
│   ├── __init__.py
│   ├── barycenter.py
│   ├── chainer_func.py
│   ├── dataset.py
│   ├── distance.py
│   ├── path.py
│   ├── setup.py
│   ├── soft_dtw.py
│   ├── soft_dtw_fast.pyx
│   └── tests
│       ├── test_chainer_func.py
│       ├── test_path.py
│       └── test_soft_dtw.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | build
3 | *.c
4 | *.so
5 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017, Mathieu Blondel
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | 1. Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | 2. Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation and/or
12 | other materials provided with the distribution.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
15 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
16 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17 | IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
18 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19 | NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
20 | OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
21 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
22 | OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23 | THE POSSIBILITY OF SUCH DAMAGE.
24 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | PYTHON ?= python
2 | CYTHON ?= cython
3 | NOSETESTS ?= nosetests
4 |
5 | # Compilation...
6 |
7 | CYTHONSRC= $(wildcard sdtw/*.pyx)
8 | CSRC= $(CYTHONSRC:.pyx=.c)
9 |
10 | inplace:
11 | $(PYTHON) setup.py build_ext -i
12 |
13 | all: cython inplace
14 |
15 | cython: $(CSRC)
16 |
17 | clean:
18 | rm -f sdtw/*.c sdtw/*.html
19 | rm -f `find sdtw -name "*.pyc"`
20 | rm -f `find sdtw -name "*.so"`
21 |
22 | %.c: %.pyx
23 | $(CYTHON) $<
24 |
25 | # Tests...
26 | #
27 | test-code: inplace
28 | $(NOSETESTS) -s sdtw
29 |
30 | test-coverage:
31 | $(NOSETESTS) -s --with-coverage --cover-html --cover-html-dir=coverage \
32 | --cover-package=sdtw sdtw
33 |
34 |
35 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | .. -*- mode: rst -*-
2 |
3 | soft-DTW
4 | =========
5 |
6 | Python implementation of soft-DTW.
7 |
8 | What is it?
9 | -----------
10 |
11 | The celebrated dynamic time warping (DTW) [1] defines the discrepancy between
12 | two time series, of possibly variable length, as their minimal alignment cost.
13 | Although the number of possible alignments is exponential in the length of the
14 | two time series, [1] showed that DTW can be computed in only quadratic time
15 | using dynamic programming.
16 |
17 | Soft-DTW [2] proposes to replace this minimum by a soft minimum. Like the
18 | original DTW, soft-DTW can be computed in quadratic time using dynamic
19 | programming. However, the main advantage of soft-DTW stems from the fact that
20 | it is differentiable everywhere and that its gradient can also be computed in
21 | quadratic time. This makes it possible to use soft-DTW for time series
22 | averaging, or as a loss function between a ground-truth time series and a time
23 | series predicted by a neural network trained end-to-end using backpropagation.
24 |
25 | Supported features
26 | ------------------
27 |
28 | * soft-DTW (forward pass) and gradient (backward pass) computations,
29 | implemented in Cython for speed
30 | * barycenters (time series averaging)
31 | * dataset loader for the `UCR archive <http://www.cs.ucr.edu/~eamonn/time_series_data/>`_
32 | * `Chainer <https://chainer.org/>`_ function
33 |
34 | Example
35 | --------
36 |
37 | .. code-block:: python
38 |
39 | from sdtw import SoftDTW
40 | from sdtw.distance import SquaredEuclidean
41 |
42 | # Time series 1: numpy array, shape = [m, d] where m = length and d = dim
43 | X = ...
44 | # Time series 2: numpy array, shape = [n, d] where n = length and d = dim
45 | Y = ...
46 |
47 | # D can also be an arbitrary distance matrix: numpy array, shape [m, n]
48 | D = SquaredEuclidean(X, Y)
49 | sdtw = SoftDTW(D, gamma=1.0)
50 | # soft-DTW discrepancy, approaches DTW as gamma -> 0
51 | value = sdtw.compute()
52 | # gradient w.r.t. D, shape = [m, n], which is also the expected alignment matrix
53 | E = sdtw.grad()
54 | # gradient w.r.t. X, shape = [m, d]
55 | G = D.jacobian_product(E)
56 |
57 | Installation
58 | ------------
59 |
60 | Binary packages are not available.
61 |
62 | This project can be installed from its git repository. It is assumed that you
63 | have a working C compiler.
64 |
65 | 1. Obtain the sources by::
66 |
67 | git clone https://github.com/mblondel/soft-dtw.git
68 |
69 | or, if ``git`` is unavailable, `download as a ZIP from GitHub <https://github.com/mblondel/soft-dtw/archive/master.zip>`_.
70 |
71 |
72 | 2. Install the dependencies::
73 |
74 | # via pip
75 |
76 | pip install numpy scipy scikit-learn cython nose
77 |
78 |
79 | # via conda
80 |
81 | conda install numpy scipy scikit-learn cython nose
82 |
83 |
84 | 3. Build and install soft-dtw::
85 |
86 | cd soft-dtw
87 | make cython
88 | python setup.py build
89 | sudo python setup.py install
90 |
91 |
92 | References
93 | ----------
94 |
95 | .. [1] Hiroaki Sakoe, Seibi Chiba.
96 | *Dynamic programming algorithm optimization for spoken word recognition.*
97 |    In: IEEE Transactions on Acoustics, Speech, and Signal Processing, 1978.
98 |
99 | .. [2] Marco Cuturi, Mathieu Blondel.
100 | *Soft-DTW: a Differentiable Loss Function for Time-Series.*
101 | In: Proc. of ICML 2017.
102 |    [`PDF <https://arxiv.org/abs/1703.01541>`_]
103 |
104 | Author
105 | ------
106 |
107 | - Mathieu Blondel, 2017
108 |
--------------------------------------------------------------------------------
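For reference, a note on the recursion implemented in ``sdtw/soft_dtw_fast.pyx``
further below. With the soft-min defined as

.. math::

    \min{}^\gamma(a_1, \ldots, a_k) = -\gamma \log \sum_{i=1}^{k} e^{-a_i/\gamma},

``_soft_dtw`` fills the accumulated cost matrix R in quadratic time via

.. math::

    R_{i,j} = D_{i,j} + \min{}^\gamma(R_{i-1,j},\; R_{i-1,j-1},\; R_{i,j-1}),

and returns :math:`R_{m,n}` as the soft-DTW value. As :math:`\gamma \to 0`, the
soft-min tends to the hard min and this recovers the classical DTW recursion.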
/examples/plot_barycenter.py:
--------------------------------------------------------------------------------
1 | # Author: Mathieu Blondel
2 | # License: Simplified BSD
3 |
4 | import numpy as np
5 |
6 | import matplotlib.pyplot as plt
7 | plt.style.use('ggplot')
8 | plt.rcParams["xtick.labelsize"] = 15
9 | plt.rcParams["ytick.labelsize"] = 15
10 |
11 | from sdtw.dataset import load_ucr
12 | from sdtw.barycenter import sdtw_barycenter
13 |
14 |
15 | X_tr, y_tr, X_te, y_te = load_ucr("ECG200")
16 |
17 | n = 10
18 |
19 | # Pick n time series at random from the same class.
20 | rng = np.random.RandomState(0)
21 | classes = np.unique(y_tr)
22 | k = rng.randint(len(classes))
23 | X = X_tr[y_tr == classes[k]]
24 | X = X[rng.permutation(len(X))[:n]]
25 |
26 | fig = plt.figure(figsize=(15,4))
27 |
28 | barycenter_init = sum(X) / len(X)
29 |
30 | fig_pos = 131
31 |
32 | for gamma in (0.1, 1, 10):
33 | ax = fig.add_subplot(fig_pos)
34 |
35 | for x in X:
36 | ax.plot(x.ravel(), c="k", linewidth=3, alpha=0.15)
37 |
38 | Z = sdtw_barycenter(X, barycenter_init, gamma=gamma)
39 | ax.plot(Z.ravel(), c="r", linewidth=3, alpha=0.7)
40 | ax.set_title(r"Soft-DTW ($\gamma$=%0.1f)" % gamma)
41 |
42 | fig_pos += 1
43 |
44 | plt.show()
45 |
--------------------------------------------------------------------------------
/examples/plot_chainer_MLP.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import numpy as np
4 |
5 | from chainer import training
6 | from chainer import iterators, optimizers, serializers
7 | from chainer import Chain
8 | import chainer.functions as F
9 | import chainer.links as L
10 | from chainer.datasets import tuple_dataset
11 |
12 | from sdtw.dataset import load_ucr
13 | from sdtw.chainer_func import SoftDTWLoss
14 |
15 | import matplotlib.pyplot as plt
16 | import matplotlib.font_manager as fm
17 | plt.style.use('ggplot')
18 | plt.rcParams["xtick.labelsize"] = 15
19 | plt.rcParams["ytick.labelsize"] = 15
20 |
21 |
22 | def split_time_series(X_tr, X_te, proportion=0.6):
23 | len_ts = X_tr.shape[1]
24 |     len_input = int(round(len_ts * proportion))
25 |     len_output = len_ts - len_input
26 |
27 | return np.float32(X_tr[:, :len_input, 0]), \
28 | np.float32(X_tr[:, len_input:, 0]), \
29 | np.float32(X_te[:, :len_input, 0]), \
30 | np.float32(X_te[:, len_input:, 0])
31 |
32 |
33 | class MLP(Chain):
34 |
35 | def __init__(self, len_input, len_output, activation="tanh", n_units=50):
36 | self.activation = activation
37 |
38 | super(MLP, self).__init__(
39 |             mid=L.Linear(len_input, n_units),
40 |             out=L.Linear(n_units, len_output),
41 | )
42 |
43 | def __call__(self, x):
44 | # Given the current observation, predict the rest.
45 | xx = self.mid(x)
46 | func = getattr(F, self.activation)
47 | h = func(xx)
48 | y = self.out(h)
49 | return y
50 |
51 |
52 | class Objective(Chain):
53 |
54 | def __init__(self, predictor, loss="euclidean", gamma=1.0):
55 | self.loss = loss
56 | self.gamma = gamma
57 | super(Objective, self).__init__(predictor=predictor)
58 |
59 | def __call__(self, x, t):
60 | y = self.predictor(x)
61 |
62 | if self.loss == "euclidean":
63 | return F.mean_squared_error(y, t)
64 |
65 | elif self.loss == "sdtw":
66 | loss = 0
67 | for i in range(y.shape[0]):
68 | y_i = F.reshape(y[i], (-1,1))
69 | t_i = F.reshape(t[i], (-1,1))
70 | loss += SoftDTWLoss(self.gamma)(y_i, t_i)
71 | return loss
72 |
73 | else:
74 | raise ValueError("Unknown loss")
75 |
76 |
77 | def train(network, loss, X_tr, Y_tr, X_te, Y_te, n_epochs=30, gamma=1):
78 |     model = Objective(network, loss=loss, gamma=gamma)
79 |
80 | #optimizer = optimizers.SGD()
81 | optimizer = optimizers.Adam()
82 | optimizer.setup(model)
83 |
84 | train = tuple_dataset.TupleDataset(X_tr, Y_tr)
85 | test = tuple_dataset.TupleDataset(X_te, Y_te)
86 |
87 | train_iter = iterators.SerialIterator(train, batch_size=1, shuffle=True)
88 | test_iter = iterators.SerialIterator(test, batch_size=1, repeat=False,
89 | shuffle=False)
90 | updater = training.StandardUpdater(train_iter, optimizer)
91 | trainer = training.Trainer(updater, (n_epochs, 'epoch'))
92 |
93 | trainer.run()
94 |
95 |
96 | if __name__ == '__main__':
97 | import os
98 | import sys
99 |
100 | try:
101 | dbname = sys.argv[1]
102 | except IndexError:
103 | dbname = "ECG200"
104 |
105 | X_tr, _, X_te, _ = load_ucr(dbname)
106 |
107 | proportion = 0.6
108 | n_units = 10
109 | n_epochs = 30
110 | gamma = 1
111 | warm_start = True
112 |
113 | X_te_ = X_te
114 | X_tr, Y_tr, X_te, Y_te = split_time_series(X_tr, X_te, proportion)
115 |
116 | len_input = X_tr.shape[1]
117 | len_output = Y_tr.shape[1]
118 |
119 | networks = [MLP(len_input, len_output, n_units=n_units),]
120 | losses = ["sdtw",]
121 | labels = ["Soft-DTW loss",]
122 |
123 | for i in range(len(networks)):
124 | if warm_start and i >= 1:
125 | # Warm-start with Euclidean-case solution
126 | networks[i].mid = copy.deepcopy(networks[0].mid)
127 | networks[i].out = copy.deepcopy(networks[0].out)
128 |
129 | train(networks[i], losses[i], X_tr, Y_tr, X_te, Y_te,
130 | n_epochs=n_epochs, gamma=gamma)
131 |
132 | max_vals = []
133 | min_vals = []
134 |
135 | fig = plt.figure(figsize=(10,6))
136 |
137 | pos = 220
138 |
139 | for i in range(min(X_te.shape[0], 4)):
140 | pos += 1
141 | ax = fig.add_subplot(pos)
142 |
143 | inputseq = np.array([X_te[i]]) # Need to wrap as minibatch...
144 |
145 | # Plot predictions.
146 | for idx, label in enumerate(labels):
147 | output = networks[idx](inputseq)
148 | output = np.squeeze(np.array(output.data))
149 |
150 | ax.plot(range(len_input + 1,len_input + len(output) + 1),
151 | output,
152 | alpha=0.75,
153 | lw=3,
154 | label=label,
155 | zorder=10)
156 |
157 | max_vals.append(output.max())
158 | min_vals.append(output.min())
159 |
160 | # Plot ground-truth time series.
161 | ground_truth = X_te_[i]
162 | max_vals.append(ground_truth.max())
163 | min_vals.append(ground_truth.min())
164 | ax.plot(ground_truth,
165 | c="k",
166 | alpha=0.3,
167 | lw=3,
168 | label='Ground truth',
169 | zorder=5)
170 |
171 | # Plot vertical line.
172 | ax.plot([len_input, len_input],
173 | [np.min(min_vals), np.max(max_vals)], lw=3, ls="--", c="k")
174 |
175 | # Legend.
176 | prop = fm.FontProperties(size=18)
177 | ax.legend(loc="best", prop=prop)
178 |
179 | fig.set_tight_layout(True)
180 | plt.show()
181 |
--------------------------------------------------------------------------------
/examples/plot_interpolation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import matplotlib.pyplot as plt
4 | plt.style.use('ggplot')
5 | plt.rcParams["xtick.labelsize"] = 15
6 | plt.rcParams["ytick.labelsize"] = 15
7 |
8 | from sdtw.barycenter import sdtw_barycenter
9 | from sdtw.dataset import load_ucr
10 |
11 | X_tr = load_ucr("Gun_Point")[0]
12 | X1, X2 = X_tr[7], X_tr[37]
13 |
14 | init_25 = 0.25 * X1 + 0.75 * X2
15 | init_50 = 0.50 * X1 + 0.50 * X2
16 | init_75 = 0.75 * X1 + 0.25 * X2
17 |
18 | bary_25 = sdtw_barycenter([X1, X2], init_25, gamma=1, max_iter=100,
19 | weights=[0.25, 0.75])
20 | bary_50 = sdtw_barycenter([X1, X2], init_50, gamma=1, max_iter=100,
21 | weights=[0.50, 0.50])
22 | bary_75 = sdtw_barycenter([X1, X2], init_75, gamma=1, max_iter=100,
23 | weights=[0.75, 0.25])
24 |
25 | colors = [
26 | (0, 51./255, 204./255),
27 | (102./255, 153./255, 255./255),
28 | (255./255, 102./255, 255./255),
29 | (255./255, 0, 102./255),
30 | (1.0, 51./255, 0),
31 | ]
32 |
33 |
34 | fig = plt.figure(figsize=(10,4))
35 |
36 | ax = fig.add_subplot(121)
37 |
38 | ax.plot(X1.ravel(), c=colors[0], lw=3)
39 | ax.plot(bary_75, c=colors[1], lw=3, alpha=0.75)
40 | ax.plot(bary_50, c=colors[2], lw=3, alpha=0.75)
41 | ax.plot(bary_25, c=colors[3], lw=3, alpha=0.75)
42 | ax.plot(X2.ravel(), c=colors[4], lw=3)
43 | ax.set_title("Soft-DTW geometry")
44 |
45 | ax = fig.add_subplot(122)
46 |
47 | ax.plot(X1.ravel(), c=colors[0], lw=3)
48 | ax.plot(init_75, c=colors[1], lw=3, alpha=0.75)
49 | ax.plot(init_50, c=colors[2], lw=3, alpha=0.75)
50 | ax.plot(init_25, c=colors[3], lw=3, alpha=0.75)
51 | ax.plot(X2.ravel(), c=colors[4], lw=3)
52 | ax.set_title("Euclidean geometry")
53 |
54 | plt.show()
55 |
--------------------------------------------------------------------------------
/sdtw/__init__.py:
--------------------------------------------------------------------------------
1 | from .soft_dtw import SoftDTW
2 |
--------------------------------------------------------------------------------
/sdtw/barycenter.py:
--------------------------------------------------------------------------------
1 | # Author: Mathieu Blondel
2 | # License: Simplified BSD
3 |
4 | import numpy as np
5 |
6 | from scipy.optimize import minimize
7 |
8 | from sdtw import SoftDTW
9 | from sdtw.distance import SquaredEuclidean
10 |
11 |
12 | def sdtw_barycenter(X, barycenter_init, gamma=1.0, weights=None,
13 | method="L-BFGS-B", tol=1e-3, max_iter=50):
14 | """
15 | Compute barycenter (time series averaging) under the soft-DTW geometry.
16 |
17 | Parameters
18 | ----------
19 | X: list
20 | List of time series, numpy arrays of shape [len(X[i]), d].
21 |
22 | barycenter_init: array, shape = [length, d]
23 | Initialization.
24 |
25 | gamma: float
26 | Regularization parameter.
27 | Lower is less smoothed (closer to true DTW).
28 |
29 | weights: None or array
30 |         Weight of each X[i]. Must have the same length as X.
31 |
32 | method: string
33 | Optimization method, passed to `scipy.optimize.minimize`.
34 |         Default: L-BFGS-B.
35 |
36 | tol: float
37 | Tolerance of the method used.
38 |
39 | max_iter: int
40 | Maximum number of iterations.
41 | """
42 | if weights is None:
43 | weights = np.ones(len(X))
44 |
45 | weights = np.array(weights)
46 |
47 | def _func(Z):
48 | # Compute objective value and grad at Z.
49 |
50 | Z = Z.reshape(*barycenter_init.shape)
51 |
52 | m = Z.shape[0]
53 | G = np.zeros_like(Z)
54 |
55 | obj = 0
56 |
57 | for i in range(len(X)):
58 | D = SquaredEuclidean(Z, X[i])
59 | sdtw = SoftDTW(D, gamma=gamma)
60 | value = sdtw.compute()
61 | E = sdtw.grad()
62 | G_tmp = D.jacobian_product(E)
63 | G += weights[i] * G_tmp
64 | obj += weights[i] * value
65 |
66 | return obj, G.ravel()
67 |
68 | # The function works with vectors so we need to vectorize barycenter_init.
69 | res = minimize(_func, barycenter_init.ravel(), method=method, jac=True,
70 | tol=tol, options=dict(maxiter=max_iter, disp=False))
71 |
72 | return res.x.reshape(*barycenter_init.shape)
73 |
--------------------------------------------------------------------------------
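A minimal usage sketch for ``sdtw_barycenter``, using random data so that it
runs without the UCR archive (the shapes and values here are illustrative
only)::

    import numpy as np

    from sdtw.barycenter import sdtw_barycenter

    rng = np.random.RandomState(0)
    # Three univariate time series of different lengths, shape [len, d] with d = 1.
    X = [rng.randn(12, 1), rng.randn(10, 1), rng.randn(15, 1)]
    # The initialization determines the length of the barycenter.
    init = rng.randn(12, 1)
    Z = sdtw_barycenter(X, init, gamma=1.0, max_iter=50)
    print(Z.shape)  # (12, 1)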
/sdtw/chainer_func.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import chainer
4 |
5 | from chainer import Function
6 |
7 | from .soft_dtw import SoftDTW
8 | from .distance import SquaredEuclidean
9 |
10 |
11 | class SoftDTWLoss(Function):
12 |
13 | def __init__(self, gamma):
14 | self.gamma = gamma
15 |
16 | def forward_cpu(self, inputs):
17 | # Z, X: both are arrays of shape length x n_dim
18 | Z, X = inputs
19 |
20 | assert Z.shape[1] == X.shape[1]
21 |
22 | D = SquaredEuclidean(Z, X)
23 | self.sdtw_ = SoftDTW(D, gamma=self.gamma)
24 | loss = self.sdtw_.compute()
25 |
26 | return np.array(loss),
27 |
28 | def backward_cpu(self, inputs, grad_outputs):
29 | Z, X = inputs
30 | # g has the same shape as the output of forward_cpu().
31 |         # It is 1 when the loss is the terminal node of the backward graph.
32 | g, = grad_outputs
33 |
34 | D = SquaredEuclidean(Z, X)
35 | E = self.sdtw_.grad()
36 |         gZ = (g * D.jacobian_product(E)).astype(Z.dtype)
37 |
38 | # We don't need the gradient w.r.t. the 2nd argument.
39 | return gZ, np.zeros_like(X)
40 |
--------------------------------------------------------------------------------
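A minimal sketch of calling ``SoftDTWLoss`` directly, mirroring the gradient
test in ``sdtw/tests/test_chainer_func.py``; the shapes are illustrative::

    import numpy as np

    from chainer import Variable

    from sdtw.chainer_func import SoftDTWLoss

    rng = np.random.RandomState(0)
    Z = Variable(rng.randn(8, 2))  # prediction, shape [length, n_dim]
    X = rng.randn(10, 2)           # target; lengths may differ, n_dim must match
    loss = SoftDTWLoss(gamma=1.0)(Z, X)
    loss.backward()                # Z.grad now holds the gradient w.r.t. Z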
/sdtw/dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import numpy as np
3 |
4 |
5 | home = os.path.expanduser("~")
6 | data_dir = os.path.join(home, "sdtw_data")
7 | ucr_dir = os.path.join(data_dir, "UCR_TS_Archive_2015")
8 |
9 |
10 | def _parse_ucr(filename):
11 | y = []
12 | X = []
13 | for line in open(filename):
14 | line = line.strip()
15 | arr = line.split(",")
16 | label = int(arr[0])
17 | feat = list(map(float, arr[1:]))
18 | feat = np.array(feat).reshape(-1, 1)
19 | y.append(label)
20 | X.append(feat)
21 | return X, np.array(y)
22 |
23 |
24 | def list_ucr():
25 | return sorted(os.listdir(ucr_dir))
26 |
27 |
28 | def load_ucr(name):
29 | folder = os.path.join(ucr_dir, name)
30 | tr = os.path.join(folder, "%s_TRAIN" % name)
31 | te = os.path.join(folder, "%s_TEST" % name)
32 |
33 | try:
34 | X_tr, y_tr = _parse_ucr(tr)
35 | X_te, y_te = _parse_ucr(te)
36 | except IOError:
37 | raise IOError("Please copy UCR_TS_Archive_2015/ to $HOME/sdtw_data/. "
38 | "Download from www.cs.ucr.edu/~eamonn/time_series_data.")
39 |
40 | y_tr = np.array(y_tr)
41 | y_te = np.array(y_te)
42 | X_tr = np.array(X_tr)
43 | X_te = np.array(X_te)
44 |
45 | return X_tr, y_tr, X_te, y_te
46 |
--------------------------------------------------------------------------------
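Assuming the archive has been copied to ``$HOME/sdtw_data/UCR_TS_Archive_2015``
as the error message above instructs, the loader is used as follows
(``ECG200`` is the dataset used by the examples; the printed shapes are
indicative)::

    from sdtw.dataset import load_ucr, list_ucr

    print(list_ucr()[:5])          # names of available datasets
    X_tr, y_tr, X_te, y_te = load_ucr("ECG200")
    print(X_tr.shape, y_tr.shape)  # (n_series, length, 1) and (n_series,)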
/sdtw/distance.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from sklearn.metrics.pairwise import euclidean_distances
4 |
5 | from .soft_dtw_fast import _jacobian_product_sq_euc
6 |
7 | class SquaredEuclidean(object):
8 |
9 | def __init__(self, X, Y):
10 | """
11 | Parameters
12 | ----------
13 | X: array, shape = [m, d]
14 | First time series.
15 |
16 | Y: array, shape = [n, d]
17 | Second time series.
18 | """
19 | self.X = X.astype(np.float64)
20 | self.Y = Y.astype(np.float64)
21 |
22 | def compute(self):
23 | """
24 | Compute distance matrix.
25 |
26 | Returns
27 | -------
28 | D: array, shape = [m, n]
29 | Distance matrix.
30 | """
31 | return euclidean_distances(self.X, self.Y, squared=True)
32 |
33 | def jacobian_product(self, E):
34 | """
35 | Compute the product between the Jacobian
36 | (a linear map from m x d to m x n) and a matrix E.
37 |
38 | Parameters
39 | ----------
40 | E: array, shape = [m, n]
41 |             Matrix to be multiplied by the Jacobian, e.g. the
42 |             soft-DTW gradient w.r.t. D (the expected alignment matrix).
42 |
43 | Returns
44 | -------
45 | G: array, shape = [m, d]
46 | Product with Jacobian
47 | ([m x d, m x n] * [m x n] = [m x d]).
48 | """
49 | G = np.zeros_like(self.X)
50 |
51 | _jacobian_product_sq_euc(self.X, self.Y, E, G)
52 |
53 | return G
54 |
--------------------------------------------------------------------------------
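Since ``D[i, j] = ||X_i - Y_j||^2``, the product computed by
``jacobian_product`` has a closed form; the following pure-NumPy reference is a
sanity check written for this annotation, not part of the package::

    import numpy as np

    def jacobian_product_reference(X, Y, E):
        # G[i, k] = sum_j E[i, j] * 2 * (X[i, k] - Y[j, k])
        #         = 2 * (row_sums(E)[i] * X[i, k] - (E @ Y)[i, k])
        return 2 * (E.sum(axis=1)[:, np.newaxis] * X - E.dot(Y))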
/sdtw/path.py:
--------------------------------------------------------------------------------
1 | # Author: Mathieu Blondel
2 | # License: Simplified BSD
3 |
4 | import numpy as np
5 |
6 |
7 | def delannoy_num(m, n):
8 | """
9 | Number of paths from the southwest corner (0, 0) of a rectangular grid to
10 | the northeast corner (m, n), using only single steps north, northeast, or
11 | east.
12 |
13 | Named after French army officer and amateur mathematician Henri Delannoy.
14 |
15 | Parameters
16 | ----------
17 | m, n : int, int
18 | Northeast corner coordinates.
19 |
20 | Returns
21 | -------
22 | delannoy_num: int
23 | Delannoy number.
24 |
25 | Reference
26 | ---------
27 | https://en.wikipedia.org/wiki/Delannoy_number
28 | """
29 | a = np.zeros([m+1, n+1])
30 | a[0,0] = 1
31 |
32 | for i in range(1, m+1):
33 | a[i,0] = 1
34 |
35 | for j in range(1, n+1):
36 | a[0,j] = 1
37 |
38 | for i in range(1, m+1):
39 | for j in range(1, n+1):
40 | a[i,j] = a[i-1, j] + a[i, j-1] + a[i-1, j-1]
41 |
42 | return a[m, n]
43 |
44 |
45 | def gen_all_paths(m, n, start=None, M=None):
46 | """
47 | Generator that produces all possible paths between (1, 1) and (m, n), using
48 |     only north, northeast, or east steps. Each path is represented as an (m, n)
49 |     numpy array with ones indicating the cells on the path.
50 |
51 | Parameters
52 | ----------
53 | m, n : int, int
54 | Northeast corner coordinates.
55 | """
56 | if start is None:
57 | start = [0, 0]
58 | M = np.zeros((m, n))
59 |
60 | i, j = start
61 | M[i, j] = 1
62 | ret = []
63 |
64 | if i == m-1 and j == n-1:
65 | yield M
66 | else:
67 | if i < m - 1:
68 |             # Could use "yield from" here (Python 3.3+).
69 | for mat in gen_all_paths(m, n, (i+1, j), M.copy()):
70 | yield mat
71 | if i < m-1 and j < n-1:
72 | for mat in gen_all_paths(m, n, (i+1, j+1), M.copy()):
73 | yield mat
74 | if j < n-1:
75 | for mat in gen_all_paths(m, n, (i, j+1), M.copy()):
76 | yield mat
77 |
--------------------------------------------------------------------------------
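A quick sanity check relating the two functions; note the off-by-one also
spelled out in ``sdtw/tests/test_path.py``: ``delannoy_num`` counts paths from
(0, 0) while ``gen_all_paths`` starts at (1, 1)::

    from sdtw.path import delannoy_num, gen_all_paths

    print(delannoy_num(2, 2))              # 13.0 (central Delannoy number)
    print(len(list(gen_all_paths(3, 3))))  # 13 as well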
/sdtw/setup.py:
--------------------------------------------------------------------------------
1 | import os.path
2 |
3 | import numpy
4 |
5 |
6 | def configuration(parent_package='', top_path=None):
7 | from numpy.distutils.misc_util import Configuration
8 |
9 | config = Configuration('sdtw', parent_package, top_path)
10 |
11 | config.add_extension('soft_dtw_fast', sources=['soft_dtw_fast.c'],
12 | include_dirs=[numpy.get_include()])
13 |
14 | config.add_subpackage('tests')
15 |
16 | return config
17 |
18 |
19 | if __name__ == '__main__':
20 | from numpy.distutils.core import setup
21 | setup(**configuration(top_path='').todict())
22 |
--------------------------------------------------------------------------------
/sdtw/soft_dtw.py:
--------------------------------------------------------------------------------
1 | # Author: Mathieu Blondel
2 | # License: Simplified BSD
3 |
4 | import numpy as np
5 |
6 | from .soft_dtw_fast import _soft_dtw
7 | from .soft_dtw_fast import _soft_dtw_grad
8 |
9 |
10 | class SoftDTW(object):
11 |
12 | def __init__(self, D, gamma=1.0):
13 | """
14 | Parameters
15 | ----------
16 | D: array, shape = [m, n] or distance object
17 | Distance matrix between elements of two time series.
18 |
19 | gamma: float
20 | Regularization parameter.
21 | Lower is less smoothed (closer to true DTW).
22 |
23 | Attributes
24 | ----------
25 | self.R_: array, shape = [m + 2, n + 2]
26 | Accumulated cost matrix (stored after calling `compute`).
27 | """
28 | if hasattr(D, "compute"):
29 | self.D = D.compute()
30 | else:
31 | self.D = D
32 |
33 | self.D = self.D.astype(np.float64)
34 |
35 | self.gamma = gamma
36 |
37 | def compute(self):
38 | """
39 | Compute soft-DTW by dynamic programming.
40 |
41 | Returns
42 | -------
43 | sdtw: float
44 | soft-DTW discrepancy.
45 | """
46 | m, n = self.D.shape
47 |
48 | # Allocate memory.
49 | # We need +2 because we use indices starting from 1
50 | # and to deal with edge cases in the backward recursion.
51 | self.R_ = np.zeros((m+2, n+2), dtype=np.float64)
52 |
53 | _soft_dtw(self.D, self.R_, gamma=self.gamma)
54 |
55 | return self.R_[m, n]
56 |
57 | def grad(self):
58 | """
59 | Compute gradient of soft-DTW w.r.t. D by dynamic programming.
60 |
61 | Returns
62 | -------
63 | grad: array, shape = [m, n]
64 | Gradient w.r.t. D.
65 | """
66 | if not hasattr(self, "R_"):
67 |             raise ValueError("compute() must be called first.")
68 |
69 | m, n = self.D.shape
70 |
71 | # Add an extra row and an extra column to D.
72 | # Needed to deal with edge cases in the recursion.
73 | D = np.vstack((self.D, np.zeros(n)))
74 | D = np.hstack((D, np.zeros((m+1, 1))))
75 |
76 | # Allocate memory.
77 | # We need +2 because we use indices starting from 1
78 | # and to deal with edge cases in the recursion.
79 | E = np.zeros((m+2, n+2))
80 |
81 | _soft_dtw_grad(D, self.R_, E, gamma=self.gamma)
82 |
83 | return E[1:-1, 1:-1]
84 |
--------------------------------------------------------------------------------
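A minimal sketch of ``SoftDTW`` with a plain cost matrix instead of a distance
object; any [m, n] array works, and the values here are random for
illustration::

    import numpy as np

    from sdtw import SoftDTW

    rng = np.random.RandomState(0)
    D = rng.rand(5, 6)        # arbitrary [m, n] cost matrix
    sdtw = SoftDTW(D, gamma=1.0)
    value = sdtw.compute()    # soft-DTW value; also caches R_ for grad()
    E = sdtw.grad()           # gradient w.r.t. D, shape (5, 6)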
/sdtw/soft_dtw_fast.pyx:
--------------------------------------------------------------------------------
1 | # Author: Mathieu Blondel
2 | # License: Simplified BSD
3 |
4 | # encoding: utf-8
5 | # cython: cdivision=True
6 | # cython: boundscheck=False
7 | # cython: wraparound=False
8 |
9 | import numpy as np
10 | cimport numpy as np
11 | np.import_array()
12 |
13 |
14 | from libc.float cimport DBL_MAX
15 | from libc.math cimport exp, log
16 | from libc.string cimport memset
17 |
18 |
19 | cdef inline double _softmin3(double a,
20 | double b,
21 | double c,
22 | double gamma):
23 | a /= -gamma
24 | b /= -gamma
25 | c /= -gamma
26 |
27 | cdef double max_val = max(max(a, b), c)
28 |
29 | cdef double tmp = 0
30 | tmp += exp(a - max_val)
31 | tmp += exp(b - max_val)
32 | tmp += exp(c - max_val)
33 |
34 | return -gamma * (log(tmp) + max_val)
35 |
36 |
37 | def _soft_dtw(np.ndarray[double, ndim=2] D,
38 | np.ndarray[double, ndim=2] R,
39 | double gamma):
40 |
41 | cdef int m = D.shape[0]
42 | cdef int n = D.shape[1]
43 |
44 | cdef int i, j
45 |
46 | # Initialization.
47 |     memset(R.data, 0, (m+2) * (n+2) * sizeof(double))  # R has shape (m+2, n+2)
48 |
49 | for i in range(m + 1):
50 | R[i, 0] = DBL_MAX
51 |
52 | for j in range(n + 1):
53 | R[0, j] = DBL_MAX
54 |
55 | R[0, 0] = 0
56 |
57 | # DP recursion.
58 | for i in range(1, m + 1):
59 | for j in range(1, n + 1):
60 | # D is indexed starting from 0.
61 | R[i, j] = D[i-1, j-1] + _softmin3(R[i-1, j],
62 | R[i-1, j-1],
63 | R[i, j-1],
64 | gamma)
65 |
66 |
67 | def _soft_dtw_grad(np.ndarray[double, ndim=2] D,
68 | np.ndarray[double, ndim=2] R,
69 | np.ndarray[double, ndim=2] E,
70 | double gamma):
71 |
72 | # We added an extra row and an extra column on the Python side.
73 | cdef int m = D.shape[0] - 1
74 | cdef int n = D.shape[1] - 1
75 |
76 | cdef int i, j
77 | cdef double a, b, c
78 |
79 | # Initialization.
80 | memset(E.data, 0, (m+2) * (n+2) * sizeof(double))
81 |
82 | for i in range(1, m+1):
83 | # For D, indices start from 0 throughout.
84 | D[i-1, n] = 0
85 | R[i, n+1] = -DBL_MAX
86 |
87 | for j in range(1, n+1):
88 | D[m, j-1] = 0
89 | R[m+1, j] = -DBL_MAX
90 |
91 | E[m+1, n+1] = 1
92 | R[m+1, n+1] = R[m, n]
93 | D[m, n] = 0
94 |
95 | # DP recursion.
96 | for j in reversed(range(1, n+1)): # ranges from n to 1
97 | for i in reversed(range(1, m+1)): # ranges from m to 1
98 | a = exp((R[i+1, j] - R[i, j] - D[i, j-1]) / gamma)
99 | b = exp((R[i, j+1] - R[i, j] - D[i-1, j]) / gamma)
100 | c = exp((R[i+1, j+1] - R[i, j] - D[i, j]) / gamma)
101 | E[i, j] = E[i+1, j] * a + E[i, j+1] * b + E[i+1,j+1] * c
102 |
103 |
104 | def _jacobian_product_sq_euc(np.ndarray[double, ndim=2] X,
105 | np.ndarray[double, ndim=2] Y,
106 | np.ndarray[double, ndim=2] E,
107 | np.ndarray[double, ndim=2] G):
108 | cdef int m = X.shape[0]
109 | cdef int n = Y.shape[0]
110 | cdef int d = X.shape[1]
111 |     cdef int i, j, k
112 |
113 |     for i in range(m):
114 |         for j in range(n):
115 |             for k in range(d):
116 |                 G[i, k] += E[i, j] * 2 * (X[i, k] - Y[j, k])
--------------------------------------------------------------------------------
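``_softmin3`` above is a stabilized log-sum-exp over three values; a pure-NumPy
equivalent for an arbitrary number of arguments (the same helper the tests
below rebuild by hand)::

    import numpy as np
    from scipy.special import logsumexp

    def softmin(z, gamma):
        # -gamma * log(sum_i exp(-z_i / gamma)), stabilized like _softmin3.
        return -gamma * logsumexp(-np.asarray(z) / gamma)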
/sdtw/tests/test_chainer_func.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from chainer import Variable
3 |
4 | from sdtw.dataset import load_ucr
5 | from sdtw.chainer_func import SoftDTWLoss
6 | from scipy.optimize import check_grad
7 |
8 |
9 | def _func(z, X):
10 | shape = (-1, X.shape[1])
11 | Z = z.reshape(*shape)
12 | return SoftDTWLoss(gamma=0.1)(Z, X).data
13 |
14 |
15 | def _grad(z, X):
16 | shape = (-1, X.shape[1])
17 | Z = z.reshape(*shape)
18 | Z = Variable(Z)
19 | loss = SoftDTWLoss(gamma=0.1)(Z, X)
20 | loss.backward(retain_grad=True)
21 | return Z.grad.ravel()
22 |
23 |
24 | def test_grad():
25 |     rng = np.random.RandomState(0)
26 |     X = rng.randn(10, 2)
27 |     Z = rng.randn(8, 2)
28 |     # check_grad returns the norm of the difference between the analytical
29 |     # and numerical gradients; it should be close to zero.
30 |     assert check_grad(_func, _grad, Z.ravel(), X) < 1e-4
29 |
--------------------------------------------------------------------------------
/sdtw/tests/test_path.py:
--------------------------------------------------------------------------------
1 | from sklearn.utils.testing import assert_equal
2 |
3 |
4 | from sdtw.path import gen_all_paths
5 | from sdtw.path import delannoy_num
6 |
7 |
8 | def test_gen_all_paths():
9 | assert_equal(len(list(gen_all_paths(2, 2))), 3)
10 | assert_equal(len(list(gen_all_paths(3, 2))), 5)
11 | assert_equal(len(list(gen_all_paths(4, 2))), 7)
12 | # delannoy_num counts paths from (0,0),
13 | # while gen_all_paths starts from (1,1).
14 | assert_equal(len(list(gen_all_paths(5, 7))), delannoy_num(5-1, 7-1))
15 | assert_equal(len(list(gen_all_paths(8, 6))), delannoy_num(8-1, 6-1))
16 |
--------------------------------------------------------------------------------
/sdtw/tests/test_soft_dtw.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from scipy.optimize import approx_fprime
4 |
5 | from sklearn.metrics.pairwise import euclidean_distances
6 | from sklearn.utils.testing import assert_almost_equal
7 | from sklearn.utils.testing import assert_array_almost_equal
8 |
9 | from sdtw.path import gen_all_paths
10 | from sdtw.distance import SquaredEuclidean
11 | from sdtw import SoftDTW
12 |
13 | # Generate two inputs randomly.
14 | rng = np.random.RandomState(0)
15 | X = rng.randn(5, 4)
16 | Y = rng.randn(6, 4)
17 | D = euclidean_distances(X, Y, squared=True)
18 |
19 |
20 | def _softmax(z):
21 | max_val = np.max(z)
22 | return max_val + np.log(np.exp(z - max_val).sum())
23 |
24 |
25 | def _softmin(z, gamma):
26 | z = np.array(z)
27 | return -gamma * _softmax(-z / gamma)
28 |
29 |
30 | def _soft_dtw_bf(D, gamma):
31 | costs = [np.sum(A * D) for A in gen_all_paths(D.shape[0], D.shape[1])]
32 | return _softmin(costs, gamma)
33 |
34 |
35 | def test_soft_dtw():
36 | for gamma in (0.001, 0.01, 0.1, 1, 10, 100, 1000):
37 | assert_almost_equal(SoftDTW(D, gamma).compute(),
38 | _soft_dtw_bf(D, gamma=gamma))
39 |
40 | def test_soft_dtw_grad():
41 | def make_func(gamma):
42 | def func(d):
43 | D_ = d.reshape(*D.shape)
44 | return SoftDTW(D_, gamma).compute()
45 | return func
46 |
47 | for gamma in (0.001, 0.01, 0.1, 1, 10, 100, 1000):
48 | sdtw = SoftDTW(D, gamma)
49 | sdtw.compute()
50 | E = sdtw.grad()
51 | func = make_func(gamma)
52 | E_num = approx_fprime(D.ravel(), func, 1e-6).reshape(*E.shape)
53 | assert_array_almost_equal(E, E_num, 5)
54 |
55 |
56 | def test_soft_dtw_grad_X():
57 | def make_func(gamma):
58 | def func(x):
59 | X_ = x.reshape(*X.shape)
60 | D_ = SquaredEuclidean(X_, Y)
61 | return SoftDTW(D_, gamma).compute()
62 | return func
63 |
64 | for gamma in (0.001, 0.01, 0.1, 1, 10, 100, 1000):
65 | dist = SquaredEuclidean(X, Y)
66 | sdtw = SoftDTW(dist, gamma)
67 | sdtw.compute()
68 | E = sdtw.grad()
69 | G = dist.jacobian_product(E)
70 |
71 | func = make_func(gamma)
72 | G_num = approx_fprime(X.ravel(), func, 1e-6).reshape(*G.shape)
73 | assert_array_almost_equal(G, G_num, 5)
74 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os.path
3 | import sys
4 | import setuptools
5 |
6 |
7 | try:
8 |     import numpy
9 |     from numpy.distutils.core import setup
10 | except ImportError:
11 |     print('numpy is required during installation')
12 |     sys.exit(1)
13 |
14 |
15 | DISTNAME = 'soft-dtw'
16 | DESCRIPTION = "Python implementation of soft-DTW"
17 | LONG_DESCRIPTION = open('README.rst').read()
18 | MAINTAINER = 'Mathieu Blondel'
19 | MAINTAINER_EMAIL = ''
20 | URL = 'https://github.com/mblondel/soft-dtw/'
21 | LICENSE = 'Simplified BSD'
22 | DOWNLOAD_URL = 'https://github.com/mblondel/soft-dtw/'
23 | VERSION = '0.1.dev0'
24 |
25 |
26 | def configuration(parent_package='', top_path=None):
27 | from numpy.distutils.misc_util import Configuration
28 |
29 | config = Configuration(None, parent_package, top_path)
30 |
31 | config.add_subpackage('sdtw')
32 |
33 | return config
34 |
35 |
36 | if __name__ == '__main__':
37 | old_path = os.getcwd()
38 | local_path = os.path.dirname(os.path.abspath(sys.argv[0]))
39 |
40 | os.chdir(local_path)
41 | sys.path.insert(0, local_path)
42 |
43 | setup(configuration=configuration,
44 | name=DISTNAME,
45 | maintainer=MAINTAINER,
46 | include_package_data=True,
47 | maintainer_email=MAINTAINER_EMAIL,
48 | description=DESCRIPTION,
49 | license=LICENSE,
50 | url=URL,
51 | version=VERSION,
52 | download_url=DOWNLOAD_URL,
53 | long_description=LONG_DESCRIPTION,
54 |           zip_safe=False,  # the package cannot run directly from a zipped .egg
55 | classifiers=[
56 | 'Intended Audience :: Science/Research',
57 | 'Intended Audience :: Developers', 'License :: OSI Approved',
58 | 'Programming Language :: C', 'Programming Language :: Python',
59 | 'Topic :: Software Development',
60 | 'Topic :: Scientific/Engineering',
61 | 'Operating System :: Microsoft :: Windows',
62 | 'Operating System :: POSIX', 'Operating System :: Unix',
63 | 'Operating System :: MacOS'
64 | ]
65 | )
66 |
--------------------------------------------------------------------------------