├── README.md ├── __pycache__ └── utils.cpython-38.pyc ├── denoise_sample.py ├── loss.py ├── networks ├── ConvLSTM.py ├── DnCNN.py ├── EventDenoiser.py ├── EventVLAD.py ├── SubBlocks.py ├── UNet.py ├── __init__.py ├── __pycache__ │ ├── ConvLSTM.cpython-38.pyc │ ├── DnCNN.cpython-38.pyc │ ├── EventDenoiser.cpython-38.pyc │ ├── SubBlocks.cpython-38.pyc │ ├── UNet.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ ├── submodules.cpython-38.pyc │ ├── unet_rpg.cpython-38.pyc │ └── vgg16.cpython-38.pyc ├── base │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── base_model.cpython-36.pyc │ │ ├── base_model.cpython-37.pyc │ │ └── base_model.cpython-38.pyc │ └── base_model.py ├── image_transformer.py ├── netvlad.py ├── submodules.py ├── unet_rpg.py └── vgg16.py ├── preprocess ├── process_dvs.py └── process_img.py ├── sample ├── d0.png ├── d1.png ├── d2.png ├── d_gt.png └── result.png └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | ## EventVLAD (IROS 2021) 2 | 3 | This repo contains the codebase for our paper presented at IROS 2021 [[pdf](https://ieeexplore.ieee.org/document/9635907)], 4 | 5 | "EventVLAD: Visual place recognition with reconstructed edges from event cameras". 6 | 7 | ### Pretrained weights 8 | 9 | - Denoiser 10 | 11 | We provide pretrained weights and a minimal example for the event-based denoiser used in our module. You may also feed the processed outputs of the event denoiser into other VPR pipelines such as [NetVLAD](https://github.com/Nanne/pytorch-NetVlad). 12 | 13 | (carla-pretrained) 14 | https://drive.google.com/file/d/1D1tHHSRd-2iVfD4GuEz0jHDlh7evkzf6/view?usp=sharing 15 | 16 | (brisbane-pretrained) 17 | https://drive.google.com/file/d/1xdoGI7vmNelaR_D9-FUk5SbB3webqa5c/view?usp=sharing 18 | 19 | - Encoder 20 | 21 | The pretrained VGG16 encoder weight for event-based place recognition can be downloaded from: 22 | 23 | https://drive.google.com/file/d/1rSIhH1pk8ADxfqYQXoos_hTuWyfiWSu3/view?usp=sharing 24 | 25 | 26 | ### Run your example 27 | 28 | - Preprocess ViViD, create event frames 29 | 30 | ```python process_img.py seqname dst_folder``` 31 | 32 | Running the above command with the downloaded bagfile will create a directory (dst_folder) and save the images and GPS coordinates. 33 | 34 | ```python process_dvs.py seqname path_prefix``` 35 | 36 | Running the above command with the preprocessed image directory and bagfile will create a set of event images using the parameters (5 ms, 1% threshold). 37 | 38 | - Event denoiser 39 | 40 | ```python denoise_sample.py``` 41 | 42 | 43 | Running the above sample requires three consecutive event-generated frames (in our sample, each event image accumulates events over 1.6 ms). The samples are located under the 'sample' folder. 44 | 45 | As a result, you will see the groundtruth (generated from an ideal camera in simulation), the noisy events from simple accumulation, and the output of the denoising module, respectively.
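For a quick start, the snippet below sketches the same inference path as `denoise_sample.py`: the three event frames are stacked into a 1 x 3 x 256 x 256 tensor, and the first output channel of `EventDenoiser` is the reconstructed edge image (the second channel is the error estimate). It assumes the carla-pretrained checkpoint above has been saved as `./denoiser_carla` and that a CUDA GPU is available.

```python
# Minimal sketch of the denoiser inference path; see denoise_sample.py for the full script
# (which additionally rotates the sample frames by 180 degrees before stacking).
import cv2
import numpy as np
import torch
from skimage import img_as_float32
from networks import EventDenoiser

net = EventDenoiser(3, slope=0.2, dep_U=5, dep_S=5)
net = torch.nn.DataParallel(net).cuda()
net.load_state_dict(torch.load('./denoiser_carla'))  # carla-pretrained weight from the link above
net.eval()

# Stack three consecutive event frames into a 1 x 3 x 256 x 256 tensor.
frames = [img_as_float32(cv2.resize(cv2.imread(f'sample/d{i}.png', 0), (256, 256))) for i in range(3)]
x = torch.from_numpy(np.stack(frames)).unsqueeze(0).cuda()

with torch.no_grad():
    out = net(x)                     # 1 x 2 x 256 x 256: [reconstructed edges, error estimate]
edges = out[0, 0].cpu().numpy()      # denoised edge image, values in [0, 1]
```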
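The denoised edges can then be passed to the place-recognition encoder. The sketch below shows one possible wiring of the repo's VGG16 encoder (`networks/EventVLAD.py`) and NetVLAD layer (`networks/netvlad.py`) via `EmbedNet`; the weight filename `vgg16_encoder.pth`, the 224 x 224 input size, the 3-channel replication of the edge image, and the input scaling are illustrative assumptions rather than the exact training configuration.

```python
# Hedged sketch: wrapping the repo's VGG16 encoder with NetVLAD to get a place descriptor.
import cv2
import numpy as np
import torch
from networks import Imagenet_vgg
from networks.netvlad import NetVLAD, EmbedNet

base = Imagenet_vgg(weights_path='./vgg16_encoder.pth')  # assumed filename; assumed to be a state_dict for this module
vlad = NetVLAD(num_clusters=64, dim=1000)                # dim matches the 1000-d vector EmbedNet feeds to NetVLAD
model = EmbedNet(base, vlad).cuda().eval()

edge = cv2.resize(edges, (224, 224))                     # `edges` from the denoiser sketch above
x = torch.from_numpy(np.repeat(edge[None, None], 3, axis=1)).float().cuda()
with torch.no_grad():
    descriptor = model(x)                                # global descriptor for place retrieval
```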
46 | 47 |  48 | 49 | ### bibtex: 50 | 51 | ``` 52 | @inproceedings{lee2021eventvlad, 53 | title={EventVLAD: Visual place recognition with reconstructed edges from event cameras}, 54 | author={Lee, Alex Junho and Kim, Ayoung}, 55 | booktitle={2021 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)}, 56 | pages={2247--2252}, 57 | year={2021}, 58 | organization={IEEE} 59 | } 60 | ``` 61 | -------------------------------------------------------------------------------- /__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexjunholee/EventVLAD/b02f7c5a5758be08e39f5b6bcd51d7c7720b3251/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /denoise_sample.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import cv2 5 | import numpy as np 6 | import torch 7 | from networks import EventDenoiser 8 | from skimage.metrics import peak_signal_noise_ratio, structural_similarity 9 | from skimage import img_as_float, img_as_ubyte, img_as_float32 10 | from utils import load_state_dict_cpu 11 | from matplotlib import pyplot as plt 12 | import time 13 | import png 14 | 15 | use_gpu = True 16 | dep_U = 4 17 | 18 | testpath = 'sample' 19 | test_imgs = ['d0', 'd1', 'd2'] 20 | truth_edge = 'd_gt' 21 | 22 | # load the pretrained model 23 | print('Loading the Model') 24 | checkpoint = torch.load('./denoiser_carla') 25 | 26 | net = EventDenoiser(3, slope=0.2, dep_U=5, dep_S=5) 27 | if use_gpu: 28 | net = torch.nn.DataParallel(net).cuda() 29 | net.load_state_dict(checkpoint) 30 | else: 31 | load_state_dict_cpu(net, checkpoint) 32 | net.eval() 33 | 34 | # load images 35 | im_gt_path = os.path.join(testpath, truth_edge+'.png') 36 | im_gt = cv2.imread(im_gt_path,0) 37 | 38 | H, W = im_gt.shape 39 | if H % 2**dep_U != 0: 40 | H -= H % 2**dep_U 41 | if W % 2**dep_U != 0: 42 | W -= W % 2**dep_U 43 | im_gt = im_gt[:H, :W] 44 | im_gt = cv2.resize(im_gt, (256, 256)) 45 | im_gt = img_as_float32(im_gt[:,:,np.newaxis]) 46 | im_gt = torch.from_numpy(im_gt.transpose((2,0,1)))[np.newaxis,] 47 | 48 | im_0 = cv2.imread(os.path.join(testpath, test_imgs[0]+'.png'), 0) 49 | im_1 = cv2.imread(os.path.join(testpath, test_imgs[1]+'.png'), 0) 50 | im_2 = cv2.imread(os.path.join(testpath, test_imgs[2]+'.png'), 0) 51 | 52 | im_0 = cv2.resize(im_0, (256, 256)) 53 | im_1 = cv2.resize(im_1, (256, 256)) 54 | im_2 = cv2.resize(im_2, (256, 256)) 55 | 56 | im_0 = img_as_float32(im_0[:,:,np.newaxis]) 57 | im_1 = img_as_float32(im_1[:,:,np.newaxis]) 58 | im_2 = img_as_float32(im_2[:,:,np.newaxis]) 59 | 60 | im_0 = cv2.rotate(im_0, cv2.ROTATE_180)[:,:,np.newaxis] 61 | im_1 = cv2.rotate(im_1, cv2.ROTATE_180)[:,:,np.newaxis] 62 | im_2 = cv2.rotate(im_2, cv2.ROTATE_180)[:,:,np.newaxis] 63 | 64 | im_0 = torch.from_numpy(im_0.transpose((2,0,1))) 65 | im_1 = torch.from_numpy(im_1.transpose((2,0,1))) 66 | im_2 = torch.from_numpy(im_2.transpose((2,0,1))) 67 | 68 | im_noisy = torch.cat((im_0,im_1,im_2),0)[np.newaxis,] 69 | if use_gpu: 70 | im_noisy = im_noisy.cuda() 71 | print('Begin Testing on GPU') 72 | else: 73 | print('Begin Testing on CPU') 74 | with torch.autograd.set_grad_enabled(False): 75 | torch.cuda.synchronize() 76 | tic = time.perf_counter() 77 | img_estim = net(im_noisy) 78 | torch.cuda.synchronize() 79 | toc = time.perf_counter() 80 | outimg = img_estim.cpu().numpy() 81 | if use_gpu: 82 
| im_noisy = im_noisy.cpu().numpy() 83 | else: 84 | im_noisy = im_noisy.numpy() 85 | #im_noisy = im_noisy[:,1,] 86 | im_noisy = np.mean(im_noisy,1) 87 | im_denoise = outimg[:,0,] 88 | im_denoise = np.transpose(im_denoise.squeeze(), (0,1)) 89 | im_denoise = cv2.rotate(img_as_ubyte(im_denoise.clip(0,1)),cv2.ROTATE_180) 90 | 91 | im_noisy = np.transpose(im_noisy.squeeze(), (0,1)) 92 | im_noisy = cv2.rotate(img_as_ubyte(im_noisy.clip(0,1)),cv2.ROTATE_180) 93 | im_gt = img_as_ubyte(im_gt.squeeze()) 94 | psnr_val = peak_signal_noise_ratio(im_gt, im_denoise, data_range=255) 95 | ssim_val = structural_similarity(im_gt, im_denoise, data_range=255, multichannel=False) 96 | 97 | print('PSNR={:5.2f}, SSIM={:7.4f}, time={:.4f}'.format(psnr_val, ssim_val, toc-tic)) 98 | plt.subplot(131) 99 | plt.imshow(im_gt) 100 | plt.title('Groundtruth') 101 | plt.subplot(132) 102 | plt.imshow(im_noisy) 103 | plt.title('Noisy Image') 104 | plt.subplot(133) 105 | plt.imshow(im_denoise) 106 | plt.title('Denoised Image') 107 | plt.show() 108 | 109 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from math import pi, log 7 | from ssim import msssim, ssim 8 | 9 | def loss_denoise(evterr_day, im_gt): 10 | ''' 11 | Size of Inputs: 12 | evterr_day = N * 2 * W * H 13 | im_gt = N * 1 * W * H 14 | ''' 15 | #for training phase 1 (recon) 16 | day_event = evterr_day[:,0,] 17 | ssim_recon = 1-msssim(im_gt.clamp(min=0, max=1), day_event[:,np.newaxis,].clamp(min=0, max=1), window_size=11, size_average=True) 18 | loss = ssim_recon 19 | 20 | #for training phase 2 (masking) 21 | 22 | # day_error = evterr_day[:,1,] 23 | # day_corrt = day_event.detach() * day_error 24 | # ssim_loss_day = 1-msssim(im_gt[:,0:1,].clamp(min=0, max=1), day_corrt[:,np.newaxis,].clamp(min=0, max=1), window_size=11, size_average=True) 25 | # l1_loss = torch.sum(torch.abs(im_gt[:,0,] - day_event)) 26 | # l1_mask_loss = torch.sum(torch.abs(day_corrt - im_gt[:,0,])) 27 | # loss = 0.1*ssim_loss_day + 0.1*ssim_mask_loss + l1_mask_loss + l1_loss 28 | 29 | return loss 30 | 31 | class HardTripletLoss(nn.Module): 32 | """Hard/Hardest Triplet Loss 33 | (pytorch implementation of https://omoindrot.github.io/triplet-loss) 34 | 35 | For each anchor, we get the hardest positive and hardest negative to form a triplet. 36 | """ 37 | def __init__(self, margin=0.1, hardest=False, squared=False): 38 | """ 39 | Args: 40 | margin: margin for triplet loss 41 | hardest: If true, loss is considered only hardest triplets. 42 | squared: If true, output is the pairwise squared euclidean distance matrix. 43 | If false, output is the pairwise euclidean distance matrix. 
44 | """ 45 | super(HardTripletLoss, self).__init__() 46 | self.margin = margin 47 | self.hardest = hardest 48 | self.squared = squared 49 | 50 | def forward(self, embeddings, labels): 51 | """ 52 | Args: 53 | labels: labels of the batch, of size (batch_size,) 54 | embeddings: tensor of shape (batch_size, embed_dim) 55 | 56 | Returns: 57 | triplet_loss: scalar tensor containing the triplet loss 58 | """ 59 | embedding_vec = torch.reshape(embeddings,(64,1000)) 60 | pairwise_dist = _pairwise_distance(embedding_vec, squared=self.squared) 61 | 62 | if self.hardest: 63 | # Get the hardest positive pairs 64 | mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float() 65 | valid_positive_dist = pairwise_dist * mask_anchor_positive 66 | hardest_positive_dist, _ = torch.max(valid_positive_dist, dim=1, keepdim=True) 67 | 68 | # Get the hardest negative pairs 69 | mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float() 70 | max_anchor_negative_dist, _ = torch.max(pairwise_dist, dim=1, keepdim=True) 71 | anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative) 72 | hardest_negative_dist, _ = torch.min(anchor_negative_dist, dim=1, keepdim=True) 73 | 74 | # Combine biggest d(a, p) and smallest d(a, n) into final triplet loss 75 | triplet_loss = F.relu(hardest_positive_dist - hardest_negative_dist + 0.1) 76 | triplet_loss = torch.mean(triplet_loss) 77 | else: 78 | anc_pos_dist = pairwise_dist.unsqueeze(dim=2) 79 | anc_neg_dist = pairwise_dist.unsqueeze(dim=1) 80 | 81 | # Compute a 3D tensor of size (batch_size, batch_size, batch_size) 82 | # triplet_loss[i, j, k] will contain the triplet loss of anc=i, pos=j, neg=k 83 | # Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1) 84 | # and the 2nd (batch_size, 1, batch_size) 85 | loss = anc_pos_dist - anc_neg_dist + self.margin 86 | 87 | mask = _get_triplet_mask(labels).float() 88 | triplet_loss = loss * mask 89 | 90 | # Remove negative losses (i.e. the easy triplets) 91 | triplet_loss = F.relu(triplet_loss) 92 | 93 | # Count number of hard triplets (where triplet_loss > 0) 94 | hard_triplets = torch.gt(triplet_loss, 1e-16).float() 95 | num_hard_triplets = torch.sum(hard_triplets) 96 | 97 | triplet_loss = torch.sum(triplet_loss) / (num_hard_triplets + 1e-16) 98 | 99 | return triplet_loss 100 | 101 | 102 | def _pairwise_distance(x, squared=False, eps=1e-16): 103 | # Compute the 2D matrix of distances between all the embeddings. 104 | cor_mat = torch.matmul(x, x.t()) 105 | norm_mat = cor_mat.diag() 106 | distances = norm_mat.unsqueeze(1) - 2 * cor_mat + norm_mat.unsqueeze(0) 107 | distances = F.relu(distances) 108 | 109 | if not squared: 110 | mask = torch.eq(distances, 0.0).float() 111 | distances = distances + mask * eps 112 | distances = torch.sqrt(distances) 113 | distances = distances * (1.0 - mask) 114 | return distances 115 | 116 | 117 | def _get_anchor_positive_triplet_mask(labels): 118 | # Return a 2D mask where mask[a, p] is True iff a and p are distinct and have same label. 
119 | 120 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 121 | 122 | indices_not_equal = torch.eye(labels.shape[0]).to(device).byte() ^ 1 123 | 124 | # Check if labels[i] == labels[j] 125 | labels_equal = torch.unsqueeze(labels, 0) == torch.unsqueeze(labels, 1) 126 | 127 | mask = indices_not_equal * labels_equal 128 | 129 | return mask 130 | 131 | 132 | def _get_anchor_negative_triplet_mask(labels): 133 | # Return a 2D mask where mask[a, n] is True iff a and n have distinct labels. 134 | 135 | # Check if labels[i] != labels[k] 136 | labels_equal = torch.unsqueeze(labels, 0) == torch.unsqueeze(labels, 1) 137 | mask = labels_equal ^ 1 138 | 139 | return mask 140 | 141 | 142 | def _get_triplet_mask(labels): 143 | """Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid. 144 | 145 | A triplet (i, j, k) is valid if: 146 | - i, j, k are distinct 147 | - labels[i] == labels[j] and labels[i] != labels[k] 148 | """ 149 | device = torch.device("cuda:0") 150 | 151 | # Check that i, j and k are distinct 152 | indices_not_same = torch.eye(labels.shape[0]).to(device).byte() ^ 1 153 | i_not_equal_j = torch.unsqueeze(indices_not_same, 2) 154 | i_not_equal_k = torch.unsqueeze(indices_not_same, 1) 155 | j_not_equal_k = torch.unsqueeze(indices_not_same, 0) 156 | distinct_indices = i_not_equal_j * i_not_equal_k * j_not_equal_k 157 | 158 | # Check if labels[i] == labels[j] and labels[i] != labels[k] 159 | label_equal = torch.eq(torch.unsqueeze(labels, 0), torch.unsqueeze(labels, 1)).long().cuda() 160 | i_equal_j = torch.unsqueeze(label_equal, 2) 161 | i_equal_k = torch.unsqueeze(label_equal, 1) 162 | valid_labels = i_equal_j * (i_equal_k ^ 1) 163 | 164 | mask = distinct_indices * valid_labels # Combine the two masks 165 | 166 | return mask 167 | -------------------------------------------------------------------------------- /networks/ConvLSTM.py: -------------------------------------------------------------------------------- 1 | from networks.base import BaseModel 2 | import torch.nn as nn 3 | import torch 4 | import numpy as np 5 | from .unet_rpg import UNet, UNetRecurrent 6 | from os.path import join 7 | from .submodules import ConvLSTM, ResidualBlock, ConvLayer, UpsampleConvLayer, TransposedConvLayer 8 | 9 | 10 | class BaseE2VID(BaseModel): 11 | def __init__(self, config): 12 | super().__init__(config) 13 | 14 | try: 15 | self.skip_type = str(config['skip_type']) 16 | except KeyError: 17 | self.skip_type = 'sum' 18 | 19 | try: 20 | self.num_encoders = int(config['num_encoders']) 21 | except KeyError: 22 | self.num_encoders = 4 23 | 24 | try: 25 | self.base_num_channels = int(config['base_num_channels']) 26 | except KeyError: 27 | self.base_num_channels = 32 28 | 29 | try: 30 | self.num_residual_blocks = int(config['num_residual_blocks']) 31 | except KeyError: 32 | self.num_residual_blocks = 2 33 | 34 | try: 35 | self.norm = str(config['norm']) 36 | except KeyError: 37 | self.norm = None 38 | 39 | try: 40 | self.use_upsample_conv = bool(config['use_upsample_conv']) 41 | except KeyError: 42 | self.use_upsample_conv = True 43 | 44 | class RecurrUNet(BaseE2VID): 45 | """ 46 | Recurrent, UNet-like architecture where each encoder is followed by a ConvLSTM or ConvGRU. 
47 | """ 48 | 49 | def __init__(self, num_bins = 3, in_channels=1, out_channels=2, depth=4, slope=0.2): 50 | self.output_channels = out_channels 51 | self.num_encoders = depth 52 | self.base_num_channels = out_channels 53 | self.num_residual_blocks = depth 54 | self.in_channels = in_channels 55 | self.num_bins = num_bins # number of bins in the voxel grid event tensor 56 | config = {} 57 | super(RecurrUNet, self).__init__(config) 58 | 59 | try: 60 | self.recurrent_block_type = str(config['recurrent_block_type']) 61 | except KeyError: 62 | self.recurrent_block_type = 'convgru' # or 'convlstm' 63 | 64 | # self.unetrecurrent = UNet(num_input_channels=self.in_channels, 65 | # num_output_channels=self.output_channels, 66 | # skip_type='sum', 67 | # activation='sigmoid', 68 | # num_encoders=self.num_encoders, 69 | # base_num_channels=self.base_num_channels, 70 | # num_residual_blocks=self.num_residual_blocks, 71 | # norm=self.norm, 72 | # use_upsample_conv=self.use_upsample_conv) 73 | self.unetrecurrent = UNetRecurrent(num_input_channels=self.in_channels, 74 | num_output_channels=self.output_channels, 75 | skip_type='sum', 76 | recurrent_block_type=self.recurrent_block_type, 77 | activation='sigmoid', 78 | num_encoders=self.num_encoders, 79 | base_num_channels=self.base_num_channels, 80 | num_residual_blocks=self.num_residual_blocks, 81 | norm=self.norm, 82 | use_upsample_conv=self.use_upsample_conv) 83 | 84 | def forward(self, event_tensor): 85 | """ 86 | :param event_tensor: N x num_bins x H x W 87 | :param prev_states: previous ConvLSTM state for each encoder module 88 | :return: reconstructed image, taking values in [0,1]. 89 | """ 90 | # img_pred = self.unetrecurrent.forward(event_tensor) 91 | states = None 92 | num_bins = event_tensor.shape[1] 93 | for nth in range(num_bins): 94 | eventimg = event_tensor[:,nth,] 95 | eventimg = eventimg[:,np.newaxis,] 96 | img_pred, states = self.unetrecurrent.forward(eventimg, states) 97 | return img_pred 98 | -------------------------------------------------------------------------------- /networks/DnCNN.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2019-09-01 19:27:05 4 | 5 | import torch.nn as nn 6 | from .SubBlocks import conv3x3 7 | 8 | class DnCNN(nn.Module): 9 | def __init__(self, in_channels, out_channels, dep=20, num_filters=64, slope=0.2): 10 | ''' 11 | Reference: 12 | K. Zhang, W. Zuo, Y. Chen, D. Meng and L. Zhang, "Beyond a Gaussian Denoiser: Residual 13 | Learning of Deep CNN for Image Denoising," TIP, 2017. 
14 | 15 | Args: 16 | in_channels (int): number of input channels 17 | out_channels (int): number of output channels 18 | dep (int): depth of the network, Default 20 19 | num_filters (int): number of filters in each layer, Default 64 20 | ''' 21 | super(DnCNN, self).__init__() 22 | self.conv1 = conv3x3(in_channels, num_filters, bias=True) 23 | self.relu = nn.LeakyReLU(slope, inplace=True) 24 | mid_layer = [] 25 | for ii in range(1, dep-1): 26 | mid_layer.append(conv3x3(num_filters, num_filters, bias=True)) 27 | mid_layer.append(nn.LeakyReLU(slope, inplace=True)) 28 | self.mid_layer = nn.Sequential(*mid_layer) 29 | self.conv_last = conv3x3(num_filters, out_channels, bias=True) 30 | 31 | def forward(self, x): 32 | x = self.conv1(x) 33 | x = self.relu(x) 34 | x = self.mid_layer(x) 35 | out = self.conv_last(x) 36 | 37 | return out 38 | 39 | -------------------------------------------------------------------------------- /networks/EventDenoiser.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2019-09-01 19:35:06 4 | 5 | import torch.nn as nn 6 | import torch 7 | from .DnCNN import DnCNN 8 | from .ConvLSTM import RecurrUNet 9 | 10 | def weight_init_kaiming(net): 11 | for m in net.modules(): 12 | if isinstance(m, nn.Conv2d): 13 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') 14 | if not m.bias is None: 15 | nn.init.constant_(m.bias, 0) 16 | elif isinstance(m, nn.BatchNorm2d): 17 | nn.init.constant_(m.weight, 1) 18 | nn.init.constant_(m.bias, 0) 19 | return net 20 | 21 | class EventDenoiser(nn.Module): 22 | def __init__(self, input_images, dep_S=5, dep_U=4, slope=0.2): 23 | super(EventDenoiser, self).__init__() 24 | config = {'num_bins' : 3} 25 | self.ReconNet = RecurrUNet(num_bins = 3, in_channels = 1, out_channels = 1, depth=dep_U, slope=slope) 26 | self.ErrorNet = DnCNN(in_channels = 3, out_channels = 1, dep=dep_S, num_filters=64, slope=slope) 27 | 28 | def forward(self, x): 29 | img_estim = self.ReconNet(x) 30 | err_estim = self.ErrorNet(x) 31 | evterr = torch.cat((img_estim,err_estim),dim=1) 32 | return evterr 33 | -------------------------------------------------------------------------------- /networks/EventVLAD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Imagenet_matconvnet_vgg_verydeep_16_dag(nn.Module): 5 | 6 | def __init__(self): 7 | super().__init__() 8 | self.meta = {'mean': [122.74494171142578, 114.94409942626953, 101.64177703857422], 9 | 'std': [1, 1, 1], 10 | 'imageSize': [224, 224, 3]} 11 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 12 | self.relu1_1 = nn.ReLU() 13 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 14 | self.relu1_2 = nn.ReLU() 15 | self.pool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False) 16 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 17 | self.relu2_1 = nn.ReLU() 18 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 19 | self.relu2_2 = nn.ReLU() 20 | self.pool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False) 21 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 22 | self.relu3_1 = nn.ReLU() 23 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), 
padding=(1, 1)) 24 | self.relu3_2 = nn.ReLU() 25 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 26 | self.relu3_3 = nn.ReLU() 27 | self.pool3 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False) 28 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 29 | self.relu4_1 = nn.ReLU() 30 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 31 | self.relu4_2 = nn.ReLU() 32 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 33 | self.relu4_3 = nn.ReLU() 34 | self.pool4 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False) 35 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 36 | self.relu5_1 = nn.ReLU() 37 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 38 | self.relu5_2 = nn.ReLU() 39 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 40 | self.relu5_3 = nn.ReLU() 41 | self.pool5 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False) 42 | self.fc6 = nn.Conv2d(512, 4096, kernel_size=[7, 7], stride=(1, 1)) 43 | self.relu6 = nn.ReLU() 44 | self.fc7 = nn.Linear(in_features=4096, out_features=4096, bias=True) 45 | self.relu7 = nn.ReLU() 46 | self.fc8 = nn.Linear(in_features=4096, out_features=1000, bias=True) 47 | 48 | def forward(self, x0): 49 | x1 = self.conv1_1(x0) 50 | x2 = self.relu1_1(x1) 51 | x3 = self.conv1_2(x2) 52 | x4 = self.relu1_2(x3) 53 | x5 = self.pool1(x4) 54 | x6 = self.conv2_1(x5) 55 | x7 = self.relu2_1(x6) 56 | x8 = self.conv2_2(x7) 57 | x9 = self.relu2_2(x8) 58 | x10 = self.pool2(x9) 59 | x11 = self.conv3_1(x10) 60 | x12 = self.relu3_1(x11) 61 | x13 = self.conv3_2(x12) 62 | x14 = self.relu3_2(x13) 63 | x15 = self.conv3_3(x14) 64 | x16 = self.relu3_3(x15) 65 | x17 = self.pool3(x16) 66 | x18 = self.conv4_1(x17) 67 | x19 = self.relu4_1(x18) 68 | x20 = self.conv4_2(x19) 69 | x21 = self.relu4_2(x20) 70 | x22 = self.conv4_3(x21) 71 | x23 = self.relu4_3(x22) 72 | x24 = self.pool4(x23) 73 | x25 = self.conv5_1(x24) 74 | x26 = self.relu5_1(x25) 75 | x27 = self.conv5_2(x26) 76 | x28 = self.relu5_2(x27) 77 | x29 = self.conv5_3(x28) 78 | x30 = self.relu5_3(x29) 79 | x31 = self.pool5(x30) 80 | x32 = self.fc6(x31) 81 | x33_preflatten = self.relu6(x32) 82 | x33 = x33_preflatten.view(x33_preflatten.size(0), -1) 83 | x34 = self.fc7(x33) 84 | x35 = self.relu7(x34) 85 | x36 = self.fc8(x35) 86 | return x36 87 | 88 | def Imagenet_vgg(weights_path=None, **kwargs): 89 | """ 90 | load imported model instance 91 | 92 | Args: 93 | weights_path (str): If set, loads model weights from the given path 94 | """ 95 | model = Imagenet_matconvnet_vgg_verydeep_16_dag() 96 | if weights_path: 97 | state_dict = torch.load(weights_path) 98 | model.load_state_dict(state_dict) 99 | return model 100 | -------------------------------------------------------------------------------- /networks/SubBlocks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2019-09-01 19:19:32 4 | 5 | import torch 6 | import torch.nn as nn 7 | import sys 8 | import torch.nn.functional as F 9 | 10 | def conv3x3(in_chn, out_chn, bias=True): 11 | layer = nn.Conv2d(in_chn, out_chn, kernel_size=3, stride=1, padding=1, bias=bias) 12 | return layer 13 | 14 | def pixel_unshuffle(input, 
upscale_factor): 15 | ''' 16 | Input: 17 | input: (N, C, rH, rW) tensor 18 | output: 19 | (N, r^2C, H, W) 20 | Written by Kai Zhang: https://github.com/cszn/FFDNet 21 | ''' 22 | batch_size, channels, in_height, in_width = input.size() 23 | 24 | out_height = in_height // upscale_factor 25 | out_width = in_width // upscale_factor 26 | 27 | input_view = input.contiguous().view( batch_size, channels, out_height, upscale_factor, 28 | out_width, upscale_factor) 29 | 30 | channels *= upscale_factor ** 2 31 | unshuffle_out = input_view.permute(0, 1, 3, 5, 2, 4).contiguous() 32 | return unshuffle_out.view(batch_size, channels, out_height, out_width) 33 | 34 | class PixelUnShuffle(nn.Module): 35 | ''' 36 | Input: 37 | input: (N, C, rH, rW) tensor 38 | output: 39 | (N, r^2C, H, W) 40 | Written by Kai Zhang: https://github.com/cszn/FFDNet 41 | ''' 42 | def __init__(self, upscale_factor): 43 | super(PixelUnShuffle, self).__init__() 44 | self.upscale_factor = upscale_factor 45 | 46 | def forward(self, input): 47 | return pixel_unshuffle(input, self.upscale_factor) 48 | 49 | def extra_repr(self): 50 | return 'upscale_factor={}'.format(self.upscale_factor) 51 | 52 | -------------------------------------------------------------------------------- /networks/UNet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2019-03-20 19:48:14 4 | # Adapted from https://github.com/jvanvugt/pytorch-unet 5 | 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | from .SubBlocks import conv3x3 10 | 11 | class UNet(nn.Module): 12 | def __init__(self, in_channels=3, out_channels=6, depth=4, wf=64, slope=0.2): 13 | """ 14 | Reference: 15 | Ronneberger O., Fischer P., Brox T. (2015) U-Net: Convolutional Networks for Biomedical 16 | Image Segmentation. MICCAI 2015. 
17 | ArXiv Version: https://arxiv.org/abs/1505.04597 18 | 19 | Args: 20 | in_channels (int): number of input channels, Default 3 21 | depth (int): depth of the network, Default 4 22 | wf (int): number of filters in the first layer, Default 32 23 | """ 24 | super(UNet, self).__init__() 25 | self.depth = depth 26 | prev_channels = in_channels 27 | self.down_path = nn.ModuleList() 28 | for i in range(depth): 29 | self.down_path.append(UNetConvBlock(prev_channels, (2**i)*wf, slope)) 30 | prev_channels = (2**i) * wf 31 | 32 | self.up_path = nn.ModuleList() 33 | for i in reversed(range(depth - 1)): 34 | self.up_path.append(UNetUpBlock(prev_channels, (2**i)*wf, slope)) 35 | prev_channels = (2**i)*wf 36 | 37 | self.last = conv3x3(prev_channels, out_channels, bias=True) 38 | 39 | def forward(self, x): 40 | blocks = [] 41 | for i, down in enumerate(self.down_path): 42 | x = down(x) 43 | if i != len(self.down_path)-1: 44 | blocks.append(x) 45 | x = F.avg_pool2d(x, 2) 46 | 47 | for i, up in enumerate(self.up_path): 48 | x = up(x, blocks[-i-1]) 49 | 50 | return self.last(x) 51 | 52 | class UNetConvBlock(nn.Module): 53 | def __init__(self, in_size, out_size, slope=0.2): 54 | super(UNetConvBlock, self).__init__() 55 | block = [] 56 | 57 | block.append(nn.Conv2d(in_size, out_size, kernel_size=3, padding=1, bias=True)) 58 | block.append(nn.LeakyReLU(slope, inplace=True)) 59 | 60 | block.append(nn.Conv2d(out_size, out_size, kernel_size=3, padding=1, bias=True)) 61 | block.append(nn.LeakyReLU(slope, inplace=True)) 62 | 63 | self.block = nn.Sequential(*block) 64 | 65 | def forward(self, x): 66 | out = self.block(x) 67 | return out 68 | 69 | class UNetUpBlock(nn.Module): 70 | def __init__(self, in_size, out_size, slope=0.2): 71 | super(UNetUpBlock, self).__init__() 72 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, bias=True) 73 | self.conv_block = UNetConvBlock(in_size, out_size, slope) 74 | 75 | def center_crop(self, layer, target_size): 76 | _, _, layer_height, layer_width = layer.size() 77 | diff_y = (layer_height - target_size[0]) // 2 78 | diff_x = (layer_width - target_size[1]) // 2 79 | return layer[:, :, diff_y:(diff_y + target_size[0]), diff_x:(diff_x + target_size[1])] 80 | 81 | def forward(self, x, bridge): 82 | up = self.up(x) 83 | crop1 = self.center_crop(bridge, up.shape[2:]) 84 | out = torch.cat([up, crop1], 1) 85 | out = self.conv_block(out) 86 | 87 | return out 88 | 89 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .DnCNN import DnCNN 2 | from .UNet import UNet 3 | from .EventDenoiser import EventDenoiser, weight_init_kaiming 4 | from .vgg16 import Imagenet_vgg, Imagenet_matconvnet_vgg_verydeep_16_dag 5 | -------------------------------------------------------------------------------- /networks/__pycache__/ConvLSTM.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexjunholee/EventVLAD/b02f7c5a5758be08e39f5b6bcd51d7c7720b3251/networks/__pycache__/ConvLSTM.cpython-38.pyc -------------------------------------------------------------------------------- /networks/__pycache__/DnCNN.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexjunholee/EventVLAD/b02f7c5a5758be08e39f5b6bcd51d7c7720b3251/networks/__pycache__/DnCNN.cpython-38.pyc 
-------------------------------------------------------------------------------- /networks/__pycache__/EventDenoiser.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexjunholee/EventVLAD/b02f7c5a5758be08e39f5b6bcd51d7c7720b3251/networks/__pycache__/EventDenoiser.cpython-38.pyc -------------------------------------------------------------------------------- /networks/__pycache__/SubBlocks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexjunholee/EventVLAD/b02f7c5a5758be08e39f5b6bcd51d7c7720b3251/networks/__pycache__/SubBlocks.cpython-38.pyc -------------------------------------------------------------------------------- /networks/__pycache__/UNet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexjunholee/EventVLAD/b02f7c5a5758be08e39f5b6bcd51d7c7720b3251/networks/__pycache__/UNet.cpython-38.pyc -------------------------------------------------------------------------------- /networks/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexjunholee/EventVLAD/b02f7c5a5758be08e39f5b6bcd51d7c7720b3251/networks/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /networks/__pycache__/submodules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexjunholee/EventVLAD/b02f7c5a5758be08e39f5b6bcd51d7c7720b3251/networks/__pycache__/submodules.cpython-38.pyc -------------------------------------------------------------------------------- /networks/__pycache__/unet_rpg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexjunholee/EventVLAD/b02f7c5a5758be08e39f5b6bcd51d7c7720b3251/networks/__pycache__/unet_rpg.cpython-38.pyc -------------------------------------------------------------------------------- /networks/__pycache__/vgg16.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexjunholee/EventVLAD/b02f7c5a5758be08e39f5b6bcd51d7c7720b3251/networks/__pycache__/vgg16.cpython-38.pyc -------------------------------------------------------------------------------- /networks/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_model import * 2 | -------------------------------------------------------------------------------- /networks/base/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexjunholee/EventVLAD/b02f7c5a5758be08e39f5b6bcd51d7c7720b3251/networks/base/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /networks/base/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexjunholee/EventVLAD/b02f7c5a5758be08e39f5b6bcd51d7c7720b3251/networks/base/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /networks/base/__pycache__/__init__.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexjunholee/EventVLAD/b02f7c5a5758be08e39f5b6bcd51d7c7720b3251/networks/base/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /networks/base/__pycache__/base_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexjunholee/EventVLAD/b02f7c5a5758be08e39f5b6bcd51d7c7720b3251/networks/base/__pycache__/base_model.cpython-36.pyc -------------------------------------------------------------------------------- /networks/base/__pycache__/base_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexjunholee/EventVLAD/b02f7c5a5758be08e39f5b6bcd51d7c7720b3251/networks/base/__pycache__/base_model.cpython-37.pyc -------------------------------------------------------------------------------- /networks/base/__pycache__/base_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexjunholee/EventVLAD/b02f7c5a5758be08e39f5b6bcd51d7c7720b3251/networks/base/__pycache__/base_model.cpython-38.pyc -------------------------------------------------------------------------------- /networks/base/base_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class BaseModel(nn.Module): 7 | """ 8 | Base class for all models 9 | """ 10 | def __init__(self, config): 11 | super(BaseModel, self).__init__() 12 | self.config = config 13 | self.logger = logging.getLogger(self.__class__.__name__) 14 | 15 | def forward(self, *input): 16 | """ 17 | Forward pass logic 18 | 19 | :return: Model output 20 | """ 21 | raise NotImplementedError 22 | 23 | def summary(self): 24 | """ 25 | Model summary 26 | """ 27 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 28 | params = sum([np.prod(p.size()) for p in model_parameters]) 29 | self.logger.info('Trainable parameters: {}'.format(params)) 30 | self.logger.info(self) 31 | -------------------------------------------------------------------------------- /networks/image_transformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import logging 6 | from tqdm import tqdm 7 | import matplotlib.pyplot as plt 8 | import seaborn as sns 9 | 10 | NUM_PIXELS = 256 11 | 12 | # Numerically stable implementations 13 | def logsoftmax(x): 14 | m = torch.max(x, -1, keepdim=True).values 15 | return x - m - torch.log(torch.exp(x - m).sum(-1, keepdim=True)) 16 | 17 | def logsumexp(x): 18 | m = x.max(-1).values 19 | return m + torch.log(torch.exp(x - m[...,None]).sum(-1)) 20 | 21 | class ConvLayer(nn.Module): 22 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None): 23 | super(ConvLayer, self).__init__() 24 | 25 | bias = False if norm == 'BN' else True 26 | self.conv3d = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 27 | nn.init.xavier_uniform(self.conv3d.weight) 28 | 29 | if activation is not None: 30 | self.activation = getattr(torch, activation, 'relu') 31 | else: 32 | self.activation = None 33 | 34 | 
self.norm = norm 35 | if norm == 'BN': 36 | self.norm_layer = nn.BatchNorm2d(out_channels) 37 | elif norm == 'IN': 38 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True) 39 | 40 | def forward(self, x): 41 | out = self.conv3d(x) 42 | 43 | if self.norm in ['BN', 'IN']: 44 | out = self.norm_layer(out) 45 | 46 | if self.activation is not None: 47 | out = self.activation(out) 48 | 49 | return out 50 | 51 | class ImageTransformer(nn.Module): 52 | """ImageTransformer with DMOL or categorical distribution.""" 53 | def __init__(self, hparams): 54 | super().__init__() 55 | self.hparams = hparams 56 | self.layers = nn.ModuleList([DecoderLayer(hparams) for _ in range(hparams.nlayers)]) 57 | self.input_dropout = nn.Dropout(p=hparams.dropout) 58 | self.last_fc1 = nn.Linear(3*256, 3*16, bias = True) 59 | self.last_fc2 = nn.Linear(3*16, 3, bias = True) 60 | nn.init.xavier_uniform(self.last_fc1.weight) 61 | nn.init.xavier_uniform(self.last_fc2.weight) 62 | self.last_conv = ConvLayer(in_channels=3,out_channels=1,kernel_size=5,stride=1,padding=2) 63 | if self.hparams.distr == "dmol": # Discretized mixture of logistic, for ordinal valued inputs 64 | assert self.hparams.channels == 3, "Only supports 3 channels for DML" 65 | size = (1, self.hparams.channels) 66 | self.embedding_conv = nn.Conv2d(1, self.hparams.hidden_size, 67 | kernel_size=size, stride=size) 68 | # 10 = 1 + 2c + c(c-1)/2; if only 1 channel, then 3 total 69 | depth = self.hparams.num_mixtures * 10 70 | self.output_dense = nn.Linear(self.hparams.hidden_size, depth, bias=False) 71 | elif self.hparams.distr == "cat": # Categorical 72 | self.embeds = nn.Embedding(NUM_PIXELS * self.hparams.channels, self.hparams.hidden_size) 73 | self.output_dense = nn.Linear(self.hparams.hidden_size, NUM_PIXELS, bias=True) 74 | else: 75 | raise ValueError("Only dmol or categorical distributions") 76 | 77 | def add_timing_signal(self, X, min_timescale=1.0, max_timescale=1.0e4): 78 | num_dims = len(X.shape) - 2 # 2 corresponds to batch and hidden_size dimensions 79 | num_timescales = self.hparams.hidden_size // (num_dims * 2) 80 | log_timescale_increment = np.log(max_timescale / min_timescale) / (num_timescales - 1) 81 | inv_timescales = min_timescale * torch.exp((torch.arange(num_timescales).float() * -log_timescale_increment)) 82 | inv_timescales = inv_timescales.to(X.device) 83 | total_signal = torch.zeros_like(X) # Only for debugging purposes 84 | for dim in range(num_dims): 85 | length = X.shape[dim + 1] # add 1 to exclude batch dim 86 | position = torch.arange(length).float().to(X.device) 87 | scaled_time = position.view(-1, 1) * inv_timescales.view(1, -1) 88 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 1) 89 | prepad = dim * 2 * num_timescales 90 | postpad = self.hparams.hidden_size - (dim + 1) * 2 * num_timescales 91 | signal = F.pad(signal, (prepad, postpad)) 92 | for _ in range(1 + dim): 93 | signal = signal.unsqueeze(0) 94 | for _ in range(num_dims - 1 - dim): 95 | signal = signal.unsqueeze(-2) 96 | X += signal 97 | total_signal += signal 98 | return X 99 | 100 | def shift_and_pad_(self, X): 101 | # Shift inputs over by 1 and pad 102 | shape = X.shape 103 | X = X.view(shape[0], shape[1] * shape[2], shape[3]) 104 | X = X[:,:-1,:] 105 | X = F.pad(X, (0, 0, 1, 0)) # Pad second to last dimension 106 | X = X.view(shape) 107 | return X 108 | 109 | def forward(self, X, sampling=False): 110 | # Reshape inputs 111 | if sampling: 112 | curr_infer_length = X.shape[1] 113 | row_size = self.hparams.image_size * 
self.hparams.channels 114 | nrows = curr_infer_length // row_size + 1 115 | X = F.pad(X, (0, nrows * row_size - curr_infer_length)) 116 | X = X.view(X.shape[0], -1, row_size) 117 | else: 118 | X = X.permute([0, 2, 3, 1]).contiguous() 119 | X = X.view(X.shape[0], X.shape[1], X.shape[2] * X.shape[3]) # Flatten channels into width 120 | 121 | # Inputs -> embeddings 122 | if self.hparams.distr == "dmol": 123 | # Create a "channel" dimension for the 1x3 convolution 124 | # (NOTE: can apply a 1x1 convolution and not reshape, this is for consistency) 125 | X = X.unsqueeze(1) 126 | X = F.relu(self.embedding_conv(X)) 127 | X = X.permute([0, 2, 3, 1]) # move channels to the end 128 | elif self.hparams.distr == "cat": 129 | # Convert to indexes, and use separate embeddings for different channels 130 | X = (X * (NUM_PIXELS - 1)).long() 131 | channel_addition = (torch.tensor([0, 1, 2]) * NUM_PIXELS).to(X.device).repeat(X.shape[2] // 3).view(1, 1, -1) 132 | X += channel_addition 133 | X = self.embeds(X) * (self.hparams.hidden_size ** 0.5) 134 | 135 | X = self.shift_and_pad_(X) 136 | X = self.add_timing_signal(X) 137 | shape = X.shape 138 | X = X.view(shape[0], -1, shape[3]) 139 | 140 | X = self.input_dropout(X) 141 | for layer in self.layers: 142 | X = layer(X) 143 | X = self.layers[-1].preprocess_(X) # NOTE: this is identity (exists to replicate tensorflow code) 144 | X = self.output_dense(X).view(shape[:3] + (-1,)) 145 | 146 | if not sampling and self.hparams.distr == "cat": # Unpack the channels 147 | X = X.view(X.shape[0], X.shape[1], X.shape[2] // self.hparams.channels, self.hparams.channels, X.shape[3]) 148 | X = X.permute([0, 3, 1, 2, 4]) 149 | X = torch.reshape(X,(64*64,3*256)) 150 | X = self.last_fc1(X) 151 | X = self.last_fc2(X) 152 | X = torch.reshape(X,(1,3,64,64,1)) 153 | X = self.last_conv(X) 154 | 155 | return X 156 | 157 | def split_to_dml_params(self, preds, targets=None, sampling=False): 158 | nm = self.hparams.num_mixtures 159 | mix_logits, locs, log_scales, coeffs = torch.split(preds, [nm, nm * 3, nm * 3, nm * 3], dim=-1) 160 | new_shape = preds.shape[:-1] + (3, nm) 161 | locs = locs.view(new_shape) 162 | coeffs = torch.tanh(coeffs.view(new_shape)) 163 | log_scales = torch.clamp(log_scales.view(new_shape), min=-7.) 164 | if not sampling: 165 | targets = targets.unsqueeze(-1) 166 | locs1 = locs[...,1,:] + coeffs[...,0,:] * targets[:,0,...] 167 | locs2 = locs[...,2,:] + coeffs[...,1,:] * targets[:,0,...] + coeffs[...,2,:] * targets[:,1,...] 168 | locs = torch.stack([locs[...,0,:], locs1, locs2], dim=-2) 169 | return mix_logits, locs, log_scales 170 | else: 171 | return mix_logits, locs, log_scales, coeffs 172 | 173 | # Modified from official PixCNN++ code 174 | def dml_logp(self, logits, means, log_scales, targets): 175 | targets = targets.unsqueeze(-1) 176 | centered_x = targets - means 177 | inv_stdv = torch.exp(-log_scales) 178 | plus_in = inv_stdv * (centered_x + 1. / 255.) 179 | cdf_plus = torch.sigmoid(plus_in) 180 | min_in = inv_stdv * (centered_x - 1. / 255.) 181 | cdf_min = torch.sigmoid(min_in) 182 | log_cdf_plus = plus_in - F.softplus(plus_in) # log probability for edge case of 0 (before scaling) 183 | log_one_minus_cdf_min = -F.softplus(min_in) # log probability for edge case of 255 (before scaling) 184 | cdf_delta = cdf_plus - cdf_min # probability for all other cases 185 | mid_in = inv_stdv * centered_x 186 | log_pdf_mid = mid_in - log_scales - 2. 
* F.softplus( 187 | mid_in) # log probability in the center of the bin, to be used in extreme cases (not actually used in our code) 188 | 189 | # now select the right output: left edge case, right edge case, normal case, extremely low prob case (doesn't actually happen for us) 190 | log_probs = torch.where(targets < -0.999, log_cdf_plus, torch.where(targets > 0.999, log_one_minus_cdf_min, 191 | torch.where(cdf_delta > 1e-5, 192 | torch.log(torch.clamp(cdf_delta, min=1e-12)), 193 | log_pdf_mid - np.log(127.5)))) 194 | log_probs = log_probs.sum(3) + logsoftmax(logits) 195 | log_probs = logsumexp(log_probs) 196 | return log_probs 197 | 198 | # Assumes targets have been rescaled to [-1., 1.] 199 | def loss(self, preds, targets): 200 | if self.hparams.distr == "dmol": 201 | # Assumes 3 channels. Input: [batch_size, height, width, 10 * 10] 202 | logits, locs, log_scales = self.split_to_dml_params(preds, targets) 203 | targets = targets.permute([0, 2, 3, 1]) 204 | log_probs = self.dml_logp(logits, locs, log_scales, targets) 205 | return -log_probs 206 | elif self.hparams.distr == "cat": 207 | targets = (targets * (NUM_PIXELS - 1)).long() 208 | # ce = F.cross_entropy(preds.permute(0, 4, 1, 2, 3), targets, reduction='none') 209 | ce = torch.abs(preds[:,:,:,:,0] - targets) 210 | return ce 211 | 212 | def accuracy(self, preds, targets): 213 | if self.hparams.distr == "cat": 214 | targets = (targets * (NUM_PIXELS - 1)).long() 215 | argmax_preds = torch.argmax(preds, dim=-1) 216 | acc = torch.eq(argmax_preds, targets).float().sum() / np.prod(argmax_preds.shape) 217 | return acc 218 | else: 219 | # Computing accuracy for dmol is more computationally intensive, so we skip it 220 | return torch.zeros((1,)) 221 | 222 | def sample_from_dmol(self, outputs): 223 | logits, locs, log_scales, coeffs = self.split_to_dml_params(outputs, sampling=True) 224 | gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) * (1. - 2 * 1e-5) + 1e-5)) 225 | sel = torch.argmax(logits + gumbel_noise, -1, keepdim=True) 226 | one_hot = torch.zeros_like(logits).scatter_(-1, sel, 1).unsqueeze(-2) 227 | locs = (locs * one_hot).sum(-1) 228 | log_scales = (log_scales * one_hot).sum(-1) 229 | coeffs = (coeffs * one_hot).sum(-1) 230 | unif = torch.rand_like(log_scales) * (1. - 2 * 1e-5) + 1e-5 231 | logistic_noise = torch.log(unif) - torch.log1p(-unif) 232 | x = locs + torch.exp(log_scales) * logistic_noise 233 | # NOTE: sampling analogously to pixcnn++, which clamps first, unlike image transformer 234 | x0 = torch.clamp(x[..., 0], -1., 1.) 235 | x1 = torch.clamp(x[..., 1] + coeffs[..., 0] * x0, -1., 1.) 236 | x2 = torch.clamp(x[..., 2] + coeffs[..., 1] * x0 + coeffs[..., 2] * x1, -1., 1.) 237 | x = torch.stack([x0, x1, x2], -1) 238 | return x 239 | 240 | def sample_from_cat(self, logits, argmax=False): 241 | if argmax: 242 | sel = torch.argmax(logits, -1, keepdim=False).float() / 255. 243 | else: 244 | gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) * (1. - 2 * 1e-5) + 1e-5)) 245 | sel = torch.argmax(logits + gumbel_noise, -1, keepdim=False).float() / 255. 
246 | return sel 247 | 248 | def sample(self, n, device, argmax=False): 249 | total_len = (self.hparams.image_size ** 2) 250 | if self.hparams.distr == "cat": 251 | total_len *= self.hparams.channels 252 | samples = torch.zeros((n, 3)).to(device) 253 | for curr_infer_length in tqdm(range(total_len)): 254 | outputs = self.forward(samples, sampling=True) 255 | outputs = outputs.view(n, -1, outputs.shape[-1])[:,curr_infer_length:curr_infer_length+1,:] 256 | if self.hparams.distr == "dmol": 257 | x = self.sample_from_dmol(outputs).squeeze() 258 | elif self.hparams.distr == "cat": 259 | x = self.sample_from_cat(outputs, argmax=argmax) 260 | if curr_infer_length == 0: 261 | samples = x 262 | else: 263 | samples = torch.cat([samples, x], 1) 264 | samples = samples.view(n, self.hparams.image_size, self.hparams.image_size, self.hparams.channels) 265 | samples = samples.permute(0, 3, 1, 2) 266 | return samples 267 | 268 | def sample_from_preds(self, preds, argmax=False): 269 | if self.hparams.distr == "dmol": 270 | samples = self.sample_from_dmol(preds) 271 | samples = samples.permute(0, 3, 1, 2) 272 | elif self.hparams.distr == "cat": 273 | samples = self.sample_from_cat(preds, argmax=argmax) 274 | return samples 275 | 276 | 277 | class DecoderLayer(nn.Module): 278 | """Implements a single layer of an unconditional ImageTransformer""" 279 | def __init__(self, hparams): 280 | super().__init__() 281 | self.attn = Attn(hparams) 282 | self.hparams = hparams 283 | self.dropout = nn.Dropout(p=hparams.dropout) 284 | self.layernorm_attn = nn.LayerNorm([self.hparams.hidden_size], eps=1e-6, elementwise_affine=True) 285 | self.layernorm_ffn = nn.LayerNorm([self.hparams.hidden_size], eps=1e-6, elementwise_affine=True) 286 | self.ffn = nn.Sequential(nn.Linear(self.hparams.hidden_size, self.hparams.filter_size, bias=True), 287 | nn.ReLU(), 288 | nn.Linear(self.hparams.filter_size, self.hparams.hidden_size, bias=True)) 289 | 290 | def preprocess_(self, X): 291 | return X 292 | 293 | # Takes care of the "postprocessing" from tensorflow code with the layernorm and dropout 294 | def forward(self, X): 295 | X = self.preprocess_(X) 296 | y = self.attn(X) 297 | X = self.layernorm_attn(self.dropout(y) + X) 298 | y = self.ffn(self.preprocess_(X)) 299 | X = self.layernorm_ffn(self.dropout(y) + X) 300 | return X 301 | 302 | class Attn(nn.Module): 303 | def __init__(self, hparams): 304 | super().__init__() 305 | self.hparams = hparams 306 | self.kd = self.hparams.total_key_depth or self.hparams.hidden_size 307 | self.vd = self.hparams.total_value_depth or self.hparams.hidden_size 308 | self.q_dense = nn.Linear(self.hparams.hidden_size, self.kd, bias=False) 309 | self.k_dense = nn.Linear(self.hparams.hidden_size, self.kd, bias=False) 310 | self.v_dense = nn.Linear(self.hparams.hidden_size, self.vd, bias=False) 311 | self.output_dense = nn.Linear(self.vd, self.hparams.hidden_size, bias=False) 312 | assert self.kd % self.hparams.num_heads == 0 313 | assert self.vd % self.hparams.num_heads == 0 314 | 315 | def dot_product_attention(self, q, k, v, bias=None): 316 | logits = torch.einsum("...kd,...qd->...qk", k, q) 317 | if bias is not None: 318 | logits += bias 319 | weights = F.softmax(logits, dim=-1) 320 | return weights @ v 321 | 322 | def forward(self, X): 323 | q = self.q_dense(X) 324 | k = self.k_dense(X) 325 | v = self.v_dense(X) 326 | # Split to shape [batch_size, num_heads, len, depth / num_heads] 327 | q = q.view(q.shape[:-1] + (self.hparams.num_heads, self.kd // self.hparams.num_heads)).permute([0, 2, 1, 3]) 328 | k = 
k.view(k.shape[:-1] + (self.hparams.num_heads, self.kd // self.hparams.num_heads)).permute([0, 2, 1, 3]) 329 | v = v.view(v.shape[:-1] + (self.hparams.num_heads, self.vd // self.hparams.num_heads)).permute([0, 2, 1, 3]) 330 | q *= (self.kd // self.hparams.num_heads) ** (-0.5) 331 | 332 | if self.hparams.attn_type == "global": 333 | bias = -1e9 * torch.triu(torch.ones(X.shape[1], X.shape[1]), 1).to(X.device) 334 | result = self.dot_product_attention(q, k, v, bias=bias) 335 | elif self.hparams.attn_type == "local_1d": 336 | len = X.shape[1] 337 | blen = self.hparams.block_length 338 | pad = (0, 0, 0, (-len) % self.hparams.block_length) # Append to multiple of block length 339 | q = F.pad(q, pad) 340 | k = F.pad(k, pad) 341 | v = F.pad(v, pad) 342 | 343 | bias = -1e9 * torch.triu(torch.ones(blen, blen), 1).to(X.device) 344 | first_output = self.dot_product_attention( 345 | q[:,:,:blen,:], k[:,:,:blen,:], v[:,:,:blen,:], bias=bias) 346 | 347 | if q.shape[2] > blen: 348 | q = q.view(q.shape[0], q.shape[1], -1, blen, q.shape[3]) 349 | k = k.view(k.shape[0], k.shape[1], -1, blen, k.shape[3]) 350 | v = v.view(v.shape[0], v.shape[1], -1, blen, v.shape[3]) 351 | local_k = torch.cat([k[:,:,:-1], k[:,:,1:]], 3) # [batch, nheads, (nblocks - 1), blen * 2, depth] 352 | local_v = torch.cat([v[:,:,:-1], v[:,:,1:]], 3) 353 | tail_q = q[:,:,1:] 354 | bias = -1e9 * torch.triu(torch.ones(blen, 2 * blen), blen + 1).to(X.device) 355 | tail_output = self.dot_product_attention(tail_q, local_k, local_v, bias=bias) 356 | tail_output = tail_output.view(tail_output.shape[0], tail_output.shape[1], -1, tail_output.shape[4]) 357 | result = torch.cat([first_output, tail_output], 2) 358 | result = result[:,:,:X.shape[1],:] 359 | else: 360 | result = first_output[:,:,:X.shape[1],:] 361 | 362 | result = result.permute([0, 2, 1, 3]).contiguous() 363 | result = result.view(result.shape[0:2] + (-1,)) 364 | result = self.output_dense(result) 365 | return result 366 | 367 | -------------------------------------------------------------------------------- /networks/netvlad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from sklearn.neighbors import NearestNeighbors 5 | 6 | import numpy as np 7 | 8 | class GeM(nn.Module): 9 | def __init__(self, p=3, eps=1e-6): 10 | super(GeM,self).__init__() 11 | self.p = nn.Parameter(torch.ones(1)*p) 12 | self.eps = eps 13 | 14 | def forward(self, x): 15 | return self.gem(x, p=self.p, eps=self.eps) 16 | 17 | def gem(self, x, p=3, eps=1e-6): 18 | return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p) 19 | 20 | def __repr__(self): 21 | return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')' 22 | 23 | class NetVLAD(nn.Module): 24 | """NetVLAD layer implementation""" 25 | 26 | def __init__(self, num_clusters=64, dim=128, alpha=100.0, 27 | normalize_input=True): 28 | """ 29 | Args: 30 | num_clusters : int 31 | The number of clusters 32 | dim : int 33 | Dimension of descriptors 34 | alpha : float 35 | Parameter of initialization. Larger value is harder assignment. 36 | normalize_input : bool 37 | If true, descriptor-wise L2 normalization is applied to input. 
38 | """ 39 | super(NetVLAD, self).__init__() 40 | self.num_clusters = num_clusters 41 | self.dim = dim 42 | self.alpha = alpha 43 | self.normalize_input = normalize_input 44 | self.conv = nn.Conv2d(dim, num_clusters, kernel_size=(1, 1), bias=True) 45 | self.centroids = nn.Parameter(torch.rand(num_clusters, dim)) 46 | self.lastfc = nn.Linear(in_features=num_clusters, out_features=1, bias=False) 47 | # self._init_params() 48 | 49 | def _init_params(self, clsts, traindescs): 50 | clstsAssign = clsts / np.linalg.norm(clsts, axis=1, keepdims=True) 51 | dots = np.dot(clstsAssign, traindescs.T) 52 | dots.sort(0) 53 | dots = dots[::-1, :] # sort, descending 54 | 55 | self.alpha = (-np.log(0.01) / np.mean(dots[0,:] - dots[1,:])).item() 56 | self.centroids = nn.Parameter(torch.from_numpy(clsts)) 57 | self.conv.weight = nn.Parameter(torch.from_numpy(self.alpha*clstsAssign).unsqueeze(2).unsqueeze(3)) 58 | self.conv.bias = None 59 | 60 | def forward(self, x): 61 | N, C = x.shape[:2] 62 | 63 | if self.normalize_input: 64 | x = F.normalize(x, p=2, dim=1) # across descriptor dim 65 | 66 | # soft-assignment 67 | soft_assign = self.conv(x).view(N, self.num_clusters, -1) 68 | soft_assign = F.softmax(soft_assign, dim=1) 69 | 70 | x_flatten = x.view(N, C, -1) 71 | 72 | # calculate residuals to each clusters 73 | vlad = torch.zeros([N, self.num_clusters, C], dtype=x.dtype, layout=x.layout, device=x.device) 74 | for C in range(self.num_clusters): # slower than non-looped, but lower memory usage 75 | residual = x_flatten.unsqueeze(0).permute(1, 0, 2, 3) - \ 76 | self.centroids[C:C+1, :].expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0) 77 | residual *= soft_assign[:,C:C+1,:].unsqueeze(2) 78 | vlad[:,C:C+1,:] = residual.sum(dim=-1) 79 | 80 | vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization 81 | # vlad = vlad.view(x.size(0), -1) # flatten 82 | vlad = self.lastfc(vlad.permute(0, 2, 1)).squeeze() # flatten 83 | # vlad = F.normalize(vlad, p=2, dim=0) # L2 normalize 84 | return vlad 85 | 86 | 87 | class EmbedNet(nn.Module): 88 | def __init__(self, base_model, net_vlad): 89 | super(EmbedNet, self).__init__() 90 | self.base_model = base_model 91 | self.net_vlad = net_vlad 92 | 93 | def forward(self, x): 94 | x = self.base_model(x) 95 | x = x[:,:,np.newaxis,np.newaxis] 96 | embedded_x = self.net_vlad(x) 97 | return embedded_x 98 | 99 | 100 | class TripletNet(nn.Module): 101 | def __init__(self, embed_net): 102 | super(TripletNet, self).__init__() 103 | self.embed_net = embed_net 104 | 105 | def forward(self, a, p, n): 106 | embedded_a = self.embed_net(a) 107 | embedded_p = self.embed_net(p) 108 | embedded_n = self.embed_net(n) 109 | return embedded_a, embedded_p, embedded_n 110 | 111 | def feature_extract(self, x): 112 | return self.embed_net(x) 113 | -------------------------------------------------------------------------------- /networks/submodules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as f 4 | from torch.nn import init 5 | 6 | 7 | class ConvLayer(nn.Module): 8 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None): 9 | super(ConvLayer, self).__init__() 10 | 11 | bias = False if norm == 'BN' else True 12 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 13 | nn.init.xavier_uniform_(self.conv2d.weight) 14 | 15 | if activation is not None: 16 | self.activation = getattr(torch, 
activation, 'relu') 17 | else: 18 | self.activation = None 19 | 20 | self.norm = norm 21 | if norm == 'BN': 22 | self.norm_layer = nn.BatchNorm2d(out_channels) 23 | elif norm == 'IN': 24 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True) 25 | 26 | def forward(self, x): 27 | out = self.conv2d(x) 28 | 29 | if self.norm in ['BN', 'IN']: 30 | out = self.norm_layer(out) 31 | 32 | if self.activation is not None: 33 | out = self.activation(out) 34 | 35 | return out 36 | 37 | 38 | class TransposedConvLayer(nn.Module): 39 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None): 40 | super(TransposedConvLayer, self).__init__() 41 | 42 | bias = False if norm == 'BN' else True 43 | self.transposed_conv2d = nn.ConvTranspose2d( 44 | in_channels, out_channels, kernel_size, stride=2, padding=padding, output_padding=1, bias=bias) 45 | 46 | if activation is not None: 47 | self.activation = getattr(torch, activation, 'relu') 48 | else: 49 | self.activation = None 50 | 51 | self.norm = norm 52 | if norm == 'BN': 53 | self.norm_layer = nn.BatchNorm2d(out_channels) 54 | elif norm == 'IN': 55 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True) 56 | 57 | def forward(self, x): 58 | out = self.transposed_conv2d(x) 59 | 60 | if self.norm in ['BN', 'IN']: 61 | out = self.norm_layer(out) 62 | 63 | if self.activation is not None: 64 | out = self.activation(out) 65 | 66 | return out 67 | 68 | 69 | class UpsampleConvLayer(nn.Module): 70 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None): 71 | super(UpsampleConvLayer, self).__init__() 72 | 73 | bias = False if norm == 'BN' else True 74 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 75 | 76 | nn.init.xavier_uniform_(self.conv2d.weight) 77 | 78 | if activation is not None: 79 | self.activation = getattr(torch, activation, 'relu') 80 | else: 81 | self.activation = None 82 | 83 | self.norm = norm 84 | if norm == 'BN': 85 | self.norm_layer = nn.BatchNorm2d(out_channels) 86 | elif norm == 'IN': 87 | self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True) 88 | 89 | def forward(self, x): 90 | x_upsampled = f.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) 91 | out = self.conv2d(x_upsampled) 92 | 93 | if self.norm in ['BN', 'IN']: 94 | out = self.norm_layer(out) 95 | 96 | if self.activation is not None: 97 | out = self.activation(out) 98 | 99 | return out 100 | 101 | 102 | class RecurrentConvLayer(nn.Module): 103 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, 104 | recurrent_block_type='convlstm', activation='relu', norm=None): 105 | super(RecurrentConvLayer, self).__init__() 106 | 107 | assert(recurrent_block_type in ['convlstm', 'convgru']) 108 | self.recurrent_block_type = recurrent_block_type 109 | if self.recurrent_block_type == 'convlstm': 110 | RecurrentBlock = ConvLSTM 111 | else: 112 | RecurrentBlock = ConvGRU 113 | self.conv = ConvLayer(in_channels, out_channels, kernel_size, stride, padding, activation, norm) 114 | self.recurrent_block = RecurrentBlock(input_size=out_channels, hidden_size=out_channels, kernel_size=3) 115 | 116 | def forward(self, x, prev_state): 117 | x = self.conv(x) 118 | state = self.recurrent_block(x, prev_state) 119 | x = state[0] if self.recurrent_block_type == 'convlstm' else state 120 | return x, state 121 | 122 | 123 | class 
DownsampleRecurrentConvLayer(nn.Module): 124 | def __init__(self, in_channels, out_channels, kernel_size=3, recurrent_block_type='convlstm', padding=0, activation='relu'): 125 | super(DownsampleRecurrentConvLayer, self).__init__() 126 | 127 | self.activation = getattr(torch, activation, 'relu') 128 | 129 | assert(recurrent_block_type in ['convlstm', 'convgru']) 130 | self.recurrent_block_type = recurrent_block_type 131 | if self.recurrent_block_type == 'convlstm': 132 | RecurrentBlock = ConvLSTM 133 | else: 134 | RecurrentBlock = ConvGRU 135 | self.recurrent_block = RecurrentBlock(input_size=in_channels, hidden_size=out_channels, kernel_size=kernel_size) 136 | 137 | def forward(self, x, prev_state): 138 | state = self.recurrent_block(x, prev_state) 139 | x = state[0] if self.recurrent_block_type == 'convlstm' else state 140 | x = f.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False) 141 | return self.activation(x), state 142 | 143 | 144 | # Residual block 145 | class ResidualBlock(nn.Module): 146 | def __init__(self, in_channels, out_channels, stride=1, downsample=None, norm=None): 147 | super(ResidualBlock, self).__init__() 148 | bias = False if norm == 'BN' else True 149 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=bias) 150 | self.norm = norm 151 | if norm == 'BN': 152 | self.bn1 = nn.BatchNorm2d(out_channels) 153 | self.bn2 = nn.BatchNorm2d(out_channels) 154 | elif norm == 'IN': 155 | self.bn1 = nn.InstanceNorm2d(out_channels) 156 | self.bn2 = nn.InstanceNorm2d(out_channels) 157 | 158 | self.relu = nn.ReLU(inplace=True) 159 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=bias) 160 | self.downsample = downsample 161 | 162 | nn.init.xavier_uniform_(self.conv1.weight) 163 | nn.init.xavier_uniform_(self.conv2.weight) 164 | 165 | def forward(self, x): 166 | residual = x 167 | out = self.conv1(x) 168 | if self.norm in ['BN', 'IN']: 169 | out = self.bn1(out) 170 | out = self.relu(out) 171 | out = self.conv2(out) 172 | if self.norm in ['BN', 'IN']: 173 | out = self.bn2(out) 174 | 175 | if self.downsample: 176 | residual = self.downsample(x) 177 | 178 | out += residual 179 | out = self.relu(out) 180 | return out 181 | 182 | 183 | class ConvLSTM(nn.Module): 184 | """Adapted from: https://github.com/Atcold/pytorch-CortexNet/blob/master/model/ConvLSTMCell.py """ 185 | 186 | def __init__(self, input_size, hidden_size, kernel_size): 187 | super(ConvLSTM, self).__init__() 188 | 189 | self.input_size = input_size 190 | self.hidden_size = hidden_size 191 | pad = kernel_size // 2 192 | 193 | # cache a tensor filled with zeros to avoid reallocating memory at each inference step if --no-recurrent is enabled 194 | self.zero_tensors = {} 195 | 196 | self.Gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size, kernel_size, padding=pad) 197 | 198 | def forward(self, input_, prev_state=None): 199 | 200 | # get batch and spatial sizes 201 | batch_size = input_.data.size()[0] 202 | spatial_size = input_.data.size()[2:] 203 | 204 | # generate empty prev_state, if None is provided 205 | if prev_state is None: 206 | 207 | # create the zero tensor if it has not been created already 208 | state_size = tuple([batch_size, self.hidden_size] + list(spatial_size)) 209 | if state_size not in self.zero_tensors: 210 | # allocate a tensor with size `spatial_size`, filled with zero (if it has not been allocated already) 211 | self.zero_tensors[state_size] = ( 212 | 
torch.zeros(state_size).to(input_.device), 213 | torch.zeros(state_size).to(input_.device) 214 | ) 215 | 216 | prev_state = self.zero_tensors[tuple(state_size)] 217 | 218 | prev_hidden, prev_cell = prev_state 219 | 220 | # data size is [batch, channel, height, width] 221 | stacked_inputs = torch.cat((input_, prev_hidden), 1) 222 | gates = self.Gates(stacked_inputs) 223 | 224 | # chunk across channel dimension 225 | in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1) 226 | 227 | # apply sigmoid non linearity 228 | in_gate = torch.sigmoid(in_gate) 229 | remember_gate = torch.sigmoid(remember_gate) 230 | out_gate = torch.sigmoid(out_gate) 231 | 232 | # apply tanh non linearity 233 | cell_gate = torch.tanh(cell_gate) 234 | 235 | # compute current cell and hidden state 236 | cell = (remember_gate * prev_cell) + (in_gate * cell_gate) 237 | hidden = out_gate * torch.tanh(cell) 238 | 239 | return hidden, cell 240 | 241 | 242 | class ConvGRU(nn.Module): 243 | """ 244 | Generate a convolutional GRU cell 245 | Adapted from: https://github.com/jacobkimmel/pytorch_convgru/blob/master/convgru.py 246 | """ 247 | 248 | def __init__(self, input_size, hidden_size, kernel_size): 249 | super().__init__() 250 | padding = kernel_size // 2 251 | self.input_size = input_size 252 | self.hidden_size = hidden_size 253 | self.reset_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding) 254 | self.update_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding) 255 | self.out_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding) 256 | 257 | init.orthogonal_(self.reset_gate.weight) 258 | init.orthogonal_(self.update_gate.weight) 259 | init.orthogonal_(self.out_gate.weight) 260 | init.constant_(self.reset_gate.bias, 0.) 261 | init.constant_(self.update_gate.bias, 0.) 262 | init.constant_(self.out_gate.bias, 0.) 
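        # Gating (see forward below) follows the standard convolutional GRU update,
        # with each gate computed by a 2-D convolution over the concatenated
        # [input, hidden] tensor:
        #   r = sigmoid(reset_gate([x, h_prev]))      # reset gate
        #   z = sigmoid(update_gate([x, h_prev]))     # update gate
        #   h_cand = tanh(out_gate([x, r * h_prev]))  # candidate state
        #   h_new = (1 - z) * h_prev + z * h_cand     # new hidden state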
263 | 264 | def forward(self, input_, prev_state): 265 | 266 | # get batch and spatial sizes 267 | batch_size = input_.data.size()[0] 268 | spatial_size = input_.data.size()[2:] 269 | 270 | # generate empty prev_state, if None is provided 271 | if prev_state is None: 272 | state_size = [batch_size, self.hidden_size] + list(spatial_size) 273 | prev_state = torch.zeros(state_size).to(input_.device) 274 | 275 | # data size is [batch, channel, height, width] 276 | stacked_inputs = torch.cat([input_, prev_state], dim=1) 277 | update = torch.sigmoid(self.update_gate(stacked_inputs)) 278 | reset = torch.sigmoid(self.reset_gate(stacked_inputs)) 279 | out_inputs = torch.tanh(self.out_gate(torch.cat([input_, prev_state * reset], dim=1))) 280 | new_state = prev_state * (1 - update) + out_inputs * update 281 | 282 | return new_state 283 | -------------------------------------------------------------------------------- /networks/unet_rpg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as f 4 | import numpy as np 5 | from torch.nn import init 6 | from .submodules import ConvLayer, UpsampleConvLayer, TransposedConvLayer, RecurrentConvLayer, ResidualBlock, ConvLSTM, ConvGRU 7 | 8 | 9 | def skip_concat(x1, x2): 10 | return torch.cat([x1, x2], dim=1) 11 | 12 | 13 | def skip_sum(x1, x2): 14 | return torch.add(x1, x2) 15 | 16 | 17 | class BaseUNet(nn.Module): 18 | def __init__(self, num_input_channels, num_output_channels=1, skip_type='sum', activation='sigmoid', 19 | num_encoders=4, base_num_channels=32, num_residual_blocks=2, norm=None, use_upsample_conv=True): 20 | super(BaseUNet, self).__init__() 21 | 22 | self.num_input_channels = num_input_channels 23 | self.num_output_channels = num_output_channels 24 | self.skip_type = skip_type 25 | self.apply_skip_connection = skip_sum if self.skip_type == 'sum' else skip_concat 26 | self.activation = activation 27 | self.norm = norm 28 | 29 | if use_upsample_conv: 30 | print('Using UpsampleConvLayer (slow, but no checkerboard artefacts)') 31 | self.UpsampleLayer = UpsampleConvLayer 32 | else: 33 | print('Using TransposedConvLayer (fast, with checkerboard artefacts)') 34 | self.UpsampleLayer = TransposedConvLayer 35 | 36 | self.num_encoders = num_encoders 37 | self.base_num_channels = base_num_channels 38 | self.num_residual_blocks = num_residual_blocks 39 | self.max_num_channels = self.base_num_channels * pow(2, self.num_encoders) 40 | 41 | assert(self.num_input_channels > 0) 42 | assert(self.num_output_channels > 0) 43 | 44 | self.encoder_input_sizes = [] 45 | for i in range(self.num_encoders): 46 | self.encoder_input_sizes.append(self.base_num_channels * pow(2, i)) 47 | 48 | self.encoder_output_sizes = [self.base_num_channels * pow(2, i + 1) for i in range(self.num_encoders)] 49 | 50 | self.activation = getattr(torch, self.activation, 'sigmoid') 51 | 52 | def build_resblocks(self): 53 | self.resblocks = nn.ModuleList() 54 | for i in range(self.num_residual_blocks): 55 | self.resblocks.append(ResidualBlock(self.max_num_channels, self.max_num_channels, norm=self.norm)) 56 | 57 | def build_decoders(self): 58 | decoder_input_sizes = list(reversed([self.base_num_channels * pow(2, i + 1) for i in range(self.num_encoders)])) 59 | 60 | self.decoders = nn.ModuleList() 61 | for input_size in decoder_input_sizes: 62 | self.decoders.append(self.UpsampleLayer(input_size if self.skip_type == 'sum' else 2 * input_size, 63 | input_size // 2, 64 | kernel_size=5, padding=2, norm=self.norm)) 65 | 66 | def 
build_prediction_layer(self): 67 | self.pred = ConvLayer(self.base_num_channels if self.skip_type == 'sum' else 2 * self.base_num_channels, 68 | self.num_output_channels, 1, activation=None, norm=self.norm) 69 | 70 | 71 | class UNet(BaseUNet): 72 | def __init__(self, num_input_channels, num_output_channels=1, skip_type='sum', activation='sigmoid', 73 | num_encoders=4, base_num_channels=32, num_residual_blocks=2, norm=None, use_upsample_conv=True): 74 | super(UNet, self).__init__(num_input_channels, num_output_channels, skip_type, activation, 75 | num_encoders, base_num_channels, num_residual_blocks, norm, use_upsample_conv) 76 | 77 | self.head = ConvLayer(self.num_input_channels, self.base_num_channels, 78 | kernel_size=5, stride=1, padding=2) # N x C x H x W -> N x 32 x H x W 79 | 80 | self.encoders = nn.ModuleList() 81 | for input_size, output_size in zip(self.encoder_input_sizes, self.encoder_output_sizes): 82 | self.encoders.append(ConvLayer(input_size, output_size, kernel_size=5, 83 | stride=2, padding=2, norm=self.norm)) 84 | 85 | self.build_resblocks() 86 | self.build_decoders() 87 | self.build_prediction_layer() 88 | 89 | def forward(self, x): 90 | """ 91 | :param x: N x num_input_channels x H x W 92 | :return: N x num_output_channels x H x W 93 | """ 94 | 95 | # head 96 | x = self.head(x) 97 | head = x 98 | 99 | # encoder 100 | blocks = [] 101 | for i, encoder in enumerate(self.encoders): 102 | x = encoder(x) 103 | blocks.append(x) 104 | 105 | # residual blocks 106 | for resblock in self.resblocks: 107 | x = resblock(x) 108 | 109 | # decoder 110 | for i, decoder in enumerate(self.decoders): 111 | x = decoder(self.apply_skip_connection(x, blocks[self.num_encoders - i - 1])) 112 | 113 | # img = self.activation(self.pred(self.apply_skip_connection(x, head))) 114 | img = self.pred(self.apply_skip_connection(x, head)) 115 | 116 | return img 117 | 118 | 119 | class UNetRecurrent(BaseUNet): 120 | """ 121 | Recurrent UNet architecture where every encoder is followed by a recurrent convolutional block, 122 | such as a ConvLSTM or a ConvGRU. 123 | Symmetric, skip connections on every encoding layer. 
124 | """ 125 | 126 | def __init__(self, num_input_channels, num_output_channels=1, skip_type='sum', 127 | recurrent_block_type='convlstm', activation='sigmoid', num_encoders=4, base_num_channels=32, 128 | num_residual_blocks=2, norm=None, use_upsample_conv=True): 129 | super(UNetRecurrent, self).__init__(num_input_channels, num_output_channels, skip_type, activation, 130 | num_encoders, base_num_channels, num_residual_blocks, norm, 131 | use_upsample_conv) 132 | 133 | self.head = ConvLayer(self.num_input_channels, self.base_num_channels, 134 | kernel_size=5, stride=1, padding=2) # N x C x H x W -> N x 32 x H x W 135 | 136 | self.encoders = nn.ModuleList() 137 | for input_size, output_size in zip(self.encoder_input_sizes, self.encoder_output_sizes): 138 | self.encoders.append(RecurrentConvLayer(input_size, output_size, 139 | kernel_size=5, stride=2, padding=2, 140 | recurrent_block_type=recurrent_block_type, 141 | norm=self.norm)) 142 | 143 | self.build_resblocks() 144 | self.build_decoders() 145 | self.build_prediction_layer() 146 | 147 | def forward(self, x, prev_states): 148 | """ 149 | :param x: N x num_input_channels x H x W 150 | :param prev_states: previous LSTM states for every encoder layer 151 | :return: N x num_output_channels x H x W 152 | """ 153 | 154 | # head 155 | x = self.head(x) 156 | head = x 157 | 158 | if prev_states is None: 159 | prev_states = [None] * self.num_encoders 160 | 161 | # encoder 162 | blocks = [] 163 | states = [] 164 | for i, encoder in enumerate(self.encoders): 165 | x, state = encoder(x, prev_states[i]) 166 | blocks.append(x) 167 | states.append(state) 168 | 169 | # residual blocks 170 | for resblock in self.resblocks: 171 | x = resblock(x) 172 | 173 | # decoder 174 | for i, decoder in enumerate(self.decoders): 175 | x = decoder(self.apply_skip_connection(x, blocks[self.num_encoders - i - 1])) 176 | 177 | # tail 178 | # img = self.activation(self.pred(self.apply_skip_connectiomn(x, head))) 179 | img = self.pred(self.apply_skip_connection(x, head)) 180 | 181 | return img, states 182 | -------------------------------------------------------------------------------- /networks/vgg16.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Imagenet_matconvnet_vgg_verydeep_16_dag(nn.Module): 5 | 6 | def __init__(self): 7 | super().__init__() 8 | self.meta = {'mean': [122.74494171142578, 114.94409942626953, 101.64177703857422], 9 | 'std': [1, 1, 1], 10 | 'imageSize': [224, 224, 3]} 11 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 12 | self.relu1_1 = nn.ReLU() 13 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 14 | self.relu1_2 = nn.ReLU() 15 | self.pool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False) 16 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 17 | self.relu2_1 = nn.ReLU() 18 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 19 | self.relu2_2 = nn.ReLU() 20 | self.pool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False) 21 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 22 | self.relu3_1 = nn.ReLU() 23 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 24 | self.relu3_2 = nn.ReLU() 25 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 
26 | self.relu3_3 = nn.ReLU() 27 | self.pool3 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False) 28 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 29 | self.relu4_1 = nn.ReLU() 30 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 31 | self.relu4_2 = nn.ReLU() 32 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 33 | self.relu4_3 = nn.ReLU() 34 | self.pool4 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False) 35 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 36 | self.relu5_1 = nn.ReLU() 37 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 38 | self.relu5_2 = nn.ReLU() 39 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1)) 40 | self.relu5_3 = nn.ReLU() 41 | self.pool5 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False) 42 | self.fc6 = nn.Conv2d(512, 4096, kernel_size=[7, 7], stride=(1, 1)) 43 | self.relu6 = nn.ReLU() 44 | self.fc7 = nn.Linear(in_features=4096, out_features=4096, bias=True) 45 | self.relu7 = nn.ReLU() 46 | self.fc8 = nn.Linear(in_features=4096, out_features=1000, bias=True) 47 | 48 | def forward(self, x0): 49 | x1 = self.conv1_1(x0) 50 | x2 = self.relu1_1(x1) 51 | x3 = self.conv1_2(x2) 52 | x4 = self.relu1_2(x3) 53 | x5 = self.pool1(x4) 54 | x6 = self.conv2_1(x5) 55 | x7 = self.relu2_1(x6) 56 | x8 = self.conv2_2(x7) 57 | x9 = self.relu2_2(x8) 58 | x10 = self.pool2(x9) 59 | x11 = self.conv3_1(x10) 60 | x12 = self.relu3_1(x11) 61 | x13 = self.conv3_2(x12) 62 | x14 = self.relu3_2(x13) 63 | x15 = self.conv3_3(x14) 64 | x16 = self.relu3_3(x15) 65 | x17 = self.pool3(x16) 66 | x18 = self.conv4_1(x17) 67 | x19 = self.relu4_1(x18) 68 | x20 = self.conv4_2(x19) 69 | x21 = self.relu4_2(x20) 70 | x22 = self.conv4_3(x21) 71 | x23 = self.relu4_3(x22) 72 | x24 = self.pool4(x23) 73 | x25 = self.conv5_1(x24) 74 | x26 = self.relu5_1(x25) 75 | x27 = self.conv5_2(x26) 76 | x28 = self.relu5_2(x27) 77 | x29 = self.conv5_3(x28) 78 | x30 = self.relu5_3(x29) 79 | x31 = self.pool5(x30) 80 | x32 = self.fc6(x31) 81 | x33_preflatten = self.relu6(x32) 82 | x33 = x33_preflatten.view(x33_preflatten.size(0), -1) 83 | x34 = self.fc7(x33) 84 | x35 = self.relu7(x34) 85 | x36 = self.fc8(x35) 86 | return x36 87 | 88 | def Imagenet_vgg(weights_path=None, **kwargs): 89 | """ 90 | load imported model instance 91 | 92 | Args: 93 | weights_path (str): If set, loads model weights from the given path 94 | """ 95 | model = Imagenet_matconvnet_vgg_verydeep_16_dag() 96 | if weights_path: 97 | state_dict = torch.load(weights_path) 98 | model.load_state_dict(state_dict) 99 | return model 100 | -------------------------------------------------------------------------------- /preprocess/process_dvs.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | 3 | import rosbag 4 | import rospy 5 | import ros_numpy 6 | import sys 7 | from sensor_msgs.msg import Image 8 | from cv_bridge import CvBridge, CvBridgeError 9 | import numpy as np 10 | import numpy.matlib 11 | import math 12 | import copy 13 | import cv2 14 | 15 | from pytictoc import TicToc 16 | from PIL import Image 17 | from pyproj import Proj, transform 18 | from pathlib import Path 19 | 20 | ## global variables 21 | t = TicToc() 22 | width = 0 23 | height = 0 24 | imglist = [] 25 | cvbdg = CvBridge() 26 | class allevents(): 27 | def __init__(self): 28 | self.x = np.zeros((0,),dtype=np.uint16) 29 | self.y = np.zeros((0,),dtype=np.uint16) 30 | self.p = np.zeros((0,),dtype=bool) 31 | self.t = np.zeros((0,),dtype=np.float32) 32 | self.nevents = 0 33 | 34 | def update_event_q(msg, evtlist): 35 | global width, height 36 | width = msg.width 37 | height = msg.height 38 | nevents = np.shape(msg.events)[0] 39 | x_arr = np.zeros((nevents,),dtype=np.uint16) 40 | y_arr = np.zeros((nevents,),dtype=np.uint16) 41 | p_arr = np.zeros((nevents,),dtype=bool) 42 | t_arr = np.zeros((nevents,),dtype=np.float64) 43 | for i in range(nevents): 44 | x_arr[i] = msg.events[i].x 45 | y_arr[i] = msg.events[i].y 46 | p_arr[i] = msg.events[i].polarity 47 | t_arr[i] = msg.events[i].ts.to_sec() 48 | order = np.argsort(t_arr) 49 | x_arr = x_arr[order] 50 | y_arr = y_arr[order] 51 | p_arr = p_arr[order] 52 | t_arr = t_arr[order] 53 | evtlist.x = np.concatenate((evtlist.x, x_arr)) 54 | evtlist.y = np.concatenate((evtlist.y, y_arr)) 55 | evtlist.p = np.concatenate((evtlist.p, p_arr)) 56 | evtlist.t = np.concatenate((evtlist.t, t_arr)) 57 | evtlist.nevents += nevents 58 | return evtlist 59 | 60 | def generate_event_img(evpath, evtlist, im_stamp): 61 | global width, height 62 | dt = 0.005 # 5ms 63 | tstamp = im_stamp 64 | dvs_img0 = np.zeros((height, width, 3), dtype=np.uint8) 65 | dvs_img1 = np.zeros((height, width, 3), dtype=np.uint8) 66 | dvs_img2 = np.zeros((height, width, 3), dtype=np.uint8) 67 | 68 | e_idx = abs(evtlist.t - tstamp)