├── .gitignore
├── README.md
└── tf2_soft_dtw.py


/.gitignore:
--------------------------------------------------------------------------------
.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Batch Soft-DTW (Dynamic Time Warping) in TensorFlow 2, with forward and backward computation
===

A custom TensorFlow 2 implementation of the forward and backward computation of the soft-DTW (Dynamic Time Warping) algorithm in batch mode, as proposed in the paper "Soft-DTW: a Differentiable Loss Function for Time-Series".

I have implemented two versions of soft-DTW: one follows the original paper, the other follows the Parallel Tacotron 2 paper (which adds a warp penalty). For the latter version, I derived the equations of the backward computation myself.

If you have questions about the code or suggestions for improvement, feel free to open an issue!
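With warp penalty `w`, the forward recursion implemented here is

```
R(i, j) = delta(i, j) + softmin_gamma( R(i-1, j-1), R(i-1, j) + w, R(i, j-1) + w )
```

A minimal usage sketch (shapes and hyper-parameter values below are illustrative only, not recommendations):

```python
import numpy as np
import tensorflow as tf
from tf2_soft_dtw import batch_soft_dtw

X = tf.constant(np.random.rand(2, 50, 80), tf.float32)  # e.g. predicted mel frames
Y = tf.constant(np.random.rand(2, 60, 80), tf.float32)  # target mel frames
loss = batch_soft_dtw(X, Y, gamma=0.01, warp=1.0, metric="L1")  # shape: [batch_size]
```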
--------------------------------------------------------------------------------
/tf2_soft_dtw.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : tf2_soft_dtw.py
# Author : zewangzhang
# Date   : 07.06.2021
# Last Modified Date: 07.06.2021
# Last Modified By  : zewangzhang

""" soft-DTW, TensorFlow 2 version """

import tensorflow as tf
import numpy as np


def batch_distance(X, Y, metric="L1"):
    """ Batch-mode distance computation.
    X: batch_size * seq_len1 * feat_dim
    Y: batch_size * seq_len2 * feat_dim
    Only the distance matrix is computed here; no padding of any kind is done.
    """

    assert metric == "L1", "wrong metric value !"

    N, T1, d = tf.shape(X)[0], tf.shape(X)[1], tf.shape(X)[2]
    T2 = tf.shape(Y)[1]

    # Tile both inputs so that every pair (X[n, i], Y[n, j]) becomes one row.
    X = tf.reshape(tf.tile(X, [1, 1, T2]), (N*T1*T2, d))
    Y = tf.reshape(tf.tile(Y, [1, T1, 1]), (N*T1*T2, d))

    # L1 distance matrix, [N, T1, T2]
    res = tf.math.abs(X-Y)
    res = tf.reduce_sum(res, axis=-1)
    res = tf.cast(tf.reshape(res, [N, T1, T2]), tf.float32)

    # Signed differences (no absolute value), kept for the sign term
    # of the backward pass.
    raw_res = X-Y
    raw_res = tf.reduce_sum(raw_res, axis=-1)
    raw_res = tf.cast(tf.reshape(raw_res, [N, T1, T2]), tf.float32)

    return res, raw_res
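

# A pure-NumPy reference implementation (a sketch for sanity checks on small
# inputs; not used in training). It assumes the warp-penalty recursion
#     R[i, j] = delta(i, j) + softmin_gamma(R[i-1, j-1], R[i-1, j] + warp, R[i, j-1] + warp),
# which is what the TensorFlow kernel below implements.
def numpy_soft_dtw_reference(x, y, gamma, warp):
    """ Soft-DTW with warp penalty for a single pair of sequences, L1 cost. """
    T1, T2 = x.shape[0], y.shape[0]
    R = np.full((T1+1, T2+1), 1000000.0)
    R[0, 0] = 0.0
    for i in range(1, T1+1):
        for j in range(1, T2+1):
            d = np.abs(x[i-1] - y[j-1]).sum()  # L1 local cost
            z = np.array([R[i-1, j-1], R[i-1, j] + warp, R[i, j-1] + warp])
            # numerically stable softmin_gamma(z) = -gamma * logsumexp(-z / gamma)
            m = np.max(-z / gamma)
            R[i, j] = d - gamma * (m + np.log(np.sum(np.exp(-z / gamma - m))))
    return R[T1, T2]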


def batch_soft_dtw(X, Y, gamma, warp, metric="L1"):
    """ Batch-mode soft-DTW computation with a custom gradient. """

    X = tf.convert_to_tensor(X, tf.float32)

    N, T1 = tf.shape(X)[0], tf.shape(X)[1]
    T2 = tf.shape(Y)[1]

    @tf.custom_gradient
    def _batch_soft_dtw_kernel(X):

        # Pairwise distance matrix and signed differences. The kernel takes X
        # itself as input so that the custom gradient can return a gradient of
        # the same shape as X.
        delta_matrix, raw_delta_matrix = batch_distance(X, Y, metric=metric)

        delta_matrix_v1 = tf.identity(tf.cast(delta_matrix, tf.float32), "delta_matrix_v1")
        raw_delta_matrix_v1 = tf.identity(tf.cast(raw_delta_matrix, tf.float32), "raw_delta_matrix_v1")

        # Pad the distance matrix to [N, T1+1, T2+1] so that cell (i, j) of the
        # R matrix reads its local cost delta(i, j) at flat index
        # (i-1)*(T2+1)+(j-1), mirroring the padded layout of the backward pass.
        delta_matrix_pad = tf.concat([delta_matrix_v1, tf.zeros([N, T1, 1], tf.float32)], axis=2)
        delta_matrix_pad = tf.concat([delta_matrix_pad, tf.zeros([N, 1, T2+1], tf.float32)], axis=1)
        delta_array = tf.TensorArray(tf.float32, size=(T1+1)*(T2+1), clear_after_read=False)
        # Transpose before unstacking so that each TensorArray element holds one
        # grid cell of shape [N]; a plain reshape would mix batch and time.
        delta_array = delta_array.unstack(
            tf.reshape(tf.transpose(delta_matrix_pad, [1, 2, 0]), [(T1+1)*(T2+1), N]))

        r_array = tf.TensorArray(tf.float32, size=(T1+1)*(T2+1), clear_after_read=False)
        r_array = r_array.write(0, tf.zeros(shape=(N, )))

        def cond_border_x(idx, array):
            return idx < T1+1

        def cond_border_y(idx, array):
            return idx < T2+1

        def body_border_x(idx, array):
            array = array.write(tf.cast(idx*(T2+1), tf.int32), 1000000*tf.ones(shape=(N, )))
            return idx+1, array

        def body_border_y(idx, array):
            array = array.write(tf.cast(idx, tf.int32), 1000000*tf.ones(shape=(N, )))
            return idx+1, array

        # Initialise the borders R[i, 0] and R[0, j] with a large value.
        _, r_array = tf.while_loop(cond_border_x, body_border_x, (1, r_array))
        _, r_array = tf.while_loop(cond_border_y, body_border_y, (1, r_array))

        def cond(idx, array):
            return idx < (T1+1) * (T2+1)

        def body(idx, array):
            i = tf.cast(tf.divide(idx, T2+1), tf.int32)  # row index
            j = tf.math.floormod(tf.cast(idx, tf.int32), T2+1)  # column index

            def inner_func_v1():
                """ Parallel Tacotron 2's version:
                    R[i,j] = delta(i,j) + softmin(R[i-1,j-1], R[i-1,j]+warp, R[i,j-1]+warp) """
                delta_ij = delta_array.read((i-1)*(T2+1)+(j-1))
                z1 = -1./gamma * (array.read((i-1)*(T2+1)+(j-1)) + delta_ij)
                z2 = -1./gamma * (warp + array.read((i-1)*(T2+1)+(j)) + delta_ij)
                z3 = -1./gamma * (warp + array.read((i)*(T2+1)+(j-1)) + delta_ij)
                soft_min_value = -gamma * tf.math.reduce_logsumexp([z1, z2, z3], axis=0)

                return array.write(idx, tf.cast(soft_min_value, tf.float32))

            def outer_func():

                return array

            # Border cells (i == 0 or j == 0) keep their initial values.
            array = tf.cond(tf.less(i, 1) | tf.less(j, 1),
                            true_fn=outer_func,
                            false_fn=inner_func_v1)

            return idx+1, array

        _, r_array = tf.while_loop(cond, body, (0, r_array))
        r_matrix = r_array.stack()

        # The final soft-DTW distance sits at R[T1, T2].
        r_matrix = tf.transpose(tf.reshape(r_matrix, (T1+1, T2+1, N)), [2, 0, 1])
        r_matrix_v1 = tf.identity(tf.cast(r_matrix, tf.float32), "r_matrix_v1")

        def grad_v1(dy):
            """ Parallel Tacotron 2's version; I solved the backward equations
                analytically. """

            # [N, T1+1, T2+1], distances padded with zeros on the right/bottom
            delta_matrix = tf.concat([delta_matrix_v1, tf.zeros([N, T1, 1], tf.float32)], axis=2)
            delta_matrix = tf.concat([delta_matrix, tf.zeros([N, 1, T2+1], tf.float32)], axis=1)
            delta_array = tf.TensorArray(tf.float32, size=(T1+1)*(T2+1), clear_after_read=False)
            delta_array = delta_array.unstack(
                tf.reshape(tf.transpose(delta_matrix, [1, 2, 0]), [(T1+1)*(T2+1), N]))
            delta_array = delta_array.write((T1+1)*(T2+1)-1, tf.zeros((N, )))

            # [N, T1+2, T2+2], R padded with large negative values and
            # R[T1+1, T2+1] = R[T1, T2]
            r_matrix = tf.concat([r_matrix_v1, -1000000*tf.ones([N, T1+1, 1], tf.float32)], axis=2)
            r_matrix = tf.concat([r_matrix, -1000000*tf.ones([N, 1, T2+2], tf.float32)], axis=1)
            r_array = tf.TensorArray(tf.float32, size=(T1+2)*(T2+2), clear_after_read=False)
            r_array = r_array.unstack(
                tf.reshape(tf.transpose(r_matrix, [1, 2, 0]), [(T1+2)*(T2+2), N]))
            r_array = r_array.write((T1+2)*(T2+2)-1, r_array.read((T1+1)*(T2+2)-2))

            # [N, T1+1, T2+1] of zeros, one grid cell per TensorArray element;
            # the bottom-right cell E[T1, T2] is seeded with ones.
            zeros_flat = tf.zeros([(T1+1)*(T2+1), N], tf.float32)

            e_array = tf.TensorArray(tf.float32, size=(T1+1)*(T2+1), clear_after_read=False)
            e_array = e_array.unstack(zeros_flat)
            e_array = e_array.write((T1+1)*(T2+1)-1, tf.ones((N, )))

            grad_array = tf.TensorArray(tf.float32, size=(T1+1)*(T2+1), clear_after_read=False)
            grad_array = grad_array.unstack(zeros_flat)
            grad_array = grad_array.write((T1+1)*(T2+1)-1, tf.ones((N, )))

            def cond(idx, array, grad_array):
                return idx > 0

            def body(idx, array, grad_array):
                # delta_array: [N, T1+1, T2+1]
                # r_array:     [N, T1+2, T2+2]
                # e_array:     [N, T1+1, T2+1]

                # The backward pass sweeps the grid in reverse column-major order.
                j = tf.cast(tf.divide(idx, T1+1), tf.int32)  # column index
                i = tf.math.floormod(tf.cast(idx, tf.int32), T1+1)  # row index

                def inner_func():

                    a = tf.math.exp(1./gamma * (r_array.read((i+1)*(T2+2)+j)-r_array.read(i*(T2+2)+j)-delta_array.read(i*(T2+1)+(j-1))-warp))
                    b = tf.math.exp(1./gamma * (r_array.read((i)*(T2+2)+(j+1))-r_array.read(i*(T2+2)+j)-delta_array.read((i-1)*(T2+1)+j)-warp))
                    c = tf.math.exp(1./gamma * (r_array.read((i+1)*(T2+2)+(j+1))-r_array.read(i*(T2+2)+j)-delta_array.read((i)*(T2+1)+j)))
                    e_value = array.read(i*(T2+1)+(j-1))*a + array.read((i-1)*(T2+1)+j)*b + array.read(i*(T2+1)+j)*c

                    return (array.write((i-1)*(T2+1)+(j-1), e_value),
                            grad_array.write((i-1)*(T2+1)+(j-1),
                                             array.read((i-1)*(T2+1)+j)
                                             + array.read(i*(T2+1)+(j-1))
                                             + array.read(i*(T2+1)+j)))

                def outer_func():

                    return array, grad_array

                array, grad_array = tf.cond((i > 0) & (j > 0),
                                            true_fn=inner_func,
                                            false_fn=outer_func)

                return idx-1, array, grad_array

            _, e_array, grad_array = tf.while_loop(cond, body, ((T1+1)*(T2+1), e_array, grad_array))
            grad_matrix = grad_array.stack()
            grad_matrix = tf.transpose(tf.reshape(grad_matrix, [T1+1, T2+1, N]), [2, 0, 1])

            # Combine the alignment weights with the sign of the raw differences
            # (the derivative of the L1 cost), ...
            tmp_grad = grad_matrix[:, 1:, 1:]*tf.math.sign(raw_delta_matrix_v1)
            # ... broadcast across the feature dimension of X (the original code
            # hard-coded 80 here, the mel-spectrogram dimension), ...
            tmp_grad = tf.linalg.matmul(tmp_grad, tf.ones(shape=[N, T2, tf.shape(X)[2]], dtype=tf.float32))
            # ... and apply the chain rule with the upstream gradient dy: [N] -> [N, 1, 1].
            tmp_grad = tmp_grad * dy[:, None, None]

            return tmp_grad

        # I use Parallel Tacotron 2's version for TTS training
        return r_matrix[:, -1, -1], grad_v1

    return _batch_soft_dtw_kernel(X)


if __name__ == '__main__':

    n = 4
    m = 3

    # sequence 1
    a = tf.Variable(np.random.rand(1, n, 2), dtype=tf.float32)

    # sequence 2 (or target sequence)
    b = np.random.rand(1, m, 2).astype(np.float32)

    eu_distance, _ = batch_distance(a, b, metric="L1")

    with tf.GradientTape() as tape:
        # warp=1.0 is an arbitrary demo value for the warp penalty weight
        soft_dtw_distance = batch_soft_dtw(a, b, gamma=0.01, warp=1.0, metric="L1")
    grad = tape.gradient(soft_dtw_distance, a)

    print(eu_distance)
    print(soft_dtw_distance)
    print(grad)
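    # Sanity check (sketch): compare against the pure-NumPy reference defined
    # above; the two values should agree up to float32 tolerance.
    ref = numpy_soft_dtw_reference(a.numpy()[0], b[0], gamma=0.01, warp=1.0)
    print("numpy reference:", ref, "tf:", float(soft_dtw_distance[0]))
--------------------------------------------------------------------------------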