├── .gitattributes ├── models ├── __init__.py ├── model_helper.py ├── clstm.py ├── crnn.py └── cmlp.py ├── LICENSE ├── synthetic.py └── README.md /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-language=Python 2 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import models.cmlp 2 | import models.clstm 3 | import models.crnn 4 | import models.model_helper 5 | -------------------------------------------------------------------------------- /models/model_helper.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def activation_helper(activation, dim=None): 5 | if activation == 'sigmoid': 6 | act = nn.Sigmoid() 7 | elif activation == 'tanh': 8 | act = nn.Tanh() 9 | elif activation == 'relu': 10 | act = nn.ReLU() 11 | elif activation == 'leakyrelu': 12 | act = nn.LeakyReLU() 13 | elif activation is None: 14 | def act(x): 15 | return x 16 | else: 17 | raise ValueError('unsupported activation: %s' % activation) 18 | return act 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Ian Covert 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /synthetic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.integrate import odeint 3 | 4 | 5 | def make_var_stationary(beta, radius=0.97): 6 | '''Rescale coefficients of VAR model to make stable.''' 7 | p = beta.shape[0] 8 | lag = beta.shape[1] // p 9 | bottom = np.hstack((np.eye(p * (lag - 1)), np.zeros((p * (lag - 1), p)))) 10 | beta_tilde = np.vstack((beta, bottom)) 11 | eigvals = np.linalg.eigvals(beta_tilde) 12 | max_eig = max(np.abs(eigvals)) 13 | nonstationary = max_eig > radius 14 | if nonstationary: 15 | return make_var_stationary(0.95 * beta, radius) 16 | else: 17 | return beta 18 | 19 | 20 | def simulate_var(p, T, lag, sparsity=0.2, beta_value=1.0, sd=0.1, seed=0): 21 | if seed is not None: 22 | np.random.seed(seed) 23 | 24 | # Set up coefficients and Granger causality ground truth. 
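# Each series always depends on its own past (diagonal of GC), plus roughly
# (p * sparsity - 1) randomly chosen other series. GC[i, j] = 1 means series j
# Granger-causes series i. The same coefficient matrix is tiled across all
# lags, then rescaled for stationarity.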
25 | GC = np.eye(p, dtype=int) 26 | beta = np.eye(p) * beta_value 27 | 28 | num_nonzero = int(p * sparsity) - 1 29 | for i in range(p): 30 | choice = np.random.choice(p - 1, size=num_nonzero, replace=False) 31 | choice[choice >= i] += 1 32 | beta[i, choice] = beta_value 33 | GC[i, choice] = 1 34 | 35 | beta = np.hstack([beta for _ in range(lag)]) 36 | beta = make_var_stationary(beta) 37 | 38 | # Generate data. 39 | burn_in = 100 40 | errors = np.random.normal(scale=sd, size=(p, T + burn_in)) 41 | X = np.zeros((p, T + burn_in)) 42 | X[:, :lag] = errors[:, :lag] 43 | for t in range(lag, T + burn_in): 44 | X[:, t] = np.dot(beta, X[:, (t-lag):t].flatten(order='F')) 45 | X[:, t] += + errors[:, t-1] 46 | 47 | return X.T[burn_in:], beta, GC 48 | 49 | 50 | def lorenz(x, t, F): 51 | '''Partial derivatives for Lorenz-96 ODE.''' 52 | p = len(x) 53 | dxdt = np.zeros(p) 54 | for i in range(p): 55 | dxdt[i] = (x[(i+1) % p] - x[(i-2) % p]) * x[(i-1) % p] - x[i] + F 56 | 57 | return dxdt 58 | 59 | 60 | def simulate_lorenz_96(p, T, F=10.0, delta_t=0.1, sd=0.1, burn_in=1000, 61 | seed=0): 62 | if seed is not None: 63 | np.random.seed(seed) 64 | 65 | # Use scipy to solve ODE. 66 | x0 = np.random.normal(scale=0.01, size=p) 67 | t = np.linspace(0, (T + burn_in) * delta_t, T + burn_in) 68 | X = odeint(lorenz, x0, t, args=(F,)) 69 | X += np.random.normal(scale=sd, size=(T + burn_in, p)) 70 | 71 | # Set up Granger causality ground truth. 72 | GC = np.zeros((p, p), dtype=int) 73 | for i in range(p): 74 | GC[i, i] = 1 75 | GC[i, (i + 1) % p] = 1 76 | GC[i, (i - 1) % p] = 1 77 | GC[i, (i - 2) % p] = 1 78 | 79 | return X[burn_in:], GC 80 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Granger Causality 2 | 3 | The `Neural-GC` repository contains code for a deep learning-based approach to discovering Granger causality networks in multivariate time series. The methods implemented here are described in [this paper](https://arxiv.org/abs/1802.05842). 4 | 5 | ## Installation 6 | 7 | To install the code, please clone the repository. All you need is `Python 3`, `PyTorch (>= 0.4.0)`, `numpy` and `scipy`. 8 | 9 | ## Usage 10 | 11 | See examples of how to apply our approach in the notebooks `cmlp_lagged_var_demo.ipynb`, `clstm_lorenz_demo.ipynb`, and `crnn_lorenz_demo.ipynb`. 12 | 13 | ## How it works 14 | 15 | The models implemented in this repository, called the cMLP, cLSTM and cRNN, are neural networks that model multivariate time series by forecasting each time series separately. During training, sparse penalties on the input layer's weight matrix set groups of parameters to zero, which can be interpreted as discovering Granger non-causality. 16 | 17 | The cMLP model can be trained with three different penalties: group lasso, group sparse group lasso, and hierarchical. The cLSTM and cRNN models both use a group lasso penalty, and they differ from one another only in the type of RNN cell they use. 18 | 19 | Training models with non-convex loss functions and non-smooth penalties requires a specialized optimization strategy, and we use a proximal gradient descent approach (ISTA). Our paper finds that ISTA provides comparable performance to two other approaches: proximal gradient descent with a line search (GISTA), which guarantees convergence to a local minimum, and Adam, which converges faster (although it requires an additional thresholding parameter). 
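As a minimal end-to-end sketch (the hyperparameters and variable names below are illustrative rather than copied from the demo notebooks), the typical workflow is to simulate data, fit a penalized model with ISTA, and read the estimated Granger causality matrix off the input layer:

```python
import numpy as np
import torch

from models.cmlp import cMLP, train_model_ista
from synthetic import simulate_var

# Simulate a sparse VAR(3) process: X_np has shape (T, p), GC_true is the
# ground-truth (p x p) Granger causality matrix.
X_np, beta, GC_true = simulate_var(p=10, T=1000, lag=3)
X = torch.tensor(X_np[np.newaxis], dtype=torch.float32)  # shape (1, T, p)

# One MLP per output series; `hidden` lists the hidden layer sizes.
cmlp = cMLP(num_series=10, lag=3, hidden=[32])

# Proximal gradient descent (ISTA) with a hierarchical penalty on the input
# layer; lam controls sparsity, lam_ridge the ridge penalty on later layers.
train_loss = train_model_ista(cmlp, X, lr=1e-3, max_iter=10000,
                              lam=0.01, lam_ridge=0.01, penalty='H')

# Entry (i, j) of the estimated matrix is 1 if series j is estimated to
# Granger-cause series i.
GC_est = cmlp.GC().numpy()
print('Accuracy: %.2f%%' % (100 * np.mean(GC_est == GC_true)))
```

The `cLSTM` and `cRNN` models are used the same way; their training functions additionally take a `context` argument that sets the length of the overlapping subsequences the series is split into for training.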
20 | 21 | ## Other information 22 | 23 | - Selecting the right regularization strength can be difficult and time consuming. To get results for many regularization strengths, you may want to run parallel training jobs or use a warm start strategy. 24 | - Pretraining (training without regularization) followed by ISTA can lead to a different result than training directly with ISTA. Given the non-convex objective function, this is unsurprising, because the initialization from pretraining is very different than a random initialization. You may need to experiment to find what works best for you. 25 | - If you want to train a debiased model with the learned sparsity pattern, use the `cMLPSparse`, `cLSTMSparse`, and `cRNNSparse` classes. 26 | 27 | ## Authors 28 | 29 | - Ian Covert () 30 | - Alex Tank 31 | - Nicholas Foti 32 | - Ali Shojaie 33 | - Emily Fox 34 | 35 | ## References 36 | 37 | - Alex Tank, Ian Covert, Nicholas Foti, Ali Shojaie, Emily Fox. "Neural Granger Causality." *Transactions on Pattern Analysis and Machine Intelligence*, 2021. -------------------------------------------------------------------------------- /models/clstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from copy import deepcopy 5 | 6 | 7 | class LSTM(nn.Module): 8 | def __init__(self, num_series, hidden): 9 | ''' 10 | LSTM model with output layer to generate predictions. 11 | 12 | Args: 13 | num_series: number of input time series. 14 | hidden: number of hidden units. 15 | ''' 16 | super(LSTM, self).__init__() 17 | self.p = num_series 18 | self.hidden = hidden 19 | 20 | # Set up network. 21 | self.lstm = nn.LSTM(num_series, hidden, batch_first=True) 22 | self.lstm.flatten_parameters() 23 | self.linear = nn.Conv1d(hidden, 1, 1) 24 | 25 | def init_hidden(self, batch): 26 | '''Initialize hidden states for LSTM cell.''' 27 | device = self.lstm.weight_ih_l0.device 28 | return (torch.zeros(1, batch, self.hidden, device=device), 29 | torch.zeros(1, batch, self.hidden, device=device)) 30 | 31 | def forward(self, X, hidden=None): 32 | # Set up hidden state. 33 | if hidden is None: 34 | hidden = self.init_hidden(X.shape[0]) 35 | 36 | # Apply LSTM. 37 | X, hidden = self.lstm(X, hidden) 38 | 39 | # Calculate predictions using output layer. 40 | X = X.transpose(2, 1) 41 | X = self.linear(X) 42 | return X.transpose(2, 1), hidden 43 | 44 | 45 | class cLSTM(nn.Module): 46 | def __init__(self, num_series, hidden): 47 | ''' 48 | cLSTM model with one LSTM per time series. 49 | 50 | Args: 51 | num_series: dimensionality of multivariate time series. 52 | hidden: number of units in LSTM cell. 53 | ''' 54 | super(cLSTM, self).__init__() 55 | self.p = num_series 56 | self.hidden = hidden 57 | 58 | # Set up networks. 59 | self.networks = nn.ModuleList([ 60 | LSTM(num_series, hidden) for _ in range(num_series)]) 61 | 62 | def forward(self, X, hidden=None): 63 | ''' 64 | Perform forward pass. 65 | 66 | Args: 67 | X: torch tensor of shape (batch, T, p). 68 | hidden: hidden states for LSTM cell. 69 | ''' 70 | if hidden is None: 71 | hidden = [None for _ in range(self.p)] 72 | pred = [self.networks[i](X, hidden[i]) 73 | for i in range(self.p)] 74 | pred, hidden = zip(*pred) 75 | pred = torch.cat(pred, dim=2) 76 | return pred, hidden 77 | 78 | def GC(self, threshold=True): 79 | ''' 80 | Extract learned Granger causality. 81 | 82 | Args: 83 | threshold: return norm of weights, or whether norm is nonzero. 
84 | 85 | Returns: 86 | GC: (p x p) matrix. Entry (i, j) indicates whether variable j is 87 | Granger causal of variable i. 88 | ''' 89 | GC = [torch.norm(net.lstm.weight_ih_l0, dim=0) 90 | for net in self.networks] 91 | GC = torch.stack(GC) 92 | if threshold: 93 | return (GC > 0).int() 94 | else: 95 | return GC 96 | 97 | 98 | class cLSTMSparse(nn.Module): 99 | def __init__(self, num_series, sparsity, hidden): 100 | ''' 101 | cLSTM model that only uses specified interactions. 102 | 103 | Args: 104 | num_series: dimensionality of multivariate time series. 105 | sparsity: torch byte tensor indicating Granger causality, with size 106 | (num_series, num_series). 107 | hidden: number of units in LSTM cell. 108 | ''' 109 | super(cLSTMSparse, self).__init__() 110 | self.p = num_series 111 | self.hidden = hidden 112 | self.sparsity = sparsity 113 | 114 | # Set up networks. 115 | self.networks = nn.ModuleList([ 116 | LSTM(int(torch.sum(sparsity[i].int())), hidden) 117 | for i in range(num_series)]) 118 | 119 | def forward(self, X, hidden=None): 120 | ''' 121 | Perform forward pass. 122 | 123 | Args: 124 | X: torch tensor of shape (batch, T, p). 125 | hidden: hidden states for LSTM cell. 126 | ''' 127 | if hidden is None: 128 | hidden = [None for _ in range(self.p)] 129 | pred = [self.networks[i](X[:, :, self.sparsity[i]], hidden[i]) 130 | for i in range(self.p)] 131 | pred, hidden = zip(*pred) 132 | pred = torch.cat(pred, dim=2) 133 | return pred, hidden 134 | 135 | 136 | def prox_update(network, lam, lr): 137 | '''Perform in place proximal update on first layer weight matrix.''' 138 | W = network.lstm.weight_ih_l0 139 | norm = torch.norm(W, dim=0, keepdim=True) 140 | W.data = ((W / torch.clamp(norm, min=(lam * lr))) 141 | * torch.clamp(norm - (lr * lam), min=0.0)) 142 | network.lstm.flatten_parameters() 143 | 144 | 145 | def regularize(network, lam): 146 | '''Calculate regularization term for first layer weight matrix.''' 147 | W = network.lstm.weight_ih_l0 148 | return lam * torch.sum(torch.norm(W, dim=0)) 149 | 150 | 151 | def ridge_regularize(network, lam): 152 | '''Apply ridge penalty at linear layer and hidden-hidden weights.''' 153 | return lam * ( 154 | torch.sum(network.linear.weight ** 2) + 155 | torch.sum(network.lstm.weight_hh_l0 ** 2)) 156 | 157 | 158 | def restore_parameters(model, best_model): 159 | '''Move parameter values from best_model to model.''' 160 | for params, best_params in zip(model.parameters(), best_model.parameters()): 161 | params.data = best_params 162 | 163 | 164 | def arrange_input(data, context): 165 | ''' 166 | Arrange a single time series into overlapping short sequences. 167 | 168 | Args: 169 | data: time series of shape (T, dim). 170 | context: length of short sequences. 171 | ''' 172 | assert context >= 1 and isinstance(context, int) 173 | input = torch.zeros(len(data) - context, context, data.shape[1], 174 | dtype=torch.float32, device=data.device) 175 | target = torch.zeros(len(data) - context, context, data.shape[1], 176 | dtype=torch.float32, device=data.device) 177 | for i in range(context): 178 | start = i 179 | end = len(data) - context + i 180 | input[:, i, :] = data[start:end] 181 | target[:, i, :] = data[start+1:end+1] 182 | return input.detach(), target.detach() 183 | 184 | 185 | def train_model_gista(clstm, X, context, lam, lam_ridge, lr, max_iter, 186 | check_every=50, r=0.8, lr_min=1e-8, sigma=0.5, 187 | monotone=False, m=10, lr_decay=0.5, 188 | begin_line_search=True, switch_tol=1e-3, verbose=1): 189 | ''' 190 | Train cLSTM model with GISTA. 
191 | 192 | Args: 193 | clstm: clstm model. 194 | X: tensor of data, shape (batch, T, p). 195 | context: length for short overlapping subsequences. 196 | lam: parameter for nonsmooth regularization. 197 | lam_ridge: parameter for ridge regularization on output layer. 198 | lr: learning rate. 199 | max_iter: max number of GISTA iterations. 200 | check_every: how frequently to record loss. 201 | r: for line search. 202 | lr_min: for line search. 203 | sigma: for line search. 204 | monotone: for line search. 205 | m: for line search. 206 | lr_decay: for adjusting initial learning rate of line search. 207 | begin_line_search: whether to begin with line search. 208 | switch_tol: tolerance for switching to line search. 209 | verbose: level of verbosity (0, 1, 2). 210 | ''' 211 | p = clstm.p 212 | clstm_copy = deepcopy(clstm) 213 | loss_fn = nn.MSELoss(reduction='mean') 214 | lr_list = [lr for _ in range(p)] 215 | 216 | # Set up data. 217 | X, Y = zip(*[arrange_input(x, context) for x in X]) 218 | X = torch.cat(X, dim=0) 219 | Y = torch.cat(Y, dim=0) 220 | 221 | # Calculate full loss. 222 | mse_list = [] 223 | smooth_list = [] 224 | loss_list = [] 225 | for i in range(p): 226 | net = clstm.networks[i] 227 | pred, _ = net(X) 228 | mse = loss_fn(pred[:, :, 0], Y[:, :, i]) 229 | ridge = ridge_regularize(net, lam_ridge) 230 | smooth = mse + ridge 231 | mse_list.append(mse) 232 | smooth_list.append(smooth) 233 | with torch.no_grad(): 234 | nonsmooth = regularize(net, lam) 235 | loss = smooth + nonsmooth 236 | loss_list.append(loss) 237 | 238 | # Set up lists for loss and mse. 239 | with torch.no_grad(): 240 | loss_mean = sum(loss_list) / p 241 | mse_mean = sum(mse_list) / p 242 | train_loss_list = [loss_mean] 243 | train_mse_list = [mse_mean] 244 | 245 | # For switching to line search. 246 | line_search = begin_line_search 247 | 248 | # For line search criterion. 249 | done = [False for _ in range(p)] 250 | assert 0 < sigma <= 1 251 | assert m > 0 252 | if not monotone: 253 | last_losses = [[loss_list[i]] for i in range(p)] 254 | 255 | for it in range(max_iter): 256 | # Backpropagate errors. 257 | sum([smooth_list[i] for i in range(p) if not done[i]]).backward() 258 | 259 | # For next iteration. 260 | new_mse_list = [] 261 | new_smooth_list = [] 262 | new_loss_list = [] 263 | 264 | # Perform GISTA step for each network. 265 | for i in range(p): 266 | # Skip if network converged. 267 | if done[i]: 268 | new_mse_list.append(mse_list[i]) 269 | new_smooth_list.append(smooth_list[i]) 270 | new_loss_list.append(loss_list[i]) 271 | continue 272 | 273 | # Prepare for line search. 274 | step = False 275 | lr_it = lr_list[i] 276 | net = clstm.networks[i] 277 | net_copy = clstm_copy.networks[i] 278 | 279 | while not step: 280 | # Perform tentative ISTA step. 281 | for param, temp_param in zip(net.parameters(), 282 | net_copy.parameters()): 283 | temp_param.data = param - lr_it * param.grad 284 | 285 | # Proximal update. 286 | prox_update(net_copy, lam, lr_it) 287 | 288 | # Check line search criterion. 
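# GISTA acceptance rule: when line search is active, accept the proximal step
# only if the penalized loss drops below the reference loss (the previous loss,
# or the worst of the last m losses when monotone=False) by more than
# (0.5 * sigma / lr) times the squared parameter change; otherwise shrink the
# step size by r and retry.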
289 | pred, _ = net_copy(X) 290 | mse = loss_fn(pred[:, :, 0], Y[:, :, i]) 291 | ridge = ridge_regularize(net_copy, lam_ridge) 292 | smooth = mse + ridge 293 | with torch.no_grad(): 294 | nonsmooth = regularize(net_copy, lam) 295 | loss = smooth + nonsmooth 296 | tol = (0.5 * sigma / lr_it) * sum( 297 | [torch.sum((param - temp_param) ** 2) 298 | for param, temp_param in 299 | zip(net.parameters(), net_copy.parameters())]) 300 | 301 | comp = loss_list[i] if monotone else max(last_losses[i]) 302 | if not line_search or (comp - loss) > tol: 303 | step = True 304 | if verbose > 1: 305 | print('Taking step, network i = %d, lr = %f' 306 | % (i, lr_it)) 307 | print('Gap = %f, tol = %f' % (comp - loss, tol)) 308 | 309 | # For next iteration. 310 | new_mse_list.append(mse) 311 | new_smooth_list.append(smooth) 312 | new_loss_list.append(loss) 313 | 314 | # Adjust initial learning rate. 315 | lr_list[i] = ( 316 | (lr_list[i] ** (1 - lr_decay)) * (lr_it ** lr_decay)) 317 | 318 | if not monotone: 319 | if len(last_losses[i]) == m: 320 | last_losses[i].pop(0) 321 | last_losses[i].append(loss) 322 | else: 323 | # Reduce learning rate. 324 | lr_it *= r 325 | if lr_it < lr_min: 326 | done[i] = True 327 | new_mse_list.append(mse_list[i]) 328 | new_smooth_list.append(smooth_list[i]) 329 | new_loss_list.append(loss_list[i]) 330 | if verbose > 0: 331 | print('Network %d converged' % (i + 1)) 332 | break 333 | 334 | # Clean up. 335 | net.zero_grad() 336 | 337 | if step: 338 | # Swap network parameters. 339 | clstm.networks[i], clstm_copy.networks[i] = net_copy, net 340 | 341 | # For next iteration. 342 | mse_list = new_mse_list 343 | smooth_list = new_smooth_list 344 | loss_list = new_loss_list 345 | 346 | # Check if all networks have converged. 347 | if sum(done) == p: 348 | if verbose > 0: 349 | print('Done at iteration = %d' % (it + 1)) 350 | break 351 | 352 | # Check progress 353 | if (it + 1) % check_every == 0: 354 | with torch.no_grad(): 355 | loss_mean = sum(loss_list) / p 356 | mse_mean = sum(mse_list) / p 357 | ridge_mean = (sum(smooth_list) - sum(mse_list)) / p 358 | nonsmooth_mean = (sum(loss_list) - sum(smooth_list)) / p 359 | 360 | train_loss_list.append(loss_mean) 361 | train_mse_list.append(mse_mean) 362 | 363 | if verbose > 0: 364 | print(('-' * 10 + 'Iter = %d' + '-' * 10) % (it + 1)) 365 | print('Total loss = %f' % loss_mean) 366 | print('MSE = %f, Ridge = %f, Nonsmooth = %f' 367 | % (mse_mean, ridge_mean, nonsmooth_mean)) 368 | print('Variable usage = %.2f%%' 369 | % (100 * torch.mean(clstm.GC().float()))) 370 | 371 | # Check whether loss has increased. 372 | if not line_search: 373 | if train_loss_list[-2] - train_loss_list[-1] < switch_tol: 374 | line_search = True 375 | if verbose > 0: 376 | print('Switching to line search') 377 | 378 | return train_loss_list, train_mse_list 379 | 380 | 381 | def train_model_adam(clstm, X, context, lr, max_iter, lam=0, lam_ridge=0, 382 | lookback=5, check_every=50, verbose=1): 383 | '''Train model with Adam.''' 384 | p = X.shape[-1] 385 | loss_fn = nn.MSELoss(reduction='mean') 386 | optimizer = torch.optim.Adam(clstm.parameters(), lr=lr) 387 | train_loss_list = [] 388 | 389 | # Set up data. 390 | X, Y = zip(*[arrange_input(x, context) for x in X]) 391 | X = torch.cat(X, dim=0) 392 | Y = torch.cat(Y, dim=0) 393 | 394 | # For early stopping. 395 | best_it = None 396 | best_loss = np.inf 397 | best_model = None 398 | 399 | for it in range(max_iter): 400 | # Calculate loss. 
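# Each network i forecasts only series i, so its single output channel (index 0)
# is compared against column i of the shifted targets; the per-series errors are
# summed into one joint loss for Adam.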
401 | pred = [clstm.networks[i](X)[0] for i in range(p)] 402 | loss = sum([loss_fn(pred[i][:, :, 0], Y[:, :, i]) for i in range(p)]) 403 | 404 | # Add penalty term. 405 | if lam > 0: 406 | loss = loss + sum([regularize(net, lam) for net in clstm.networks]) 407 | 408 | if lam_ridge > 0: 409 | loss = loss + sum([ridge_regularize(net, lam_ridge) 410 | for net in clstm.networks]) 411 | 412 | # Take gradient step. 413 | loss.backward() 414 | optimizer.step() 415 | clstm.zero_grad() 416 | 417 | # Check progress. 418 | if (it + 1) % check_every == 0: 419 | mean_loss = loss / p 420 | train_loss_list.append(mean_loss.detach()) 421 | 422 | if verbose > 0: 423 | print(('-' * 10 + 'Iter = %d' + '-' * 10) % (it + 1)) 424 | print('Loss = %f' % mean_loss) 425 | 426 | # Check for early stopping. 427 | if mean_loss < best_loss: 428 | best_loss = mean_loss 429 | best_it = it 430 | best_model = deepcopy(clstm) 431 | elif (it - best_it) == lookback * check_every: 432 | if verbose: 433 | print('Stopping early') 434 | break 435 | 436 | # Restore best model. 437 | restore_parameters(clstm, best_model) 438 | 439 | return train_loss_list 440 | 441 | 442 | def train_model_ista(clstm, X, context, lr, max_iter, lam=0, lam_ridge=0, 443 | lookback=5, check_every=50, verbose=1): 444 | '''Train model with Adam.''' 445 | p = X.shape[-1] 446 | loss_fn = nn.MSELoss(reduction='mean') 447 | train_loss_list = [] 448 | 449 | # Set up data. 450 | X, Y = zip(*[arrange_input(x, context) for x in X]) 451 | X = torch.cat(X, dim=0) 452 | Y = torch.cat(Y, dim=0) 453 | 454 | # For early stopping. 455 | best_it = None 456 | best_loss = np.inf 457 | best_model = None 458 | 459 | # Calculate smooth error. 460 | pred = [clstm.networks[i](X)[0] for i in range(p)] 461 | loss = sum([loss_fn(pred[i][:, :, 0], Y[:, :, i]) for i in range(p)]) 462 | ridge = sum([ridge_regularize(net, lam_ridge) for net in clstm.networks]) 463 | smooth = loss + ridge 464 | 465 | for it in range(max_iter): 466 | # Take gradient step. 467 | smooth.backward() 468 | for param in clstm.parameters(): 469 | param.data -= lr * param.grad 470 | 471 | # Take prox step. 472 | if lam > 0: 473 | for net in clstm.networks: 474 | prox_update(net, lam, lr) 475 | 476 | clstm.zero_grad() 477 | 478 | # Calculate loss for next iteration. 479 | pred = [clstm.networks[i](X)[0] for i in range(p)] 480 | loss = sum([loss_fn(pred[i][:, :, 0], Y[:, :, i]) for i in range(p)]) 481 | ridge = sum([ridge_regularize(net, lam_ridge) 482 | for net in clstm.networks]) 483 | smooth = loss + ridge 484 | 485 | # Check progress. 486 | if (it + 1) % check_every == 0: 487 | # Add nonsmooth penalty. 488 | nonsmooth = sum([regularize(net, lam) for net in clstm.networks]) 489 | mean_loss = (smooth + nonsmooth) / p 490 | train_loss_list.append(mean_loss.detach()) 491 | 492 | if verbose > 0: 493 | print(('-' * 10 + 'Iter = %d' + '-' * 10) % (it + 1)) 494 | print('Loss = %f' % mean_loss) 495 | print('Variable usage = %.2f%%' 496 | % (100 * torch.mean(clstm.GC().float()))) 497 | 498 | # Check for early stopping. 499 | if mean_loss < best_loss: 500 | best_loss = mean_loss 501 | best_it = it 502 | best_model = deepcopy(clstm) 503 | elif (it - best_it) == lookback * check_every: 504 | if verbose: 505 | print('Stopping early') 506 | break 507 | 508 | # Restore best model. 
509 | restore_parameters(clstm, best_model) 510 | 511 | return train_loss_list 512 | 513 | 514 | def train_unregularized(clstm, X, context, lr, max_iter, lookback=5, 515 | check_every=50, verbose=1): 516 | '''Train model with Adam.''' 517 | p = X.shape[-1] 518 | loss_fn = nn.MSELoss(reduction='mean') 519 | optimizer = torch.optim.Adam(clstm.parameters(), lr=lr) 520 | train_loss_list = [] 521 | 522 | # Set up data. 523 | X, Y = zip(*[arrange_input(x, context) for x in X]) 524 | X = torch.cat(X, dim=0) 525 | Y = torch.cat(Y, dim=0) 526 | 527 | # For early stopping. 528 | best_it = None 529 | best_loss = np.inf 530 | best_model = None 531 | 532 | for it in range(max_iter): 533 | # Calculate loss. 534 | pred, hidden = clstm(X) 535 | loss = sum([loss_fn(pred[:, :, i], Y[:, :, i]) for i in range(p)]) 536 | 537 | # Take gradient step. 538 | loss.backward() 539 | optimizer.step() 540 | clstm.zero_grad() 541 | 542 | # Check progress. 543 | if (it + 1) % check_every == 0: 544 | mean_loss = loss / p 545 | train_loss_list.append(mean_loss.detach()) 546 | 547 | if verbose > 0: 548 | print(('-' * 10 + 'Iter = %d' + '-' * 10) % (it + 1)) 549 | print('Loss = %f' % mean_loss) 550 | 551 | # Check for early stopping. 552 | if mean_loss < best_loss: 553 | best_loss = mean_loss 554 | best_it = it 555 | best_model = deepcopy(clstm) 556 | elif (it - best_it) == lookback * check_every: 557 | if verbose: 558 | print('Stopping early') 559 | break 560 | 561 | # Restore best model. 562 | restore_parameters(clstm, best_model) 563 | 564 | return train_loss_list 565 | -------------------------------------------------------------------------------- /models/crnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from copy import deepcopy 5 | 6 | 7 | class RNN(nn.Module): 8 | def __init__(self, num_series, hidden, nonlinearity): 9 | ''' 10 | RNN model with output layer to generate predictions. 11 | 12 | Args: 13 | num_series: number of input time series. 14 | hidden: number of hidden units. 15 | ''' 16 | super(RNN, self).__init__() 17 | self.p = num_series 18 | self.hidden = hidden 19 | 20 | # Set up network. 21 | self.rnn = nn.RNN(num_series, hidden, nonlinearity=nonlinearity, 22 | batch_first=True) 23 | self.rnn.flatten_parameters() 24 | self.linear = nn.Conv1d(hidden, 1, 1) 25 | 26 | def init_hidden(self, batch): 27 | '''Initialize hidden states for RNN cell.''' 28 | device = self.rnn.weight_ih_l0.device 29 | return torch.zeros(1, batch, self.hidden, device=device) 30 | 31 | def forward(self, X, hidden=None, truncation=None): 32 | # Set up hidden state. 33 | if hidden is None: 34 | hidden = self.init_hidden(X.shape[0]) 35 | 36 | # Apply RNN. 37 | X, hidden = self.rnn(X, hidden) 38 | 39 | # Calculate predictions using output layer. 40 | X = X.transpose(2, 1) 41 | X = self.linear(X) 42 | return X.transpose(2, 1), hidden 43 | 44 | 45 | class cRNN(nn.Module): 46 | def __init__(self, num_series, hidden, nonlinearity='relu'): 47 | ''' 48 | cRNN model with one RNN per time series. 49 | 50 | Args: 51 | num_series: dimensionality of multivariate time series. 52 | hidden: number of units in RNN cell. 53 | nonlinearity: nonlinearity of RNN cell. 54 | ''' 55 | super(cRNN, self).__init__() 56 | self.p = num_series 57 | self.hidden = hidden 58 | 59 | # Set up networks. 
60 | self.networks = nn.ModuleList([ 61 | RNN(num_series, hidden, nonlinearity) for _ in range(num_series)]) 62 | 63 | def forward(self, X, hidden=None): 64 | ''' 65 | Perform forward pass. 66 | 67 | Args: 68 | X: torch tensor of shape (batch, T, p). 69 | hidden: hidden states for RNN cell. 70 | ''' 71 | if hidden is None: 72 | hidden = [None for _ in range(self.p)] 73 | pred = [self.networks[i](X, hidden[i]) 74 | for i in range(self.p)] 75 | pred, hidden = zip(*pred) 76 | pred = torch.cat(pred, dim=2) 77 | return pred, hidden 78 | 79 | def GC(self, threshold=True): 80 | ''' 81 | Extract learned Granger causality. 82 | 83 | Args: 84 | threshold: return norm of weights, or whether norm is nonzero. 85 | 86 | Returns: 87 | GC: (p x p) matrix. Entry (i, j) indicates whether variable j is 88 | Granger causal of variable i. 89 | ''' 90 | GC = [torch.norm(net.rnn.weight_ih_l0, dim=0) 91 | for net in self.networks] 92 | GC = torch.stack(GC) 93 | if threshold: 94 | return (GC > 0).int() 95 | else: 96 | return GC 97 | 98 | 99 | class cRNNSparse(nn.Module): 100 | def __init__(self, num_series, sparsity, hidden, nonlinearity='relu'): 101 | ''' 102 | cRNN model that only uses specified interactions. 103 | 104 | Args: 105 | num_series: dimensionality of multivariate time series. 106 | sparsity: torch byte tensor indicating Granger causality, with size 107 | (num_series, num_series). 108 | hidden: number of units in RNN cell. 109 | nonlinearity: nonlinearity of RNN cell. 110 | ''' 111 | super(cRNNSparse, self).__init__() 112 | self.p = num_series 113 | self.hidden = hidden 114 | self.sparsity = sparsity 115 | 116 | # Set up networks. 117 | self.networks = nn.ModuleList([ 118 | RNN(int(torch.sum(sparsity[i].int())), hidden, nonlinearity) 119 | for i in range(num_series)]) 120 | 121 | def forward(self, X, i=None, hidden=None, truncation=None): 122 | '''Perform forward pass. 123 | 124 | Args: 125 | X: torch tensor of shape (batch, T, p). 126 | i: index of the time series to forecast. 127 | hidden: hidden states for RNN cell. 128 | ''' 129 | if hidden is None: 130 | hidden = [None for _ in range(self.p)] 131 | pred = [self.networks[i](X[:, :, self.sparsity[i]], hidden[i]) 132 | for i in range(self.p)] 133 | pred, hidden = zip(*pred) 134 | pred = torch.cat(pred, dim=2) 135 | return pred, hidden 136 | 137 | 138 | def prox_update(network, lam, lr): 139 | '''Perform in place proximal update on first layer weight matrix.''' 140 | W = network.rnn.weight_ih_l0 141 | norm = torch.norm(W, dim=0, keepdim=True) 142 | W.data = ((W / torch.clamp(norm, min=(lam * lr))) 143 | * torch.clamp(norm - (lr * lam), min=0.0)) 144 | network.rnn.flatten_parameters() 145 | 146 | 147 | def regularize(network, lam): 148 | '''Calculate regularization term for first layer weight matrix.''' 149 | W = network.rnn.weight_ih_l0 150 | return lam * torch.sum(torch.norm(W, dim=0)) 151 | 152 | 153 | def ridge_regularize(network, lam): 154 | '''Apply ridge penalty at linear layer and hidden-hidden weights.''' 155 | return lam * ( 156 | torch.sum(network.linear.weight ** 2) + 157 | torch.sum(network.rnn.weight_hh_l0 ** 2)) 158 | 159 | 160 | def restore_parameters(model, best_model): 161 | '''Move parameter values from best_model to model.''' 162 | for params, best_params in zip(model.parameters(), best_model.parameters()): 163 | params.data = best_params 164 | 165 | 166 | def arrange_input(data, context): 167 | ''' 168 | Arrange a single time series into overlapping short sequences. 169 | 170 | Args: 171 | data: time series of shape (T, dim). 
172 | context: length of short sequences. 173 | ''' 174 | assert context >= 1 and isinstance(context, int) 175 | input = torch.zeros(len(data) - context, context, data.shape[1], 176 | dtype=torch.float32, device=data.device) 177 | target = torch.zeros(len(data) - context, context, data.shape[1], 178 | dtype=torch.float32, device=data.device) 179 | for i in range(context): 180 | start = i 181 | end = len(data) - context + i 182 | input[:, i, :] = data[start:end] 183 | target[:, i, :] = data[start+1:end+1] 184 | return input.detach(), target.detach() 185 | 186 | 187 | def train_model_gista(crnn, X, context, lam, lam_ridge, lr, max_iter, 188 | check_every=50, r=0.8, lr_min=1e-8, sigma=0.5, 189 | monotone=False, m=10, lr_decay=0.5, 190 | begin_line_search=True, switch_tol=1e-3, verbose=1): 191 | ''' 192 | Train cRNN model with GISTA. 193 | 194 | Args: 195 | crnn: crnn model. 196 | X: tensor of data, shape (batch, T, p). 197 | context: length for short overlapping subsequences. 198 | lam: parameter for nonsmooth regularization. 199 | lam_ridge: parameter for ridge regularization on output layer. 200 | lr: learning rate. 201 | max_iter: max number of GISTA iterations. 202 | check_every: how frequently to record loss. 203 | r: for line search. 204 | lr_min: for line search. 205 | sigma: for line search. 206 | monotone: for line search. 207 | m: for line search. 208 | lr_decay: for adjusting initial learning rate of line search. 209 | begin_line_search: whether to begin with line search. 210 | switch_tol: tolerance for switching to line search. 211 | verbose: level of verbosity (0, 1, 2). 212 | ''' 213 | p = crnn.p 214 | crnn_copy = deepcopy(crnn) 215 | loss_fn = nn.MSELoss(reduction='mean') 216 | lr_list = [lr for _ in range(p)] 217 | 218 | # Set up data. 219 | X, Y = zip(*[arrange_input(x, context) for x in X]) 220 | X = torch.cat(X, dim=0) 221 | Y = torch.cat(Y, dim=0) 222 | 223 | # Calculate full loss. 224 | mse_list = [] 225 | smooth_list = [] 226 | loss_list = [] 227 | for i in range(p): 228 | net = crnn.networks[i] 229 | pred, _ = net(X) 230 | mse = loss_fn(pred[:, :, 0], Y[:, :, i]) 231 | ridge = ridge_regularize(net, lam_ridge) 232 | smooth = mse + ridge 233 | mse_list.append(mse) 234 | smooth_list.append(smooth) 235 | with torch.no_grad(): 236 | nonsmooth = regularize(net, lam) 237 | loss = smooth + nonsmooth 238 | loss_list.append(loss) 239 | 240 | # Set up lists for loss and mse. 241 | with torch.no_grad(): 242 | loss_mean = sum(loss_list) / p 243 | mse_mean = sum(mse_list) / p 244 | train_loss_list = [loss_mean] 245 | train_mse_list = [mse_mean] 246 | 247 | # For switching to line search. 248 | line_search = begin_line_search 249 | 250 | # For line search criterion. 251 | done = [False for _ in range(p)] 252 | assert 0 < sigma <= 1 253 | assert m > 0 254 | if not monotone: 255 | last_losses = [[loss_list[i]] for i in range(p)] 256 | 257 | for it in range(max_iter): 258 | # Backpropagate errors. 259 | sum([smooth_list[i] for i in range(p) if not done[i]]).backward() 260 | 261 | # For next iteration. 262 | new_mse_list = [] 263 | new_smooth_list = [] 264 | new_loss_list = [] 265 | 266 | # Perform GISTA step for each network. 267 | for i in range(p): 268 | # Skip if network converged. 269 | if done[i]: 270 | new_mse_list.append(mse_list[i]) 271 | new_smooth_list.append(smooth_list[i]) 272 | new_loss_list.append(loss_list[i]) 273 | continue 274 | 275 | # Prepare for line search. 
276 | step = False 277 | lr_it = lr_list[i] 278 | net = crnn.networks[i] 279 | net_copy = crnn_copy.networks[i] 280 | 281 | while not step: 282 | # Perform tentative ISTA step. 283 | for param, temp_param in zip(net.parameters(), 284 | net_copy.parameters()): 285 | temp_param.data = param - lr_it * param.grad 286 | 287 | # Proximal update. 288 | prox_update(net_copy, lam, lr_it) 289 | 290 | # Check line search criterion. 291 | pred, _ = net_copy(X) 292 | mse = loss_fn(pred[:, :, 0], Y[:, :, i]) 293 | ridge = ridge_regularize(net_copy, lam_ridge) 294 | smooth = mse + ridge 295 | with torch.no_grad(): 296 | nonsmooth = regularize(net_copy, lam) 297 | loss = smooth + nonsmooth 298 | tol = (0.5 * sigma / lr_it) * sum( 299 | [torch.sum((param - temp_param) ** 2) 300 | for param, temp_param in 301 | zip(net.parameters(), net_copy.parameters())]) 302 | 303 | comp = loss_list[i] if monotone else max(last_losses[i]) 304 | if not line_search or (comp - loss) > tol: 305 | step = True 306 | if verbose > 1: 307 | print('Taking step, network i = %d, lr = %f' 308 | % (i, lr_it)) 309 | print('Gap = %f, tol = %f' % (comp - loss, tol)) 310 | 311 | # For next iteration. 312 | new_mse_list.append(mse) 313 | new_smooth_list.append(smooth) 314 | new_loss_list.append(loss) 315 | 316 | # Adjust initial learning rate. 317 | lr_list[i] = ( 318 | (lr_list[i] ** (1 - lr_decay)) * (lr_it ** lr_decay)) 319 | 320 | if not monotone: 321 | if len(last_losses[i]) == m: 322 | last_losses[i].pop(0) 323 | last_losses[i].append(loss) 324 | else: 325 | # Reduce learning rate. 326 | lr_it *= r 327 | if lr_it < lr_min: 328 | done[i] = True 329 | new_mse_list.append(mse_list[i]) 330 | new_smooth_list.append(smooth_list[i]) 331 | new_loss_list.append(loss_list[i]) 332 | if verbose > 0: 333 | print('Network %d converged' % (i + 1)) 334 | break 335 | 336 | # Clean up. 337 | net.zero_grad() 338 | 339 | if step: 340 | # Swap network parameters. 341 | crnn.networks[i], crnn_copy.networks[i] = net_copy, net 342 | 343 | # For next iteration. 344 | mse_list = new_mse_list 345 | smooth_list = new_smooth_list 346 | loss_list = new_loss_list 347 | 348 | # Check if all networks have converged. 349 | if sum(done) == p: 350 | if verbose > 0: 351 | print('Done at iteration = %d' % (it + 1)) 352 | break 353 | 354 | # Check progress 355 | if (it + 1) % check_every == 0: 356 | with torch.no_grad(): 357 | loss_mean = sum(loss_list) / p 358 | mse_mean = sum(mse_list) / p 359 | ridge_mean = (sum(smooth_list) - sum(mse_list)) / p 360 | nonsmooth_mean = (sum(loss_list) - sum(smooth_list)) / p 361 | 362 | train_loss_list.append(loss_mean) 363 | train_mse_list.append(mse_mean) 364 | 365 | if verbose > 0: 366 | print(('-' * 10 + 'Iter = %d' + '-' * 10) % (it + 1)) 367 | print('Total loss = %f' % loss_mean) 368 | print('MSE = %f, Ridge = %f, Nonsmooth = %f' 369 | % (mse_mean, ridge_mean, nonsmooth_mean)) 370 | print('Variable usage = %.2f%%' 371 | % (100 * torch.mean(crnn.GC().float()))) 372 | 373 | # Check whether loss has increased. 
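# When running with a fixed step size (begin_line_search=False), switch to the
# GISTA line search once the loss improvement between checkpoints falls below
# switch_tol (or the loss increases).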
374 | if not line_search: 375 | if train_loss_list[-2] - train_loss_list[-1] < switch_tol: 376 | line_search = True 377 | if verbose > 0: 378 | print('Switching to line search') 379 | 380 | return train_loss_list, train_mse_list 381 | 382 | 383 | def train_model_adam(crnn, X, context, lr, max_iter, lam=0, lam_ridge=0, 384 | lookback=5, check_every=50, verbose=1): 385 | '''Train model with Adam.''' 386 | p = X.shape[-1] 387 | loss_fn = nn.MSELoss(reduction='mean') 388 | optimizer = torch.optim.Adam(crnn.parameters(), lr=lr) 389 | train_loss_list = [] 390 | 391 | # Set up data. 392 | X, Y = zip(*[arrange_input(x, context) for x in X]) 393 | X = torch.cat(X, dim=0) 394 | Y = torch.cat(Y, dim=0) 395 | 396 | # For early stopping. 397 | best_it = None 398 | best_loss = np.inf 399 | best_model = None 400 | 401 | for it in range(max_iter): 402 | # Calculate loss. 403 | pred = [crnn.networks[i](X)[0] for i in range(p)] 404 | loss = sum([loss_fn(pred[i][:, :, 0], Y[:, :, i]) for i in range(p)]) 405 | 406 | # Add penalty term. 407 | if lam > 0: 408 | loss = loss + sum([regularize(net, lam) for net in crnn.networks]) 409 | 410 | if lam_ridge > 0: 411 | loss = loss + sum([ridge_regularize(net, lam_ridge) 412 | for net in crnn.networks]) 413 | 414 | # Take gradient step. 415 | loss.backward() 416 | optimizer.step() 417 | crnn.zero_grad() 418 | 419 | # Check progress. 420 | if (it + 1) % check_every == 0: 421 | mean_loss = loss / p 422 | train_loss_list.append(mean_loss.detach()) 423 | 424 | if verbose > 0: 425 | print(('-' * 10 + 'Iter = %d' + '-' * 10) % (it + 1)) 426 | print('Loss = %f' % mean_loss) 427 | 428 | # Check for early stopping. 429 | if mean_loss < best_loss: 430 | best_loss = mean_loss 431 | best_it = it 432 | best_model = deepcopy(crnn) 433 | elif (it - best_it) == lookback * check_every: 434 | if verbose: 435 | print('Stopping early') 436 | break 437 | 438 | # Restore best model. 439 | restore_parameters(crnn, best_model) 440 | 441 | return train_loss_list 442 | 443 | 444 | def train_model_ista(crnn, X, context, lr, max_iter, lam=0, lam_ridge=0, 445 | lookback=5, check_every=50, verbose=1): 446 | '''Train model with Adam.''' 447 | p = X.shape[-1] 448 | loss_fn = nn.MSELoss(reduction='mean') 449 | train_loss_list = [] 450 | 451 | # Set up data. 452 | X, Y = zip(*[arrange_input(x, context) for x in X]) 453 | X = torch.cat(X, dim=0) 454 | Y = torch.cat(Y, dim=0) 455 | 456 | # For early stopping. 457 | best_it = None 458 | best_loss = np.inf 459 | best_model = None 460 | 461 | # Calculate smooth error. 462 | pred = [crnn.networks[i](X)[0] for i in range(p)] 463 | loss = sum([loss_fn(pred[i][:, :, 0], Y[:, :, i]) for i in range(p)]) 464 | ridge = sum([ridge_regularize(net, lam_ridge) for net in crnn.networks]) 465 | smooth = loss + ridge 466 | 467 | for it in range(max_iter): 468 | # Take gradient step. 469 | smooth.backward() 470 | for param in crnn.parameters(): 471 | param.data -= lr * param.grad 472 | 473 | # Take prox step. 474 | if lam > 0: 475 | for net in crnn.networks: 476 | prox_update(net, lam, lr) 477 | 478 | crnn.zero_grad() 479 | 480 | # Calculate loss for next iteration. 481 | pred = [crnn.networks[i](X)[0] for i in range(p)] 482 | loss = sum([loss_fn(pred[i][:, :, 0], Y[:, :, i]) for i in range(p)]) 483 | ridge = sum([ridge_regularize(net, lam_ridge) 484 | for net in crnn.networks]) 485 | smooth = loss + ridge 486 | 487 | # Check progress. 488 | if (it + 1) % check_every == 0: 489 | # Add nonsmooth penalty. 
490 | nonsmooth = sum([regularize(net, lam) for net in crnn.networks]) 491 | mean_loss = (smooth + nonsmooth) / p 492 | train_loss_list.append(mean_loss.detach()) 493 | 494 | if verbose > 0: 495 | print(('-' * 10 + 'Iter = %d' + '-' * 10) % (it + 1)) 496 | print('Loss = %f' % mean_loss) 497 | print('Variable usage = %.2f%%' 498 | % (100 * torch.mean(crnn.GC().float()))) 499 | 500 | # Check for early stopping. 501 | if mean_loss < best_loss: 502 | best_loss = mean_loss 503 | best_it = it 504 | best_model = deepcopy(crnn) 505 | elif (it - best_it) == lookback * check_every: 506 | if verbose: 507 | print('Stopping early') 508 | break 509 | 510 | # Restore best model. 511 | restore_parameters(crnn, best_model) 512 | 513 | return train_loss_list 514 | 515 | 516 | def train_unregularized(crnn, X, context, lr, max_iter, lookback=5, 517 | check_every=50, verbose=1): 518 | '''Train model with Adam.''' 519 | p = X.shape[-1] 520 | loss_fn = nn.MSELoss(reduction='mean') 521 | optimizer = torch.optim.Adam(crnn.parameters(), lr=lr) 522 | train_loss_list = [] 523 | 524 | # Set up data. 525 | X, Y = zip(*[arrange_input(x, context) for x in X]) 526 | X = torch.cat(X, dim=0) 527 | Y = torch.cat(Y, dim=0) 528 | 529 | # For early stopping. 530 | best_it = None 531 | best_loss = np.inf 532 | best_model = None 533 | 534 | for it in range(max_iter): 535 | # Calculate loss. 536 | pred, hidden = crnn(X) 537 | loss = sum([loss_fn(pred[:, :, i], Y[:, :, i]) for i in range(p)]) 538 | 539 | # Take gradient step. 540 | loss.backward() 541 | optimizer.step() 542 | crnn.zero_grad() 543 | 544 | # Check progress. 545 | if (it + 1) % check_every == 0: 546 | mean_loss = loss / p 547 | train_loss_list.append(mean_loss.detach()) 548 | 549 | if verbose > 0: 550 | print(('-' * 10 + 'Iter = %d' + '-' * 10) % (it + 1)) 551 | print('Loss = %f' % mean_loss) 552 | 553 | # Check for early stopping. 554 | if mean_loss < best_loss: 555 | best_loss = mean_loss 556 | best_it = it 557 | best_model = deepcopy(crnn) 558 | elif (it - best_it) == lookback * check_every: 559 | if verbose: 560 | print('Stopping early') 561 | break 562 | 563 | # Restore best model. 564 | restore_parameters(crnn, best_model) 565 | 566 | return train_loss_list 567 | -------------------------------------------------------------------------------- /models/cmlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from copy import deepcopy 5 | from models.model_helper import activation_helper 6 | 7 | 8 | class MLP(nn.Module): 9 | def __init__(self, num_series, lag, hidden, activation): 10 | super(MLP, self).__init__() 11 | self.activation = activation_helper(activation) 12 | 13 | # Set up network. 14 | layer = nn.Conv1d(num_series, hidden[0], lag) 15 | modules = [layer] 16 | 17 | for d_in, d_out in zip(hidden, hidden[1:] + [1]): 18 | layer = nn.Conv1d(d_in, d_out, 1) 19 | modules.append(layer) 20 | 21 | # Register parameters. 22 | self.layers = nn.ModuleList(modules) 23 | 24 | def forward(self, X): 25 | X = X.transpose(2, 1) 26 | for i, fc in enumerate(self.layers): 27 | if i != 0: 28 | X = self.activation(X) 29 | X = fc(X) 30 | 31 | return X.transpose(2, 1) 32 | 33 | 34 | class cMLP(nn.Module): 35 | def __init__(self, num_series, lag, hidden, activation='relu'): 36 | ''' 37 | cMLP model with one MLP per time series. 38 | 39 | Args: 40 | num_series: dimensionality of multivariate time series. 41 | lag: number of previous time points to use in prediction. 
42 | hidden: list of number of hidden units per layer. 43 | activation: nonlinearity at each layer. 44 | ''' 45 | super(cMLP, self).__init__() 46 | self.p = num_series 47 | self.lag = lag 48 | self.activation = activation_helper(activation) 49 | 50 | # Set up networks. 51 | self.networks = nn.ModuleList([ 52 | MLP(num_series, lag, hidden, activation) 53 | for _ in range(num_series)]) 54 | 55 | def forward(self, X): 56 | ''' 57 | Perform forward pass. 58 | 59 | Args: 60 | X: torch tensor of shape (batch, T, p). 61 | ''' 62 | return torch.cat([network(X) for network in self.networks], dim=2) 63 | 64 | def GC(self, threshold=True, ignore_lag=True): 65 | ''' 66 | Extract learned Granger causality. 67 | 68 | Args: 69 | threshold: return norm of weights, or whether norm is nonzero. 70 | ignore_lag: if true, calculate norm of weights jointly for all lags. 71 | 72 | Returns: 73 | GC: (p x p) or (p x p x lag) matrix. In first case, entry (i, j) 74 | indicates whether variable j is Granger causal of variable i. In 75 | second case, entry (i, j, k) indicates whether it's Granger causal 76 | at lag k. 77 | ''' 78 | if ignore_lag: 79 | GC = [torch.norm(net.layers[0].weight, dim=(0, 2)) 80 | for net in self.networks] 81 | else: 82 | GC = [torch.norm(net.layers[0].weight, dim=0) 83 | for net in self.networks] 84 | GC = torch.stack(GC) 85 | if threshold: 86 | return (GC > 0).int() 87 | else: 88 | return GC 89 | 90 | 91 | class cMLPSparse(nn.Module): 92 | def __init__(self, num_series, sparsity, lag, hidden, activation='relu'): 93 | ''' 94 | cMLP model that only uses specified interactions. 95 | 96 | Args: 97 | num_series: dimensionality of multivariate time series. 98 | sparsity: torch byte tensor indicating Granger causality, with size 99 | (num_series, num_series). 100 | lag: number of previous time points to use in prediction. 101 | hidden: list of number of hidden units per layer. 102 | activation: nonlinearity at each layer. 103 | ''' 104 | super(cMLPSparse, self).__init__() 105 | self.p = num_series 106 | self.lag = lag 107 | self.activation = activation_helper(activation) 108 | self.sparsity = sparsity 109 | 110 | # Set up networks. 111 | self.networks = [] 112 | for i in range(num_series): 113 | num_inputs = int(torch.sum(sparsity[i].int())) 114 | self.networks.append(MLP(num_inputs, lag, hidden, activation)) 115 | 116 | # Register parameters. 117 | param_list = [] 118 | for i in range(num_series): 119 | param_list += list(self.networks[i].parameters()) 120 | self.param_list = nn.ParameterList(param_list) 121 | 122 | def forward(self, X): 123 | ''' 124 | Perform forward pass. 125 | 126 | Args: 127 | X: torch tensor of shape (batch, T, p). 128 | ''' 129 | return torch.cat([self.networks[i](X[:, :, self.sparsity[i]]) 130 | for i in range(self.p)], dim=2) 131 | 132 | 133 | def prox_update(network, lam, lr, penalty): 134 | ''' 135 | Perform in place proximal update on first layer weight matrix. 136 | 137 | Args: 138 | network: MLP network. 139 | lam: regularization parameter. 140 | lr: learning rate. 141 | penalty: one of GL (group lasso), GSGL (group sparse group lasso), 142 | H (hierarchical). 
143 | ''' 144 | W = network.layers[0].weight 145 | hidden, p, lag = W.shape 146 | if penalty == 'GL': 147 | norm = torch.norm(W, dim=(0, 2), keepdim=True) 148 | W.data = ((W / torch.clamp(norm, min=(lr * lam))) 149 | * torch.clamp(norm - (lr * lam), min=0.0)) 150 | elif penalty == 'GSGL': 151 | norm = torch.norm(W, dim=0, keepdim=True) 152 | W.data = ((W / torch.clamp(norm, min=(lr * lam))) 153 | * torch.clamp(norm - (lr * lam), min=0.0)) 154 | norm = torch.norm(W, dim=(0, 2), keepdim=True) 155 | W.data = ((W / torch.clamp(norm, min=(lr * lam))) 156 | * torch.clamp(norm - (lr * lam), min=0.0)) 157 | elif penalty == 'H': 158 | # Lowest indices along third axis touch most lagged values. 159 | for i in range(lag): 160 | norm = torch.norm(W[:, :, :(i + 1)], dim=(0, 2), keepdim=True) 161 | W.data[:, :, :(i+1)] = ( 162 | (W.data[:, :, :(i+1)] / torch.clamp(norm, min=(lr * lam))) 163 | * torch.clamp(norm - (lr * lam), min=0.0)) 164 | else: 165 | raise ValueError('unsupported penalty: %s' % penalty) 166 | 167 | 168 | def regularize(network, lam, penalty): 169 | ''' 170 | Calculate regularization term for first layer weight matrix. 171 | 172 | Args: 173 | network: MLP network. 174 | penalty: one of GL (group lasso), GSGL (group sparse group lasso), 175 | H (hierarchical). 176 | ''' 177 | W = network.layers[0].weight 178 | hidden, p, lag = W.shape 179 | if penalty == 'GL': 180 | return lam * torch.sum(torch.norm(W, dim=(0, 2))) 181 | elif penalty == 'GSGL': 182 | return lam * (torch.sum(torch.norm(W, dim=(0, 2))) 183 | + torch.sum(torch.norm(W, dim=0))) 184 | elif penalty == 'H': 185 | # Lowest indices along third axis touch most lagged values. 186 | return lam * sum([torch.sum(torch.norm(W[:, :, :(i+1)], dim=(0, 2))) 187 | for i in range(lag)]) 188 | else: 189 | raise ValueError('unsupported penalty: %s' % penalty) 190 | 191 | 192 | def ridge_regularize(network, lam): 193 | '''Apply ridge penalty at all subsequent layers.''' 194 | return lam * sum([torch.sum(fc.weight ** 2) for fc in network.layers[1:]]) 195 | 196 | 197 | def restore_parameters(model, best_model): 198 | '''Move parameter values from best_model to model.''' 199 | for params, best_params in zip(model.parameters(), best_model.parameters()): 200 | params.data = best_params 201 | 202 | 203 | def train_model_gista(cmlp, X, lam, lam_ridge, lr, penalty, max_iter, 204 | check_every=100, r=0.8, lr_min=1e-8, sigma=0.5, 205 | monotone=False, m=10, lr_decay=0.5, 206 | begin_line_search=True, switch_tol=1e-3, verbose=1): 207 | ''' 208 | Train cMLP model with GISTA. 209 | 210 | Args: 211 | clstm: clstm model. 212 | X: tensor of data, shape (batch, T, p). 213 | lam: parameter for nonsmooth regularization. 214 | lam_ridge: parameter for ridge regularization on output layer. 215 | lr: learning rate. 216 | penalty: type of nonsmooth regularization. 217 | max_iter: max number of GISTA iterations. 218 | check_every: how frequently to record loss. 219 | r: for line search. 220 | lr_min: for line search. 221 | sigma: for line search. 222 | monotone: for line search. 223 | m: for line search. 224 | lr_decay: for adjusting initial learning rate of line search. 225 | begin_line_search: whether to begin with line search. 226 | switch_tol: tolerance for switching to line search. 227 | verbose: level of verbosity (0, 1, 2). 228 | ''' 229 | p = cmlp.p 230 | lag = cmlp.lag 231 | cmlp_copy = deepcopy(cmlp) 232 | loss_fn = nn.MSELoss(reduction='mean') 233 | lr_list = [lr for _ in range(p)] 234 | 235 | # Calculate full loss. 
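# The first layer is a Conv1d over time with kernel size `lag`, so feeding
# X[:, :-1] yields one prediction per length-`lag` window, aligned with the
# targets X[:, lag:, i:i+1] for series i.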
236 | mse_list = [] 237 | smooth_list = [] 238 | loss_list = [] 239 | for i in range(p): 240 | net = cmlp.networks[i] 241 | mse = loss_fn(net(X[:, :-1]), X[:, lag:, i:i+1]) 242 | ridge = ridge_regularize(net, lam_ridge) 243 | smooth = mse + ridge 244 | mse_list.append(mse) 245 | smooth_list.append(smooth) 246 | with torch.no_grad(): 247 | nonsmooth = regularize(net, lam, penalty) 248 | loss = smooth + nonsmooth 249 | loss_list.append(loss) 250 | 251 | # Set up lists for loss and mse. 252 | with torch.no_grad(): 253 | loss_mean = sum(loss_list) / p 254 | mse_mean = sum(mse_list) / p 255 | train_loss_list = [loss_mean] 256 | train_mse_list = [mse_mean] 257 | 258 | # For switching to line search. 259 | line_search = begin_line_search 260 | 261 | # For line search criterion. 262 | done = [False for _ in range(p)] 263 | assert 0 < sigma <= 1 264 | assert m > 0 265 | if not monotone: 266 | last_losses = [[loss_list[i]] for i in range(p)] 267 | 268 | for it in range(max_iter): 269 | # Backpropagate errors. 270 | sum([smooth_list[i] for i in range(p) if not done[i]]).backward() 271 | 272 | # For next iteration. 273 | new_mse_list = [] 274 | new_smooth_list = [] 275 | new_loss_list = [] 276 | 277 | # Perform GISTA step for each network. 278 | for i in range(p): 279 | # Skip if network converged. 280 | if done[i]: 281 | new_mse_list.append(mse_list[i]) 282 | new_smooth_list.append(smooth_list[i]) 283 | new_loss_list.append(loss_list[i]) 284 | continue 285 | 286 | # Prepare for line search. 287 | step = False 288 | lr_it = lr_list[i] 289 | net = cmlp.networks[i] 290 | net_copy = cmlp_copy.networks[i] 291 | 292 | while not step: 293 | # Perform tentative ISTA step. 294 | for param, temp_param in zip(net.parameters(), 295 | net_copy.parameters()): 296 | temp_param.data = param - lr_it * param.grad 297 | 298 | # Proximal update. 299 | prox_update(net_copy, lam, lr_it, penalty) 300 | 301 | # Check line search criterion. 302 | mse = loss_fn(net_copy(X[:, :-1]), X[:, lag:, i:i+1]) 303 | ridge = ridge_regularize(net_copy, lam_ridge) 304 | smooth = mse + ridge 305 | with torch.no_grad(): 306 | nonsmooth = regularize(net_copy, lam, penalty) 307 | loss = smooth + nonsmooth 308 | tol = (0.5 * sigma / lr_it) * sum( 309 | [torch.sum((param - temp_param) ** 2) 310 | for param, temp_param in 311 | zip(net.parameters(), net_copy.parameters())]) 312 | 313 | comp = loss_list[i] if monotone else max(last_losses[i]) 314 | if not line_search or (comp - loss) > tol: 315 | step = True 316 | if verbose > 1: 317 | print('Taking step, network i = %d, lr = %f' 318 | % (i, lr_it)) 319 | print('Gap = %f, tol = %f' % (comp - loss, tol)) 320 | 321 | # For next iteration. 322 | new_mse_list.append(mse) 323 | new_smooth_list.append(smooth) 324 | new_loss_list.append(loss) 325 | 326 | # Adjust initial learning rate. 327 | lr_list[i] = ( 328 | (lr_list[i] ** (1 - lr_decay)) * (lr_it ** lr_decay)) 329 | 330 | if not monotone: 331 | if len(last_losses[i]) == m: 332 | last_losses[i].pop(0) 333 | last_losses[i].append(loss) 334 | else: 335 | # Reduce learning rate. 336 | lr_it *= r 337 | if lr_it < lr_min: 338 | done[i] = True 339 | new_mse_list.append(mse_list[i]) 340 | new_smooth_list.append(smooth_list[i]) 341 | new_loss_list.append(loss_list[i]) 342 | if verbose > 0: 343 | print('Network %d converged' % (i + 1)) 344 | break 345 | 346 | # Clean up. 347 | net.zero_grad() 348 | 349 | if step: 350 | # Swap network parameters. 351 | cmlp.networks[i], cmlp_copy.networks[i] = net_copy, net 352 | 353 | # For next iteration. 
354 | mse_list = new_mse_list 355 | smooth_list = new_smooth_list 356 | loss_list = new_loss_list 357 | 358 | # Check if all networks have converged. 359 | if sum(done) == p: 360 | if verbose > 0: 361 | print('Done at iteration = %d' % (it + 1)) 362 | break 363 | 364 | # Check progress. 365 | if (it + 1) % check_every == 0: 366 | with torch.no_grad(): 367 | loss_mean = sum(loss_list) / p 368 | mse_mean = sum(mse_list) / p 369 | ridge_mean = (sum(smooth_list) - sum(mse_list)) / p 370 | nonsmooth_mean = (sum(loss_list) - sum(smooth_list)) / p 371 | 372 | train_loss_list.append(loss_mean) 373 | train_mse_list.append(mse_mean) 374 | 375 | if verbose > 0: 376 | print(('-' * 10 + 'Iter = %d' + '-' * 10) % (it + 1)) 377 | print('Total loss = %f' % loss_mean) 378 | print('MSE = %f, Ridge = %f, Nonsmooth = %f' 379 | % (mse_mean, ridge_mean, nonsmooth_mean)) 380 | print('Variable usage = %.2f%%' 381 | % (100 * torch.mean(cmlp.GC().float()))) 382 | 383 | # Check whether loss has increased. 384 | if not line_search: 385 | if train_loss_list[-2] - train_loss_list[-1] < switch_tol: 386 | line_search = True 387 | if verbose > 0: 388 | print('Switching to line search') 389 | 390 | return train_loss_list, train_mse_list 391 | 392 | 393 | def train_model_adam(cmlp, X, lr, max_iter, lam=0, lam_ridge=0, penalty='H', 394 | lookback=5, check_every=100, verbose=1): 395 | '''Train model with Adam.''' 396 | lag = cmlp.lag 397 | p = X.shape[-1] 398 | loss_fn = nn.MSELoss(reduction='mean') 399 | optimizer = torch.optim.Adam(cmlp.parameters(), lr=lr) 400 | train_loss_list = [] 401 | 402 | # For early stopping. 403 | best_it = None 404 | best_loss = np.inf 405 | best_model = None 406 | 407 | for it in range(max_iter): 408 | # Calculate loss. 409 | loss = sum([loss_fn(cmlp.networks[i](X[:, :-1]), X[:, lag:, i:i+1]) 410 | for i in range(p)]) 411 | 412 | # Add penalty terms. 413 | if lam > 0: 414 | loss = loss + sum([regularize(net, lam, penalty) 415 | for net in cmlp.networks]) 416 | if lam_ridge > 0: 417 | loss = loss + sum([ridge_regularize(net, lam_ridge) 418 | for net in cmlp.networks]) 419 | 420 | # Take gradient step. 421 | loss.backward() 422 | optimizer.step() 423 | cmlp.zero_grad() 424 | 425 | # Check progress. 426 | if (it + 1) % check_every == 0: 427 | mean_loss = loss / p 428 | train_loss_list.append(mean_loss.detach()) 429 | 430 | if verbose > 0: 431 | print(('-' * 10 + 'Iter = %d' + '-' * 10) % (it + 1)) 432 | print('Loss = %f' % mean_loss) 433 | 434 | # Check for early stopping. 435 | if mean_loss < best_loss: 436 | best_loss = mean_loss 437 | best_it = it 438 | best_model = deepcopy(cmlp) 439 | elif (it - best_it) == lookback * check_every: 440 | if verbose: 441 | print('Stopping early') 442 | break 443 | 444 | # Restore best model. 445 | restore_parameters(cmlp, best_model) 446 | 447 | return train_loss_list 448 | 449 | 450 | def train_model_ista(cmlp, X, lr, max_iter, lam=0, lam_ridge=0, penalty='H', 451 | lookback=5, check_every=100, verbose=1): 452 | '''Train model with Adam.''' 453 | lag = cmlp.lag 454 | p = X.shape[-1] 455 | loss_fn = nn.MSELoss(reduction='mean') 456 | train_loss_list = [] 457 | 458 | # For early stopping. 459 | best_it = None 460 | best_loss = np.inf 461 | best_model = None 462 | 463 | # Calculate smooth error. 464 | loss = sum([loss_fn(cmlp.networks[i](X[:, :-1]), X[:, lag:, i:i+1]) 465 | for i in range(p)]) 466 | ridge = sum([ridge_regularize(net, lam_ridge) for net in cmlp.networks]) 467 | smooth = loss + ridge 468 | 469 | for it in range(max_iter): 470 | # Take gradient step. 
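# ISTA: a plain gradient step on the smooth part (MSE + ridge), followed by a
# proximal step that soft-thresholds the input-layer weight groups for the
# chosen penalty; groups whose norm falls below lr * lam are zeroed out.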
471 | smooth.backward() 472 | for param in cmlp.parameters(): 473 | param.data = param - lr * param.grad 474 | 475 | # Take prox step. 476 | if lam > 0: 477 | for net in cmlp.networks: 478 | prox_update(net, lam, lr, penalty) 479 | 480 | cmlp.zero_grad() 481 | 482 | # Calculate loss for next iteration. 483 | loss = sum([loss_fn(cmlp.networks[i](X[:, :-1]), X[:, lag:, i:i+1]) 484 | for i in range(p)]) 485 | ridge = sum([ridge_regularize(net, lam_ridge) for net in cmlp.networks]) 486 | smooth = loss + ridge 487 | 488 | # Check progress. 489 | if (it + 1) % check_every == 0: 490 | # Add nonsmooth penalty. 491 | nonsmooth = sum([regularize(net, lam, penalty) 492 | for net in cmlp.networks]) 493 | mean_loss = (smooth + nonsmooth) / p 494 | train_loss_list.append(mean_loss.detach()) 495 | 496 | if verbose > 0: 497 | print(('-' * 10 + 'Iter = %d' + '-' * 10) % (it + 1)) 498 | print('Loss = %f' % mean_loss) 499 | print('Variable usage = %.2f%%' 500 | % (100 * torch.mean(cmlp.GC().float()))) 501 | 502 | # Check for early stopping. 503 | if mean_loss < best_loss: 504 | best_loss = mean_loss 505 | best_it = it 506 | best_model = deepcopy(cmlp) 507 | elif (it - best_it) == lookback * check_every: 508 | if verbose: 509 | print('Stopping early') 510 | break 511 | 512 | # Restore best model. 513 | restore_parameters(cmlp, best_model) 514 | 515 | return train_loss_list 516 | 517 | 518 | def train_unregularized(cmlp, X, lr, max_iter, lookback=5, check_every=100, 519 | verbose=1): 520 | '''Train model with Adam and no regularization.''' 521 | lag = cmlp.lag 522 | p = X.shape[-1] 523 | loss_fn = nn.MSELoss(reduction='mean') 524 | optimizer = torch.optim.Adam(cmlp.parameters(), lr=lr) 525 | train_loss_list = [] 526 | 527 | # For early stopping. 528 | best_it = None 529 | best_loss = np.inf 530 | best_model = None 531 | 532 | for it in range(max_iter): 533 | # Calculate loss. 534 | pred = cmlp(X[:, :-1]) 535 | loss = sum([loss_fn(pred[:, :, i], X[:, lag:, i]) for i in range(p)]) 536 | 537 | # Take gradient step. 538 | loss.backward() 539 | optimizer.step() 540 | cmlp.zero_grad() 541 | 542 | # Check progress. 543 | if (it + 1) % check_every == 0: 544 | mean_loss = loss / p 545 | train_loss_list.append(mean_loss.detach()) 546 | 547 | if verbose > 0: 548 | print(('-' * 10 + 'Iter = %d' + '-' * 10) % (it + 1)) 549 | print('Loss = %f' % mean_loss) 550 | 551 | # Check for early stopping. 552 | if mean_loss < best_loss: 553 | best_loss = mean_loss 554 | best_it = it 555 | best_model = deepcopy(cmlp) 556 | elif (it - best_it) == lookback * check_every: 557 | if verbose: 558 | print('Stopping early') 559 | break 560 | 561 | # Restore best model. 562 | restore_parameters(cmlp, best_model) 563 | 564 | return train_loss_list 565 | --------------------------------------------------------------------------------