├── README.md
└── sdtw.py

/README.md:
--------------------------------------------------------------------------------
# stdw_pytorch
Implementation of soft dynamic time warping (Soft-DTW) in PyTorch.

This is my implementation of the Soft Dynamic Time Warping loss function described in https://arxiv.org/abs/1703.01541.

Currently I have only a 'naive' implementation; it does not extend the fast Cython implementation in
https://github.com/mblondel/soft-dtw to incorporate a batch dimension. If I continue to use this in my line of
research, I may implement a Cython / CUDA version to increase speed.

#######
See: https://github.com/Maghoumi/pytorch-softdtw-cuda for a ***much*** better implementation
#######
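
A minimal usage sketch (the shapes below are illustrative; run it from the repository root so that `sdtw.py` is importable):

```python
import torch
from sdtw import SoftDTWLoss

# Batch of 8 pairs of multivariate time series: 5 spatial dimensions, lengths 100 and 120
x = torch.randn(8, 5, 100, requires_grad=True)
y = torch.randn(8, 5, 120)

criterion = SoftDTWLoss(gamma=1.0)
loss = criterion(x, y)    # one soft-DTW value per batch element, shape (8,)
loss.mean().backward()    # gradients are computed with respect to x only
```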
--------------------------------------------------------------------------------
/sdtw.py:
--------------------------------------------------------------------------------
'''
Author : Luke Y. Prince
email  : luke.prince@utoronto.ca
github : lyprince
date   : 19 Feb 2019
'''

import torch
from math import inf


class SoftDTWLoss(torch.nn.Module):
    '''
    Soft-DTW (Dynamic Time Warping) loss function as defined in Cuturi and Blondel (2017), Soft-DTW:
    a Differentiable Loss Function for Time-Series. In: Proc. of ICML 2017.
    https://arxiv.org/abs/1703.01541
    '''

    def __init__(self, gamma=1.0, spatial_independent=False, bandwidth=None):
        '''
        __init__(self, gamma=1.0, spatial_independent=False, bandwidth=None):

        Arguments:
            gamma (float)              : smoothing parameter (default=1.0)
            spatial_independent (bool) : treat spatial dimensions as independent (default=False).
                                         When False, each time point x_t is treated as a vector in
                                         multi-dimensional space. When True, each time point x_t is treated
                                         as a set of independent scalars x_{i,t}. This is a shortcut for
                                         creating a 'false' singleton spatial dimension so that data can
                                         continue to be treated as a 3-tensor of size (batch x space x time).
                                         TODO: implement for arbitrary spatial dimensions.
            bandwidth (int)            : width of the Sakoe-Chiba band constraint (default=None, i.e. no
                                         constraint)
        '''

        super(SoftDTWLoss, self).__init__()

        self.gamma = gamma
        self.spatial_independent = spatial_independent
        self.bandwidth = bandwidth

    def forward(self, x, y):
        '''
        forward(self, x, y):

        Arguments:
            x (torch.Tensor): Time series data of size (batch_dim x space_dim x x_time_dim)
            y (torch.Tensor): Time series data of size (batch_dim x space_dim x y_time_dim)

        Returns:
            loss (torch.Tensor): Loss for each data point in the batch. Size = batch_dim
        '''

        return SoftDTWLossFunction.apply(x, y, (self.gamma, self.spatial_independent, self.bandwidth))


class SoftDTWLossFunction(torch.autograd.Function):
    '''
    Custom autograd function for Soft-DTW.

    See https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#pytorch-defining-new-autograd-functions
    for details on defining new autograd functions.
    '''

    @staticmethod
    def forward(ctx, x, y, params):
        '''
        forward(ctx, x, y, params):

        Compute the forward pass for Soft-DTW, storing the intermediate alignment costs R and the squared
        Euclidean distance costs D for use in the backward pass.
        See Algorithm 1 in https://arxiv.org/abs/1703.01541

        Arguments:
            ctx              : context
            x (torch.Tensor) : tensor of dimensions (batch_dim x space_dim x x_time_dim)
            y (torch.Tensor) : tensor of dimensions (batch_dim x space_dim x y_time_dim)
            params (tuple)   : (gamma, spatial_independent, bandwidth)

        Returns:
            SoftDTW_Loss (torch.Tensor)
        '''

        # Store parameters in context variable
        gamma, spatial_independent, bandwidth = params
        ctx.gamma = gamma
        ctx.spatial_independent = spatial_independent
        ctx.bandwidth = bandwidth

        # Determine device and store in context variable
        ctx.device = 'cuda' if x.is_cuda else 'cpu'

        # Determine dimensions
        x_batch_dim, x_space_dim, x_time_dim = x.shape
        y_batch_dim, y_space_dim, y_time_dim = y.shape

        # Store dimensions in context variable
        ctx.x_time_dim = x_time_dim
        ctx.y_time_dim = y_time_dim

        # Check batch dimensions are equal
        if x_batch_dim == y_batch_dim:
            batch_dim = x_batch_dim
            ctx.batch_dim = batch_dim
            del x_batch_dim, y_batch_dim
        else:
            raise RuntimeError('Unequal batch dimensions')

        # Check space dimensions are equal
        if x_space_dim == y_space_dim:
            space_dim = x_space_dim
            ctx.space_dim = space_dim
            del x_space_dim, y_space_dim
        else:
            raise RuntimeError('Unequal space dimensions')

        # Determine dimensions for the squared Euclidean distance Gram matrix.
        # +1 because padding is needed at the end for the backward function.
        D_dims = (batch_dim, x_time_dim + 1, y_time_dim + 1, space_dim) \
                 if spatial_independent else (batch_dim, x_time_dim + 1, y_time_dim + 1)

        # Determine dimensions for the Soft-DTW distance Gram matrix.
        # +2 because padding is needed on either side for the forward and backward functions.
        R_dims = (batch_dim, x_time_dim + 2, y_time_dim + 2, space_dim) \
                 if spatial_independent else (batch_dim, x_time_dim + 2, y_time_dim + 2)

        # Create Gram matrices
        D = torch.zeros(D_dims).to(ctx.device)
        R = torch.ones(R_dims).to(ctx.device) * inf

        # Initialize the top-left corner of the Soft-DTW Gram matrix; the rest of the border stays at +inf
        R[:, 0, 0] = 0

        # Sweep diagonally through the Gram matrices to compute alignment costs.
        # See https://towardsdatascience.com/gpu-optimized-dynamic-programming-8d5ba3d7064f for inspiration
        for (i, j), (ip1, jp1) in zip(MatrixDiagonalIndexIterator(m=x_time_dim, n=y_time_dim,
                                                                  bandwidth=bandwidth),
                                      MatrixDiagonalIndexIterator(m=x_time_dim + 1, n=y_time_dim + 1,
                                                                  k_start=1, bandwidth=bandwidth)):

            # Compute squared Euclidean distances
            if spatial_independent:
                D[:, i, j] = (x[:, :, i] - y[:, :, j]).permute(0, 2, 1).pow(2)
            else:
                D[:, i, j] = (x[:, :, i] - y[:, :, j]).permute(0, 2, 1).pow(2).sum(dim=-1)

            # Add soft minimum of alignment costs
            R[:, ip1, jp1] = D[:, i, j] + softmin([R[:, i, j],
                                                   R[:, ip1, j],
                                                   R[:, i, jp1]],
                                                  gamma=gamma)

        ctx.save_for_backward(x, y)
        ctx.R = R
        ctx.D = D

        return R[:, -2, -2].sum(dim=-1) if spatial_independent else R[:, -2, -2]
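
    # The diagonal sweep in forward() implements the recursion from Algorithm 1 of
    # Cuturi and Blondel (2017),
    #     r[i, j] = d(x_i, y_j) + softmin_gamma(r[i-1, j-1], r[i-1, j], r[i, j-1]),
    # where the +1 index offsets account for the padded border of R. Visiting the cells
    # anti-diagonal by anti-diagonal lets each update be evaluated as a single batched
    # tensor operation, since every cell on a diagonal depends only on the two previous diagonals.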

    @staticmethod
    def backward(ctx, grad_output):
        '''
        backward(ctx, grad_output):

        Compute the Soft-DTW gradient with respect to x.
        See Algorithm 2 in https://arxiv.org/abs/1703.01541
        '''
        # Get saved tensors
        x, y = ctx.saved_tensors

        # Determine size of the alignment gradient matrix
        E_dims = (ctx.batch_dim, ctx.x_time_dim + 2, ctx.y_time_dim + 2, ctx.space_dim) \
                 if ctx.spatial_independent else (ctx.batch_dim, ctx.x_time_dim + 2, ctx.y_time_dim + 2)

        # Create alignment gradient matrix
        E = torch.zeros(E_dims).to(ctx.device)
        E[:, -1, -1] = 1

        # Replace the +inf padding so that exponentials of padded cells evaluate to zero
        ctx.R[torch.isinf(ctx.R)] = -inf
        ctx.R[:, -1, -1] = ctx.R[:, -2, -2]

        rev_idxs = reversed(list(MatrixDiagonalIndexIterator(ctx.x_time_dim, ctx.y_time_dim,
                                                             bandwidth=ctx.bandwidth)))
        rev_idxsp1 = reversed(list(MatrixDiagonalIndexIterator(ctx.x_time_dim + 1, ctx.y_time_dim + 1,
                                                               k_start=1, bandwidth=ctx.bandwidth)))
        rev_idxsp2 = reversed(list(MatrixDiagonalIndexIterator(ctx.x_time_dim + 2, ctx.y_time_dim + 2,
                                                               k_start=2, bandwidth=ctx.bandwidth)))

        # Sweep diagonally through the alignment gradient matrix in reverse
        for (i, j), (ip1, jp1), (ip2, jp2) in zip(rev_idxs, rev_idxsp1, rev_idxsp2):
            a = torch.exp((ctx.R[:, ip2, jp1] - ctx.R[:, ip1, jp1] - ctx.D[:, ip1, j]) / ctx.gamma)
            b = torch.exp((ctx.R[:, ip1, jp2] - ctx.R[:, ip1, jp1] - ctx.D[:, i, jp1]) / ctx.gamma)
            c = torch.exp((ctx.R[:, ip2, jp2] - ctx.R[:, ip1, jp1] - ctx.D[:, ip1, jp1]) / ctx.gamma)

            E[:, ip1, jp1] = E[:, ip2, jp1] * a + E[:, ip1, jp2] * b + E[:, ip2, jp2] * c

        # Compute the Jacobian product to obtain the gradient with respect to x
        if ctx.spatial_independent:
            G = jacobean_product_squared_euclidean(x.unsqueeze(2), y.unsqueeze(2),
                                                   E[:, 1:-1, 1:-1].permute(0, 3, 2, 1)).squeeze(2)
        else:
            G = jacobean_product_squared_euclidean(x, y, E[:, 1:-1, 1:-1].permute(0, 2, 1))

        # Scale by the incoming gradient (one loss value per batch element)
        G = G * grad_output.view(-1, 1, 1)

        # Must return as many outputs as there are inputs to the forward function
        return G, None, None


def softmin(x, gamma):
    '''
    softmin(x, gamma):

    Soft minimum function used to smooth DTW and make it differentiable.

    Arguments:
        x     : list of tensors [x1, ..., xN] to compute the soft minimum over
        gamma : smoothing parameter

    Returns:
        smin_x : soft minimum of x
    '''
    # Obtain dimensions of inputs
    dims = tuple([len(x), *x[0].shape])

    # Stack inputs along a new leading dimension
    x = -torch.cat(x).reshape(dims) / gamma

    # Compute and return the soft minimum
    return -gamma * torch.logsumexp(x, dim=0)
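
# softmin computes the smoothed minimum
#     min^gamma(a_1, ..., a_N) = -gamma * log(sum_i exp(-a_i / gamma)),
# which approaches the hard minimum as gamma -> 0. As an illustration,
# softmin([torch.tensor([0.]), torch.tensor([1.]), torch.tensor([2.])], gamma=1.0)
# gives roughly tensor([-0.41]) = -log(1 + exp(-1) + exp(-2)), slightly below the
# hard minimum of 0.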


def jacobean_product_squared_euclidean(X, Y, Bt):
    '''
    jacobean_product_squared_euclidean(X, Y, Bt):

    Jacobian product of the squared Euclidean distance matrix and the alignment matrix.
    See equations 2 and 2.5 of https://arxiv.org/abs/1703.01541
    '''
    ones = torch.ones(Y.shape).to('cuda' if Bt.is_cuda else 'cpu')
    return 2 * (ones.matmul(Bt) * X - Y.matmul(Bt))


class MatrixDiagonalIndexIterator:
    '''
    Custom iterator class that returns successive anti-diagonal indices of a matrix
    '''

    def __init__(self, m, n, k_start=0, bandwidth=None):
        '''
        __init__(self, m, n, k_start=0, bandwidth=None):

        Arguments:
            m (int)         : number of rows in the matrix
            n (int)         : number of columns in the matrix
            k_start (int)   : (k_start, k_start) index to begin from
            bandwidth (int) : bandwidth of the Sakoe-Chiba band to constrain indices to
        '''
        self.m = m
        self.n = n
        self.k = k_start
        self.k_max = self.m + self.n - k_start - 1
        self.bandwidth = bandwidth

    def __iter__(self):
        return self

    def __next__(self):
        if hasattr(self, 'i') and hasattr(self, 'j'):

            if self.k == self.k_max:
                raise StopIteration

            elif self.k < self.m and self.k < self.n:
                # Diagonal is still growing: extend both index lists
                self.i = self.i + [self.k]
                self.j = [self.k] + self.j
                self.k += 1

            elif self.k >= self.m and self.k < self.n:
                # Out of rows: slide the column indices along
                self.j.pop(-1)
                self.j = [self.k] + self.j
                self.k += 1

            elif self.k < self.m and self.k >= self.n:
                # Out of columns: slide the row indices along
                self.i.pop(0)
                self.i = self.i + [self.k]
                self.k += 1

            elif self.k >= self.m and self.k >= self.n:
                # Diagonal is shrinking: trim both index lists
                self.i.pop(0)
                self.j.pop(-1)
                self.k += 1

        else:
            # First iteration: start at (k_start, k_start)
            self.i = [self.k]
            self.j = [self.k]
            self.k += 1

        if self.bandwidth:
            i_scb, j_scb = sakoe_chiba_band(self.i.copy(), self.j.copy(), self.m, self.n, self.bandwidth)
            return i_scb, j_scb
        else:
            return self.i.copy(), self.j.copy()


def sakoe_chiba_band(i_list, j_list, m, n, bandwidth=1):
    # Keep only the index pairs that lie within the given band around the
    # (rescaled) main diagonal of an m x n grid
    i_scb, j_scb = zip(*[(i, j) for i, j in zip(i_list, j_list)
                         if abs(2 * (i * (n - 1) - j * (m - 1))) < max(m, n) * (bandwidth + 1)])
    return list(i_scb), list(j_scb)
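

# Illustrative smoke test: run `python sdtw.py` to print the anti-diagonal sweep order
# and the loss/gradient shapes on random data.
if __name__ == '__main__':
    # Anti-diagonal sweep order over a 3 x 3 alignment matrix; this should print
    #   [0] [0]
    #   [0, 1] [1, 0]
    #   [0, 1, 2] [2, 1, 0]
    #   [1, 2] [2, 1]
    #   [2] [2]
    for rows, cols in MatrixDiagonalIndexIterator(m=3, n=3):
        print(rows, cols)

    # Soft-DTW loss on random data: one loss value per batch element,
    # with gradients flowing back to x only
    x = torch.randn(2, 3, 10, requires_grad=True)
    y = torch.randn(2, 3, 12)
    loss = SoftDTWLoss(gamma=1.0)(x, y)
    loss.sum().backward()
    print(loss.shape, x.grad.shape)  # torch.Size([2]) torch.Size([2, 3, 10])
--------------------------------------------------------------------------------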