├── readme.md
└── soft_dtw.py

/readme.md:
--------------------------------------------------------------------------------
pytorch-softdtw
===

An implementation of Soft-DTW [1] for PyTorch. It should run pretty fast.

More goodies: check out [pytorch-softdtw-cuda](https://github.com/Maghoumi/pytorch-softdtw-cuda) by Maghoumi, a heavily upgraded version that runs in parallel on CUDA.

Install
---

Just copy `soft_dtw.py` into your project. I don't believe in Python's package managers; setting up the environment took me longer than actually writing this thing.

Depends on PyTorch and Numba.

How to use
---

The `_SoftDTW` autograd function computes the smoothed DTW distance (one scalar per batch element) for a given distance matrix, and calling `backward` on the result gives you the derivative (a matrix) of that distance with respect to the distance matrix. The `SoftDTW` module wraps this function: it takes two sequences, builds the pairwise squared-Euclidean distance matrix, and returns the soft-DTW discrepancy, so it can be used like any other loss.

As the original authors point out [1], this derivative is the expected DTW alignment matrix, i.e. a soft version of the optimal warping path. This is comparable to the forward-backward algorithm for HMMs.

You may also specify the temperature gamma (a positive number). As gamma goes to zero, the result converges to that of the original "hard" DTW.

```python
from soft_dtw import SoftDTW
...
criterion = SoftDTW(gamma=1.0, normalize=True) # just like nn.MSELoss()
...
loss = criterion(out, target)
```

With `normalize=True` the module returns `SoftDTW(x, y) - (SoftDTW(x, x) + SoftDTW(y, y)) / 2`, which is zero when the two inputs are identical; with `normalize=False` it returns the raw soft-DTW discrepancy.

### Does it support pruning?

No. You can mess with the main loops yourself (the nested loops in `compute_softdtw` and `compute_softdtw_backward` in `soft_dtw.py`).
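
### Reading off the soft alignment

To see the "expected path" interpretation concretely, here is a minimal sketch (the shapes and the gamma value are just illustrative): feed a distance matrix straight to the underlying autograd function and look at its gradient.

```python
import torch
from soft_dtw import _SoftDTW

# one distance matrix between a length-10 and a length-8 sequence
D = torch.rand(1, 10, 8, requires_grad=True)

value = _SoftDTW.apply(D, 1.0)  # soft-DTW value, shape (1,)
value.sum().backward()

# D.grad has the same shape as D; its entries weight how strongly each cell
# contributes to the alignment, and they concentrate on a single path as gamma -> 0
print(D.grad.shape)  # torch.Size([1, 10, 8])
```

For training you normally don't need this; just use the `SoftDTW` module as a loss, as in the snippet above.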
License
---

Look, I just took their paper and wrote a single-file Python thingy. If you want to say thanks, which is of course welcome, then buy me a drink. Not a big fan of beer though. Umeshu with soda would be nice. Yeah sweety sweety.

Reference
---

[1] M. Cuturi and M. Blondel. "Soft-DTW: a Differentiable Loss Function for Time-Series". In Proceedings of ICML 2017.

--------------------------------------------------------------------------------
/soft_dtw.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from numba import jit
from torch.autograd import Function


@jit(nopython=True)
def compute_softdtw(D, gamma):
    # Forward DP: R[k, i, j] is the soft-DTW value of aligning the first i
    # elements of one sequence with the first j elements of the other for
    # batch element k. R carries one extra leading and trailing row/column of
    # padding; the trailing one is only used by the backward pass.
    B = D.shape[0]
    N = D.shape[1]
    M = D.shape[2]
    R = np.ones((B, N + 2, M + 2)) * np.inf
    R[:, 0, 0] = 0
    for k in range(B):
        for j in range(1, M + 1):
            for i in range(1, N + 1):
                # soft-min over the three predecessor cells, computed in a
                # numerically stable way (log-sum-exp with max subtraction)
                r0 = -R[k, i - 1, j - 1] / gamma
                r1 = -R[k, i - 1, j] / gamma
                r2 = -R[k, i, j - 1] / gamma
                rmax = max(max(r0, r1), r2)
                rsum = np.exp(r0 - rmax) + np.exp(r1 - rmax) + np.exp(r2 - rmax)
                softmin = -gamma * (np.log(rsum) + rmax)
                R[k, i, j] = D[k, i - 1, j - 1] + softmin
    return R


@jit(nopython=True)
def compute_softdtw_backward(D_, R, gamma):
    # Backward DP: E accumulates the gradient of the soft-DTW value with
    # respect to each cell of the distance matrix, i.e. the expected (soft)
    # alignment matrix.
    B = D_.shape[0]
    N = D_.shape[1]
    M = D_.shape[2]
    D = np.zeros((B, N + 2, M + 2))
    E = np.zeros((B, N + 2, M + 2))
    D[:, 1:N + 1, 1:M + 1] = D_
    E[:, -1, -1] = 1
    R[:, :, -1] = -np.inf
    R[:, -1, :] = -np.inf
    R[:, -1, -1] = R[:, -2, -2]
    for k in range(B):
        for j in range(M, 0, -1):
            for i in range(N, 0, -1):
                a0 = (R[k, i + 1, j] - R[k, i, j] - D[k, i + 1, j]) / gamma
                b0 = (R[k, i, j + 1] - R[k, i, j] - D[k, i, j + 1]) / gamma
                c0 = (R[k, i + 1, j + 1] - R[k, i, j] - D[k, i + 1, j + 1]) / gamma
                a = np.exp(a0)
                b = np.exp(b0)
                c = np.exp(c0)
                E[k, i, j] = E[k, i + 1, j] * a + E[k, i, j + 1] * b + E[k, i + 1, j + 1] * c
    return E[:, 1:N + 1, 1:M + 1]


class _SoftDTW(Function):
    # Autograd wrapper around the two Numba kernels: forward maps a batched
    # distance matrix to the soft-DTW value, backward returns the gradient of
    # that value with respect to the distance matrix.
    @staticmethod
    def forward(ctx, D, gamma):
        dev = D.device
        dtype = D.dtype
        gamma = torch.tensor([gamma], device=dev, dtype=dtype)
        D_ = D.detach().cpu().numpy()
        g_ = gamma.item()
        R = torch.Tensor(compute_softdtw(D_, g_)).to(dev).type(dtype)
        ctx.save_for_backward(D, R, gamma)
        return R[:, -2, -2]

    @staticmethod
    def backward(ctx, grad_output):
        dev = grad_output.device
        dtype = grad_output.dtype
        D, R, gamma = ctx.saved_tensors
        D_ = D.detach().cpu().numpy()
        R_ = R.detach().cpu().numpy()
        g_ = gamma.item()
        E = torch.Tensor(compute_softdtw_backward(D_, R_, g_)).to(dev).type(dtype)
        # chain rule: scale the alignment matrix by the incoming gradient;
        # gamma is a hyperparameter, so its gradient slot is None
        return grad_output.view(-1, 1, 1).expand_as(E) * E, None


class SoftDTW(torch.nn.Module):
    def __init__(self, gamma=1.0, normalize=False):
        super(SoftDTW, self).__init__()
        self.normalize = normalize
        self.gamma = gamma
        self.func_dtw = _SoftDTW.apply

    def calc_distance_matrix(self, x, y):
        # pairwise squared Euclidean distances, shape (batch, n, m)
        n = x.size(1)
        m = y.size(1)
        d = x.size(2)
        x = x.unsqueeze(2).expand(-1, n, m, d)
        y = y.unsqueeze(1).expand(-1, n, m, d)
        dist = torch.pow(x - y, 2).sum(3)
        return dist

    def forward(self, x, y):
        assert len(x.shape) == len(y.shape)
        squeeze = False
        if len(x.shape) < 3:
            # add a batch dimension for unbatched (length, dim) inputs
            x = x.unsqueeze(0)
            y = y.unsqueeze(0)
            squeeze = True
        if self.normalize:
            D_xy = self.calc_distance_matrix(x, y)
            out_xy = self.func_dtw(D_xy, self.gamma)
            D_xx = self.calc_distance_matrix(x, x)
            out_xx = self.func_dtw(D_xx, self.gamma)
            D_yy = self.calc_distance_matrix(y, y)
            out_yy = self.func_dtw(D_yy, self.gamma)
            result = out_xy - 1 / 2 * (out_xx + out_yy)  # normalized "distance"
        else:
            D_xy = self.calc_distance_matrix(x, y)
            out_xy = self.func_dtw(D_xy, self.gamma)
            result = out_xy  # raw discrepancy
        return result.squeeze(0) if squeeze else result
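

# Minimal smoke test / usage sketch: the shapes, gamma and seed below are
# arbitrary illustrative choices. Run with `python soft_dtw.py`.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch, n, m, dim = 4, 15, 12, 3
    x = torch.randn(batch, n, dim, requires_grad=True)
    y = torch.randn(batch, m, dim)

    criterion = SoftDTW(gamma=1.0, normalize=True)
    loss = criterion(x, y)  # one value per batch element, shape (batch,)
    loss.mean().backward()  # gradients flow back through the distance matrix
    print("loss:", loss.detach())
    print("grad shape:", x.grad.shape)

    # with normalize=True the discrepancy between a sequence and itself is zero
    print("self-discrepancy:", criterion(x.detach(), x.detach()))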
--------------------------------------------------------------------------------