├── .gitignore ├── README.md ├── convlstmcell.py ├── debug.py ├── kitti_data.py ├── kitti_test.py ├── kitti_train.py └── prednet.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .vscode 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **PredNet** implementation of PyTorch. 2 | 3 | ### Details 4 | "Deep Predictive Coding Networks for Video Prediction and Unsupervised Learning"(https://arxiv.org/abs/1605.08104) 5 | 6 | The PredNet is a deep recurrent convolutional neural network that is inspired by the neuroscience concept of predictive coding (Rao and Ballard, 1999; Friston, 2005) 7 | 8 | Original paper's [code](https://github.com/coxlab/prednet) is writen in Keras. Examples and project website can be found [here](https://coxlab.github.io/prednet/). 9 | 10 | 11 | ConvLSTMCell is borrowed from https://gist.github.com/Kaixhin/57901e91e5c5a8bac3eb0cbbdd3aba81 12 | 13 | ### Training data 14 | The preprocessed KITTI data is used which can be accessed using `downlaod_data.sh` from https://github.com/coxlab/prednet -------------------------------------------------------------------------------- /convlstmcell.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | from torch.nn import Parameter 5 | from torch.nn import functional as F 6 | from torch.autograd import Variable 7 | from torch.nn.modules.utils import _pair 8 | 9 | # https://gist.github.com/Kaixhin/57901e91e5c5a8bac3eb0cbbdd3aba81 10 | 11 | class ConvLSTMCell(nn.Module): 12 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1, dilation=1, groups=1, bias=True): 13 | super(ConvLSTMCell, self).__init__() 14 | if in_channels % groups != 0: 15 | raise ValueError('in_channels must be divisible by groups') 16 | if out_channels % groups != 0: 17 | raise ValueError('out_channels must be divisible by groups') 18 | kernel_size = _pair(kernel_size) 19 | stride = _pair(stride) 20 | padding = _pair(padding) 21 | dilation = _pair(dilation) 22 | self.in_channels = in_channels 23 | self.out_channels = out_channels 24 | self.kernel_size = kernel_size 25 | self.stride = stride 26 | self.padding = padding 27 | self.padding_h = tuple( 28 | k // 2 for k, s, p, d in zip(kernel_size, stride, padding, dilation)) 29 | self.dilation = dilation 30 | self.groups = groups 31 | self.weight_ih = Parameter(torch.Tensor( 32 | 4 * out_channels, in_channels // groups, *kernel_size)) 33 | self.weight_hh = Parameter(torch.Tensor( 34 | 4 * out_channels, out_channels // groups, *kernel_size)) 35 | self.weight_ch = Parameter(torch.Tensor( 36 | 3 * out_channels, out_channels // groups, *kernel_size)) 37 | if bias: 38 | self.bias_ih = Parameter(torch.Tensor(4 * out_channels)) 39 | self.bias_hh = Parameter(torch.Tensor(4 * out_channels)) 40 | self.bias_ch = Parameter(torch.Tensor(3 * out_channels)) 41 | else: 42 | self.register_parameter('bias_ih', None) 43 | self.register_parameter('bias_hh', None) 44 | self.register_parameter('bias_ch', None) 45 | self.register_buffer('wc_blank', torch.zeros(1, 1, 1, 1)) 46 | self.reset_parameters() 47 | 48 | def reset_parameters(self): 49 | n = 4 * self.in_channels 50 | for k in self.kernel_size: 51 | n *= k 52 | stdv = 1. / math.sqrt(n) 53 | self.weight_ih.data.uniform_(-stdv, stdv) 54 | self.weight_hh.data.uniform_(-stdv, stdv) 55 | self.weight_ch.data.uniform_(-stdv, stdv) 56 | if self.bias_ih is not None: 57 | self.bias_ih.data.uniform_(-stdv, stdv) 58 | self.bias_hh.data.uniform_(-stdv, stdv) 59 | self.bias_ch.data.uniform_(-stdv, stdv) 60 | 61 | def forward(self, input, hx): 62 | h_0, c_0 = hx 63 | wx = F.conv2d(input, self.weight_ih, self.bias_ih, 64 | self.stride, self.padding, self.dilation, self.groups) 65 | 66 | wh = F.conv2d(h_0, self.weight_hh, self.bias_hh, self.stride, 67 | self.padding_h, self.dilation, self.groups) 68 | 69 | # Cell uses a Hadamard product instead of a convolution? 70 | wc = F.conv2d(c_0, self.weight_ch, self.bias_ch, self.stride, 71 | self.padding_h, self.dilation, self.groups) 72 | 73 | wxhc = wx + wh + torch.cat((wc[:, :2 * self.out_channels], Variable(self.wc_blank).expand( 74 | wc.size(0), wc.size(1) // 3, wc.size(2), wc.size(3)), wc[:, 2 * self.out_channels:]), 1) 75 | 76 | i = F.sigmoid(wxhc[:, :self.out_channels]) 77 | f = F.sigmoid(wxhc[:, self.out_channels:2 * self.out_channels]) 78 | g = F.tanh(wxhc[:, 2 * self.out_channels:3 * self.out_channels]) 79 | o = F.sigmoid(wxhc[:, 3 * self.out_channels:]) 80 | 81 | c_1 = f * c_0 + i * g 82 | h_1 = o * F.tanh(c_1) 83 | return h_1, (h_1, c_1) 84 | -------------------------------------------------------------------------------- /debug.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def info(prefix, var): 4 | print('-------{}----------'.format(prefix)) 5 | if isinstance(var, torch.autograd.variable.Variable): 6 | print('Variable:') 7 | print('size: ', var.data.size()) 8 | print('data type: ', type(var.data)) 9 | elif isinstance(var, torch.FloatTensor) or isinstance(var, torch.cuda.FloatTensor): 10 | print('Tensor:') 11 | print('size: ', var.size()) 12 | print('type: ', type(var)) 13 | else: 14 | print(type(var)) -------------------------------------------------------------------------------- /kitti_data.py: -------------------------------------------------------------------------------- 1 | import hickle as hkl 2 | 3 | import torch 4 | import torch.utils.data as data 5 | 6 | 7 | 8 | class KITTI(data.Dataset): 9 | def __init__(self, datafile, sourcefile, nt): 10 | self.datafile = datafile 11 | self.sourcefile = sourcefile 12 | self.X = hkl.load(self.datafile) 13 | self.sources = hkl.load(self.sourcefile) 14 | self.nt = nt 15 | cur_loc = 0 16 | possible_starts = [] 17 | while cur_loc < self.X.shape[0] - self.nt + 1: 18 | if self.sources[cur_loc] == self.sources[cur_loc + self.nt - 1]: 19 | possible_starts.append(cur_loc) 20 | cur_loc += self.nt 21 | else: 22 | cur_loc += 1 23 | self.possible_starts = possible_starts 24 | 25 | def __getitem__(self, index): 26 | loc = self.possible_starts[index] 27 | return self.X[loc:loc+self.nt] 28 | 29 | 30 | def __len__(self): 31 | return len(self.possible_starts) -------------------------------------------------------------------------------- /kitti_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import hickle as hkl 5 | 6 | from torch.utils.data import DataLoader 7 | from torch.autograd import Variable 8 | from kitti_data import KITTI 9 | from prednet import PredNet 10 | 11 | import torchvision 12 | 13 | def save_image(tensor, filename, nrow=8, padding=2, 14 | normalize=False, range=None, scale_each=False, pad_value=0): 15 | from PIL import Image 16 | im = Image.fromarray(np.rollaxis(tensor.numpy(), 0, 3)) 17 | im.save(filename) 18 | from scipy.misc import imshow, imsave 19 | 20 | batch_size = 16 21 | A_channels = (3, 48, 96, 192) 22 | R_channels = (3, 48, 96, 192) 23 | 24 | DATA_DIR = '/media/lei/000F426D0004CCF4/datasets/kitti_data' 25 | test_file = os.path.join(DATA_DIR, 'X_test.hkl') 26 | test_sources = os.path.join(DATA_DIR, 'sources_test.hkl') 27 | 28 | nt = 10 29 | 30 | kitti_test = KITTI(test_file, test_sources, nt) 31 | 32 | test_loader = DataLoader(kitti_test, batch_size=batch_size, shuffle=False) 33 | 34 | model = PredNet(R_channels, A_channels, output_mode='prediction') 35 | model.load_state_dict(torch.load('training.pt')) 36 | 37 | if torch.cuda.is_available(): 38 | print('Using GPU.') 39 | model.cuda() 40 | 41 | for i, inputs in enumerate(test_loader): 42 | inputs = inputs.permute(0, 1, 4, 2, 3) # batch x time_steps x channel x width x height 43 | inputs = Variable(inputs.cuda()) 44 | origin = inputs.data.cpu().byte()[:, nt-1] 45 | print('origin:') 46 | print(type(origin)) 47 | print(origin.size()) 48 | 49 | print('predicted:') 50 | pred = model(inputs) 51 | pred = pred.data.cpu().byte() 52 | print(type(pred)) 53 | print(pred.size()) 54 | origin = torchvision.utils.make_grid(origin, nrow=4) 55 | pred = torchvision.utils.make_grid(pred, nrow=4) 56 | save_image(origin, 'origin.jpg') 57 | save_image(pred, 'predicted.jpg') 58 | break 59 | 60 | -------------------------------------------------------------------------------- /kitti_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | from torch.utils.data import DataLoader 7 | 8 | from kitti_data import KITTI 9 | from prednet import PredNet 10 | 11 | from debug import info 12 | 13 | 14 | num_epochs = 150 15 | batch_size = 16 16 | A_channels = (3, 48, 96, 192) 17 | R_channels = (3, 48, 96, 192) 18 | lr = 0.001 # if epoch < 75 else 0.0001 19 | nt = 10 # num of time steps 20 | 21 | layer_loss_weights = Variable(torch.FloatTensor([[1.], [0.], [0.], [0.]]).cuda()) 22 | time_loss_weights = 1./(nt - 1) * torch.ones(nt, 1) 23 | time_loss_weights[0] = 0 24 | time_loss_weights = Variable(time_loss_weights.cuda()) 25 | 26 | DATA_DIR = '/media/lei/000F426D0004CCF4/datasets/kitti_data' 27 | 28 | train_file = os.path.join(DATA_DIR, 'X_train.hkl') 29 | train_sources = os.path.join(DATA_DIR, 'sources_train.hkl') 30 | val_file = os.path.join(DATA_DIR, 'X_val.hkl') 31 | val_sources = os.path.join(DATA_DIR, 'sources_val.hkl') 32 | 33 | 34 | kitti_train = KITTI(train_file, train_sources, nt) 35 | kitti_val = KITTI(val_file, val_sources, nt) 36 | 37 | train_loader = DataLoader(kitti_train, batch_size=batch_size, shuffle=True) 38 | val_loader = DataLoader(kitti_val, batch_size=batch_size, shuffle=True) 39 | 40 | model = PredNet(R_channels, A_channels, output_mode='error') 41 | if torch.cuda.is_available(): 42 | print('Using GPU.') 43 | model.cuda() 44 | 45 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 46 | 47 | def lr_scheduler(optimizer, epoch): 48 | if epoch < num_epochs //2: 49 | return optimizer 50 | else: 51 | for param_group in optimizer.param_groups: 52 | param_group['lr'] = 0.0001 53 | return optimizer 54 | 55 | 56 | 57 | for epoch in range(num_epochs): 58 | optimizer = lr_scheduler(optimizer, epoch) 59 | for i, inputs in enumerate(train_loader): 60 | inputs = inputs.permute(0, 1, 4, 2, 3) # batch x time_steps x channel x width x height 61 | inputs = Variable(inputs.cuda()) 62 | errors = model(inputs) # batch x n_layers x nt 63 | loc_batch = errors.size(0) 64 | errors = torch.mm(errors.view(-1, nt), time_loss_weights) # batch*n_layers x 1 65 | errors = torch.mm(errors.view(loc_batch, -1), layer_loss_weights) 66 | errors = torch.mean(errors) 67 | 68 | optimizer.zero_grad() 69 | 70 | errors.backward() 71 | 72 | optimizer.step() 73 | if i%10 == 0: 74 | print('Epoch: {}/{}, step: {}/{}, errors: {}'.format(epoch, num_epochs, i, len(kitti_train)//batch_size, errors.data[0])) 75 | 76 | torch.save(model.state_dict(), 'training.pt') 77 | -------------------------------------------------------------------------------- /prednet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | from convlstmcell import ConvLSTMCell 5 | from torch.autograd import Variable 6 | 7 | 8 | from debug import info 9 | 10 | 11 | class PredNet(nn.Module): 12 | def __init__(self, R_channels, A_channels, output_mode='error'): 13 | super(PredNet, self).__init__() 14 | self.r_channels = R_channels + (0, ) # for convenience 15 | self.a_channels = A_channels 16 | self.n_layers = len(R_channels) 17 | self.output_mode = output_mode 18 | 19 | default_output_modes = ['prediction', 'error'] 20 | assert output_mode in default_output_modes, 'Invalid output_mode: ' + str(output_mode) 21 | 22 | for i in range(self.n_layers): 23 | cell = ConvLSTMCell(2 * self.a_channels[i] + self.r_channels[i+1], self.r_channels[i], 24 | (3, 3)) 25 | setattr(self, 'cell{}'.format(i), cell) 26 | 27 | for i in range(self.n_layers): 28 | conv = nn.Sequential(nn.Conv2d(self.r_channels[i], self.a_channels[i], 3, padding=1), nn.ReLU()) 29 | if i == 0: 30 | conv.add_module('satlu', SatLU()) 31 | setattr(self, 'conv{}'.format(i), conv) 32 | 33 | 34 | self.upsample = nn.Upsample(scale_factor=2) 35 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 36 | 37 | for l in range(self.n_layers - 1): 38 | update_A = nn.Sequential(nn.Conv2d(2* self.a_channels[l], self.a_channels[l+1], (3, 3), padding=1), self.maxpool) 39 | setattr(self, 'update_A{}'.format(l), update_A) 40 | 41 | self.reset_parameters() 42 | 43 | def reset_parameters(self): 44 | for l in range(self.n_layers): 45 | cell = getattr(self, 'cell{}'.format(l)) 46 | cell.reset_parameters() 47 | 48 | def forward(self, input): 49 | 50 | R_seq = [None] * self.n_layers 51 | H_seq = [None] * self.n_layers 52 | E_seq = [None] * self.n_layers 53 | 54 | w, h = input.size(-2), input.size(-1) 55 | batch_size = input.size(0) 56 | 57 | for l in range(self.n_layers): 58 | E_seq[l] = Variable(torch.zeros(batch_size, 2*self.a_channels[l], w, h)).cuda() 59 | R_seq[l] = Variable(torch.zeros(batch_size, self.r_channels[l], w, h)).cuda() 60 | w = w//2 61 | h = h//2 62 | time_steps = input.size(1) 63 | total_error = [] 64 | 65 | for t in range(time_steps): 66 | A = input[:,t] 67 | A = A.type(torch.cuda.FloatTensor) 68 | 69 | for l in reversed(range(self.n_layers)): 70 | cell = getattr(self, 'cell{}'.format(l)) 71 | if t == 0: 72 | E = E_seq[l] 73 | R = R_seq[l] 74 | hx = (R, R) 75 | else: 76 | E = E_seq[l] 77 | R = R_seq[l] 78 | hx = H_seq[l] 79 | if l == self.n_layers - 1: 80 | R, hx = cell(E, hx) 81 | else: 82 | tmp = torch.cat((E, self.upsample(R_seq[l+1])), 1) 83 | R, hx = cell(tmp, hx) 84 | R_seq[l] = R 85 | H_seq[l] = hx 86 | 87 | 88 | for l in range(self.n_layers): 89 | conv = getattr(self, 'conv{}'.format(l)) 90 | A_hat = conv(R_seq[l]) 91 | if l == 0: 92 | frame_prediction = A_hat 93 | pos = F.relu(A_hat - A) 94 | neg = F.relu(A - A_hat) 95 | E = torch.cat([pos, neg],1) 96 | E_seq[l] = E 97 | if l < self.n_layers - 1: 98 | update_A = getattr(self, 'update_A{}'.format(l)) 99 | A = update_A(E) 100 | if self.output_mode == 'error': 101 | mean_error = torch.cat([torch.mean(e.view(e.size(0), -1), 1, keepdim=True) for e in E_seq], 1) 102 | # batch x n_layers 103 | total_error.append(mean_error) 104 | 105 | if self.output_mode == 'error': 106 | return torch.stack(total_error, 2) # batch x n_layers x nt 107 | elif self.output_mode == 'prediction': 108 | return frame_prediction 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | class SatLU(nn.Module): 120 | 121 | def __init__(self, lower=0, upper=255, inplace=False): 122 | super(SatLU, self).__init__() 123 | self.lower = lower 124 | self.upper = upper 125 | self.inplace = inplace 126 | 127 | def forward(self, input): 128 | return F.hardtanh(input, self.lower, self.upper, self.inplace) 129 | 130 | 131 | def __repr__(self): 132 | inplace_str = ', inplace' if self.inplace else '' 133 | return self.__class__.__name__ + ' ('\ 134 | + 'min_val=' + str(self.lower) \ 135 | + ', max_val=' + str(self.upper) \ 136 | + inplace_str + ')' --------------------------------------------------------------------------------