├── DILATE_neurips19.pdf
├── LICENCE.md
├── README.md
├── data
│   └── synthetic_dataset.py
├── fig2.png
├── loss
│   ├── dilate_loss.py
│   ├── path_soft_dtw.py
│   └── soft_dtw.py
├── main.py
├── models
│   └── seq2seq.py
└── poster_DILATE.pdf

/DILATE_neurips19.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vincent-leguen/DILATE/5f114b587fd7abb7b79726ed68d8e4a91049cc0e/DILATE_neurips19.pdf
--------------------------------------------------------------------------------
/LICENCE.md:
--------------------------------------------------------------------------------
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DILATE: DIstortion Loss with shApe and tImE
[Vincent Le Guen](https://www.linkedin.com/in/vincentleguen/), [Nicolas Thome](http://cedric.cnam.fr/~thomen/)

Code for our NeurIPS 2019 paper "Shape and Time Distortion Loss for Training Deep Time Series Forecasting Models"

![](https://github.com/vincent-leguen/DILATE/blob/master/fig2.png)

If you find this code useful for your research, please cite our [paper](https://papers.nips.cc/paper/8672-shape-and-time-distortion-loss-for-training-deep-time-series-forecasting-models):

```
@incollection{leguen19dilate,
  title = {Shape and Time Distortion Loss for Training Deep Time Series Forecasting Models},
  author = {Le Guen, Vincent and Thome, Nicolas},
  booktitle = {Advances in Neural Information Processing Systems},
  pages = {4191--4203},
  year = {2019}
}
```

## Abstract
This paper addresses the problem of time series forecasting for non-stationary signals and multiple future steps prediction. To handle this challenging task, we introduce DILATE (DIstortion Loss including shApe and TimE), a new objective function for training deep neural networks. DILATE aims at accurately predicting sudden changes, and explicitly incorporates two terms supporting precise shape and temporal change detection. We introduce a differentiable loss function suitable for training deep neural nets, and provide a custom back-prop implementation for speeding up optimization. We also introduce a variant of DILATE, which provides a smooth generalization of temporally-constrained Dynamic Time Warping (DTW). Experiments carried out on various non-stationary datasets reveal the very good behaviour of DILATE compared to models trained with the standard Mean Squared Error (MSE) loss function, and also to DTW and variants. DILATE is also agnostic to the choice of the model, and we highlight its benefit for training fully connected networks as well as specialized recurrent architectures, showing its capacity to improve over state-of-the-art trajectory forecasting approaches.
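
## Usage
A minimal sketch of calling the loss directly (random tensors on CPU, with the default repository layout; `alpha` balances the shape and temporal terms and `gamma` controls the soft-DTW smoothing):

```python
import torch
from loss.dilate_loss import dilate_loss

# Toy batch: 8 predicted series of 20 steps, 1 feature each.
outputs = torch.randn(8, 20, 1, requires_grad=True)
targets = torch.randn(8, 20, 1)

loss, loss_shape, loss_temporal = dilate_loss(
    outputs, targets, alpha=0.5, gamma=0.01, device=torch.device('cpu'))
loss.backward()  # gradients flow through the custom soft-DTW autograd Functions
```

Running `python main.py` trains the seq2seq model on the synthetic dataset with both the DILATE and MSE losses.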
--------------------------------------------------------------------------------
/data/synthetic_dataset.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import random
from torch.utils.data import Dataset, DataLoader

def create_synthetic_dataset(N, N_input, N_output, sigma):
    # N: number of samples in each split (train, test)
    # N_input: number of time steps in the input series
    # N_output: number of time steps in the output series
    # sigma: standard deviation of the additive noise
    X = []
    breakpoints = []
    for k in range(2*N):
        serie = np.array([sigma*random.random() for i in range(N_input+N_output)])
        i1 = random.randint(1, 10)
        i2 = random.randint(10, 18)
        j1 = random.random()
        j2 = random.random()
        interval = abs(i2-i1) + random.randint(-3, 3)
        serie[i1:i1+1] += j1
        serie[i2:i2+1] += j2
        serie[i2+interval:] += (j2-j1)
        X.append(serie)
        breakpoints.append(i2+interval)
    X = np.stack(X)
    breakpoints = np.array(breakpoints)
    return X[0:N, 0:N_input], X[0:N, N_input:N_input+N_output], X[N:2*N, 0:N_input], X[N:2*N, N_input:N_input+N_output], breakpoints[0:N], breakpoints[N:2*N]


class SyntheticDataset(torch.utils.data.Dataset):
    def __init__(self, X_input, X_target, breakpoints):
        super(SyntheticDataset, self).__init__()
        self.X_input = X_input
        self.X_target = X_target
        self.breakpoints = breakpoints

    def __len__(self):
        return (self.X_input).shape[0]

    def __getitem__(self, idx):
        return (self.X_input[idx, :, np.newaxis], self.X_target[idx, :, np.newaxis], self.breakpoints[idx])
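
# Illustrative sanity check (a sketch added for clarity, not part of the
# original training pipeline): generate a tiny split and inspect the shapes.
if __name__ == "__main__":
    X_tr_in, X_tr_tg, X_te_in, X_te_tg, bkp_tr, bkp_te = create_synthetic_dataset(
        N=10, N_input=20, N_output=20, sigma=0.01)
    print(X_tr_in.shape, X_tr_tg.shape)  # (10, 20) (10, 20)
    dataset = SyntheticDataset(X_tr_in, X_tr_tg, bkp_tr)
    sample_input, sample_target, bkp = dataset[0]
    print(sample_input.shape, sample_target.shape)  # (20, 1) (20, 1)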
--------------------------------------------------------------------------------
/fig2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vincent-leguen/DILATE/5f114b587fd7abb7b79726ed68d8e4a91049cc0e/fig2.png
--------------------------------------------------------------------------------
/loss/dilate_loss.py:
--------------------------------------------------------------------------------
import torch
from . import soft_dtw
from . import path_soft_dtw

def dilate_loss(outputs, targets, alpha, gamma, device):
    # outputs, targets: shape (batch_size, N_output, 1)
    batch_size, N_output = outputs.shape[0:2]
    softdtw_batch = soft_dtw.SoftDTWBatch.apply
    # Shape term: batch-averaged soft-DTW over the pairwise cost matrices.
    D = torch.zeros((batch_size, N_output, N_output)).to(device)
    for k in range(batch_size):
        Dk = soft_dtw.pairwise_distances(targets[k, :, :].view(-1, 1), outputs[k, :, :].view(-1, 1))
        D[k:k+1, :, :] = Dk
    loss_shape = softdtw_batch(D, gamma)

    # Temporal term: soft alignment path weighted by squared index differences.
    path_dtw = path_soft_dtw.PathDTWBatch.apply
    path = path_dtw(D, gamma)
    Omega = soft_dtw.pairwise_distances(torch.arange(1, N_output+1, dtype=torch.float).view(N_output, 1)).to(device)
    loss_temporal = torch.sum(path*Omega) / (N_output*N_output)
    loss = alpha*loss_shape + (1-alpha)*loss_temporal
    return loss, loss_shape, loss_temporal
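
# Illustrative check (a sketch added for clarity; run from the repo root as
# `python -m loss.dilate_loss` so the relative imports resolve): with alpha=1
# the loss reduces to the shape term (batch-averaged soft-DTW), and with
# alpha=0 to the temporal term alone.
if __name__ == "__main__":
    torch.manual_seed(0)
    out = torch.randn(4, 10, 1)
    tgt = torch.randn(4, 10, 1)
    loss, shape_term, temporal_term = dilate_loss(out, tgt, alpha=1.0, gamma=0.01,
                                                  device=torch.device('cpu'))
    print(torch.allclose(loss, shape_term))  # True: only the shape term remains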
--------------------------------------------------------------------------------
/loss/path_soft_dtw.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from torch.autograd import Function
from numba import jit


@jit(nopython=True)
def my_max(x, gamma):
    # soft maximum, computed with the log-sum-exp trick for stability
    max_x = np.max(x)
    exp_x = np.exp((x - max_x) / gamma)
    Z = np.sum(exp_x)
    return gamma * np.log(Z) + max_x, exp_x / Z

@jit(nopython=True)
def my_min(x, gamma):
    # soft minimum, derived from the soft maximum of -x
    min_x, argmin_x = my_max(-x, gamma)
    return -min_x, argmin_x

@jit(nopython=True)
def my_max_hessian_product(p, z, gamma):
    return (p * z - p * np.sum(p * z)) / gamma

@jit(nopython=True)
def my_min_hessian_product(p, z, gamma):
    return -my_max_hessian_product(p, z, gamma)


@jit(nopython=True)
def dtw_grad(theta, gamma):
    m = theta.shape[0]
    n = theta.shape[1]
    V = np.zeros((m + 1, n + 1))
    V[:, 0] = 1e10
    V[0, :] = 1e10
    V[0, 0] = 0

    Q = np.zeros((m + 2, n + 2, 3))

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            # theta is indexed starting from 0
            v, Q[i, j] = my_min(np.array([V[i, j - 1],
                                          V[i - 1, j - 1],
                                          V[i - 1, j]]), gamma)
            V[i, j] = theta[i - 1, j - 1] + v

    E = np.zeros((m + 2, n + 2))
    E[m + 1, :] = 0
    E[:, n + 1] = 0
    E[m + 1, n + 1] = 1
    Q[m + 1, n + 1] = 1

    for i in range(m, 0, -1):
        for j in range(n, 0, -1):
            E[i, j] = Q[i, j + 1, 0] * E[i, j + 1] + \
                      Q[i + 1, j + 1, 1] * E[i + 1, j + 1] + \
                      Q[i + 1, j, 2] * E[i + 1, j]

    return V[m, n], E[1:m + 1, 1:n + 1], Q, E


@jit(nopython=True)
def dtw_hessian_prod(theta, Z, Q, E, gamma):
    m = Z.shape[0]
    n = Z.shape[1]

    V_dot = np.zeros((m + 1, n + 1))
    V_dot[0, 0] = 0

    Q_dot = np.zeros((m + 2, n + 2, 3))
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            # theta is indexed starting from 0
            V_dot[i, j] = Z[i - 1, j - 1] + \
                          Q[i, j, 0] * V_dot[i, j - 1] + \
                          Q[i, j, 1] * V_dot[i - 1, j - 1] + \
                          Q[i, j, 2] * V_dot[i - 1, j]

            v = np.array([V_dot[i, j - 1], V_dot[i - 1, j - 1], V_dot[i - 1, j]])
            Q_dot[i, j] = my_min_hessian_product(Q[i, j], v, gamma)
    E_dot = np.zeros((m + 2, n + 2))

    for j in range(n, 0, -1):
        for i in range(m, 0, -1):
            E_dot[i, j] = Q_dot[i, j + 1, 0] * E[i, j + 1] + \
                          Q[i, j + 1, 0] * E_dot[i, j + 1] + \
                          Q_dot[i + 1, j + 1, 1] * E[i + 1, j + 1] + \
                          Q[i + 1, j + 1, 1] * E_dot[i + 1, j + 1] + \
                          Q_dot[i + 1, j, 2] * E[i + 1, j] + \
                          Q[i + 1, j, 2] * E_dot[i + 1, j]

    return V_dot[m, n], E_dot[1:m + 1, 1:n + 1]


class PathDTWBatch(Function):
    @staticmethod
    def forward(ctx, D, gamma):  # D.shape: [batch_size, N, N]
        batch_size, N, _ = D.shape
        device = D.device
        D_cpu = D.detach().cpu().numpy()
        gamma_gpu = torch.FloatTensor([gamma]).to(device)

        grad_gpu = torch.zeros((batch_size, N, N)).to(device)
        Q_gpu = torch.zeros((batch_size, N+2, N+2, 3)).to(device)
        E_gpu = torch.zeros((batch_size, N+2, N+2)).to(device)

        for k in range(0, batch_size):  # loop over all D in the batch
            _, grad_cpu_k, Q_cpu_k, E_cpu_k = dtw_grad(D_cpu[k, :, :], gamma)
            grad_gpu[k, :, :] = torch.FloatTensor(grad_cpu_k).to(device)
            Q_gpu[k, :, :, :] = torch.FloatTensor(Q_cpu_k).to(device)
            E_gpu[k, :, :] = torch.FloatTensor(E_cpu_k).to(device)
        ctx.save_for_backward(grad_gpu, D, Q_gpu, E_gpu, gamma_gpu)
        return torch.mean(grad_gpu, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        device = grad_output.device
        grad_gpu, D_gpu, Q_gpu, E_gpu, gamma = ctx.saved_tensors
        D_cpu = D_gpu.detach().cpu().numpy()
        Q_cpu = Q_gpu.detach().cpu().numpy()
        E_cpu = E_gpu.detach().cpu().numpy()
        gamma = gamma.detach().cpu().numpy()[0]
        Z = grad_output.detach().cpu().numpy()

        batch_size, N, _ = D_cpu.shape
        Hessian = torch.zeros((batch_size, N, N)).to(device)
        for k in range(0, batch_size):
            _, hess_k = dtw_hessian_prod(D_cpu[k, :, :], Z, Q_cpu[k, :, :, :], E_cpu[k, :, :], gamma)
            Hessian[k:k+1, :, :] = torch.FloatTensor(hess_k).to(device)

        return Hessian, None
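
# Illustrative usage (a sketch added for clarity; runnable directly as
# `python loss/path_soft_dtw.py`): the forward pass returns the batch-averaged
# soft alignment matrix, i.e. the gradient of soft-DTW with respect to the
# pairwise cost matrix, which can be read as an expected DTW path.
if __name__ == "__main__":
    torch.manual_seed(0)
    D = torch.rand(2, 10, 10)
    path = PathDTWBatch.apply(D, 0.01)
    print(path.shape)  # torch.Size([10, 10])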
--------------------------------------------------------------------------------
/loss/soft_dtw.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from numba import jit
from torch.autograd import Function

def pairwise_distances(x, y=None):
    '''
    Input: x is a Nxd matrix
           y is an optional Mxd matrix
    Output: dist is a NxM matrix where dist[i,j] is the squared norm between x[i,:] and y[j,:],
            i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
    If y is not given, then y=x is used.
    '''
    x_norm = (x**2).sum(1).view(-1, 1)
    if y is not None:
        y_t = torch.transpose(y, 0, 1)
        y_norm = (y**2).sum(1).view(1, -1)
    else:
        y_t = torch.transpose(x, 0, 1)
        y_norm = x_norm.view(1, -1)

    dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
    return torch.clamp(dist, 0.0, float('inf'))

@jit(nopython=True)
def compute_softdtw(D, gamma):
    N = D.shape[0]
    M = D.shape[1]
    R = np.zeros((N + 2, M + 2)) + 1e8
    R[0, 0] = 0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            # soft minimum of the three predecessors, via log-sum-exp
            r0 = -R[i - 1, j - 1] / gamma
            r1 = -R[i - 1, j] / gamma
            r2 = -R[i, j - 1] / gamma
            rmax = max(max(r0, r1), r2)
            rsum = np.exp(r0 - rmax) + np.exp(r1 - rmax) + np.exp(r2 - rmax)
            softmin = - gamma * (np.log(rsum) + rmax)
            R[i, j] = D[i - 1, j - 1] + softmin
    return R

@jit(nopython=True)
def compute_softdtw_backward(D_, R, gamma):
    N = D_.shape[0]
    M = D_.shape[1]
    D = np.zeros((N + 2, M + 2))
    E = np.zeros((N + 2, M + 2))
    D[1:N + 1, 1:M + 1] = D_
    E[-1, -1] = 1
    R[:, -1] = -1e8
    R[-1, :] = -1e8
    R[-1, -1] = R[-2, -2]
    for j in range(M, 0, -1):
        for i in range(N, 0, -1):
            a0 = (R[i + 1, j] - R[i, j] - D[i + 1, j]) / gamma
            b0 = (R[i, j + 1] - R[i, j] - D[i, j + 1]) / gamma
            c0 = (R[i + 1, j + 1] - R[i, j] - D[i + 1, j + 1]) / gamma
            a = np.exp(a0)
            b = np.exp(b0)
            c = np.exp(c0)
            E[i, j] = E[i + 1, j] * a + E[i, j + 1] * b + E[i + 1, j + 1] * c
    return E[1:N + 1, 1:M + 1]


class SoftDTWBatch(Function):
    @staticmethod
    def forward(ctx, D, gamma=1.0):  # D.shape: [batch_size, N, N]
        dev = D.device
        batch_size, N, _ = D.shape
        gamma = torch.FloatTensor([gamma]).to(dev)
        D_ = D.detach().cpu().numpy()
        g_ = gamma.item()

        total_loss = 0
        R = torch.zeros((batch_size, N+2, N+2)).to(dev)
        for k in range(0, batch_size):  # loop over all D in the batch
            Rk = torch.FloatTensor(compute_softdtw(D_[k, :, :], g_)).to(dev)
            R[k:k+1, :, :] = Rk
            total_loss = total_loss + Rk[-2, -2]
        ctx.save_for_backward(D, R, gamma)
        return total_loss / batch_size

    @staticmethod
    def backward(ctx, grad_output):
        dev = grad_output.device
        D, R, gamma = ctx.saved_tensors
        batch_size, N, _ = D.shape
        D_ = D.detach().cpu().numpy()
        R_ = R.detach().cpu().numpy()
        g_ = gamma.item()

        E = torch.zeros((batch_size, N, N)).to(dev)
        for k in range(batch_size):
            Ek = torch.FloatTensor(compute_softdtw_backward(D_[k, :, :], R_[k, :, :], g_)).to(dev)
            E[k:k+1, :, :] = Ek

        return grad_output * E, None
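
# Illustrative check (a sketch added for clarity; runnable directly as
# `python loss/soft_dtw.py`): the soft minimum is a lower bound on the hard
# minimum, so as gamma -> 0 the soft-DTW value increases toward the classical
# DTW cost.
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.rand(12, 1)
    y = torch.rand(12, 1)
    D = pairwise_distances(x, y).unsqueeze(0)  # batch with a single cost matrix
    for g in (1.0, 0.1, 0.001):
        print(g, float(SoftDTWBatch.apply(D, g)))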
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from data.synthetic_dataset import create_synthetic_dataset, SyntheticDataset
from models.seq2seq import EncoderRNN, DecoderRNN, Net_GRU
from loss.dilate_loss import dilate_loss
from torch.utils.data import DataLoader
import random
from tslearn.metrics import dtw, dtw_path
import matplotlib.pyplot as plt
import warnings
warnings.simplefilter('ignore')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
random.seed(0)

# parameters
batch_size = 100
N = 500
N_input = 20
N_output = 20
sigma = 0.01
gamma = 0.01

# Load synthetic dataset
X_train_input, X_train_target, X_test_input, X_test_target, train_bkp, test_bkp = create_synthetic_dataset(N, N_input, N_output, sigma)
dataset_train = SyntheticDataset(X_train_input, X_train_target, train_bkp)
dataset_test = SyntheticDataset(X_test_input, X_test_target, test_bkp)
trainloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=1)
testloader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=1)


def train_model(net, loss_type, learning_rate, epochs=1000, gamma=0.001,
                print_every=50, eval_every=50, verbose=1, Lambda=1, alpha=0.5):

    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    criterion = torch.nn.MSELoss()

    for epoch in range(epochs):
        for i, data in enumerate(trainloader, 0):
            inputs, target, _ = data
            inputs = inputs.to(torch.float32).to(device)
            target = target.to(torch.float32).to(device)
            batch_size, N_output = target.shape[0:2]

            # forward + backward + optimize
            outputs = net(inputs)
            loss_mse, loss_shape, loss_temporal = torch.tensor(0), torch.tensor(0), torch.tensor(0)

            if (loss_type == 'mse'):
                loss_mse = criterion(target, outputs)
                loss = loss_mse

            if (loss_type == 'dilate'):
                loss, loss_shape, loss_temporal = dilate_loss(outputs, target, alpha, gamma, device)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if (verbose):
            if (epoch % print_every == 0):
                print('epoch ', epoch, ' loss ', loss.item(), ' loss shape ', loss_shape.item(), ' loss temporal ', loss_temporal.item())
                eval_model(net, testloader, gamma, verbose=1)


def eval_model(net, loader, gamma, verbose=1):
    criterion = torch.nn.MSELoss()
    losses_mse = []
    losses_dtw = []
    losses_tdi = []

    for i, data in enumerate(loader, 0):
        # get the inputs
        inputs, target, breakpoints = data
        inputs = inputs.to(torch.float32).to(device)
        target = target.to(torch.float32).to(device)
        batch_size, N_output = target.shape[0:2]
        outputs = net(inputs)

        # MSE
        loss_mse = criterion(target, outputs)
        loss_dtw, loss_tdi = 0, 0
        # DTW and TDI
        for k in range(batch_size):
            target_k_cpu = target[k, :, 0:1].view(-1).detach().cpu().numpy()
            output_k_cpu = outputs[k, :, 0:1].view(-1).detach().cpu().numpy()

            path, sim = dtw_path(target_k_cpu, output_k_cpu)
            loss_dtw += sim

            Dist = 0
            for i_, j_ in path:
                Dist += (i_-j_)*(i_-j_)
            loss_tdi += Dist / (N_output*N_output)

        loss_dtw = loss_dtw / batch_size
        loss_tdi = loss_tdi / batch_size

        # print statistics
        losses_mse.append(loss_mse.item())
        losses_dtw.append(loss_dtw)
        losses_tdi.append(loss_tdi)

    print(' Eval mse= ', np.array(losses_mse).mean(), ' dtw= ', np.array(losses_dtw).mean(), ' tdi= ', np.array(losses_tdi).mean())


encoder = EncoderRNN(input_size=1, hidden_size=128, num_grulstm_layers=1, batch_size=batch_size).to(device)
decoder = DecoderRNN(input_size=1, hidden_size=128, num_grulstm_layers=1, fc_units=16, output_size=1).to(device)
net_gru_dilate = Net_GRU(encoder, decoder, N_output, device).to(device)
train_model(net_gru_dilate, loss_type='dilate', learning_rate=0.001, epochs=500, gamma=gamma, print_every=50, eval_every=50, verbose=1)

encoder = EncoderRNN(input_size=1, hidden_size=128, num_grulstm_layers=1, batch_size=batch_size).to(device)
decoder = DecoderRNN(input_size=1, hidden_size=128, num_grulstm_layers=1, fc_units=16, output_size=1).to(device)
net_gru_mse = Net_GRU(encoder, decoder, N_output, device).to(device)
train_model(net_gru_mse, loss_type='mse', learning_rate=0.001, epochs=500, gamma=gamma, print_every=50, eval_every=50, verbose=1)

# Visualize results
gen_test = iter(testloader)
test_inputs, test_targets, breaks = next(gen_test)

test_inputs = test_inputs.to(torch.float32).to(device)
test_targets = test_targets.to(torch.float32).to(device)
criterion = torch.nn.MSELoss()

nets = [net_gru_mse, net_gru_dilate]

for ind in range(1, 51):
    plt.figure()
    plt.rcParams['figure.figsize'] = (17.0, 5.0)
    k = 1
    for net in nets:
        pred = net(test_inputs).to(device)

        input = test_inputs.detach().cpu().numpy()[ind, :, :]
        target = test_targets.detach().cpu().numpy()[ind, :, :]
        preds = pred.detach().cpu().numpy()[ind, :, :]

        plt.subplot(1, 3, k)
        plt.plot(range(0, N_input), input, label='input', linewidth=3)
        plt.plot(range(N_input-1, N_input+N_output), np.concatenate([input[N_input-1:N_input], target]), label='target', linewidth=3)
        plt.plot(range(N_input-1, N_input+N_output), np.concatenate([input[N_input-1:N_input], preds]), label='prediction', linewidth=3)
        plt.xticks(range(0, 40, 2))
        plt.legend()
        k = k+1

    plt.show()
--------------------------------------------------------------------------------
/models/seq2seq.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class EncoderRNN(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_grulstm_layers, batch_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.num_grulstm_layers = num_grulstm_layers
        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_grulstm_layers, batch_first=True)

    def forward(self, input, hidden):  # input [batch_size, length T, dimensionality d]
        output, hidden = self.gru(input, hidden)
        return output, hidden

    def init_hidden(self, device):
        # [num_layers*num_directions, batch, hidden_size]
        return torch.zeros(self.num_grulstm_layers, self.batch_size, self.hidden_size, device=device)

class DecoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_grulstm_layers, fc_units, output_size):
        super(DecoderRNN, self).__init__()
        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_grulstm_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, fc_units)
        self.out = nn.Linear(fc_units, output_size)

    def forward(self, input, hidden):
        output, hidden = self.gru(input, hidden)
        output = F.relu(self.fc(output))
        output = self.out(output)
        return output, hidden

class Net_GRU(nn.Module):
    def __init__(self, encoder, decoder, target_length, device):
        super(Net_GRU, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.target_length = target_length
        self.device = device

    def forward(self, x):
        input_length = x.shape[1]
        encoder_hidden = self.encoder.init_hidden(self.device)
        # encode the input sequence one step at a time
        for ei in range(input_length):
            encoder_output, encoder_hidden = self.encoder(x[:, ei:ei+1, :], encoder_hidden)

        decoder_input = x[:, -1, :].unsqueeze(1)  # first decoder input = last element of input sequence
        decoder_hidden = encoder_hidden

        # decode autoregressively: each output is fed back as the next input
        outputs = torch.zeros([x.shape[0], self.target_length, x.shape[2]]).to(self.device)
        for di in range(self.target_length):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            decoder_input = decoder_output
            outputs[:, di:di+1, :] = decoder_output
        return outputs
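
# Illustrative shape check (a sketch added for clarity, not part of the
# original repo): wire the encoder and decoder together and confirm that the
# network forecasts `target_length` steps.
if __name__ == "__main__":
    device = torch.device('cpu')
    encoder = EncoderRNN(input_size=1, hidden_size=128, num_grulstm_layers=1, batch_size=4)
    decoder = DecoderRNN(input_size=1, hidden_size=128, num_grulstm_layers=1, fc_units=16, output_size=1)
    net = Net_GRU(encoder, decoder, target_length=20, device=device)
    x = torch.randn(4, 20, 1)  # (batch, input length, features)
    print(net(x).shape)        # torch.Size([4, 20, 1])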
--------------------------------------------------------------------------------
/poster_DILATE.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vincent-leguen/DILATE/5f114b587fd7abb7b79726ed68d8e4a91049cc0e/poster_DILATE.pdf
--------------------------------------------------------------------------------