├── .gitignore ├── LICENSE ├── Makefile ├── README.rst ├── examples ├── plot_barycenter.py ├── plot_chainer_MLP.py └── plot_interpolation.py ├── sdtw ├── __init__.py ├── barycenter.py ├── chainer_func.py ├── dataset.py ├── distance.py ├── path.py ├── setup.py ├── soft_dtw.py ├── soft_dtw_fast.pyx └── tests │ ├── test_chainer_func.py │ ├── test_path.py │ └── test_soft_dtw.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | build 3 | *.c 4 | *.so 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017, Mathieu Blondel 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation and/or 12 | other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 15 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 16 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 17 | IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 18 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT 19 | NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, 20 | OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 21 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 22 | OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF 23 | THE POSSIBILITY OF SUCH DAMAGE. 24 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PYTHON ?= python 2 | CYTHON ?= cython 3 | NOSETESTS ?= nosetests 4 | 5 | # Compilation... 6 | 7 | CYTHONSRC= $(wildcard sdtw/*.pyx) 8 | CSRC= $(CYTHONSRC:.pyx=.c) 9 | 10 | inplace: 11 | $(PYTHON) setup.py build_ext -i 12 | 13 | all: cython inplace 14 | 15 | cython: $(CSRC) 16 | 17 | clean: 18 | rm -f sdtw/*.c sdtw/*.html 19 | rm -f `find sdtw -name "*.pyc"` 20 | rm -f `find sdtw -name "*.so"` 21 | 22 | %.c: %.pyx 23 | $(CYTHON) $< 24 | 25 | # Tests... 26 | # 27 | test-code: inplace 28 | $(NOSETESTS) -s sdtw 29 | 30 | test-coverage: 31 | $(NOSETESTS) -s --with-coverage --cover-html --cover-html-dir=coverage \ 32 | --cover-package=sdtw sdtw 33 | 34 | 35 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | .. -*- mode: rst -*- 2 | 3 | soft-DTW 4 | ========= 5 | 6 | Python implementation of soft-DTW. 7 | 8 | What is it? 9 | ----------- 10 | 11 | The celebrated dynamic time warping (DTW) [1] defines the discrepancy between 12 | two time series, of possibly variable length, as their minimal alignment cost. 13 | Although the number of possible alignments is exponential in the length of the 14 | two time series, [1] showed that DTW can be computed in only quadractic time 15 | using dynamic programming. 16 | 17 | Soft-DTW [2] proposes to replace this minimum by a soft minimum. Like the 18 | original DTW, soft-DTW can be computed in quadratic time using dynamic 19 | programming. However, the main advantage of soft-DTW stems from the fact that 20 | it is differentiable everywhere and that its gradient can also be computed in 21 | quadratic time. This enables to use soft-DTW for time series averaging or as a 22 | loss function, between a ground-truth time series and a time series predicted 23 | by a neural network, trained end-to-end using backpropagation. 24 | 25 | Supported features 26 | ------------------ 27 | 28 | * soft-DTW (forward pass) and gradient (backward pass) computations, 29 | implemented in Cython for speed 30 | * barycenters (time series averaging) 31 | * dataset loader for the `UCR archive `_ 32 | * `Chainer `_ function 33 | 34 | Example 35 | -------- 36 | 37 | .. code-block:: python 38 | 39 | from sdtw import SoftDTW 40 | from sdtw.distance import SquaredEuclidean 41 | 42 | # Time series 1: numpy array, shape = [m, d] where m = length and d = dim 43 | X = ... 44 | # Time series 2: numpy array, shape = [n, d] where n = length and d = dim 45 | Y = ... 46 | 47 | # D can also be an arbitrary distance matrix: numpy array, shape [m, n] 48 | D = SquaredEuclidean(X, Y) 49 | sdtw = SoftDTW(D, gamma=1.0) 50 | # soft-DTW discrepancy, approaches DTW as gamma -> 0 51 | value = sdtw.compute() 52 | # gradient w.r.t. D, shape = [m, n], which is also the expected alignment matrix 53 | E = sdtw.grad() 54 | # gradient w.r.t. X, shape = [m, d] 55 | G = D.jacobian_product(E) 56 | 57 | Installation 58 | ------------ 59 | 60 | Binary packages are not available. 61 | 62 | This project can be installed from its git repository. It is assumed that you 63 | have a working C compiler. 64 | 65 | 1. Obtain the sources by:: 66 | 67 | git clone https://github.com/mblondel/soft-dtw.git 68 | 69 | or, if `git` is unavailable, `download as a ZIP from GitHub `_. 70 | 71 | 72 | 2. Install the dependencies:: 73 | 74 | # via pip 75 | 76 | pip install numpy scipy scikit-learn cython nose 77 | 78 | 79 | # via conda 80 | 81 | conda install numpy scipy scikit-learn cython nose 82 | 83 | 84 | 3. Build and install soft-dtw:: 85 | 86 | cd soft-dtw 87 | make cython 88 | python setup.py build 89 | sudo python setup.py install 90 | 91 | 92 | References 93 | ---------- 94 | 95 | .. [1] Hiroaki Sakoe, Seibi Chiba. 96 | *Dynamic programming algorithm optimization for spoken word recognition.* 97 | In: IEEE Trans. on Acoustics, Speech, and Sig. Proc, 1978. 98 | 99 | .. [2] Marco Cuturi, Mathieu Blondel. 100 | *Soft-DTW: a Differentiable Loss Function for Time-Series.* 101 | In: Proc. of ICML 2017. 102 | [`PDF `_] 103 | 104 | Author 105 | ------ 106 | 107 | - Mathieu Blondel, 2017 108 | -------------------------------------------------------------------------------- /examples/plot_barycenter.py: -------------------------------------------------------------------------------- 1 | # Author: Mathieu Blondel 2 | # License: Simplified BSD 3 | 4 | import numpy as np 5 | 6 | import matplotlib.pylab as plt 7 | plt.style.use('ggplot') 8 | plt.rcParams["xtick.labelsize"] = 15 9 | plt.rcParams["ytick.labelsize"] = 15 10 | 11 | from sdtw.dataset import load_ucr 12 | from sdtw.barycenter import sdtw_barycenter 13 | 14 | 15 | X_tr, y_tr, X_te, y_te = load_ucr("ECG200") 16 | 17 | n = 10 18 | 19 | # Pick n time series at random from the same class. 20 | rng = np.random.RandomState(0) 21 | classes = np.unique(y_tr) 22 | k = rng.randint(len(classes)) 23 | X = X_tr[y_tr == classes[k]] 24 | X = X[rng.permutation(len(X))[:n]] 25 | 26 | fig = plt.figure(figsize=(15,4)) 27 | 28 | barycenter_init = sum(X) / len(X) 29 | 30 | fig_pos = 131 31 | 32 | for gamma in (0.1, 1, 10): 33 | ax = fig.add_subplot(fig_pos) 34 | 35 | for x in X: 36 | ax.plot(x.ravel(), c="k", linewidth=3, alpha=0.15) 37 | 38 | Z = sdtw_barycenter(X, barycenter_init, gamma=gamma) 39 | ax.plot(Z.ravel(), c="r", linewidth=3, alpha=0.7) 40 | ax.set_title(r"Soft-DTW ($\gamma$=%0.1f)" % gamma) 41 | 42 | fig_pos += 1 43 | 44 | plt.show() 45 | -------------------------------------------------------------------------------- /examples/plot_chainer_MLP.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | 5 | from chainer import training 6 | from chainer import iterators, optimizers, serializers 7 | from chainer import Chain 8 | import chainer.functions as F 9 | import chainer.links as L 10 | from chainer.datasets import tuple_dataset 11 | 12 | from sdtw.dataset import load_ucr 13 | from sdtw.chainer_func import SoftDTWLoss 14 | 15 | import matplotlib.pylab as plt 16 | import matplotlib.font_manager as fm 17 | plt.style.use('ggplot') 18 | plt.rcParams["xtick.labelsize"] = 15 19 | plt.rcParams["ytick.labelsize"] = 15 20 | 21 | 22 | def split_time_series(X_tr, X_te, proportion=0.6): 23 | len_ts = X_tr.shape[1] 24 | len_input=int(round(len_ts * proportion)) 25 | len_output=len_ts - len_input 26 | 27 | return np.float32(X_tr[:, :len_input, 0]), \ 28 | np.float32(X_tr[:, len_input:, 0]), \ 29 | np.float32(X_te[:, :len_input, 0]), \ 30 | np.float32(X_te[:, len_input:, 0]) 31 | 32 | 33 | class MLP(Chain): 34 | 35 | def __init__(self, len_input, len_output, activation="tanh", n_units=50): 36 | self.activation = activation 37 | 38 | super(MLP, self).__init__( 39 | mid = L.Linear(len_input, n_units), 40 | out=L.Linear(n_units, len_output), 41 | ) 42 | 43 | def __call__(self, x): 44 | # Given the current observation, predict the rest. 45 | xx = self.mid(x) 46 | func = getattr(F, self.activation) 47 | h = func(xx) 48 | y = self.out(h) 49 | return y 50 | 51 | 52 | class Objective(Chain): 53 | 54 | def __init__(self, predictor, loss="euclidean", gamma=1.0): 55 | self.loss = loss 56 | self.gamma = gamma 57 | super(Objective, self).__init__(predictor=predictor) 58 | 59 | def __call__(self, x, t): 60 | y = self.predictor(x) 61 | 62 | if self.loss == "euclidean": 63 | return F.mean_squared_error(y, t) 64 | 65 | elif self.loss == "sdtw": 66 | loss = 0 67 | for i in range(y.shape[0]): 68 | y_i = F.reshape(y[i], (-1,1)) 69 | t_i = F.reshape(t[i], (-1,1)) 70 | loss += SoftDTWLoss(self.gamma)(y_i, t_i) 71 | return loss 72 | 73 | else: 74 | raise ValueError("Unknown loss") 75 | 76 | 77 | def train(network, loss, X_tr, Y_tr, X_te, Y_te, n_epochs=30, gamma=1): 78 | model= Objective(network, loss=loss, gamma=gamma) 79 | 80 | #optimizer = optimizers.SGD() 81 | optimizer = optimizers.Adam() 82 | optimizer.setup(model) 83 | 84 | train = tuple_dataset.TupleDataset(X_tr, Y_tr) 85 | test = tuple_dataset.TupleDataset(X_te, Y_te) 86 | 87 | train_iter = iterators.SerialIterator(train, batch_size=1, shuffle=True) 88 | test_iter = iterators.SerialIterator(test, batch_size=1, repeat=False, 89 | shuffle=False) 90 | updater = training.StandardUpdater(train_iter, optimizer) 91 | trainer = training.Trainer(updater, (n_epochs, 'epoch')) 92 | 93 | trainer.run() 94 | 95 | 96 | if __name__ == '__main__': 97 | import os 98 | import sys 99 | 100 | try: 101 | dbname = sys.argv[1] 102 | except IndexError: 103 | dbname = "ECG200" 104 | 105 | X_tr, _, X_te, _ = load_ucr(dbname) 106 | 107 | proportion = 0.6 108 | n_units = 10 109 | n_epochs = 30 110 | gamma = 1 111 | warm_start = True 112 | 113 | X_te_ = X_te 114 | X_tr, Y_tr, X_te, Y_te = split_time_series(X_tr, X_te, proportion) 115 | 116 | len_input = X_tr.shape[1] 117 | len_output = Y_tr.shape[1] 118 | 119 | networks = [MLP(len_input, len_output, n_units=n_units),] 120 | losses = ["sdtw",] 121 | labels = ["Soft-DTW loss",] 122 | 123 | for i in range(len(networks)): 124 | if warm_start and i >= 1: 125 | # Warm-start with Euclidean-case solution 126 | networks[i].mid = copy.deepcopy(networks[0].mid) 127 | networks[i].out = copy.deepcopy(networks[0].out) 128 | 129 | train(networks[i], losses[i], X_tr, Y_tr, X_te, Y_te, 130 | n_epochs=n_epochs, gamma=gamma) 131 | 132 | max_vals = [] 133 | min_vals = [] 134 | 135 | fig = plt.figure(figsize=(10,6)) 136 | 137 | pos = 220 138 | 139 | for i in range(min(X_te.shape[0], 4)): 140 | pos += 1 141 | ax = fig.add_subplot(pos) 142 | 143 | inputseq = np.array([X_te[i]]) # Need to wrap as minibatch... 144 | 145 | # Plot predictions. 146 | for idx, label in enumerate(labels): 147 | output = networks[idx](inputseq) 148 | output = np.squeeze(np.array(output.data)) 149 | 150 | ax.plot(range(len_input + 1,len_input + len(output) + 1), 151 | output, 152 | alpha=0.75, 153 | lw=3, 154 | label=label, 155 | zorder=10) 156 | 157 | max_vals.append(output.max()) 158 | min_vals.append(output.min()) 159 | 160 | # Plot ground-truth time series. 161 | ground_truth = X_te_[i] 162 | max_vals.append(ground_truth.max()) 163 | min_vals.append(ground_truth.min()) 164 | ax.plot(ground_truth, 165 | c="k", 166 | alpha=0.3, 167 | lw=3, 168 | label='Ground truth', 169 | zorder=5) 170 | 171 | # Plot vertical line. 172 | ax.plot([len_input, len_input], 173 | [np.min(min_vals), np.max(max_vals)], lw=3, ls="--", c="k") 174 | 175 | # Legend. 176 | prop = fm.FontProperties(size=18) 177 | ax.legend(loc="best", prop=prop) 178 | 179 | fig.set_tight_layout(True) 180 | plt.show() 181 | -------------------------------------------------------------------------------- /examples/plot_interpolation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import matplotlib.pylab as plt 4 | plt.style.use('ggplot') 5 | plt.rcParams["xtick.labelsize"] = 15 6 | plt.rcParams["ytick.labelsize"] = 15 7 | 8 | from sdtw.barycenter import sdtw_barycenter 9 | from sdtw.dataset import load_ucr 10 | 11 | X_tr = load_ucr("Gun_Point")[0] 12 | X1, X2 = X_tr[7], X_tr[37] 13 | 14 | init_25 = 0.25 * X1 + 0.75 * X2 15 | init_50 = 0.50 * X1 + 0.50 * X2 16 | init_75 = 0.75 * X1 + 0.25 * X2 17 | 18 | bary_25 = sdtw_barycenter([X1, X2], init_25, gamma=1, max_iter=100, 19 | weights=[0.25, 0.75]) 20 | bary_50 = sdtw_barycenter([X1, X2], init_50, gamma=1, max_iter=100, 21 | weights=[0.50, 0.50]) 22 | bary_75 = sdtw_barycenter([X1, X2], init_75, gamma=1, max_iter=100, 23 | weights=[0.75, 0.25]) 24 | 25 | colors = [ 26 | (0, 51./255, 204./255), 27 | (102./255, 153./255, 255./255), 28 | (255./255, 102./255, 255./255), 29 | (255./255, 0, 102./255), 30 | (1.0, 51./255, 0), 31 | ] 32 | 33 | 34 | fig = plt.figure(figsize=(10,4)) 35 | 36 | ax = fig.add_subplot(121) 37 | 38 | ax.plot(X1.ravel(), c=colors[0], lw=3) 39 | ax.plot(bary_75, c=colors[1], lw=3, alpha=0.75) 40 | ax.plot(bary_50, c=colors[2], lw=3, alpha=0.75) 41 | ax.plot(bary_25, c=colors[3], lw=3, alpha=0.75) 42 | ax.plot(X2.ravel(), c=colors[4], lw=3) 43 | ax.set_title("Soft-DTW geometry") 44 | 45 | ax = fig.add_subplot(122) 46 | 47 | ax.plot(X1.ravel(), c=colors[0], lw=3) 48 | ax.plot(init_75, c=colors[1], lw=3, alpha=0.75) 49 | ax.plot(init_50, c=colors[2], lw=3, alpha=0.75) 50 | ax.plot(init_25, c=colors[3], lw=3, alpha=0.75) 51 | ax.plot(X2.ravel(), c=colors[4], lw=3) 52 | ax.set_title("Euclidean geometry") 53 | 54 | plt.show() 55 | -------------------------------------------------------------------------------- /sdtw/__init__.py: -------------------------------------------------------------------------------- 1 | from .soft_dtw import SoftDTW 2 | -------------------------------------------------------------------------------- /sdtw/barycenter.py: -------------------------------------------------------------------------------- 1 | # Author: Mathieu Blondel 2 | # License: Simplified BSD 3 | 4 | import numpy as np 5 | 6 | from scipy.optimize import minimize 7 | 8 | from sdtw import SoftDTW 9 | from sdtw.distance import SquaredEuclidean 10 | 11 | 12 | def sdtw_barycenter(X, barycenter_init, gamma=1.0, weights=None, 13 | method="L-BFGS-B", tol=1e-3, max_iter=50): 14 | """ 15 | Compute barycenter (time series averaging) under the soft-DTW geometry. 16 | 17 | Parameters 18 | ---------- 19 | X: list 20 | List of time series, numpy arrays of shape [len(X[i]), d]. 21 | 22 | barycenter_init: array, shape = [length, d] 23 | Initialization. 24 | 25 | gamma: float 26 | Regularization parameter. 27 | Lower is less smoothed (closer to true DTW). 28 | 29 | weights: None or array 30 | Weights of each X[i]. Must be the same size as len(X). 31 | 32 | method: string 33 | Optimization method, passed to `scipy.optimize.minimize`. 34 | Default: L-BFGS. 35 | 36 | tol: float 37 | Tolerance of the method used. 38 | 39 | max_iter: int 40 | Maximum number of iterations. 41 | """ 42 | if weights is None: 43 | weights = np.ones(len(X)) 44 | 45 | weights = np.array(weights) 46 | 47 | def _func(Z): 48 | # Compute objective value and grad at Z. 49 | 50 | Z = Z.reshape(*barycenter_init.shape) 51 | 52 | m = Z.shape[0] 53 | G = np.zeros_like(Z) 54 | 55 | obj = 0 56 | 57 | for i in range(len(X)): 58 | D = SquaredEuclidean(Z, X[i]) 59 | sdtw = SoftDTW(D, gamma=gamma) 60 | value = sdtw.compute() 61 | E = sdtw.grad() 62 | G_tmp = D.jacobian_product(E) 63 | G += weights[i] * G_tmp 64 | obj += weights[i] * value 65 | 66 | return obj, G.ravel() 67 | 68 | # The function works with vectors so we need to vectorize barycenter_init. 69 | res = minimize(_func, barycenter_init.ravel(), method=method, jac=True, 70 | tol=tol, options=dict(maxiter=max_iter, disp=False)) 71 | 72 | return res.x.reshape(*barycenter_init.shape) 73 | -------------------------------------------------------------------------------- /sdtw/chainer_func.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import chainer 4 | 5 | from chainer import Function 6 | 7 | from .soft_dtw import SoftDTW 8 | from .distance import SquaredEuclidean 9 | 10 | 11 | class SoftDTWLoss(Function): 12 | 13 | def __init__(self, gamma): 14 | self.gamma = gamma 15 | 16 | def forward_cpu(self, inputs): 17 | # Z, X: both are arrays of shape length x n_dim 18 | Z, X = inputs 19 | 20 | assert Z.shape[1] == X.shape[1] 21 | 22 | D = SquaredEuclidean(Z, X) 23 | self.sdtw_ = SoftDTW(D, gamma=self.gamma) 24 | loss = self.sdtw_.compute() 25 | 26 | return np.array(loss), 27 | 28 | def backward_cpu(self, inputs, grad_outputs): 29 | Z, X = inputs 30 | # g has the same shape as the output of forward_cpu(). 31 | # g should always be 1 since it's the last function (loss function) 32 | g, = grad_outputs 33 | 34 | D = SquaredEuclidean(Z, X) 35 | E = self.sdtw_.grad() 36 | gZ = D.jacobian_product(E).astype(Z.dtype) 37 | 38 | # We don't need the gradient w.r.t. the 2nd argument. 39 | return gZ, np.zeros_like(X) 40 | -------------------------------------------------------------------------------- /sdtw/dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import numpy as np 3 | 4 | 5 | home = os.path.expanduser("~") 6 | data_dir = os.path.join(home, "sdtw_data") 7 | ucr_dir = os.path.join(data_dir, "UCR_TS_Archive_2015") 8 | 9 | 10 | def _parse_ucr(filename): 11 | y = [] 12 | X = [] 13 | for line in open(filename): 14 | line = line.strip() 15 | arr = line.split(",") 16 | label = int(arr[0]) 17 | feat = list(map(float, arr[1:])) 18 | feat = np.array(feat).reshape(-1, 1) 19 | y.append(label) 20 | X.append(feat) 21 | return X, np.array(y) 22 | 23 | 24 | def list_ucr(): 25 | return sorted(os.listdir(ucr_dir)) 26 | 27 | 28 | def load_ucr(name): 29 | folder = os.path.join(ucr_dir, name) 30 | tr = os.path.join(folder, "%s_TRAIN" % name) 31 | te = os.path.join(folder, "%s_TEST" % name) 32 | 33 | try: 34 | X_tr, y_tr = _parse_ucr(tr) 35 | X_te, y_te = _parse_ucr(te) 36 | except IOError: 37 | raise IOError("Please copy UCR_TS_Archive_2015/ to $HOME/sdtw_data/. " 38 | "Download from www.cs.ucr.edu/~eamonn/time_series_data.") 39 | 40 | y_tr = np.array(y_tr) 41 | y_te = np.array(y_te) 42 | X_tr = np.array(X_tr) 43 | X_te = np.array(X_te) 44 | 45 | return X_tr, y_tr, X_te, y_te 46 | -------------------------------------------------------------------------------- /sdtw/distance.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from sklearn.metrics.pairwise import euclidean_distances 4 | 5 | from .soft_dtw_fast import _jacobian_product_sq_euc 6 | 7 | class SquaredEuclidean(object): 8 | 9 | def __init__(self, X, Y): 10 | """ 11 | Parameters 12 | ---------- 13 | X: array, shape = [m, d] 14 | First time series. 15 | 16 | Y: array, shape = [n, d] 17 | Second time series. 18 | """ 19 | self.X = X.astype(np.float64) 20 | self.Y = Y.astype(np.float64) 21 | 22 | def compute(self): 23 | """ 24 | Compute distance matrix. 25 | 26 | Returns 27 | ------- 28 | D: array, shape = [m, n] 29 | Distance matrix. 30 | """ 31 | return euclidean_distances(self.X, self.Y, squared=True) 32 | 33 | def jacobian_product(self, E): 34 | """ 35 | Compute the product between the Jacobian 36 | (a linear map from m x d to m x n) and a matrix E. 37 | 38 | Parameters 39 | ---------- 40 | E: array, shape = [m, n] 41 | Second time series. 42 | 43 | Returns 44 | ------- 45 | G: array, shape = [m, d] 46 | Product with Jacobian 47 | ([m x d, m x n] * [m x n] = [m x d]). 48 | """ 49 | G = np.zeros_like(self.X) 50 | 51 | _jacobian_product_sq_euc(self.X, self.Y, E, G) 52 | 53 | return G 54 | -------------------------------------------------------------------------------- /sdtw/path.py: -------------------------------------------------------------------------------- 1 | # Author: Mathieu Blondel 2 | # License: Simplified BSD 3 | 4 | import numpy as np 5 | 6 | 7 | def delannoy_num(m, n): 8 | """ 9 | Number of paths from the southwest corner (0, 0) of a rectangular grid to 10 | the northeast corner (m, n), using only single steps north, northeast, or 11 | east. 12 | 13 | Named after French army officer and amateur mathematician Henri Delannoy. 14 | 15 | Parameters 16 | ---------- 17 | m, n : int, int 18 | Northeast corner coordinates. 19 | 20 | Returns 21 | ------- 22 | delannoy_num: int 23 | Delannoy number. 24 | 25 | Reference 26 | --------- 27 | https://en.wikipedia.org/wiki/Delannoy_number 28 | """ 29 | a = np.zeros([m+1, n+1]) 30 | a[0,0] = 1 31 | 32 | for i in range(1, m+1): 33 | a[i,0] = 1 34 | 35 | for j in range(1, n+1): 36 | a[0,j] = 1 37 | 38 | for i in range(1, m+1): 39 | for j in range(1, n+1): 40 | a[i,j] = a[i-1, j] + a[i, j-1] + a[i-1, j-1] 41 | 42 | return a[m, n] 43 | 44 | 45 | def gen_all_paths(m, n, start=None, M=None): 46 | """ 47 | Generator that produces all possible paths between (1, 1) and (m, n), using 48 | only north, northeast, or east steps. Each path is represented as a (m, n) 49 | numpy array with ones indicating the path. 50 | 51 | Parameters 52 | ---------- 53 | m, n : int, int 54 | Northeast corner coordinates. 55 | """ 56 | if start is None: 57 | start = [0, 0] 58 | M = np.zeros((m, n)) 59 | 60 | i, j = start 61 | M[i, j] = 1 62 | ret = [] 63 | 64 | if i == m-1 and j == n-1: 65 | yield M 66 | else: 67 | if i < m - 1: 68 | # Can use yield_from starting from Python 3.3. 69 | for mat in gen_all_paths(m, n, (i+1, j), M.copy()): 70 | yield mat 71 | if i < m-1 and j < n-1: 72 | for mat in gen_all_paths(m, n, (i+1, j+1), M.copy()): 73 | yield mat 74 | if j < n-1: 75 | for mat in gen_all_paths(m, n, (i, j+1), M.copy()): 76 | yield mat 77 | -------------------------------------------------------------------------------- /sdtw/setup.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | import numpy 4 | 5 | 6 | def configuration(parent_package='', top_path=None): 7 | from numpy.distutils.misc_util import Configuration 8 | 9 | config = Configuration('sdtw', parent_package, top_path) 10 | 11 | config.add_extension('soft_dtw_fast', sources=['soft_dtw_fast.c'], 12 | include_dirs=[numpy.get_include()]) 13 | 14 | config.add_subpackage('tests') 15 | 16 | return config 17 | 18 | 19 | if __name__ == '__main__': 20 | from numpy.distutils.core import setup 21 | setup(**configuration(top_path='').todict()) 22 | -------------------------------------------------------------------------------- /sdtw/soft_dtw.py: -------------------------------------------------------------------------------- 1 | # Author: Mathieu Blondel 2 | # License: Simplified BSD 3 | 4 | import numpy as np 5 | 6 | from .soft_dtw_fast import _soft_dtw 7 | from .soft_dtw_fast import _soft_dtw_grad 8 | 9 | 10 | class SoftDTW(object): 11 | 12 | def __init__(self, D, gamma=1.0): 13 | """ 14 | Parameters 15 | ---------- 16 | D: array, shape = [m, n] or distance object 17 | Distance matrix between elements of two time series. 18 | 19 | gamma: float 20 | Regularization parameter. 21 | Lower is less smoothed (closer to true DTW). 22 | 23 | Attributes 24 | ---------- 25 | self.R_: array, shape = [m + 2, n + 2] 26 | Accumulated cost matrix (stored after calling `compute`). 27 | """ 28 | if hasattr(D, "compute"): 29 | self.D = D.compute() 30 | else: 31 | self.D = D 32 | 33 | self.D = self.D.astype(np.float64) 34 | 35 | self.gamma = gamma 36 | 37 | def compute(self): 38 | """ 39 | Compute soft-DTW by dynamic programming. 40 | 41 | Returns 42 | ------- 43 | sdtw: float 44 | soft-DTW discrepancy. 45 | """ 46 | m, n = self.D.shape 47 | 48 | # Allocate memory. 49 | # We need +2 because we use indices starting from 1 50 | # and to deal with edge cases in the backward recursion. 51 | self.R_ = np.zeros((m+2, n+2), dtype=np.float64) 52 | 53 | _soft_dtw(self.D, self.R_, gamma=self.gamma) 54 | 55 | return self.R_[m, n] 56 | 57 | def grad(self): 58 | """ 59 | Compute gradient of soft-DTW w.r.t. D by dynamic programming. 60 | 61 | Returns 62 | ------- 63 | grad: array, shape = [m, n] 64 | Gradient w.r.t. D. 65 | """ 66 | if not hasattr(self, "R_"): 67 | raise ValueError("Needs to call compute() first.") 68 | 69 | m, n = self.D.shape 70 | 71 | # Add an extra row and an extra column to D. 72 | # Needed to deal with edge cases in the recursion. 73 | D = np.vstack((self.D, np.zeros(n))) 74 | D = np.hstack((D, np.zeros((m+1, 1)))) 75 | 76 | # Allocate memory. 77 | # We need +2 because we use indices starting from 1 78 | # and to deal with edge cases in the recursion. 79 | E = np.zeros((m+2, n+2)) 80 | 81 | _soft_dtw_grad(D, self.R_, E, gamma=self.gamma) 82 | 83 | return E[1:-1, 1:-1] 84 | -------------------------------------------------------------------------------- /sdtw/soft_dtw_fast.pyx: -------------------------------------------------------------------------------- 1 | # Author: Mathieu Blondel 2 | # License: Simplified BSD 3 | 4 | # encoding: utf-8 5 | # cython: cdivision=True 6 | # cython: boundscheck=False 7 | # cython: wraparound=False 8 | 9 | import numpy as np 10 | cimport numpy as np 11 | np.import_array() 12 | 13 | 14 | from libc.float cimport DBL_MAX 15 | from libc.math cimport exp, log 16 | from libc.string cimport memset 17 | 18 | 19 | cdef inline double _softmin3(double a, 20 | double b, 21 | double c, 22 | double gamma): 23 | a /= -gamma 24 | b /= -gamma 25 | c /= -gamma 26 | 27 | cdef double max_val = max(max(a, b), c) 28 | 29 | cdef double tmp = 0 30 | tmp += exp(a - max_val) 31 | tmp += exp(b - max_val) 32 | tmp += exp(c - max_val) 33 | 34 | return -gamma * (log(tmp) + max_val) 35 | 36 | 37 | def _soft_dtw(np.ndarray[double, ndim=2] D, 38 | np.ndarray[double, ndim=2] R, 39 | double gamma): 40 | 41 | cdef int m = D.shape[0] 42 | cdef int n = D.shape[1] 43 | 44 | cdef int i, j 45 | 46 | # Initialization. 47 | memset(R.data, 0, (m+1) * (n+1) * sizeof(double)) 48 | 49 | for i in range(m + 1): 50 | R[i, 0] = DBL_MAX 51 | 52 | for j in range(n + 1): 53 | R[0, j] = DBL_MAX 54 | 55 | R[0, 0] = 0 56 | 57 | # DP recursion. 58 | for i in range(1, m + 1): 59 | for j in range(1, n + 1): 60 | # D is indexed starting from 0. 61 | R[i, j] = D[i-1, j-1] + _softmin3(R[i-1, j], 62 | R[i-1, j-1], 63 | R[i, j-1], 64 | gamma) 65 | 66 | 67 | def _soft_dtw_grad(np.ndarray[double, ndim=2] D, 68 | np.ndarray[double, ndim=2] R, 69 | np.ndarray[double, ndim=2] E, 70 | double gamma): 71 | 72 | # We added an extra row and an extra column on the Python side. 73 | cdef int m = D.shape[0] - 1 74 | cdef int n = D.shape[1] - 1 75 | 76 | cdef int i, j 77 | cdef double a, b, c 78 | 79 | # Initialization. 80 | memset(E.data, 0, (m+2) * (n+2) * sizeof(double)) 81 | 82 | for i in range(1, m+1): 83 | # For D, indices start from 0 throughout. 84 | D[i-1, n] = 0 85 | R[i, n+1] = -DBL_MAX 86 | 87 | for j in range(1, n+1): 88 | D[m, j-1] = 0 89 | R[m+1, j] = -DBL_MAX 90 | 91 | E[m+1, n+1] = 1 92 | R[m+1, n+1] = R[m, n] 93 | D[m, n] = 0 94 | 95 | # DP recursion. 96 | for j in reversed(range(1, n+1)): # ranges from n to 1 97 | for i in reversed(range(1, m+1)): # ranges from m to 1 98 | a = exp((R[i+1, j] - R[i, j] - D[i, j-1]) / gamma) 99 | b = exp((R[i, j+1] - R[i, j] - D[i-1, j]) / gamma) 100 | c = exp((R[i+1, j+1] - R[i, j] - D[i, j]) / gamma) 101 | E[i, j] = E[i+1, j] * a + E[i, j+1] * b + E[i+1,j+1] * c 102 | 103 | 104 | def _jacobian_product_sq_euc(np.ndarray[double, ndim=2] X, 105 | np.ndarray[double, ndim=2] Y, 106 | np.ndarray[double, ndim=2] E, 107 | np.ndarray[double, ndim=2] G): 108 | cdef int m = X.shape[0] 109 | cdef int n = Y.shape[0] 110 | cdef int d = X.shape[1] 111 | 112 | for i in range(m): 113 | for j in range(n): 114 | for k in range(d): 115 | G[i, k] += E[i,j] * 2 * (X[i, k] - Y[j, k]) 116 | -------------------------------------------------------------------------------- /sdtw/tests/test_chainer_func.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from chainer import Variable 3 | 4 | from sdtw.dataset import load_ucr 5 | from sdtw.chainer_func import SoftDTWLoss 6 | from scipy.optimize import check_grad 7 | 8 | 9 | def _func(z, X): 10 | shape = (-1, X.shape[1]) 11 | Z = z.reshape(*shape) 12 | return SoftDTWLoss(gamma=0.1)(Z, X).data 13 | 14 | 15 | def _grad(z, X): 16 | shape = (-1, X.shape[1]) 17 | Z = z.reshape(*shape) 18 | Z = Variable(Z) 19 | loss = SoftDTWLoss(gamma=0.1)(Z, X) 20 | loss.backward(retain_grad=True) 21 | return Z.grad.ravel() 22 | 23 | 24 | def test_grad(): 25 | rng = np.random.RandomState(0) 26 | X = rng.randn(10, 2) 27 | Z = rng.randn(8, 2) 28 | print(check_grad(_func, _grad, Z.ravel(), X)) 29 | -------------------------------------------------------------------------------- /sdtw/tests/test_path.py: -------------------------------------------------------------------------------- 1 | from sklearn.utils.testing import assert_equal 2 | 3 | 4 | from sdtw.path import gen_all_paths 5 | from sdtw.path import delannoy_num 6 | 7 | 8 | def test_gen_all_paths(): 9 | assert_equal(len(list(gen_all_paths(2, 2))), 3) 10 | assert_equal(len(list(gen_all_paths(3, 2))), 5) 11 | assert_equal(len(list(gen_all_paths(4, 2))), 7) 12 | # delannoy_num counts paths from (0,0), 13 | # while gen_all_paths starts from (1,1). 14 | assert_equal(len(list(gen_all_paths(5, 7))), delannoy_num(5-1, 7-1)) 15 | assert_equal(len(list(gen_all_paths(8, 6))), delannoy_num(8-1, 6-1)) 16 | -------------------------------------------------------------------------------- /sdtw/tests/test_soft_dtw.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from scipy.optimize import approx_fprime 4 | 5 | from sklearn.metrics.pairwise import euclidean_distances 6 | from sklearn.utils.testing import assert_almost_equal 7 | from sklearn.utils.testing import assert_array_almost_equal 8 | 9 | from sdtw.path import gen_all_paths 10 | from sdtw.distance import SquaredEuclidean 11 | from sdtw import SoftDTW 12 | 13 | # Generate two inputs randomly. 14 | rng = np.random.RandomState(0) 15 | X = rng.randn(5, 4) 16 | Y = rng.randn(6, 4) 17 | D = euclidean_distances(X, Y, squared=True) 18 | 19 | 20 | def _softmax(z): 21 | max_val = np.max(z) 22 | return max_val + np.log(np.exp(z - max_val).sum()) 23 | 24 | 25 | def _softmin(z, gamma): 26 | z = np.array(z) 27 | return -gamma * _softmax(-z / gamma) 28 | 29 | 30 | def _soft_dtw_bf(D, gamma): 31 | costs = [np.sum(A * D) for A in gen_all_paths(D.shape[0], D.shape[1])] 32 | return _softmin(costs, gamma) 33 | 34 | 35 | def test_soft_dtw(): 36 | for gamma in (0.001, 0.01, 0.1, 1, 10, 100, 1000): 37 | assert_almost_equal(SoftDTW(D, gamma).compute(), 38 | _soft_dtw_bf(D, gamma=gamma)) 39 | 40 | def test_soft_dtw_grad(): 41 | def make_func(gamma): 42 | def func(d): 43 | D_ = d.reshape(*D.shape) 44 | return SoftDTW(D_, gamma).compute() 45 | return func 46 | 47 | for gamma in (0.001, 0.01, 0.1, 1, 10, 100, 1000): 48 | sdtw = SoftDTW(D, gamma) 49 | sdtw.compute() 50 | E = sdtw.grad() 51 | func = make_func(gamma) 52 | E_num = approx_fprime(D.ravel(), func, 1e-6).reshape(*E.shape) 53 | assert_array_almost_equal(E, E_num, 5) 54 | 55 | 56 | def test_soft_dtw_grad_X(): 57 | def make_func(gamma): 58 | def func(x): 59 | X_ = x.reshape(*X.shape) 60 | D_ = SquaredEuclidean(X_, Y) 61 | return SoftDTW(D_, gamma).compute() 62 | return func 63 | 64 | for gamma in (0.001, 0.01, 0.1, 1, 10, 100, 1000): 65 | dist = SquaredEuclidean(X, Y) 66 | sdtw = SoftDTW(dist, gamma) 67 | sdtw.compute() 68 | E = sdtw.grad() 69 | G = dist.jacobian_product(E) 70 | 71 | func = make_func(gamma) 72 | G_num = approx_fprime(X.ravel(), func, 1e-6).reshape(*G.shape) 73 | assert_array_almost_equal(G, G_num, 5) 74 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os.path 3 | import sys 4 | import setuptools 5 | from numpy.distutils.core import setup 6 | 7 | 8 | try: 9 | import numpy 10 | except ImportError: 11 | print('numpy is required during installation') 12 | sys.exit(1) 13 | 14 | 15 | DISTNAME = 'soft-dtw' 16 | DESCRIPTION = "Python implementation of soft-DTW" 17 | LONG_DESCRIPTION = open('README.rst').read() 18 | MAINTAINER = 'Mathieu Blondel' 19 | MAINTAINER_EMAIL = '' 20 | URL = 'https://github.com/mblondel/soft-dtw/' 21 | LICENSE = 'Simplified BSD' 22 | DOWNLOAD_URL = 'https://github.com/mblondel/soft-dtw/' 23 | VERSION = '0.1.dev0' 24 | 25 | 26 | def configuration(parent_package='', top_path=None): 27 | from numpy.distutils.misc_util import Configuration 28 | 29 | config = Configuration(None, parent_package, top_path) 30 | 31 | config.add_subpackage('sdtw') 32 | 33 | return config 34 | 35 | 36 | if __name__ == '__main__': 37 | old_path = os.getcwd() 38 | local_path = os.path.dirname(os.path.abspath(sys.argv[0])) 39 | 40 | os.chdir(local_path) 41 | sys.path.insert(0, local_path) 42 | 43 | setup(configuration=configuration, 44 | name=DISTNAME, 45 | maintainer=MAINTAINER, 46 | include_package_data=True, 47 | maintainer_email=MAINTAINER_EMAIL, 48 | description=DESCRIPTION, 49 | license=LICENSE, 50 | url=URL, 51 | version=VERSION, 52 | download_url=DOWNLOAD_URL, 53 | long_description=LONG_DESCRIPTION, 54 | zip_safe=False, # the package can run out of an .egg file 55 | classifiers=[ 56 | 'Intended Audience :: Science/Research', 57 | 'Intended Audience :: Developers', 'License :: OSI Approved', 58 | 'Programming Language :: C', 'Programming Language :: Python', 59 | 'Topic :: Software Development', 60 | 'Topic :: Scientific/Engineering', 61 | 'Operating System :: Microsoft :: Windows', 62 | 'Operating System :: POSIX', 'Operating System :: Unix', 63 | 'Operating System :: MacOS' 64 | ] 65 | ) 66 | --------------------------------------------------------------------------------