├── .gitignore ├── Figure ├── ELEDNet_demo1_v2.gif └── ELEDNet_demo2_v2.gif ├── README.md ├── models ├── model_factory │ └── models_final.py ├── model_manager.py └── submodules.py ├── pretrained_model └── README.md ├── sample_data ├── blur_images │ ├── 00014.png │ ├── 00015.png │ └── 00016.png └── event_voxel │ ├── 00014.npz │ ├── 00015.npz │ └── 00016.npz ├── test_model.py ├── test_sample.py ├── train.py └── utils ├── dataloader.py ├── eval_metrics.py ├── make_train_dataset.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | experiments/ 2 | pretrained_model/ 3 | saved_img/ 4 | __pycache__/ 5 | *.pyc -------------------------------------------------------------------------------- /Figure/ELEDNet_demo1_v2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelpro/ELEDNet/036259bd29cd6465941f5adbc4205035d06d5408/Figure/ELEDNet_demo1_v2.gif -------------------------------------------------------------------------------- /Figure/ELEDNet_demo2_v2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelpro/ELEDNet/036259bd29cd6465941f5adbc4205035d06d5408/Figure/ELEDNet_demo2_v2.gif -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ELEDNet (ECCV 2024) 2 | 3 | Official repository for the ECCV 2024 paper, **"Towards Real-world Event-guided Low-light Video Enhancement and Deblurring."** 4 | 5 | [[Paper](https://arxiv.org/abs/2408.14916)] 6 | [[Supp](https://drive.google.com/file/d/1xBy29Iy3ae7V0YTasPGBbE9Xf6fNUX3L/view?usp=sharing)] 7 | 8 | 9 | ## Video Demos 10 | ![ELEDNet Demo_1](https://github.com/intelpro/ELEDNet/blob/main/Figure/ELEDNet_demo1_v2.gif) 11 | ![ELEDNet Demo_2](https://github.com/intelpro/ELEDNet/blob/main/Figure/ELEDNet_demo2_v2.gif) 12 | 13 | 14 | 15 | ## Downloading the RELED datasets 16 | Please download and unzip the RELED dataset. 17 | 18 | * [[RELED-Train](https://drive.google.com/file/d/1SiUTEOm6ZrLgXnh2t1LeUqy0xDjiubH6/view?usp=drive_link)] / [[RELED-Test](https://drive.google.com/file/d/18XXfjZ59rQulFRH18UNHI9Gm0ZRGeJwN/view)] 19 | 20 | The dataset follows the below directory format: 21 | ``` 22 | ├── RELED/ 23 | ├── train/ 24 | │ ├── 0000/ 25 | │ │ ├── blur_processed/ 26 | │ │ │ ├── 00000.png 27 | │ │ │ ├── ... 28 | │ │ │ └── 00148.png 29 | │ │ ├── gt_processed/ 30 | │ │ │ ├── 00000.png 31 | │ │ │ ├── ... 32 | │ │ │ └── 00148.png 33 | │ │ ├── events/ 34 | │ │ │ ├── 00000.npz 35 | │ │ │ ├── ... 36 | │ │ │ └── 00148.npz 37 | │ │ └── event_voxel/ 38 | │ │ ├── 00000.npz 39 | │ │ ├── ... 40 | │ │ └── 00148.npz 41 | │ ├── 0001/ 42 | │ │ ├── ... 43 | ├── test/ 44 | │ ├── 0000/ 45 | │ │ ├── ... 46 | │ ├── 0001/ 47 | │ │ ├── ... 48 | ``` 49 | 50 | Sub-directory Descriptions: 51 | - **blur_processed**: Contains low-light blurred images (`*.png` files). 52 | - **gt_processed**: Contains normal-light sharp images (`*.png` files). 53 | - **events**: Contains raw event data in `.npz` format. 54 | - **event_voxel**: Contains event voxel data in `.npz` format. 
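For reference, the snippet below is a minimal sketch (not part of the repository's scripts) showing how the files of one training sequence can be paired under this layout; the `RELED` root path and the sequence id `0000` are placeholders.

```python
import os
import glob

# Placeholder paths -- point RELED_ROOT at your unzipped dataset and pick a sequence id.
RELED_ROOT = 'RELED'
seq_dir = os.path.join(RELED_ROOT, 'train', '0000')

blur_paths  = sorted(glob.glob(os.path.join(seq_dir, 'blur_processed', '*.png')))
gt_paths    = sorted(glob.glob(os.path.join(seq_dir, 'gt_processed', '*.png')))
voxel_paths = sorted(glob.glob(os.path.join(seq_dir, 'event_voxel', '*.npz')))

# Files sharing the same index (00000 ... 00148) describe the same frame:
# the low-light blurred input, its normal-light sharp ground truth, and the event voxel.
assert len(blur_paths) == len(gt_paths) == len(voxel_paths)
for blur_png, gt_png, voxel_npz in zip(blur_paths, gt_paths, voxel_paths):
    print(blur_png, gt_png, voxel_npz)
```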
55 | 56 | Reading Raw Event Data (`events`) and Event Voxel Data (`event_voxel`): 57 | 58 | To read event and event voxel data from `.npz` files using Python and NumPy: 59 | 60 | ```python 61 | import numpy as np 62 | 63 | # Replace YOUR_EVENT_DIR with the path to the directory containing the .npz files; each file is loaded individually 64 | event_data = np.load('YOUR_EVENT_DIR/00000.npz')['data'] 65 | ``` 66 | 67 | ## Requirements 68 | * PyTorch 1.9 69 | * CUDA 11.2 70 | * Python 3.7 71 | 72 | ## Quick Train model 73 | 74 | Download the repository: 75 | 76 | ```bash 77 | $ git clone https://github.com/intelpro/ELEDNet 78 | ``` 79 | 80 | 81 | Before training our model, you need to preprocess the raw dataset. 82 | 83 | Run the following command to preprocess the dataset: 84 | 85 | ```bash 86 | $ python utils/make_train_dataset.py --train_data_dir TRAIN_DATASET_DIR 87 | ``` 88 | 89 | - `--train_data_dir TRAIN_DATASET_DIR`: Specifies the directory containing the **training split** of the **RELED dataset**. 90 | - Preprocessing **divides the blur, event voxel, and ground truth (GT) data into four parts** to speed up training. 91 | 92 | Once preprocessing is complete, you can proceed to model training: 93 | 94 | ```bash 95 | $ python train.py --data_dir DATASET_DIR 96 | ``` 97 | 98 | - `--data_dir DATASET_DIR`: Specifies the directory containing the complete RELED dataset, including both the training and test sets. 99 | 100 | ## Quick Test model 101 | 102 | Download the repository: 103 | 104 | ```bash 105 | $ git clone https://github.com/intelpro/ELEDNet 106 | ``` 107 | 108 | Download the network weights (trained on the RELED dataset) and place the downloaded model inside the `./pretrained_model` directory. 109 | 110 | 🔗 **[Download Our Pretrained Weights](https://drive.google.com/file/d/1sJiaIMOrt2Vs931FOjsx4oHROrEzNu2u/view?usp=sharing)** 111 | 112 | ```bash 113 | # Ensure the directory exists 114 | mkdir -p pretrained_model 115 | 116 | # Move the downloaded model to the correct location 117 | mv /path/to/downloaded/Ours_RELED.pth ./pretrained_model/ 118 | ``` 119 | 120 | Generate output images using our model and the sample data provided in this repository: 121 | 122 | ```bash 123 | $ python test_sample.py --resume_ckpt True --ckpt_dir ./pretrained_model/Ours_RELED.pth 124 | ``` 125 | 126 | ## Test model 127 | 128 | To test on the full RELED dataset, generate output images using the following command: 129 | 130 | ```bash 131 | 132 | $ python test_model.py --data_dir RELED_PATH --resume_ckpt True --ckpt_dir PATH_CKPT --saved_dir SAVED_DIR 133 | 134 | ``` 135 | 136 | - `--data_dir RELED_PATH`: Path to the RELED dataset. 137 | - `--resume_ckpt True`: Enables loading of a pretrained model checkpoint. 138 | - `--ckpt_dir PATH_CKPT`: Path to the pretrained checkpoint file. 139 | - `--saved_dir SAVED_DIR`: Directory where output images will be saved. 140 | 141 | 142 | 143 | ## Reference 144 | > Taewoo Kim, Jaeseok Jeong, Hoonhee Cho, Yuhwan Jeong, and Kuk-Jin Yoon, **"Towards Real-world Event-guided Low-light Video Enhancement and Deblurring,"** In *ECCV*, 2024. 
145 | ```bibtex 146 | @inproceedings{kim2024towards, 147 | title={Towards Real-world Event-guided Low-light Video Enhancement and Deblurring}, 148 | author={Kim, Taewoo and Jeong, Jaeseok and Cho, Hoonhee and Jeong, Yuhwan and Yoon, Kuk-Jin}, 149 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)}, 150 | pages={433--451}, 151 | year={2024}, 152 | publisher={Springer} 153 | } 154 | ``` 155 | 156 | ## Contact 157 | If you have any question, please send an email to taewoo(an625148@gmail.com) 158 | 159 | ## License 160 | The project codes and datasets can be used for research and education only. -------------------------------------------------------------------------------- /models/model_factory/models_final.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import argparse 5 | from collections import namedtuple 6 | from string import Template 7 | from torch.autograd import Function 8 | from torch.nn.modules.utils import _pair 9 | from torchvision.ops import DeformConv2d 10 | from thop import profile 11 | from einops import rearrange 12 | import cupy 13 | import numbers 14 | from models.submodules import * 15 | 16 | 17 | def conv1x1(in_channels, out_channels, stride=1): 18 | return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=True) 19 | 20 | 21 | def conv3x3(in_channels, out_channels, stride=1): 22 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=True) 23 | 24 | 25 | def conv5x5(in_channels, out_channels, stride=1): 26 | return nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=stride, padding=2, bias=True) 27 | 28 | 29 | def deconv4x4(in_channels, out_channels, stride=2): 30 | return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=stride, padding=1) 31 | 32 | 33 | def deconv5x5(in_channels, out_channels, stride=2): 34 | return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=5, stride=stride, padding=2, output_padding=1) 35 | 36 | def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1): 37 | return nn.Conv2d( 38 | in_channels, out_channels, kernel_size, 39 | padding=(kernel_size//2), bias=bias, stride = stride) 40 | 41 | 42 | class ResBlock(nn.Module): 43 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, groups=1): 44 | super(ResBlock, self).__init__() 45 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, stride=stride, 46 | padding=get_same_padding(kernel_size, dilation), dilation=dilation, groups=groups) 47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, stride=1, 48 | padding=get_same_padding(kernel_size, dilation), dilation=dilation, groups=groups) 49 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 50 | 51 | self.res_translate = None 52 | if not inplanes == planes or not stride == 1: 53 | self.res_translate = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride) 54 | 55 | 56 | def forward(self, x): 57 | residual = x 58 | 59 | out = self.relu(self.conv1(x)) 60 | out = self.conv2(out) 61 | 62 | if self.res_translate is not None: 63 | residual = self.res_translate(residual) 64 | out += residual 65 | 66 | return out 67 | 68 | class DownSample(nn.Module): 69 | def __init__(self, in_channels, s_factor): 70 | super(DownSample, self).__init__() 71 | self.down = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False), 72 
| nn.Conv2d(in_channels, in_channels + s_factor, 1, stride=1, padding=0, bias=False)) 73 | 74 | def forward(self, x): 75 | x = self.down(x) 76 | return x 77 | 78 | 79 | class UpSample(nn.Module): 80 | def __init__(self, in_channels, s_factor): 81 | super(UpSample, self).__init__() 82 | self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 83 | nn.Conv2d(in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False)) 84 | 85 | def forward(self, x): 86 | x = self.up(x) 87 | return x 88 | 89 | 90 | class SkipUpSample(nn.Module): 91 | def __init__(self, in_channels, s_factor): 92 | super(SkipUpSample, self).__init__() 93 | self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 94 | nn.Conv2d(in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False)) 95 | 96 | def forward(self, x, y): 97 | x = self.up(x) 98 | x = x + y 99 | return x 100 | 101 | 102 | # Channel Attention Layer 103 | class CALayer(nn.Module): 104 | def __init__(self, channel, reduction=16, bias=False): 105 | super(CALayer, self).__init__() 106 | # global average pooling: feature --> point 107 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 108 | # feature channel downscale and upscale --> channel weight 109 | self.conv_du = nn.Sequential( 110 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias), 111 | nn.ReLU(inplace=True), 112 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias), 113 | nn.Sigmoid() 114 | ) 115 | 116 | def forward(self, x): 117 | y = self.avg_pool(x) 118 | y = self.conv_du(y) 119 | return x * y 120 | 121 | ## Channel Attention Block (CAB) 122 | class CAB(nn.Module): 123 | def __init__(self, n_feat, kernel_size, reduction, bias, act): 124 | super(CAB, self).__init__() 125 | modules_body = [] 126 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 127 | modules_body.append(act) 128 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 129 | 130 | self.CA = CALayer(n_feat, reduction, bias=bias) 131 | self.body = nn.Sequential(*modules_body) 132 | 133 | def forward(self, x): 134 | res = self.body(x) 135 | res = self.CA(res) 136 | res += x 137 | return res 138 | 139 | ## Original Resolution Block (ORB) 140 | class CABs(nn.Module): 141 | def __init__(self, n_feat, kernel_size, reduction, act, bias, num_cab): 142 | super(CABs, self).__init__() 143 | modules_body = [] 144 | modules_body = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(num_cab)] 145 | modules_body.append(conv(n_feat, n_feat, kernel_size)) 146 | self.body = nn.Sequential(*modules_body) 147 | 148 | def forward(self, x): 149 | res = self.body(x) 150 | res += x 151 | return res 152 | 153 | # RDB-based RNN cell 154 | class shallow_cell(nn.Module): 155 | def __init__(self, n_feat): 156 | super(shallow_cell, self).__init__() 157 | self.n_feats = n_feat 158 | act = nn.PReLU() 159 | bias = False 160 | reduction = 4 161 | self.shallow_feat = nn.Sequential(conv(3, self.n_feats, 3, bias=bias), 162 | CAB(self.n_feats, 3, reduction, bias=bias, act=act)) 163 | 164 | def forward(self,x): 165 | feat = self.shallow_feat(x) 166 | return feat 167 | 168 | 169 | # RDB-based RNN cell 170 | class shallow_cell_events(nn.Module): 171 | def __init__(self, n_feat): 172 | super(shallow_cell_events, self).__init__() 173 | self.n_feats = n_feat 174 | act = nn.PReLU() 175 | bias = False 176 | reduction = 4 177 | self.shallow_feat = nn.Sequential(conv(16, self.n_feats, 3, bias=bias), 178 | CAB(self.n_feats, 3, 
reduction, bias=bias, act=act)) 179 | 180 | def forward(self,x): 181 | feat = self.shallow_feat(x) 182 | return feat 183 | 184 | def conv_down(in_chn, out_chn, bias=False): 185 | layer = nn.Conv2d(in_chn, out_chn, kernel_size=4, stride=2, padding=1, bias=bias) 186 | return layer 187 | 188 | 189 | def to_3d(x): 190 | return rearrange(x, 'b c h w -> b (h w) c') 191 | 192 | def to_4d(x,h,w): 193 | return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) 194 | 195 | class BiasFree_LayerNorm(nn.Module): 196 | def __init__(self, normalized_shape): 197 | super(BiasFree_LayerNorm, self).__init__() 198 | if isinstance(normalized_shape, numbers.Integral): 199 | normalized_shape = (normalized_shape,) 200 | normalized_shape = torch.Size(normalized_shape) 201 | 202 | assert len(normalized_shape) == 1 203 | 204 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 205 | self.normalized_shape = normalized_shape 206 | 207 | def forward(self, x): 208 | sigma = x.var(-1, keepdim=True, unbiased=False) 209 | return x / torch.sqrt(sigma+1e-5) * self.weight 210 | 211 | class WithBias_LayerNorm(nn.Module): 212 | def __init__(self, normalized_shape): 213 | super(WithBias_LayerNorm, self).__init__() 214 | if isinstance(normalized_shape, numbers.Integral): 215 | normalized_shape = (normalized_shape,) 216 | normalized_shape = torch.Size(normalized_shape) 217 | assert len(normalized_shape) == 1 218 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 219 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 220 | self.normalized_shape = normalized_shape 221 | 222 | def forward(self, x): 223 | mu = x.mean(-1, keepdim=True) 224 | sigma = x.var(-1, keepdim=True, unbiased=False) 225 | return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias 226 | 227 | class LayerNorm(nn.Module): 228 | def __init__(self, dim, LayerNorm_type): 229 | super(LayerNorm, self).__init__() 230 | if LayerNorm_type =='BiasFree': 231 | self.body = BiasFree_LayerNorm(dim) 232 | else: 233 | self.body = WithBias_LayerNorm(dim) 234 | 235 | def forward(self, x): 236 | h, w = x.shape[-2:] 237 | return to_4d(self.body(to_3d(x)), h, w) 238 | 239 | Stream = namedtuple('Stream', ['ptr']) 240 | 241 | def Dtype(t): 242 | if isinstance(t, torch.cuda.FloatTensor): 243 | return 'float' 244 | elif isinstance(t, torch.cuda.DoubleTensor): 245 | return 'double' 246 | 247 | 248 | # @cupy._util.memoize(for_each_device=True) 249 | def load_kernel(kernel_name, code, **kwargs): 250 | code = Template(code).substitute(**kwargs) 251 | kernel_code = cupy.cuda.compile_with_cache(code) 252 | return kernel_code.get_function(kernel_name) 253 | 254 | 255 | CUDA_NUM_THREADS = 1024 256 | 257 | kernel_loop = ''' 258 | #define CUDA_KERNEL_LOOP(i, n) \ 259 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 260 | i < (n); \ 261 | i += blockDim.x * gridDim.x) 262 | ''' 263 | 264 | 265 | def GET_BLOCKS(N): 266 | return (N + CUDA_NUM_THREADS - 1) // CUDA_NUM_THREADS 267 | 268 | 269 | _idynamic_kernel = kernel_loop + ''' 270 | extern "C" 271 | __global__ void idynamic_forward_kernel( 272 | const ${Dtype}* bottom_data, const ${Dtype}* weight_data, ${Dtype}* top_data) { 273 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 274 | const int n = index / ${channels} / ${top_height} / ${top_width}; 275 | const int c = (index / ${top_height} / ${top_width}) % ${channels}; 276 | const int h = (index / ${top_width}) % ${top_height}; 277 | const int w = index % ${top_width}; 278 | const int g = c / (${channels} / ${groups}); 279 | ${Dtype} value = 0; 280 | #pragma unroll 281 | for (int 
kh = 0; kh < ${kernel_h}; ++kh) { 282 | #pragma unroll 283 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 284 | const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h}; 285 | const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w}; 286 | if ((h_in >= 0) && (h_in < ${bottom_height}) 287 | && (w_in >= 0) && (w_in < ${bottom_width})) { 288 | const int offset = ((n * ${channels} + c) * ${bottom_height} + h_in) 289 | * ${bottom_width} + w_in; 290 | const int offset_weight = ((((n * ${groups} + g) * ${kernel_h} + kh) * ${kernel_w} + kw) * ${top_height} + h) 291 | * ${top_width} + w; 292 | value += weight_data[offset_weight] * bottom_data[offset]; 293 | } 294 | } 295 | } 296 | top_data[index] = value; 297 | } 298 | } 299 | ''' 300 | 301 | _idynamic_kernel_backward_grad_input = kernel_loop + ''' 302 | extern "C" 303 | __global__ void idynamic_backward_grad_input_kernel( 304 | const ${Dtype}* const top_diff, const ${Dtype}* const weight_data, ${Dtype}* const bottom_diff) { 305 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 306 | const int n = index / ${channels} / ${bottom_height} / ${bottom_width}; 307 | const int c = (index / ${bottom_height} / ${bottom_width}) % ${channels}; 308 | const int h = (index / ${bottom_width}) % ${bottom_height}; 309 | const int w = index % ${bottom_width}; 310 | const int g = c / (${channels} / ${groups}); 311 | ${Dtype} value = 0; 312 | #pragma unroll 313 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 314 | #pragma unroll 315 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 316 | const int h_out_s = h + ${pad_h} - kh * ${dilation_h}; 317 | const int w_out_s = w + ${pad_w} - kw * ${dilation_w}; 318 | if (((h_out_s % ${stride_h}) == 0) && ((w_out_s % ${stride_w}) == 0)) { 319 | const int h_out = h_out_s / ${stride_h}; 320 | const int w_out = w_out_s / ${stride_w}; 321 | if ((h_out >= 0) && (h_out < ${top_height}) 322 | && (w_out >= 0) && (w_out < ${top_width})) { 323 | const int offset = ((n * ${channels} + c) * ${top_height} + h_out) 324 | * ${top_width} + w_out; 325 | const int offset_weight = ((((n * ${groups} + g) * ${kernel_h} + kh) * ${kernel_w} + kw) * ${top_height} + h_out) 326 | * ${top_width} + w_out; 327 | value += weight_data[offset_weight] * top_diff[offset]; 328 | } 329 | } 330 | } 331 | } 332 | bottom_diff[index] = value; 333 | } 334 | } 335 | ''' 336 | 337 | _idynamic_kernel_backward_grad_weight = kernel_loop + ''' 338 | extern "C" 339 | __global__ void idynamic_backward_grad_weight_kernel( 340 | const ${Dtype}* const top_diff, const ${Dtype}* const bottom_data, ${Dtype}* const buffer_data) { 341 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 342 | const int h = (index / ${top_width}) % ${top_height}; 343 | const int w = index % ${top_width}; 344 | const int kh = (index / ${kernel_w} / ${top_height} / ${top_width}) 345 | % ${kernel_h}; 346 | const int kw = (index / ${top_height} / ${top_width}) % ${kernel_w}; 347 | const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h}; 348 | const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w}; 349 | if ((h_in >= 0) && (h_in < ${bottom_height}) 350 | && (w_in >= 0) && (w_in < ${bottom_width})) { 351 | const int g = (index / ${kernel_h} / ${kernel_w} / ${top_height} / ${top_width}) % ${groups}; 352 | const int n = (index / ${groups} / ${kernel_h} / ${kernel_w} / ${top_height} / ${top_width}) % ${num}; 353 | ${Dtype} value = 0; 354 | #pragma unroll 355 | for (int c = g * (${channels} / ${groups}); c < (g + 1) * (${channels} / ${groups}); ++c) { 356 | const int top_offset = ((n * ${channels} + c) * 
${top_height} + h) 357 | * ${top_width} + w; 358 | const int bottom_offset = ((n * ${channels} + c) * ${bottom_height} + h_in) 359 | * ${bottom_width} + w_in; 360 | value += top_diff[top_offset] * bottom_data[bottom_offset]; 361 | } 362 | buffer_data[index] = value; 363 | } else { 364 | buffer_data[index] = 0; 365 | } 366 | } 367 | } 368 | ''' 369 | 370 | class _idynamic(Function): 371 | @staticmethod 372 | def forward(ctx, input, weight, stride, padding, dilation): 373 | assert input.dim() == 4 and input.is_cuda 374 | assert weight.dim() == 6 and weight.is_cuda 375 | batch_size, channels, height, width = input.size() 376 | kernel_h, kernel_w = weight.size()[2:4] 377 | output_h = int((height + 2 * padding[0] - (dilation[0] * (kernel_h - 1) + 1)) / stride[0] + 1) 378 | output_w = int((width + 2 * padding[1] - (dilation[1] * (kernel_w - 1) + 1)) / stride[1] + 1) 379 | 380 | output = input.new(batch_size, channels, output_h, output_w) 381 | n = output.numel() 382 | 383 | with torch.cuda.device_of(input): 384 | f = load_kernel('idynamic_forward_kernel', _idynamic_kernel, Dtype=Dtype(input), nthreads=n, 385 | num=batch_size, channels=channels, groups=weight.size()[1], 386 | bottom_height=height, bottom_width=width, 387 | top_height=output_h, top_width=output_w, 388 | kernel_h=kernel_h, kernel_w=kernel_w, 389 | stride_h=stride[0], stride_w=stride[1], 390 | dilation_h=dilation[0], dilation_w=dilation[1], 391 | pad_h=padding[0], pad_w=padding[1]) 392 | f(block=(CUDA_NUM_THREADS, 1, 1), 393 | grid=(GET_BLOCKS(n), 1, 1), 394 | args=[input.data_ptr(), weight.data_ptr(), output.data_ptr()], 395 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 396 | 397 | ctx.save_for_backward(input, weight) 398 | ctx.stride, ctx.padding, ctx.dilation = stride, padding, dilation 399 | return output 400 | 401 | @staticmethod 402 | def backward(ctx, grad_output): 403 | assert grad_output.is_cuda 404 | if not grad_output.is_contiguous(): 405 | grad_output.contiguous() 406 | input, weight = ctx.saved_tensors 407 | stride, padding, dilation = ctx.stride, ctx.padding, ctx.dilation 408 | 409 | batch_size, channels, height, width = input.size() 410 | kernel_h, kernel_w = weight.size()[2:4] 411 | output_h, output_w = grad_output.size()[2:] 412 | 413 | grad_input, grad_weight = None, None 414 | 415 | opt = dict(Dtype=Dtype(grad_output), 416 | num=batch_size, channels=channels, groups=weight.size()[1], 417 | bottom_height=height, bottom_width=width, 418 | top_height=output_h, top_width=output_w, 419 | kernel_h=kernel_h, kernel_w=kernel_w, 420 | stride_h=stride[0], stride_w=stride[1], 421 | dilation_h=dilation[0], dilation_w=dilation[1], 422 | pad_h=padding[0], pad_w=padding[1]) 423 | 424 | with torch.cuda.device_of(input): 425 | if ctx.needs_input_grad[0]: 426 | grad_input = input.new(input.size()) 427 | 428 | n = grad_input.numel() 429 | opt['nthreads'] = n 430 | 431 | f = load_kernel('idynamic_backward_grad_input_kernel', 432 | _idynamic_kernel_backward_grad_input, **opt) 433 | f(block=(CUDA_NUM_THREADS, 1, 1), 434 | grid=(GET_BLOCKS(n), 1, 1), 435 | args=[grad_output.data_ptr(), weight.data_ptr(), grad_input.data_ptr()], 436 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 437 | 438 | if ctx.needs_input_grad[1]: 439 | grad_weight = weight.new(weight.size()) 440 | 441 | n = grad_weight.numel() 442 | opt['nthreads'] = n 443 | 444 | f = load_kernel('idynamic_backward_grad_weight_kernel', 445 | _idynamic_kernel_backward_grad_weight, **opt) 446 | f(block=(CUDA_NUM_THREADS, 1, 1), 447 | grid=(GET_BLOCKS(n), 
1, 1), 448 | args=[grad_output.data_ptr(), input.data_ptr(), grad_weight.data_ptr()], 449 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 450 | 451 | return grad_input, grad_weight, None, None, None 452 | 453 | 454 | def _idynamic_cuda(input, weight, bias=None, stride=1, padding=0, dilation=1): 455 | """ idynamic kernel 456 | """ 457 | assert input.size(0) == weight.size(0) 458 | assert input.size(-2) // stride == weight.size(-2) 459 | assert input.size(-1) // stride == weight.size(-1) 460 | if input.is_cuda: 461 | out = _idynamic.apply(input, weight, _pair(stride), _pair(padding), _pair(dilation)) 462 | if bias is not None: 463 | out += bias.view(1, -1, 1, 1) 464 | else: 465 | raise NotImplementedError 466 | return out 467 | 468 | 469 | 470 | 471 | ############################### 472 | # ResNet 473 | ############################### 474 | 475 | 476 | def get_same_padding(kernel_size, dilation): 477 | kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) 478 | padding = (kernel_size - 1) // 2 479 | return padding 480 | 481 | 482 | class ResBlock(nn.Module): 483 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, groups=1): 484 | super(ResBlock, self).__init__() 485 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, stride=stride, 486 | padding=get_same_padding(kernel_size, dilation), dilation=dilation, groups=groups) 487 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, stride=1, 488 | padding=get_same_padding(kernel_size, dilation), dilation=dilation, groups=groups) 489 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 490 | 491 | self.res_translate = None 492 | if not inplanes == planes or not stride == 1: 493 | self.res_translate = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride) 494 | 495 | 496 | def forward(self, x): 497 | residual = x 498 | 499 | out = self.relu(self.conv1(x)) 500 | out = self.conv2(out) 501 | 502 | if self.res_translate is not None: 503 | residual = self.res_translate(residual) 504 | out += residual 505 | 506 | return out 507 | 508 | 509 | class Encoder(nn.Module): 510 | def __init__(self, n_feat, scale_unetfeats, kernel_size=3, reduction=4, bias=False): 511 | super(Encoder, self).__init__() 512 | act = nn.PReLU() 513 | self.encoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 514 | self.encoder_level2 = [CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 515 | self.encoder_level3 = [CAB(n_feat + (scale_unetfeats * 2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 516 | 517 | self.encoder_level1 = nn.Sequential(*self.encoder_level1) 518 | self.encoder_level2 = nn.Sequential(*self.encoder_level2) 519 | self.encoder_level3 = nn.Sequential(*self.encoder_level3) 520 | 521 | self.down12 = DownSample(n_feat, scale_unetfeats) 522 | self.down23 = DownSample(n_feat + scale_unetfeats, scale_unetfeats) 523 | 524 | def forward(self, x): 525 | ### level 1 526 | enc1 = self.encoder_level1(x) 527 | x = self.down12(enc1) 528 | ### level 2 529 | enc2 = self.encoder_level2(x) 530 | x = self.down23(enc2) 531 | ### level 3 532 | enc3 = self.encoder_level3(x) 533 | return [enc1, enc2, enc3] 534 | 535 | 536 | class IDynamicDWConv(nn.Module): 537 | def __init__(self, channels, kernel_size, group_channels, down, conv_group): 538 | super(IDynamicDWConv, self).__init__() 539 | self.kernel_size = kernel_size 540 | self.channels = channels 541 | self.group_channels = group_channels 542 | self.down = 
down 543 | self.groups = self.channels // self.group_channels 544 | self.avgpool = nn.AvgPool2d(kernel_size=down, stride=down) 545 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 546 | Block1 = [ResBlock(channels, channels, kernel_size=kernel_size, stride=1, groups=conv_group) for _ in range(3)] 547 | Block2 = [ResBlock(channels, channels, kernel_size=kernel_size, stride=1, groups=conv_group) for _ in range(3)] 548 | self.tokernel = nn.Conv2d(channels, kernel_size**2*self.groups, 1, 1, 0) 549 | self.Block1 = nn.Sequential(*Block1) 550 | self.Block2 = nn.Sequential(*Block2) 551 | 552 | def forward(self, x, y): 553 | weight = self.tokernel(self.Block2(self.maxpool(self.Block1(self.avgpool(y))))) 554 | weight = F.interpolate(weight, scale_factor=2*self.down, mode='bilinear') 555 | b, c, h, w = weight.shape 556 | weight = weight.view(b, self.groups, self.kernel_size, self.kernel_size, h, w) 557 | out = _idynamic_cuda(x, weight, stride=1, padding=(self.kernel_size - 1) // 2) 558 | return out 559 | 560 | 561 | 562 | 563 | class ffn_align(nn.Module): 564 | def __init__(self, dim, ffn_expansion_factor, bias): 565 | super(ffn_align, self).__init__() 566 | hidden_features = int(dim*ffn_expansion_factor) 567 | self.project_in_f = nn.Conv3d(dim, hidden_features, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=bias) 568 | self.project_in_ev = nn.Conv3d(dim, hidden_features, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0), bias=bias) 569 | ## 570 | self.conv1_prev = nn.Conv2d(hidden_features*3, hidden_features, 3, 1, 1, bias=bias) 571 | self.conv1_future = nn.Conv2d(hidden_features*3, hidden_features, 3, 1, 1, bias=bias) 572 | self.resblock_forward = ResidualBlocks2D(hidden_features, 3) 573 | self.resblock_backward = ResidualBlocks2D(hidden_features, 3) 574 | self.conv_prop = nn.Conv2d(hidden_features*3, hidden_features, 3, 1, 1,bias=bias) 575 | self.resblock_prop = ResidualBlocks2D(hidden_features, 3) 576 | ## align-ment 577 | self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=3, stride=1, padding=1, bias=bias) 578 | 579 | def forward(self, x_f, x_ev): 580 | bs, t, c, h, w = x_f.shape 581 | x_f = self.project_in_f(x_f) 582 | x_ev = self.project_in_ev(x_ev) 583 | x_f = rearrange(x_f, 'b c t h w -> b t c h w', b=bs) 584 | x_ev = rearrange(x_ev, 'b c t h w -> b t c h w', b=bs) 585 | ### propagation 586 | ## prev 587 | x_f_prev = x_f[:, 0] 588 | x_ev_prev = torch.cat((x_ev[:, 0], x_ev[:, 1]), dim=1) 589 | input_prev = torch.cat((x_f_prev, x_ev_prev), dim=1) 590 | prop_prev = self.conv1_prev(input_prev) 591 | prop_prev = self.resblock_forward(prop_prev) 592 | ## future 593 | x_f_future = x_f[:, 2] 594 | x_ev_future = torch.cat((x_ev[:, 1], x_ev[:, 2]), dim=1) 595 | input_future = torch.cat((x_f_future, x_ev_future), dim=1) 596 | prop_future = self.conv1_future(input_future) 597 | prop_future = self.resblock_backward(prop_future) 598 | ## cur 599 | x_f_cur = x_f[:, 1] 600 | prop_input = torch.cat((x_f_cur, prop_future, prop_prev), dim=1) 601 | prop_out = self.conv_prop(prop_input) 602 | prop_out = self.resblock_prop(prop_out) 603 | return prop_out 604 | 605 | class alignment_layer(nn.Module): 606 | def __init__(self, dim, ffn_expansion_factor, bias, LayerNorm_type): 607 | super(alignment_layer, self).__init__() 608 | self.norm_frame = LayerNorm(dim, LayerNorm_type) 609 | self.norm_event = LayerNorm(dim, LayerNorm_type) 610 | self.alignment_ffn = ffn_align(dim, ffn_expansion_factor, bias) 611 | 612 | def forward(self, x_f, x_ev): 613 | b = x_f.shape[0] 614 | x_f = 
self.norm_frame(rearrange(x_f, 'b t c h w -> (b t) c h w', b=b)) 615 | x_ev = self.norm_event(rearrange(x_ev, 'b t c h w -> (b t) c h w', b=b)) 616 | x_f_re = rearrange(x_f, '(b t) c h w -> b c t h w', b=b) 617 | x_ev_re = rearrange(x_ev, '(b t) c h w -> b c t h w', b=b) 618 | aligned_feat = self.alignment_ffn(x_f_re, x_ev_re) 619 | return aligned_feat 620 | 621 | 622 | ########################################################################## 623 | ## Gated-Dconv Feed-Forward Network (GDFN) 624 | class FeedForward(nn.Module): 625 | def __init__(self, dim, ffn_expansion_factor, bias): 626 | super(FeedForward, self).__init__() 627 | hidden_features = int(dim*ffn_expansion_factor) 628 | self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) 629 | self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, 630 | padding=1, groups=hidden_features*2, bias=bias) 631 | self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) 632 | 633 | def forward(self, x): 634 | x = self.project_in(x) 635 | x1, x2 = self.dwconv(x).chunk(2, dim=1) 636 | x = F.gelu(x1) * x2 637 | x = self.project_out(x) 638 | return x 639 | 640 | 641 | 642 | ########################################################################## 643 | ## Multi-DConv Head Transposed Self-Attention (MDTA) 644 | class Attention(nn.Module): 645 | def __init__(self, dim, num_heads, stride, bias): 646 | super(Attention, self).__init__() 647 | self.num_heads = num_heads 648 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) 649 | 650 | self.stride = stride 651 | self.qk = nn.Conv2d(dim, dim*2, kernel_size=1, bias=bias) 652 | self.qk_dwconv = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=self.stride, padding=1, groups=dim*2, bias=bias) 653 | 654 | self.v = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) 655 | self.v_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias) 656 | 657 | self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) 658 | 659 | def forward(self, x): 660 | b,c,h,w = x.shape 661 | 662 | qk = self.qk_dwconv(self.qk(x)) 663 | q,k = qk.chunk(2, dim=1) 664 | 665 | v = self.v_dwconv(self.v(x)) 666 | 667 | b, f, h1, w1 = q.size() 668 | 669 | q = rearrange(q, 'b (head c) h1 w1 -> b head c (h1 w1)', head=self.num_heads) 670 | k = rearrange(k, 'b (head c) h1 w1 -> b head c (h1 w1)', head=self.num_heads) 671 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 672 | 673 | q = torch.nn.functional.normalize(q, dim=-1) 674 | k = torch.nn.functional.normalize(k, dim=-1) 675 | 676 | attn = (q @ k.transpose(-2, -1)) * self.temperature 677 | attn = attn.softmax(dim=-1) 678 | 679 | out = (attn @ v) 680 | 681 | out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 682 | 683 | out = self.project_out(out) 684 | return out 685 | 686 | 687 | class alignment(nn.Module): 688 | def __init__(self, dim, dim_prev=None, memory=False, stride=1, type='group_conv'): 689 | super(alignment, self).__init__() 690 | act = nn.GELU() 691 | bias = False 692 | kernel_size = 3 693 | padding = kernel_size//2 694 | deform_groups = 8 695 | out_channels = deform_groups * 3 * (kernel_size**2) 696 | ## fw 697 | self.fw_offset_conv = nn.Conv2d(dim, out_channels, kernel_size, stride=1, padding=padding, bias=bias) 698 | self.fw_conv1 = nn.Conv2d(dim*4, dim, 3, 1, 1, bias=bias) 699 | self.fw_bottleneck = nn.Sequential(nn.Conv2d(dim, dim, kernel_size = 3, padding = 1, bias = bias), act) 700 | self.fw_deform = 
DeformConv2d(dim, dim, kernel_size, padding = 2, groups = deform_groups, dilation=2) 701 | ## bw 702 | self.bw_offset_conv = nn.Conv2d(dim, out_channels, kernel_size, stride=1, padding=padding, bias=bias) 703 | self.bw_conv1 = nn.Conv2d(dim*4, dim, 3, 1, 1, bias=bias) 704 | self.bw_bottleneck = nn.Sequential(nn.Conv2d(dim, dim, kernel_size = 3, padding = 1, bias = bias), act) 705 | self.bw_deform = DeformConv2d(dim, dim, kernel_size, padding = 2, groups = deform_groups, dilation=2) 706 | ## alignment final 707 | self.align_final = nn.Conv2d(dim*2, dim, 3, 1, 1, bias=bias) 708 | ## memory offset 709 | if memory==True: 710 | ## fw offset 711 | self.fw_offset_feat_up = nn.ConvTranspose2d(dim_prev, dim, 3, stride=2, padding=1, output_padding=1) 712 | self.bw_offset_feat_up = nn.ConvTranspose2d(dim_prev, dim, 3, stride=2, padding=1, output_padding=1) 713 | self.fw_bottleneck_prev = nn.Sequential(nn.Conv2d(dim*2, dim, kernel_size = 3 , padding=1, bias=bias), act) 714 | self.bw_bottleneck_prev = nn.Sequential(nn.Conv2d(dim*2, dim, kernel_size = 3 , padding=1, bias=bias), act) 715 | 716 | def offset_gen(self, x): 717 | o1, o2, mask = torch.chunk(x, 3, dim=1) 718 | offset = torch.cat((o1, o2), dim=1) 719 | mask = torch.sigmoid(mask) 720 | return offset, mask 721 | 722 | def forward(self, x, x_ev, bs, prev_fw_offset_feat=None, prev_bw_offset_feat=None): 723 | ## 724 | x = rearrange(x, '(b t) c h w -> b t c h w', b=bs) 725 | x_ev = rearrange(x_ev, '(b t) c h w -> b t c h w', b=bs) 726 | ## 727 | x_prev = x[:, 0] 728 | x_cur = x[:, 1] 729 | x_future = x[:, 2] 730 | ### forward prop 731 | x_ev_prev = torch.cat((x_ev[:, 0], x_ev[:, 1]), dim=1) 732 | input_prev = torch.cat((x_prev, x_cur, x_ev_prev), dim=1) 733 | prop_prev = self.fw_conv1(input_prev) 734 | fw_offset_feat = self.fw_bottleneck(prop_prev) 735 | ## forward memory 736 | if prev_fw_offset_feat is not None: 737 | prev_fw_offset_feat = self.fw_offset_feat_up(prev_fw_offset_feat) 738 | fw_offset_feat = self.fw_bottleneck_prev(torch.cat((fw_offset_feat, prev_fw_offset_feat), dim=1)) 739 | fw_offset, fw_mask = self.offset_gen(self.fw_offset_conv(fw_offset_feat)) 740 | prop_prev = self.fw_deform(prop_prev, fw_offset, fw_mask) 741 | ### backward prop 742 | x_ev_future = torch.cat((x_ev[:, 1], x_ev[:, 2]), dim=1) 743 | input_future = torch.cat((x_cur, x_future, x_ev_future), dim=1) 744 | prop_future = self.bw_conv1(input_future) 745 | bw_offset_feat = self.bw_bottleneck(prop_future) 746 | ## bw memory 747 | if prev_bw_offset_feat is not None: 748 | prev_bw_offset_feat = self.fw_offset_feat_up(prev_bw_offset_feat) 749 | bw_offset_feat = self.fw_bottleneck_prev(torch.cat((bw_offset_feat, prev_bw_offset_feat), dim=1)) 750 | bw_offset, bw_mask = self.offset_gen(self.bw_offset_conv(bw_offset_feat)) 751 | prop_future = self.bw_deform(prop_future, bw_offset, bw_mask) 752 | ## 753 | aligned_feat = self.align_final(torch.cat((prop_prev, prop_future), dim=1)) 754 | return aligned_feat, fw_offset_feat, bw_offset_feat 755 | 756 | 757 | class Transformer(nn.Module): 758 | def __init__(self, dim, num_heads, stride, ffn_expansion_factor, bias, LayerNorm_type): 759 | super(Transformer, self).__init__() 760 | self.norm1 = LayerNorm(dim, LayerNorm_type) 761 | self.attn = Attention(dim, num_heads, stride, bias) 762 | self.norm2 = LayerNorm(dim, LayerNorm_type) 763 | self.ffn = FeedForward(dim, ffn_expansion_factor, bias) 764 | 765 | def forward(self, x): 766 | x = x + self.attn(self.norm1(x)) 767 | x = x + self.ffn(self.norm2(x)) 768 | return x 769 | 770 | 771 | 772 | 
class EDTFA(nn.Module): 773 | def __init__(self, in_channels=48): 774 | super(EDTFA, self).__init__() 775 | heads = [1,2,4,8] 776 | bias = False 777 | LayerNorm_type = 'WithBias' 778 | base_channel = 32 779 | scale_unet_feat = base_channel 780 | ## down and encoder 781 | self.down1 = nn.Conv2d(base_channel, base_channel + scale_unet_feat, 3, stride=2, padding=1) 782 | self.down2 = nn.Conv2d(base_channel+ scale_unet_feat, base_channel+2*scale_unet_feat, 3, stride=2, padding=1) 783 | self.encoder_level1 = nn.Sequential(*[Transformer(dim=base_channel, num_heads=heads[0], stride=1, 784 | ffn_expansion_factor=2.66, bias=bias, 785 | LayerNorm_type=LayerNorm_type) for i in range(2)]) 786 | self.encoder_level2 = nn.Sequential(*[Transformer(dim=base_channel + scale_unet_feat, num_heads=heads[1], stride=1, 787 | ffn_expansion_factor=2.66, bias=bias, 788 | LayerNorm_type=LayerNorm_type) for i in range(2)]) 789 | ## event - encoder 790 | self.down1_e = nn.Conv2d(base_channel, base_channel + scale_unet_feat, 3, stride=2, padding=1) 791 | self.down2_e = nn.Conv2d(base_channel + scale_unet_feat, base_channel + 2*scale_unet_feat , 3, stride=2, padding=1) 792 | self.encoder_level1_e = nn.Sequential(*[Transformer(dim=base_channel, num_heads=heads[0], stride=1 ,ffn_expansion_factor=2.66, 793 | bias=bias, LayerNorm_type=LayerNorm_type) for i in range(2)]) 794 | self.encoder_level2_e = nn.Sequential(*[Transformer(dim=base_channel + scale_unet_feat, num_heads=heads[1], stride=1, 795 | ffn_expansion_factor=2.66, bias=bias, 796 | LayerNorm_type=LayerNorm_type) for i in range(2)]) 797 | self.alignment0 = alignment(base_channel, base_channel+scale_unet_feat, memory=True) 798 | self.alignment1 = alignment(base_channel + scale_unet_feat, base_channel + 2*scale_unet_feat, memory=True) 799 | self.alignment2 = alignment(base_channel + 2*scale_unet_feat) 800 | 801 | self.up1 = nn.ConvTranspose2d(base_channel + scale_unet_feat, 802 | base_channel, 3, stride=2, padding=1, output_padding=1) 803 | self.up2 = nn.ConvTranspose2d(base_channel + 2*scale_unet_feat, 804 | base_channel + scale_unet_feat, 3, stride=2, padding=1, output_padding=1) 805 | 806 | def forward(self, x, x_event, bs): 807 | ## frame 808 | x = self.encoder_level1(x) 809 | enc1_f = self.down1(x) 810 | enc1_f = self.encoder_level2(enc1_f) 811 | enc2_f = self.down2(enc1_f) 812 | ### event 813 | x_event = self.encoder_level1_e(x_event) 814 | enc1_e = self.down1_e(x_event) 815 | enc1_e = self.encoder_level2_e(enc1_e) 816 | enc2_e = self.down2_e(enc1_e) 817 | ## alignment 818 | enc2, fw_offset_fw2, bw_offset_bw2 = self.alignment2(enc2_f, enc2_e, bs) 819 | dec1 = self.up2(enc2) 820 | enc1_f, fw_offset_fw1, bw_offset_bw1 = self.alignment1(enc1_f, enc1_e, bs, fw_offset_fw2, bw_offset_bw2) 821 | dec1 = dec1 + enc1_f 822 | dec0 = self.up1(dec1) 823 | x, _, _ = self.alignment0(x, x_event, bs, fw_offset_fw1, bw_offset_bw1) 824 | x = x + dec0 825 | return x 826 | 827 | 828 | class Decoder(nn.Module): 829 | def __init__(self, n_feat, scale_unetfeats, kernel_size=3, reduction=4, bias=False): 830 | super(Decoder, self).__init__() 831 | act = nn.PReLU() 832 | self.decoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 833 | self.decoder_level2 = [CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 834 | self.decoder_level3 = [CAB(n_feat + (scale_unetfeats * 2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 835 | 836 | self.decoder_level1 = nn.Sequential(*self.decoder_level1) 837 | 
self.decoder_level2 = nn.Sequential(*self.decoder_level2) 838 | self.decoder_level3 = nn.Sequential(*self.decoder_level3) 839 | 840 | self.skip_attn1 = CAB(n_feat, kernel_size, reduction, bias=bias, act=act) 841 | self.skip_attn2 = CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) 842 | 843 | self.up21 = SkipUpSample(n_feat, scale_unetfeats) 844 | self.up32 = SkipUpSample(n_feat + scale_unetfeats, scale_unetfeats) 845 | ## reconstruction 846 | self.recons_1 = conv5x5(n_feat, 3) 847 | self.recons_2 = conv5x5(n_feat+scale_unetfeats, 3) 848 | self.recons_3 = conv5x5(n_feat+2*scale_unetfeats, 3) 849 | 850 | def forward(self, outs): 851 | enc1, enc2, enc3 = outs 852 | dec3 = self.decoder_level3(enc3) 853 | deblurred_img_0 = self.recons_3(dec3) 854 | deblurred_img_0 = torch.clamp(deblurred_img_0, 0, 1) 855 | 856 | x = self.up32(dec3, self.skip_attn2(enc2)) 857 | dec2 = self.decoder_level2(x) 858 | deblurred_img_1 = self.recons_2(dec2) 859 | deblurred_img_1 = torch.clamp(deblurred_img_1, 0, 1) 860 | 861 | x = self.up21(dec2, self.skip_attn1(enc1)) 862 | dec1 = self.decoder_level1(x) 863 | deblurred_img_2 = self.recons_1(dec1) 864 | deblurred_img_2 = torch.clamp(deblurred_img_2, 0, 1) 865 | return [deblurred_img_2, deblurred_img_1, deblurred_img_0] 866 | 867 | #### fusion module ### 868 | class Frequency_module(nn.Module): 869 | def __init__(self, num_channels): 870 | super(Frequency_module, self).__init__() 871 | self.conv1_y = nn.Conv2d(2 * num_channels, 2 * num_channels, kernel_size=1, stride=1, padding=0) 872 | self.conv2_y = nn.Conv2d(2 * num_channels, 2 * num_channels, kernel_size=1, stride=1, padding=0) 873 | 874 | self.conv1_z = nn.Conv2d(2 * num_channels, 2 * num_channels, kernel_size=1, stride=1, padding=1) 875 | self.conv2_z = nn.Conv2d(2 * num_channels, 2 * num_channels, kernel_size=1, stride=1, padding=1) 876 | ## sigma 877 | self.sigma = 7 878 | 879 | def make_gaussian(self, y_idx, x_idx, height, width): 880 | yv, xv = torch.meshgrid([torch.arange(0, height), torch.arange(0, width)]) 881 | yv = yv.unsqueeze(0).float().cuda() 882 | xv = xv.unsqueeze(0).float().cuda() 883 | g = torch.exp(- ((yv - y_idx) ** 2 + (xv - x_idx) ** 2) / (2 * self.sigma** 2)) 884 | return g.unsqueeze(0) #1, 1, H, W 885 | 886 | def forward(self, x): 887 | b, c, h, w = x.shape 888 | x = x.float() 889 | y = torch.fft.fft2(x) 890 | 891 | h_idx, w_idx = h // 2, w // 2 892 | high_filter = self.make_gaussian(h_idx, w_idx, h, w) 893 | ## high frequency regions 894 | f = y * high_filter 895 | 896 | y_imag = f.imag 897 | y_real = f.real 898 | y_f = torch.cat([y_real, y_imag], dim=1) 899 | f = F.relu(self.conv1_y(y_f)) 900 | 901 | f = self.conv2_y(f).float() 902 | f_real, f_imag = torch.chunk(f, 2, dim=1) 903 | f = torch.complex(f_real, f_imag) 904 | 905 | f = torch.fft.ifft2(f, s=(h, w)).float() 906 | high_out = x+f 907 | return high_out 908 | 909 | 910 | class FeedForward_bottom_level(nn.Module): 911 | def __init__(self, dim, dim_prev, ffn_expansion_factor, bias): 912 | super(FeedForward_bottom_level, self).__init__() 913 | hidden_features = int(dim*ffn_expansion_factor) 914 | self.num_feat = hidden_features 915 | self.deconv_hidden = deconv4x4(dim_prev, hidden_features) 916 | self.project_f = conv3x3(dim, hidden_features) 917 | self.project_e = conv3x3(dim, hidden_features) 918 | self.project_cat_feat = conv3x3_leaky_relu(hidden_features*3, hidden_features*2) 919 | self.conv1_1 = conv3x3_leaky_relu(2*hidden_features, hidden_features) 920 | self.conv1_2 = conv3x3_leaky_relu(2*hidden_features, 
hidden_features) 921 | ## dynamic filter 922 | self.freq = Frequency_module(hidden_features) 923 | self.dp = IDynamicDWConv(hidden_features, 3, 1, 2, 2) 924 | self.cat_conv = conv3x3_leaky_relu(hidden_features, hidden_features*2) 925 | self.conv2 = nn.Conv2d(hidden_features, hidden_features, 3, 1, 1, bias=True) 926 | self.conv3 = nn.Conv2d(hidden_features, hidden_features, 3, 1, 1, bias=True) 927 | ## output projection 928 | self.project_CAB = CABs(hidden_features, 3, reduction=4, act = nn.PReLU(), bias=False, num_cab=3) 929 | self.project_out = conv3x3_leaky_relu(hidden_features, dim) 930 | 931 | def forward(self, x_f, x_ev, fusion_feat_prev): 932 | deconv_fusion_feat = self.deconv_hidden(fusion_feat_prev) 933 | x_f = self.project_f(x_f) 934 | x_ev = self.project_e(x_ev) 935 | ## cat feat 936 | cat_feat = torch.cat((x_f, x_ev, deconv_fusion_feat), dim=1) 937 | cat_feat = self.project_cat_feat(cat_feat) 938 | ## split 939 | cat_feat1 = self.conv1_1(cat_feat) 940 | cat_feat2 = self.conv1_2(cat_feat) 941 | # cat_feat1, cat_feat2 = torch.split(cat_feat, self.num_feat, dim=1) 942 | ## high frequency 943 | high_feat = self.freq(cat_feat1) 944 | high_feat = self.dp(high_feat, high_feat) 945 | high_feat = high_feat * torch.sigmoid(self.conv2(high_feat)) 946 | ## org feat 947 | cat_feat2 = cat_feat2 * torch.sigmoid(self.conv3(cat_feat2)) 948 | cat_feat = high_feat + cat_feat2 949 | ## dp 950 | cat_feat = self.project_CAB(cat_feat) 951 | x = self.project_out(cat_feat) 952 | return x 953 | 954 | class ImageEventFusion_bottom_level(nn.Module): 955 | def __init__(self, dim, dim_prev, ffn_expansion_factor, bias, LayerNorm_type): 956 | super(ImageEventFusion_bottom_level, self).__init__() 957 | self.ffn = FeedForward_bottom_level(dim, dim_prev, ffn_expansion_factor, bias) 958 | 959 | def forward(self, x_f, x_ev, fusion_feat_prev): 960 | x = self.ffn(x_f, x_ev, fusion_feat_prev) 961 | return x 962 | 963 | class FeedForward_top_level(nn.Module): 964 | def __init__(self, dim, ffn_expansion_factor, bias): 965 | super(FeedForward_top_level, self).__init__() 966 | hidden_features = int(dim*ffn_expansion_factor) 967 | self.num_feat = hidden_features 968 | ### modules 969 | self.project_f = conv3x3(dim, hidden_features) 970 | self.project_e = conv3x3(dim, hidden_features) 971 | self.project_cat_feat = conv3x3_leaky_relu(hidden_features*2, hidden_features*2) 972 | self.conv1_1 = conv3x3_leaky_relu(2*hidden_features, hidden_features) 973 | self.conv1_2 = conv3x3_leaky_relu(2*hidden_features, hidden_features) 974 | self.freq = Frequency_module(hidden_features) 975 | self.dp = IDynamicDWConv(hidden_features, 3, 1, 2, 2) 976 | self.cat_conv = conv3x3_leaky_relu(hidden_features, hidden_features*2) 977 | self.conv2 = nn.Conv2d(hidden_features, hidden_features, 3, 1, 1, bias=True) 978 | self.conv3 = nn.Conv2d(hidden_features, hidden_features, 3, 1, 1, bias=True) 979 | self.project_CAB = CABs(hidden_features, 3, reduction=4, act = nn.PReLU(), bias=False, num_cab=3) 980 | self.project_out = conv3x3_leaky_relu(hidden_features, dim) 981 | 982 | def forward(self, x_f, x_ev): 983 | x_f = self.project_f(x_f) 984 | x_ev = self.project_e(x_ev) 985 | ## cat feat 986 | cat_feat = torch.cat((x_f, x_ev), dim=1) 987 | cat_feat = self.project_cat_feat(cat_feat) 988 | ## split 989 | cat_feat1 = self.conv1_1(cat_feat) 990 | cat_feat2 = self.conv1_2(cat_feat) 991 | ## high frequency 992 | high_feat = self.freq(cat_feat1) 993 | high_feat = self.dp(high_feat, high_feat) 994 | high_feat = high_feat * torch.sigmoid(self.conv2(high_feat)) 
995 | ## org feat 996 | cat_feat2 = cat_feat2 * torch.sigmoid(self.conv3(cat_feat2)) 997 | cat_feat = high_feat + cat_feat2 998 | ## output feature processing 999 | cat_feat = self.project_CAB(cat_feat) 1000 | x = self.project_out(cat_feat) 1001 | return x 1002 | 1003 | class ImageEventFusion_top_level(nn.Module): 1004 | def __init__(self, dim, ffn_expansion_factor, bias, LayerNorm_type): 1005 | super(ImageEventFusion_top_level, self).__init__() 1006 | self.ffn = FeedForward_top_level(dim, ffn_expansion_factor, bias) 1007 | 1008 | def forward(self, x_f, x_ev): 1009 | x = self.ffn(x_f, x_ev) 1010 | return x 1011 | 1012 | 1013 | class EventDeblurNet(nn.Module): 1014 | def __init__(self): 1015 | super(EventDeblurNet, self).__init__() 1016 | ### ev down convolution 1017 | base_feat = 32 1018 | scale_unet_feat = base_feat 1019 | self.device = torch.device('cuda') 1020 | # RNN cell 1021 | self.shallow_cell_frames = shallow_cell(n_feat=base_feat) 1022 | self.shallow_cell_events = shallow_cell_events(n_feat=base_feat) 1023 | self.encoder_frame = Encoder(n_feat=base_feat, scale_unetfeats=scale_unet_feat) 1024 | self.encoder_event = Encoder(n_feat=base_feat, scale_unetfeats=scale_unet_feat) 1025 | # decoder 1026 | self.decoder = Decoder(n_feat=base_feat, scale_unetfeats=scale_unet_feat) 1027 | ## recon id 1028 | self.recon_id = 1 1029 | self.scale_range = 3 1030 | ## alignment 1031 | self.align = EDTFA(base_feat) 1032 | ## alignment module 1033 | ## fusion module 1034 | fusion_list = [] 1035 | fusion_list.append(ImageEventFusion_bottom_level(base_feat, int(base_feat + scale_unet_feat), 1036 | ffn_expansion_factor=1, bias=False, LayerNorm_type='WithBias')) 1037 | fusion_list.append(ImageEventFusion_bottom_level(int(base_feat + scale_unet_feat), int(base_feat + 2*scale_unet_feat), 1038 | ffn_expansion_factor=1.25, bias=False, LayerNorm_type='WithBias')) 1039 | fusion_list.append(ImageEventFusion_top_level(int(base_feat + 2*scale_unet_feat), 1040 | ffn_expansion_factor=1.5, bias=False, LayerNorm_type='WithBias')) 1041 | self.fusion_list = nn.ModuleList(fusion_list) 1042 | 1043 | def forward(self, batch): 1044 | b, t, c, h, w = batch['blur_input_clip'].shape 1045 | x_frame = batch['blur_input_clip'] 1046 | x_event = batch['event_vox_clip'] 1047 | ## frame feature processing 1048 | x_frame = rearrange(x_frame, 'b t c h w -> (b t) c h w') 1049 | f_feature = self.shallow_cell_frames(x_frame) 1050 | ## event feature processing 1051 | x_event = rearrange(x_event, 'b t c h w -> (b t) c h w') 1052 | ## event feature 1053 | e_feature = self.shallow_cell_events(x_event) 1054 | ### alignment 1055 | aligned_feat = self.align(f_feature, e_feature, b) 1056 | # ## frame encoding 1057 | f_encoder_outs = self.encoder_frame(aligned_feat) 1058 | ## event encoding 1059 | e_encoder_outs = self.encoder_event(e_feature) 1060 | # ## align module 1061 | # # prop_hidden = torch.new_zeros(b, c, h, w) 1062 | for scale_idx in range(self.scale_range-1, -1, -1): 1063 | e_encoder_outs[scale_idx] = rearrange(e_encoder_outs[scale_idx], '(b t) c h w -> b t c h w', b=b) 1064 | # ## fusion 1065 | fusion_out_0 = self.fusion_list[-1](f_encoder_outs[-1], e_encoder_outs[-1][:, 1]) 1066 | fusion_out_1 = self.fusion_list[-2](f_encoder_outs[-2], e_encoder_outs[-2][:, 1], fusion_out_0) 1067 | fusion_out_2 = self.fusion_list[-3](f_encoder_outs[-3], e_encoder_outs[-3][:, 1], fusion_out_1) 1068 | # ## 1069 | encoder_outs = [] 1070 | encoder_outs.append(fusion_out_2) 1071 | encoder_outs.append(fusion_out_1) 1072 | encoder_outs.append(fusion_out_0) 
1073 | # ################## 1074 | # # TFR # 1075 | # ################## 1076 | output_dict = self.decoder(encoder_outs) 1077 | return output_dict -------------------------------------------------------------------------------- /models/model_manager.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import os 4 | from math import ceil 5 | import importlib 6 | from utils.utils import AverageMeter 7 | import torch.nn.functional as F 8 | 9 | 10 | class ModelManager(object): 11 | def __init__(self, args): 12 | self.batch = {} 13 | self.args = args 14 | self.voxel_num_bins = args.voxel_num_bins 15 | self.smoothness_weight = 1.0 16 | self.scale = 3 if getattr(args, 'loss_type', None) == 'multi_scale' else 1 17 | self.ms_lambda_dict = [1.0, 0.5, 0.25] if self.scale == 3 else [] 18 | self.downsample = nn.AvgPool2d(2, stride=2) if self.scale == 3 else None 19 | 20 | def initilalize_deblur_model(self, args, model_folder, model_name, tb_path): 21 | mod = importlib.import_module('models.' + model_folder + '.' + model_name) 22 | self.deblur_net = mod.EventDeblurNet() 23 | self.save_path = os.path.join(tb_path, 'saved_model') if tb_path else None 24 | os.makedirs(self.save_path, exist_ok=True) if self.save_path else None 25 | self.loss_total_meter = AverageMeter() 26 | self.loss_deblur_meter = AverageMeter() 27 | 28 | def cuda_deblur(self): 29 | self.deblur_net.cuda() 30 | 31 | def use_multi_gpu_deblur(self): # data parallel 32 | self.deblur_net = nn.DataParallel(self.deblur_net) 33 | 34 | def count_total_parameters(self): 35 | return sum(p.numel() for p in self.deblur_net.parameters()) 36 | 37 | def get_deblurnet_optimizer_params(self): 38 | return self.deblur_net.parameters() 39 | 40 | def get_l1_loss(self, x, y, reduction_): 41 | loss = F.l1_loss(x, y, reduction=reduction_) 42 | return loss 43 | 44 | def get_chainbor_loss(self, x, y): 45 | loss = ((((x - y) ** 2 + 1e-6) ** 0.5).mean()) 46 | return loss 47 | 48 | def backward_warp(self, x, flo): 49 | ''' 50 | x shape : [B,C,T,H,W] 51 | t_value shape : [B,1] ############### 52 | ''' 53 | B, C, H, W = x.size() 54 | # mesh grid 55 | xx = torch.arange(0, W).view(1, 1, 1, W).expand(B, 1, H, W) 56 | yy = torch.arange(0, H).view(1, 1, H, 1).expand(B, 1, H, W) 57 | grid = torch.cat((xx, yy), 1).float() 58 | 59 | if x.is_cuda: 60 | grid = grid.cuda() 61 | vgrid = torch.autograd.Variable(grid) + flo 62 | 63 | # scale grid to [-1,1] 64 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0 65 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0 66 | 67 | vgrid = vgrid.permute(0, 2, 3, 1) # [B,H,W,2] 68 | output = nn.functional.grid_sample(x, vgrid, align_corners=True) 69 | mask = torch.autograd.Variable(torch.ones(x.size())).cuda() 70 | mask = nn.functional.grid_sample(mask, vgrid, align_corners=True) 71 | 72 | mask = mask.masked_fill_(mask < 0.999, 0) 73 | mask = mask.masked_fill_(mask > 0, 1) 74 | return output * mask 75 | 76 | def set_test_inputs(self, sample): 77 | self.batch['event_vox_clip'] = sample['event_vox_clip'] 78 | self.batch['blur_input_clip'] = sample['blur_input_clip'] 79 | 80 | def set_video_inputs(self, sample): 81 | self.batch['event_vox_clip'] = sample['event_vox_clip'] 82 | self.batch['blur_input_clip'] = sample['blur_input_clip'] 83 | self.batch['clean_middle'] = sample['clean_middle'] 84 | 85 | def set_test_video_inputs(self, sample): 86 | self.batch['event_vox_clip'] = sample['event_vox_clip'] 87 | 
self.batch['blur_input_clip'] = sample['blur_input_clip'] 88 | self.batch['clean_gt_clip'] = sample['clean_gt_clip'] 89 | self.batch['clean_middle'] = sample['clean_middle'] 90 | 91 | def forward_deblur_net(self): 92 | self.batch['output_deblur'] = self.deblur_net(self.batch) 93 | 94 | def get_single_loss(self): 95 | self.loss_total = self.get_chainbor_loss(self.batch['clean_middle'], self.batch['output_deblur'][0]) 96 | return self.loss_total 97 | 98 | def get_single_dict_loss(self): 99 | loss_dict = [] 100 | for i in range(len(self.batch['output_deblur'])): 101 | loss_dict.append(self.get_chainbor_loss(self.batch['clean_middle'], self.batch['output_deblur'][i])) 102 | self.loss_total = sum(loss_dict) 103 | return self.loss_total 104 | 105 | ## multi scale loss 106 | def get_multi_scale_single_loss(self): 107 | loss_dict = [] 108 | for scale_idx in range(len(self.batch['output_deblur'])): 109 | loss_dict.append(self.ms_lambda_dict[scale_idx]*self.get_chainbor_loss(self.batch['clean_middle_ms'][scale_idx], self.batch['output_deblur'][scale_idx])) 110 | self.loss_total = sum(loss_dict) 111 | return self.loss_total 112 | 113 | def update_loss_meters_deblur(self): 114 | # total loss update 115 | self.loss_total_meter.update(self.loss_total.item(), 1) 116 | 117 | def update_loss_meters_all(self): 118 | # total loss update 119 | self.loss_total_meter.update(self.loss_total.item(), 1) 120 | self.loss_deblur_meter.update(self.loss_image.item(), 1) 121 | 122 | def reset_loss_meters_deblur(self): 123 | self.loss_total_meter.reset() 124 | 125 | def reset_loss_meters_all(self): 126 | self.loss_total_meter.reset() 127 | self.loss_deblur_meter.reset() 128 | 129 | def del_batch(self): 130 | del self.batch 131 | self.batch = dict() 132 | 133 | def load_model(self, state_dict): 134 | self.deblur_net.load_state_dict(state_dict) 135 | print('load model') -------------------------------------------------------------------------------- /models/submodules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.nn import init 5 | from torch.nn.modules.batchnorm import _BatchNorm 6 | 7 | 8 | def actFunc(act, *args, **kwargs): 9 | act = act.lower() 10 | if act == 'relu': 11 | return nn.ReLU() 12 | elif act == 'relu6': 13 | return nn.ReLU6() 14 | elif act == 'leakyrelu': 15 | return nn.LeakyReLU(0.1) 16 | elif act == 'prelu': 17 | return nn.PReLU() 18 | elif act == 'rrelu': 19 | return nn.RReLU(0.1, 0.3) 20 | elif act == 'selu': 21 | return nn.SELU() 22 | elif act == 'celu': 23 | return nn.CELU() 24 | elif act == 'elu': 25 | return nn.ELU() 26 | elif act == 'gelu': 27 | return nn.GELU() 28 | elif act == 'tanh': 29 | return nn.Tanh() 30 | else: 31 | raise NotImplementedError 32 | 33 | class ResBlock(nn.Module): 34 | """ 35 | Residual block 36 | """ 37 | def __init__(self, in_chs, activation='relu', batch_norm=False): 38 | super(ResBlock, self).__init__() 39 | op = [] 40 | for i in range(2): 41 | op.append(conv3x3(in_chs, in_chs)) 42 | if batch_norm: 43 | op.append(nn.BatchNorm2d(in_chs)) 44 | if i == 0: 45 | op.append(actFunc(activation)) 46 | self.main_branch = nn.Sequential(*op) 47 | 48 | def forward(self, x): 49 | out = self.main_branch(x) 50 | out += x 51 | return out 52 | 53 | 54 | ### resblock 55 | def make_layer(basic_block, num_basic_block, **kwarg): 56 | """Make layers by stacking the same blocks. 57 | 58 | Args: 59 | basic_block (nn.module): nn.module class for basic block. 
60 | num_basic_block (int): number of blocks. 61 | 62 | Returns: 63 | nn.Sequential: Stacked blocks in nn.Sequential. 64 | """ 65 | layers = [] 66 | for _ in range(num_basic_block): 67 | layers.append(basic_block(**kwarg)) 68 | return nn.Sequential(*layers) 69 | 70 | @torch.no_grad() 71 | def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): 72 | """Initialize network weights. 73 | 74 | Args: 75 | module_list (list[nn.Module] | nn.Module): Modules to be initialized. 76 | scale (float): Scale initialized weights, especially for residual 77 | blocks. Default: 1. 78 | bias_fill (float): The value to fill bias. Default: 0 79 | kwargs (dict): Other arguments for initialization function. 80 | """ 81 | if not isinstance(module_list, list): 82 | module_list = [module_list] 83 | for module in module_list: 84 | for m in module.modules(): 85 | if isinstance(m, nn.Conv2d): 86 | init.kaiming_normal_(m.weight, **kwargs) 87 | m.weight.data *= scale 88 | if m.bias is not None: 89 | m.bias.data.fill_(bias_fill) 90 | elif isinstance(m, nn.Linear): 91 | init.kaiming_normal_(m.weight, **kwargs) 92 | m.weight.data *= scale 93 | if m.bias is not None: 94 | m.bias.data.fill_(bias_fill) 95 | elif isinstance(m, _BatchNorm): 96 | init.constant_(m.weight, 1) 97 | if m.bias is not None: 98 | m.bias.data.fill_(bias_fill) 99 | 100 | 101 | class ResidualBlockNoBN3D(nn.Module): 102 | """Residual block without BN. 103 | 104 | It has a style of: 105 | ---Conv-ReLU-Conv-+- 106 | |________________| 107 | 108 | Args: 109 | num_feat (int): Channel number of intermediate features. 110 | Default: 64. 111 | res_scale (float): Residual scale. Default: 1. 112 | pytorch_init (bool): If set to True, use pytorch default init, 113 | otherwise, use default_init_weights. Default: False. 114 | """ 115 | 116 | def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): 117 | super(ResidualBlockNoBN3D, self).__init__() 118 | self.res_scale = res_scale 119 | self.conv1 = nn.Conv3d(num_feat, num_feat, 3, 1, 1, bias=True) 120 | self.conv2 = nn.Conv3d(num_feat, num_feat, 3, 1, 1, bias=True) 121 | self.relu = nn.ReLU(inplace=True) 122 | 123 | if not pytorch_init: 124 | default_init_weights([self.conv1, self.conv2], 0.1) 125 | 126 | def forward(self, x): 127 | identity = x 128 | out = self.conv2(self.relu(self.conv1(x))) 129 | return identity + out * self.res_scale 130 | 131 | class ResidualBlockNoBN(nn.Module): 132 | """Residual block without BN. 133 | 134 | It has a style of: 135 | ---Conv-ReLU-Conv-+- 136 | |________________| 137 | 138 | Args: 139 | num_feat (int): Channel number of intermediate features. 140 | Default: 64. 141 | res_scale (float): Residual scale. Default: 1. 142 | pytorch_init (bool): If set to True, use pytorch default init, 143 | otherwise, use default_init_weights. Default: False. 144 | """ 145 | 146 | def __init__(self, num_feat=64, pytorch_init=False): 147 | super(ResidualBlockNoBN, self).__init__() 148 | self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) 149 | self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) 150 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 151 | 152 | if not pytorch_init: 153 | default_init_weights([self.conv1, self.conv2], 0.1) 154 | 155 | def forward(self, x): 156 | identity = x 157 | out = self.conv2(self.lrelu(self.conv1(x))) 158 | return identity + out 159 | 160 | 161 | class ResidualBlockNoBN2D(nn.Module): 162 | """Residual block without BN. 
163 | 164 | It has a style of: 165 | ---Conv-ReLU-Conv-+- 166 | |________________| 167 | 168 | Args: 169 | num_feat (int): Channel number of intermediate features. 170 | Default: 64. 171 | res_scale (float): Residual scale. Default: 1. 172 | pytorch_init (bool): If set to True, use pytorch default init, 173 | otherwise, use default_init_weights. Default: False. 174 | """ 175 | 176 | def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): 177 | super(ResidualBlockNoBN2D, self).__init__() 178 | self.res_scale = res_scale 179 | self.conv1 = nn.Conv3d(num_feat, num_feat, (1, 3, 3), 1, (0, 1, 1), bias=True) 180 | self.conv2 = nn.Conv3d(num_feat, num_feat, (1, 3, 3), 1, (0, 1, 1), bias=True) 181 | self.relu = nn.ReLU(inplace=True) 182 | 183 | if not pytorch_init: 184 | default_init_weights([self.conv1, self.conv2], 0.1) 185 | 186 | def forward(self, x): 187 | identity = x 188 | out = self.conv2(self.relu(self.conv1(x))) 189 | return identity + out * self.res_scale 190 | 191 | class ResidualBlocks2D(nn.Module): 192 | def __init__(self, num_feat=64, num_block=30): 193 | super().__init__() 194 | self.main = nn.Sequential( 195 | make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat)) 196 | 197 | def forward(self, fea): 198 | return self.main(fea) 199 | 200 | # conv blocks 201 | def conv1x1(in_channels, out_channels, stride=1): 202 | return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=True) 203 | 204 | def conv1x1_relu(in_channels, out_channels, stride=1): 205 | return nn.Sequential(conv1x1(in_channels, out_channels, stride), nn.ReLU()) 206 | 207 | def conv3x3(in_channels, out_channels, stride=1): 208 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=True) 209 | 210 | def conv3x3_relu(in_channels, out_channels, stride=1): 211 | return nn.Sequential(conv3x3(in_channels, out_channels, stride), nn.ReLU()) 212 | 213 | def conv5x5(in_channels, out_channels, stride=1): 214 | return nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=stride, padding=2, bias=True) 215 | 216 | def deconv4x4(in_channels, out_channels, stride=2): 217 | return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=stride, padding=1) 218 | 219 | def deconv5x5(in_channels, out_channels, stride=2): 220 | return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=5, stride=stride, padding=2, output_padding=1) 221 | 222 | def conv_resblock_three(in_channels, out_channels, stride=1): 223 | return nn.Sequential(conv3x3(in_channels, out_channels, stride), nn.ReLU(), ResBlock(out_channels), ResBlock(out_channels), ResBlock(out_channels)) 224 | 225 | def conv_resblock_two(in_channels, out_channels, stride=1): 226 | return nn.Sequential(conv3x3(in_channels, out_channels, stride), nn.ReLU(), ResBlock(out_channels), ResBlock(out_channels)) 227 | 228 | def conv_resblock_one(in_channels, out_channels, stride=1): 229 | return nn.Sequential(conv3x3(in_channels, out_channels, stride), nn.ReLU(), ResBlock(out_channels)) 230 | 231 | def conv_1x1_resblock_one(in_channels, out_channels, stride=1): 232 | return nn.Sequential(conv1x1(in_channels, out_channels, stride), nn.ReLU(), ResBlock(out_channels)) 233 | 234 | def conv_resblock_two_DS(in_channels, out_channels, stride=2): 235 | return nn.Sequential(conv3x3(in_channels, out_channels, stride), nn.ReLU(), ResBlock(out_channels), ResBlock(out_channels)) 236 | 237 | def conv3x3_leaky_relu(in_channels, out_channels, stride=1): 238 | return nn.Sequential(nn.Conv2d(in_channels, 
out_channels, kernel_size=3, stride=stride, padding=1, bias=True), nn.LeakyReLU(0.1)) 239 | 240 | def conv1x1_leaky_relu(in_channels, out_channels, stride=1): 241 | return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=True), nn.LeakyReLU(0.1)) -------------------------------------------------------------------------------- /pretrained_model/README.md: -------------------------------------------------------------------------------- 1 | put the downloaded model here -------------------------------------------------------------------------------- /sample_data/blur_images/00014.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelpro/ELEDNet/036259bd29cd6465941f5adbc4205035d06d5408/sample_data/blur_images/00014.png -------------------------------------------------------------------------------- /sample_data/blur_images/00015.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelpro/ELEDNet/036259bd29cd6465941f5adbc4205035d06d5408/sample_data/blur_images/00015.png -------------------------------------------------------------------------------- /sample_data/blur_images/00016.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelpro/ELEDNet/036259bd29cd6465941f5adbc4205035d06d5408/sample_data/blur_images/00016.png -------------------------------------------------------------------------------- /sample_data/event_voxel/00014.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelpro/ELEDNet/036259bd29cd6465941f5adbc4205035d06d5408/sample_data/event_voxel/00014.npz -------------------------------------------------------------------------------- /sample_data/event_voxel/00015.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelpro/ELEDNet/036259bd29cd6465941f5adbc4205035d06d5408/sample_data/event_voxel/00015.npz -------------------------------------------------------------------------------- /sample_data/event_voxel/00016.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelpro/ELEDNet/036259bd29cd6465941f5adbc4205035d06d5408/sample_data/event_voxel/00016.npz -------------------------------------------------------------------------------- /test_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import datetime 4 | import argparse 5 | from torch.utils.data import DataLoader 6 | from utils.utils import * 7 | from utils.dataloader import get_test_dataset 8 | from models.model_manager import ModelManager 9 | from tqdm import tqdm 10 | 11 | 12 | def get_argument(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--val_batch_size', type = int, default=4) 15 | # training params 16 | parser.add_argument('--num_test_video_frames', type = int, default=3) 17 | parser.add_argument('--voxel_num_bins', type = int, default=16) 18 | parser.add_argument('--learning_rate', type = float, default=1e-4) 19 | parser.add_argument('--mode', type = str, default='test') 20 | # model discription 21 | parser.add_argument('--model_folder', type=str, default='model_factory') 22 | parser.add_argument('--model_name', type=str, default='models_final') 23 | # data 
loading params 24 | parser.add_argument('--experiment_name', type = str, default='test_networks') 25 | parser.add_argument('--num_threads', type = int, default=12) 26 | parser.add_argument('--use_multigpu', type=str2bool, default='True') 27 | ## required fileds 28 | parser.add_argument('--data_dir', type = str, default = '/media/mnt2/dataset/RELED/') 29 | parser.add_argument('--resume_ckpt', type=str2bool, default='True') 30 | parser.add_argument('--ckpt_dir', type = str, default='./pretrained_model/Ours_RELED.pth') 31 | parser.add_argument('--saved_dir', type = str, default='./saved_img') 32 | args = parser.parse_args() 33 | return args 34 | 35 | class Tester: 36 | def __init__(self, args): 37 | """ 38 | Initializes the Tester class for evaluating the model. 39 | - Sets up the test data loader. 40 | - Initializes and loads the model. 41 | - Defines evaluation metrics (PSNR, SSIM). 42 | - Configures logging. 43 | """ 44 | self.args = args 45 | # Define the logging and saving path for the experiment (includes date and experiment name). 46 | tb_path = f'./experiments/{datetime.datetime.now().strftime("%y%m%d-" + args.experiment_name + "/%H%M")} ' 47 | # Create the test dataset loader. 48 | self.test_loader = DataLoader(get_test_dataset(args, mode='test'), 49 | batch_size=args.val_batch_size, shuffle=False, 50 | num_workers=args.num_threads, pin_memory=False) 51 | # Initialize the model. 52 | self.model = ModelManager(args) 53 | self.model.initilalize_deblur_model(args, model_folder=args.model_folder, model_name=args.model_name, tb_path=tb_path) 54 | # Load the checkpoint if resuming from a saved model. 55 | if args.resume_ckpt: 56 | ckpt = torch.load(args.ckpt_dir)['model_state_dict'] 57 | # Remove "module." prefix if it exists (for models trained with DataParallel). 58 | new_ckpt = {k.replace("module.", "") if k.startswith("module.") else k: v for k, v in ckpt.items()} 59 | # Save the modified checkpoint 60 | self.model.load_model(new_ckpt) 61 | # Configure device settings 62 | self._setup_device() 63 | # save 64 | self.output_dir = os.path.join(args.saved_dir, 'output_img') 65 | os.makedirs(self.output_dir, exist_ok=True) 66 | self.gt_dir = os.path.join(args.saved_dir, 'gt_img') 67 | os.makedirs(self.gt_dir, exist_ok=True) 68 | 69 | def _setup_device(self): 70 | """ 71 | Configures the computing device. 72 | - Moves the model to GPU if available. 73 | - Enables multi-GPU support if specified. 74 | """ 75 | if torch.cuda.is_available(): 76 | self.model.cuda_deblur() 77 | if self.args.use_multigpu: 78 | self.model.use_multi_gpu_deblur() 79 | 80 | def test(self): 81 | """ 82 | Performs testing on the dataset. 83 | - Iterates through the test loader and evaluates the model. 84 | - Computes PSNR and SSIM for each sample. 85 | - Logs the final evaluation results. 86 | """ 87 | self.model.del_batch() 88 | # 89 | global_cnt = 0 90 | with torch.no_grad(): # Disable gradient calculations for testing. 91 | for sample in tqdm(self.test_loader, desc='Testing Progress'): 92 | sample = batch2device(sample) 93 | self.model.set_video_inputs(sample) 94 | self.model.forward_deblur_net() 95 | # Compute PSNR and SSIM metrics. 
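                # (PSNR/SSIM can be computed offline from the image pairs saved in the loop below,
                #  e.g. with psnr_calculate / ssim_calculate from utils/eval_metrics.py.)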
96 |                 for batch_idx in range(self.model.batch['clean_middle'].shape[0]):  # iterate over the actual batch size (the last batch can be smaller than val_batch_size) 97 |                     output_img = 255*self.model.batch['output_deblur'][0][batch_idx, ...].squeeze().detach().cpu().numpy().transpose(1,2,0) 98 |                     clean_middle = 255*self.model.batch['clean_middle'][batch_idx, ...].squeeze().detach().cpu().numpy().transpose(1,2,0) 99 |                     cv2.imwrite(os.path.join(self.output_dir, str(global_cnt).zfill(5) + '.png'), cv2.cvtColor(output_img, cv2.COLOR_RGB2BGR)) 100 |                     cv2.imwrite(os.path.join(self.gt_dir, str(global_cnt).zfill(5) + '.png'), cv2.cvtColor(clean_middle, cv2.COLOR_RGB2BGR)) 101 |                     global_cnt += 1 102 |                 self.model.del_batch() 103 |         # Free up GPU memory. 104 |         torch.cuda.empty_cache() 105 | 106 | 107 | if __name__ == '__main__': 108 |     args = get_argument()  # Parse arguments. 109 |     tester = Tester(args)  # Initialize the Tester. 110 |     tester.test()  # Run the test. -------------------------------------------------------------------------------- /test_sample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from utils.utils import * 4 | from models.model_manager import ModelManager 5 | from torchvision import transforms 6 | import torchvision.transforms.functional as TF  # functional transforms API 7 | from PIL import Image 8 | 9 | 10 | 11 | def get_argument(): 12 |     parser = argparse.ArgumentParser() 13 |     # params 14 |     parser.add_argument('--num_test_video_frames', type = int, default=3) 15 |     parser.add_argument('--mode', type = str, default='test') 16 |     parser.add_argument('--voxel_num_bins', type = int, default=16) 17 |     # model description 18 |     parser.add_argument('--model_folder', type=str, default='model_factory') 19 |     parser.add_argument('--model_name', type=str, default='models_final') 20 |     # data loading params 21 |     parser.add_argument('--experiment_name', type = str, default='test_networks') 22 |     parser.add_argument('--num_threads', type = int, default=12) 23 |     parser.add_argument('--sample_folder_path', type = str, default='./sample_data') 24 |     parser.add_argument('--resume_ckpt', type=str2bool, required=True) 25 |     parser.add_argument('--ckpt_dir', type = str, required=True) 26 |     args = parser.parse_args() 27 |     return args 28 | 29 | 30 | if __name__ == '__main__': 31 |     # Parse command-line arguments 32 |     args = get_argument() 33 |     # Initialize model manager 34 |     model = ModelManager(args) 35 |     # Initialize the deblurring model 36 |     model.initilalize_deblur_model(args, model_folder=args.model_folder, model_name=args.model_name, tb_path=None) 37 |     # Load the checkpoint if resuming from a saved model 38 |     if args.resume_ckpt: 39 |         ckpt = torch.load(args.ckpt_dir)['model_state_dict'] 40 |         # Remove "module."
prefix if it exists (for models trained with DataParallel) 41 | new_ckpt = {k.replace("module.", "") if k.startswith("module.") else k: v for k, v in ckpt.items()} 42 | # Load the modified checkpoint into the model 43 | model.load_model(new_ckpt) 44 | # Configure device settings 45 | if torch.cuda.is_available(): 46 | model.cuda_deblur() 47 | # Configure output directory for saving results 48 | output_dir = os.path.join(args.sample_folder_path, 'output_folder') 49 | os.makedirs(output_dir, exist_ok=True) # Create directory if it doesn't exist 50 | # Clear any previous batch data in the model 51 | model.del_batch() 52 | # Define transformation for image conversion 53 | transform = transforms.ToTensor() 54 | with torch.no_grad(): 55 | sample = dict() 56 | ## Load file paths for input data 57 | blur_image_path = os.path.join(args.sample_folder_path, 'blur_images') 58 | event_voxel_path = os.path.join(args.sample_folder_path, 'event_voxel') 59 | # Get sorted filenames for blur images and event voxels 60 | blur_image_names = sorted(os.listdir(blur_image_path)) 61 | event_voxel_names = sorted(os.listdir(event_voxel_path)) 62 | # Lists to store tensor representations of images and event voxels 63 | event_vox_list, blur_list = [], [] 64 | # Iterate through the given number of test video frames 65 | for i in range(args.num_test_video_frames): 66 | # Load blur image and convert it to tensor 67 | blur_image = Image.open(os.path.join(blur_image_path, blur_image_names[i])) 68 | blur_image_tensor = transform(blur_image) 69 | # Load event voxel data (assuming .npz format with "data" key) 70 | event_voxel = np.load(os.path.join(event_voxel_path, event_voxel_names[i]))["data"] 71 | event_vox_tensor = torch.from_numpy(event_voxel) 72 | # Append tensors to the corresponding lists (adding batch dimension) 73 | event_vox_list.append(event_vox_tensor[None, ...]) 74 | blur_list.append(blur_image_tensor[None, ...]) 75 | # Concatenate the tensors along the batch dimension and add extra dimension 76 | event_vox_tensor = torch.cat(event_vox_list)[None, ...] # Shape: (1, num_frames, ...) 77 | blur_input_clip = torch.cat(blur_list)[None, ...] # Shape: (1, num_frames, ...) 
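        # At this point (assuming RGB frames and the default voxel_num_bins=16):
        #   event_vox_tensor  -> (1, num_test_video_frames, 16, H, W)
        #   blur_input_clip   -> (1, num_test_video_frames, 3, H, W)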
78 | # Assign processed inputs to the sample dictionary 79 | sample['event_vox_clip'] = event_vox_tensor 80 | sample['blur_input_clip'] = blur_input_clip 81 | # Move data to the appropriate device 82 | sample = batch2device(sample) 83 | # Set inputs for the model 84 | model.set_test_inputs(sample) 85 | # Run the deblurring model 86 | model.forward_deblur_net() 87 | # Extract the deblurred output from the model's batch dictionary 88 | output_deblur = model.batch['output_deblur'][0] # Extract first frame of output 89 | # Convert tensor output to a PIL image 90 | output_deblur_cpu = TF.to_pil_image(output_deblur.cpu().squeeze()) 91 | # Save the output image to the designated directory 92 | output_deblur_cpu.save(os.path.join(output_dir, blur_image_names[args.num_test_video_frames // 2])) 93 | # Clear batch data after processing 94 | model.del_batch() 95 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import datetime 4 | import argparse 5 | from collections import OrderedDict 6 | from torch.optim import Adam 7 | from torch.utils.data import DataLoader 8 | from tensorboardX import SummaryWriter 9 | from tqdm import tqdm, trange 10 | from math import ceil 11 | from utils.utils import * 12 | from utils.dataloader import get_train_dataset, get_test_dataset 13 | from models.model_manager import ModelManager 14 | 15 | 16 | def get_argument(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--val_batch_size', type = int, default=1) 19 | parser.add_argument('--batch_size', type = int, default=1) 20 | parser.add_argument('--total_epochs', type = int, default=201) 21 | # training params 22 | parser.add_argument('--num_train_video_frames', type = int, default=3) 23 | parser.add_argument('--num_test_video_frames', type = int, default=3) 24 | parser.add_argument('--voxel_num_bins', type = int, default=16) 25 | parser.add_argument('--num_video_frames', type = int, default=1) 26 | parser.add_argument('--crop_size', type = int, default=256) 27 | parser.add_argument('--learning_rate', type = float, default=1e-4) 28 | parser.add_argument('--mode', type = str, default='train') 29 | # model discription 30 | parser.add_argument('--model_folder', type=str, default='model_factory') 31 | parser.add_argument('--model_name', type=str, default='models_final') 32 | # data loading params 33 | parser.add_argument('--num_threads', type = int, default=12) 34 | parser.add_argument('--test_epoch_every', type=int, default=40) 35 | parser.add_argument('--experiment_name', type = str, default='train_networks') 36 | parser.add_argument('--loss_type', type = str, default='multi_scale') 37 | parser.add_argument('--loss_name', type = str, default='multi_scale_single') 38 | parser.add_argument('--tb_update_thresh', type = int, default=200) 39 | parser.add_argument('--data_dir', type = str, default = '/media/mnt2/dataset/RELED/') 40 | parser.add_argument('--use_multigpu', type=str2bool, default='True') 41 | parser.add_argument('--resume_ckpt', type=str2bool, default=False) 42 | parser.add_argument('--ckpt_dir', type = str, default='./experiments/240302-train_networks_v27_ms_hpf_dp_SA_v4_reduced_ffn_mimo_dffn/0056/saved_model/') 43 | args = parser.parse_args() 44 | return args 45 | 46 | 47 | class Trainer: 48 | def __init__(self, args): 49 | self.args = args 50 | self.tb_iter_cnt = 0 51 | self.tb_iter_cnt_val = 0 52 | self.tb_iter_thresh = 
args.tb_update_thresh 53 | tb_path = f'./experiments/{datetime.datetime.now().strftime("%y%m%d-" + args.experiment_name + "/%H%M")} ' 54 | self.tb = SummaryWriter(tb_path, flush_secs=1) 55 | 56 | self.train_loader = DataLoader(get_train_dataset(args, mode='train'), 57 | batch_size=args.batch_size, shuffle=True, 58 | num_workers=args.num_threads, pin_memory=True, drop_last=True) 59 | 60 | self.test_loader = DataLoader(get_test_dataset(args, mode='test'), 61 | batch_size=args.val_batch_size, shuffle=False, 62 | num_workers=args.num_threads, pin_memory=False) 63 | 64 | self.model = ModelManager(args) 65 | self.model.initilalize_deblur_model(args, model_folder=args.model_folder, model_name=args.model_name, tb_path=tb_path) 66 | self._setup_device() 67 | self.optimizer = Adam(self.model.get_deblurnet_optimizer_params(), lr=args.learning_rate) 68 | self.start_epoch = 0 69 | self.end_epoch = args.total_epochs 70 | self.PSNR_calculator = PSNR() 71 | self.SSIM_calculator = SSIM() 72 | self.logger = get_logger(tb_path, 'log.txt', 'append') 73 | self._log_arguments() 74 | 75 | def _setup_device(self): 76 | if torch.cuda.is_available(): 77 | self.model.cuda_deblur() 78 | if self.args.use_multigpu: 79 | self.model.use_multi_gpu_deblur() 80 | 81 | def _log_arguments(self): 82 | self.logger.info(f'Overall parameter count: {self.model.count_total_parameters() * 1e-6:.4f} MB') 83 | for arg, val in vars(self.args).items(): 84 | self.logger.info(f'{arg}: {val}') 85 | self.tb.add_text('logs', f'model name: {self.args.model_name}') 86 | 87 | def train(self): 88 | for self.epoch in trange(self.start_epoch, self.end_epoch, desc='Epoch Progress'): 89 | for sample in tqdm(self.train_loader, desc='Training Progress'): 90 | self.optimizer.zero_grad() 91 | sample = batch2device(sample) 92 | self.model.set_video_inputs(sample) 93 | self.model.forward_deblur_net() 94 | loss = self.model.get_single_loss() 95 | loss.backward() 96 | self.optimizer.step() 97 | self.model.update_loss_meters_deblur() 98 | self.tb_iter_cnt += 1 99 | if self.args.batch_size * self.tb_iter_cnt > self.tb_iter_thresh: 100 | self.log_train_tb() 101 | if self.epoch % 20 == 0: 102 | self.test(self.epoch) 103 | self.save_model(self.epoch) 104 | 105 | def test(self, epoch): 106 | psnr_meter, ssim_meter = AverageMeter(), AverageMeter() 107 | self.model.del_batch() 108 | with torch.no_grad(): 109 | for sample in tqdm(self.test_loader, desc='Testing Progress'): 110 | sample = batch2device(sample) 111 | self.model.set_video_inputs(sample) 112 | self.model.forward_deblur_net() 113 | psnr_meter.update(self.PSNR_calculator(self.model.batch['clean_middle'], self.model.batch['output_deblur'][0]).mean().item()) 114 | ssim_meter.update(self.SSIM_calculator(self.model.batch['clean_middle'], self.model.batch['output_deblur'][0]).mean().item()) 115 | self.tb.add_scalar('val_progress/avg_psnr/', psnr_meter.avg, epoch) 116 | self.tb.add_scalar('val_progress/avg_ssim/', ssim_meter.avg, epoch) 117 | self.model.del_batch() 118 | torch.cuda.empty_cache() 119 | return psnr_meter.avg 120 | 121 | def log_train_tb(self): 122 | self.tb.add_scalar('train_progress/loss_total', self.model.loss_total_meter.avg, self.tb_iter_cnt) 123 | self.tb.add_image('train_blur/input', self.model.batch['blur_input'][0], self.tb_iter_cnt) 124 | self.tb.add_image('train_output/clean_est', self.model.batch['output_deblur'][0], self.tb_iter_cnt) 125 | self.tb.add_image('train_output/clean_gt', self.model.batch['clean_middle'][0], self.tb_iter_cnt) 126 | self.tb_iter_cnt = 0 127 | 
self.model.reset_loss_meters_deblur() 128 | 129 | def save_model(self, epoch): 130 | state = { 131 | 'epoch': epoch, 132 | 'model_state_dict': self.model.deblur_net.state_dict(), 133 | 'optimizer_state_dict': self.optimizer.state_dict() 134 | } 135 | save_path = os.path.join(self.model.save_path, f'model_{epoch}_ep.pth') 136 | torch.save(state, save_path) 137 | 138 | 139 | if __name__ == '__main__': 140 | args = get_argument() 141 | trainer = Trainer(args) 142 | if args.mode == 'train': 143 | trainer.train() 144 | else: 145 | trainer.test(0) 146 | -------------------------------------------------------------------------------- /utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | from torch.utils import data as data 5 | from torch.utils.data import ConcatDataset 6 | from torchvision import transforms 7 | import numpy as np 8 | import random 9 | from utils.utils import randomCrop 10 | from PIL import Image 11 | 12 | 13 | class Train_Video_Dataset(data.Dataset): 14 | def __init__(self, args, data_path, crop_size=256): 15 | super(Train_Video_Dataset, self).__init__() 16 | ## 17 | self.num_frames_seq = args.num_train_video_frames 18 | self.middle_frame_id = self.num_frames_seq//2 19 | ## image and event prefix 20 | self.event_vox_prefix = 'event_voxel_parsed' 21 | self.blur_image_prefix = 'blur_processed_parsed' 22 | self.sharp_image_prefix = 'gt_processed_parsed' 23 | # transform 24 | self.transform = transforms.ToTensor() 25 | # data aug params 26 | self.get_filetaxnomy(data_path) 27 | ## crop 28 | self.crop_height = crop_size 29 | self.crop_width = crop_size 30 | 31 | def get_filetaxnomy(self, data_dir): 32 | self.input_dict = {} 33 | self.input_dict['blur_images'] = {} 34 | self.input_dict['blur_images']['0'] = [] 35 | self.input_dict['blur_images']['1'] = [] 36 | self.input_dict['blur_images']['2'] = [] 37 | self.input_dict['blur_images']['3'] = [] 38 | self.input_dict['sharp_images'] = {} 39 | self.input_dict['sharp_images']['0'] = [] 40 | self.input_dict['sharp_images']['1'] = [] 41 | self.input_dict['sharp_images']['2'] = [] 42 | self.input_dict['sharp_images']['3'] = [] 43 | self.input_dict['event_voxel'] = {} 44 | self.input_dict['event_voxel']['0'] = [] 45 | self.input_dict['event_voxel']['1'] = [] 46 | self.input_dict['event_voxel']['2'] = [] 47 | self.input_dict['event_voxel']['3'] = [] 48 | event_voxel_dir = os.path.join(data_dir, self.event_vox_prefix) 49 | blur_image_dir = os.path.join(data_dir, self.blur_image_prefix) 50 | sharp_image_dir = os.path.join(data_dir, self.sharp_image_prefix) 51 | for patch_idx in range(4): 52 | event_vox_patch_dir = os.path.join(event_voxel_dir, str(patch_idx).zfill(5)) 53 | blur_patch_dir = os.path.join(blur_image_dir, str(patch_idx).zfill(5)) 54 | sharp_patch_dir = os.path.join(sharp_image_dir, str(patch_idx).zfill(5)) 55 | num_blur_images = len(os.listdir(blur_patch_dir)) 56 | for image_idx in range(num_blur_images): 57 | blur_name = os.path.join(blur_patch_dir, str(image_idx).zfill(5) + '.png') 58 | sharp_name = os.path.join(sharp_patch_dir, str(image_idx).zfill(5) + '.png') 59 | left_voxel_name = os.path.join(event_vox_patch_dir, str(image_idx).zfill(5) + '.npz') 60 | self.input_dict['blur_images'][str(patch_idx)].append(blur_name) 61 | self.input_dict['sharp_images'][str(patch_idx)].append(sharp_name) 62 | self.input_dict['event_voxel'][str(patch_idx)].append(left_voxel_name) 63 | 64 | def __getitem__(self, index): 65 | ## patch number 66 | 
rand_patch_idx = np.random.randint(0, 4) 67 | ## event vox read 68 | event_vox_list = [] 69 | blur_list, gt_list = [], [] 70 | # new_index = self.num_frames_seq 71 | for video_num_idx in range(index, index + self.num_frames_seq): 72 | ## event voxel 73 | left_event_vox = np.load(self.input_dict['event_voxel'][str(rand_patch_idx)][video_num_idx])["data"] 74 | left_event_vox_tensor = torch.from_numpy(left_event_vox) 75 | event_vox_list.append(left_event_vox_tensor[None, ...]) 76 | blur_image = Image.open(self.input_dict['blur_images'][str(rand_patch_idx)][video_num_idx]) 77 | gt_image = Image.open(self.input_dict['sharp_images'][str(rand_patch_idx)][video_num_idx]) 78 | blur_image_tensor = self.transform(blur_image) 79 | gt_image_tensor = self.transform(gt_image) 80 | blur_list.append(blur_image_tensor[None, ...]) 81 | gt_list.append(gt_image_tensor[None, ...]) 82 | blur_input_clip = torch.cat(blur_list) 83 | gt_clip = torch.cat(gt_list) 84 | gt_clip_middle = gt_clip[self.middle_frame_id] 85 | event_vox_tensor = torch.cat(event_vox_list) 86 | _, _, height, width = gt_clip.shape 87 | # random crop 88 | x = random.randint(0, width - self.crop_width) 89 | y = random.randint(0, height - self.crop_height) 90 | gt_image_tensor = randomCrop(gt_clip, x, y, self.crop_height, self.crop_width) 91 | gt_clip_middle = randomCrop(gt_clip_middle, x, y, self.crop_height, self.crop_width ) 92 | blur_input_clip = randomCrop(blur_input_clip, x, y, self.crop_height, self.crop_width) 93 | event_vox_cropped = randomCrop(event_vox_tensor, x, y, self.crop_height, self.crop_width) 94 | ### sample 95 | sample = {} 96 | sample['clean_gt_clip'] = gt_image_tensor 97 | sample['clean_middle'] = gt_clip_middle 98 | sample['blur_input_clip'] = blur_input_clip 99 | sample['event_vox_clip'] = event_vox_cropped 100 | return sample 101 | 102 | def __len__(self): 103 | return len(self.input_dict['blur_images']['0'])-self.num_frames_seq//2-1 104 | 105 | 106 | class Test_Video_Dataset(data.Dataset): 107 | def __init__(self, args, data_path): 108 | super(Test_Video_Dataset, self).__init__() 109 | self.num_frames_seq = args.num_test_video_frames 110 | self.middle_frame_id = self.num_frames_seq//2 111 | ## image and event prefix 112 | self.event_vox_prefix = 'event_voxel' 113 | self.blur_image_prefix = 'blur_processed' 114 | self.sharp_image_prefix = 'gt_processed' 115 | # transform 116 | self.transform = transforms.ToTensor() 117 | # data aug params 118 | self.get_filetaxnomy(data_path) 119 | 120 | def get_filetaxnomy(self, data_dir): 121 | self.input_dict = {} 122 | self.input_dict['blur_images'] = [] 123 | self.input_dict['sharp_images'] = [] 124 | self.input_dict['event_voxel'] = [] 125 | event_voxel_dir = os.path.join(data_dir, self.event_vox_prefix) 126 | blur_image_dir = os.path.join(data_dir, self.blur_image_prefix) 127 | sharp_image_dir = os.path.join(data_dir, self.sharp_image_prefix) 128 | num_blur_images = len(os.listdir(blur_image_dir)) 129 | for image_idx in range(num_blur_images): 130 | blur_name = os.path.join(blur_image_dir, str(image_idx).zfill(5) + '.png') 131 | sharp_name = os.path.join(sharp_image_dir, str(image_idx).zfill(5) + '.png') 132 | left_voxel_name = os.path.join(event_voxel_dir, str(image_idx).zfill(5) + '.npz') 133 | self.input_dict['blur_images'].append(blur_name) 134 | self.input_dict['sharp_images'].append(sharp_name) 135 | self.input_dict['event_voxel'].append(left_voxel_name) 136 | 137 | def __getitem__(self, index): 138 | ## event vox read 139 | event_vox_list = [] 140 | blur_list, gt_list = [], 
[] 141 | for video_num_idx in range(index, index + self.num_frames_seq): 142 | ## event voxel 143 | left_event_vox = np.load(self.input_dict['event_voxel'][video_num_idx])["data"] 144 | left_event_vox_tensor = torch.from_numpy(left_event_vox) 145 | ## images 146 | blur_image = Image.open(self.input_dict['blur_images'][video_num_idx]) 147 | gt_image = Image.open(self.input_dict['sharp_images'][video_num_idx]) 148 | blur_image_tensor = self.transform(blur_image) 149 | gt_image_tensor = self.transform(gt_image) 150 | ## append to the list 151 | event_vox_list.append(left_event_vox_tensor[None, ...]) 152 | blur_list.append(blur_image_tensor[None, ...]) 153 | gt_list.append(gt_image_tensor[None, ...]) 154 | blur_input_clip = torch.cat(blur_list) 155 | gt_clip = torch.cat(gt_list) 156 | gt_clip_middle = gt_clip[self.middle_frame_id] 157 | event_vox_tensor = torch.cat(event_vox_list) 158 | ## prepare sample 159 | sample = {} 160 | sample['clean_gt_clip'] = gt_clip 161 | sample['clean_middle'] = gt_clip_middle 162 | sample['blur_input_clip'] = blur_input_clip 163 | sample['event_vox_clip'] = event_vox_tensor 164 | return sample 165 | 166 | def __len__(self): 167 | return len(self.input_dict['blur_images'])-self.num_frames_seq//2-1 168 | 169 | def get_train_dataset(args, mode): 170 | data_with_mode = os.path.join(args.data_dir, mode) 171 | scene_list = os.listdir(data_with_mode) 172 | dataset_list = [] 173 | for scene in scene_list: 174 | data_path = os.path.join(args.data_dir, mode, scene) 175 | dset = Train_Video_Dataset(args, data_path) 176 | dataset_list.append(dset) 177 | dataset_train_concat = ConcatDataset(dataset_list) 178 | return dataset_train_concat 179 | 180 | def get_test_dataset(args, mode): 181 | data_with_mode = os.path.join(args.data_dir, mode) 182 | scene_list = os.listdir(data_with_mode) 183 | dataset_list = [] 184 | for scene in scene_list: 185 | data_path = os.path.join(data_with_mode, scene) 186 | dsets = Test_Video_Dataset(args, data_path) 187 | dataset_list.append(dsets) 188 | dataset_test_concat = ConcatDataset(dataset_list) 189 | return dataset_test_concat -------------------------------------------------------------------------------- /utils/eval_metrics.py: -------------------------------------------------------------------------------- 1 | from scipy.ndimage import gaussian_filter 2 | import numpy as np 3 | 4 | 5 | def ssim_calculate(img1, img2, sd=1.5, C1=0.01**2, C2=0.03**2): 6 | # Processing input image 7 | img1 = np.array(img1, dtype=np.float32) / 255 8 | img1 = img1.transpose((2, 0, 1)) 9 | 10 | # Processing gt image 11 | img2 = np.array(img2, dtype=np.float32) / 255 12 | img2 = img2.transpose((2, 0, 1)) 13 | 14 | 15 | mu1 = gaussian_filter(img1, sd) 16 | mu2 = gaussian_filter(img2, sd) 17 | mu1_sq = mu1 * mu1 18 | mu2_sq = mu2 * mu2 19 | mu1_mu2 = mu1 * mu2 20 | sigma1_sq = gaussian_filter(img1 * img1, sd) - mu1_sq 21 | sigma2_sq = gaussian_filter(img2 * img2, sd) - mu2_sq 22 | sigma12 = gaussian_filter(img1 * img2, sd) - mu1_mu2 23 | 24 | ssim_num = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) 25 | 26 | ssim_den = ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 27 | 28 | ssim_map = ssim_num / ssim_den 29 | return np.mean(ssim_map) 30 | 31 | def psnr_calculate(x, y, val_range=255.0): 32 | x = x.astype(np.float) 33 | y = y.astype(np.float) 34 | diff = (x - y) / val_range 35 | mse = np.mean(diff ** 2) 36 | psnr = -10 * np.log10(mse) 37 | return psnr -------------------------------------------------------------------------------- /utils/make_train_dataset.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import cv2 4 | import numpy as np 5 | import argparse 6 | 7 | 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser(description="Low-light deblurring dataset parser") 11 | parser.add_argument("--train_data_dir", type=str, default = '/media/mnt2/dataset/RELED/train', 12 | help="Path to the training dataset directory") 13 | args = parser.parse_args() 14 | # Prefix setting.. 15 | event_vox_prefix = 'event_voxel' 16 | blur_prefix = 'blur_processed' 17 | gt_prefix = 'gt_processed' 18 | event_vox_parsed_prefix = 'event_voxel_parsed' 19 | blur_parsed_prefix = 'blur_processed_parsed' 20 | gt_parsed_prefix = 'gt_processed_parsed' 21 | print(f"Train data directory: {args.train_data_dir}") 22 | print(f"Event voxel directory: {event_vox_prefix}") 23 | print(f"Blur directory: {blur_prefix}") 24 | print(f"GT directory: {gt_prefix}") 25 | ## scene list 26 | scene_list = os.listdir(args.train_data_dir) 27 | scene_list.sort() 28 | for scene in scene_list: 29 | ## event vox 30 | event_vox_dir = os.path.join(args.train_data_dir, scene, event_vox_prefix) 31 | event_vox_list = glob.glob(os.path.join(event_vox_dir, '*.npz')) 32 | ## blur 33 | blur_dir = os.path.join(args.train_data_dir, scene, blur_prefix) 34 | blur_list = glob.glob(os.path.join(blur_dir, '*.png')) 35 | ## gt 36 | gt_dir = os.path.join(args.train_data_dir, scene, gt_prefix) 37 | gt_list = glob.glob(os.path.join(gt_dir, '*.png')) 38 | event_vox_list.sort() 39 | blur_list.sort() 40 | gt_list.sort() 41 | num_data = len(event_vox_list) 42 | ## target dir 43 | event_vox_parsed_dir = os.path.join(args.train_data_dir, scene, event_vox_parsed_prefix) 44 | blur_parsed_dir = os.path.join(args.train_data_dir, scene, blur_parsed_prefix) 45 | gt_parsed_dir = os.path.join(args.train_data_dir, scene, gt_parsed_prefix) 46 | if not os.path.exists(event_vox_parsed_dir): 47 | os.makedirs(event_vox_parsed_dir) 48 | if not os.path.exists(blur_parsed_dir): 49 | os.makedirs(blur_parsed_dir) 50 | if not os.path.exists(gt_parsed_dir): 51 | os.makedirs(gt_parsed_dir) 52 | ## clean dir 53 | clean_image_dir_0 = os.path.join(gt_parsed_dir, str(0).zfill(5)) 54 | clean_image_dir_1 = os.path.join(gt_parsed_dir, str(1).zfill(5)) 55 | clean_image_dir_2 = os.path.join(gt_parsed_dir, str(2).zfill(5)) 56 | clean_image_dir_3 = os.path.join(gt_parsed_dir, str(3).zfill(5)) 57 | if not os.path.exists(clean_image_dir_0): 58 | os.makedirs(clean_image_dir_0) 59 | if not os.path.exists(clean_image_dir_1): 60 | os.makedirs(clean_image_dir_1) 61 | if not os.path.exists(clean_image_dir_2): 62 | os.makedirs(clean_image_dir_2) 63 | if not os.path.exists(clean_image_dir_3): 64 | os.makedirs(clean_image_dir_3) 65 | ## blur dir 66 | blur_dir_0 = os.path.join(blur_parsed_dir, str(0).zfill(5)) 67 | blur_dir_1 = os.path.join(blur_parsed_dir, str(1).zfill(5)) 68 | blur_dir_2 = os.path.join(blur_parsed_dir, str(2).zfill(5)) 69 | blur_dir_3 = os.path.join(blur_parsed_dir, str(3).zfill(5)) 70 | if not os.path.exists(blur_dir_0): 71 | os.makedirs(blur_dir_0) 72 | if not os.path.exists(blur_dir_1): 73 | os.makedirs(blur_dir_1) 74 | if not os.path.exists(blur_dir_2): 75 | os.makedirs(blur_dir_2) 76 | if not os.path.exists(blur_dir_3): 77 | os.makedirs(blur_dir_3) 78 | ## vox dir 79 | vox_dir_0 = os.path.join(event_vox_parsed_dir, str(0).zfill(5)) 80 | vox_dir_1 = os.path.join(event_vox_parsed_dir, str(1).zfill(5)) 81 | vox_dir_2 = os.path.join(event_vox_parsed_dir, 
str(2).zfill(5)) 82 | vox_dir_3 = os.path.join(event_vox_parsed_dir, str(3).zfill(5)) 83 | if not os.path.exists(vox_dir_0): 84 | os.makedirs(vox_dir_0) 85 | if not os.path.exists(vox_dir_1): 86 | os.makedirs(vox_dir_1) 87 | if not os.path.exists(vox_dir_2): 88 | os.makedirs(vox_dir_2) 89 | if not os.path.exists(vox_dir_3): 90 | os.makedirs(vox_dir_3) 91 | for data_idx in range(num_data): 92 | ### sharp 93 | image_name = gt_list[data_idx] 94 | cur_image = cv2.imread(image_name) 95 | 96 | h,w,c = cur_image.shape 97 | h_u = int(h/2) 98 | w_u = int(w/2) 99 | 100 | saved_sharp_image_name1 = os.path.join(clean_image_dir_0, str(data_idx).zfill(5) + '.png') 101 | saved_sharp_image_name2 = os.path.join(clean_image_dir_1, str(data_idx).zfill(5) + '.png') 102 | saved_sharp_image_name3 = os.path.join(clean_image_dir_2, str(data_idx).zfill(5) + '.png') 103 | saved_sharp_image_name4 = os.path.join(clean_image_dir_3, str(data_idx).zfill(5) + '.png') 104 | 105 | c1=cur_image[:h_u, :w_u,:] 106 | c2=cur_image[h_u:, w_u:,:] 107 | c3=cur_image[:h_u, w_u:,:] 108 | c4=cur_image[h_u:, :w_u,:] 109 | 110 | cv2.imwrite(saved_sharp_image_name1, c1) 111 | cv2.imwrite(saved_sharp_image_name2, c2) 112 | cv2.imwrite(saved_sharp_image_name3, c3) 113 | cv2.imwrite(saved_sharp_image_name4, c4) 114 | 115 | ### blur 116 | blur_image_name = blur_list[data_idx] 117 | blur_image = cv2.imread(blur_image_name) 118 | 119 | saved_blur_image_name1 = os.path.join(blur_dir_0, str(data_idx).zfill(5) + '.png') 120 | saved_blur_image_name2 = os.path.join(blur_dir_1, str(data_idx).zfill(5) + '.png') 121 | saved_blur_image_name3 = os.path.join(blur_dir_2, str(data_idx).zfill(5) + '.png') 122 | saved_blur_image_name4 = os.path.join(blur_dir_3, str(data_idx).zfill(5) + '.png') 123 | 124 | b1=blur_image[:h_u, :w_u,:] 125 | b2=blur_image[h_u:, w_u:,:] 126 | b3=blur_image[:h_u, w_u:,:] 127 | b4=blur_image[h_u:, :w_u,:] 128 | 129 | cv2.imwrite(saved_blur_image_name1, b1) 130 | cv2.imwrite(saved_blur_image_name2, b2) 131 | cv2.imwrite(saved_blur_image_name3, b3) 132 | cv2.imwrite(saved_blur_image_name4, b4) 133 | 134 | ## event voxel 135 | voxel_grid = np.load(event_vox_list[data_idx])["data"] 136 | v1 = voxel_grid[:,:h_u, :w_u] 137 | v2 = voxel_grid[:,h_u:, w_u:] 138 | v3 = voxel_grid[:,:h_u, w_u:] 139 | v4 = voxel_grid[:,h_u:, :w_u] 140 | 141 | np.savez_compressed(os.path.join(vox_dir_0, str(data_idx).zfill(5) + '.npz'), data=v1) 142 | np.savez_compressed(os.path.join(vox_dir_1, str(data_idx).zfill(5) + '.npz'), data=v2) 143 | np.savez_compressed(os.path.join(vox_dir_2, str(data_idx).zfill(5) + '.npz'), data=v3) 144 | np.savez_compressed(os.path.join(vox_dir_3, str(data_idx).zfill(5) + '.npz'), data=v4) -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import math 4 | import torch 5 | from math import exp 6 | from torch.autograd import Variable 7 | import torch.nn.functional as F 8 | import torch.nn as nn 9 | import logging 10 | import cv2 11 | 12 | 13 | def randomCrop(tensor, x, y, height, width): 14 | tensor = tensor[..., y:y+height, x:x+width] 15 | return tensor 16 | 17 | def str2bool(v): 18 | return v.lower() in ('true') 19 | 20 | 21 | def batch2device(dictionary_of_tensors): 22 | if isinstance(dictionary_of_tensors, dict): 23 | return {key: batch2device(value) for key, value in dictionary_of_tensors.items()} 24 | return dictionary_of_tensors.cuda() 25 | 26 | def 
batch2device_test(dictionary_of_tensors): 27 | sample = {} 28 | for key, value in dictionary_of_tensors.items(): 29 | if torch.is_tensor(value): 30 | sample[key] = value.cuda() 31 | else: 32 | sample[key] = value 33 | return sample 34 | 35 | 36 | # logger 37 | def get_logger(save_directory, filename, mode, name=__file__, level=logging.INFO): 38 | logger = logging.getLogger(name) 39 | if getattr(logger, '_init_done__', None): 40 | logger.setLevel(level) 41 | return logger 42 | 43 | logger._init_done__ = True 44 | logger.propagate = False 45 | logger.setLevel(level) 46 | 47 | formatter = logging.Formatter("%(message)s") 48 | handler = logging.StreamHandler() 49 | handler.setFormatter(formatter) 50 | handler.setLevel(0) 51 | 52 | del logger.handlers[:] 53 | logger.addHandler(handler) 54 | 55 | # file handler 56 | if mode=='append': 57 | file_handler = logging.FileHandler(os.path.join(save_directory, filename), mode='a') 58 | if mode=='write': 59 | file_handler = logging.FileHandler(os.path.join(save_directory, filename), mode='w') 60 | logger.addHandler(file_handler) 61 | return logger 62 | 63 | 64 | 65 | class AverageMeter(object): 66 | """Computes and stores the average and current value""" 67 | def __init__(self): 68 | self.reset() 69 | 70 | def reset(self): 71 | self.val = 0 72 | self.avg = 0 73 | self.sum = 0 74 | self.count = 0 75 | 76 | def update(self, val, n=1): 77 | self.val = val 78 | self.sum += val * n 79 | self.count += n 80 | self.avg = self.sum / self.count 81 | 82 | 83 | class PSNR: 84 | def __init__(self): 85 | self.name = "PSNR" 86 | 87 | @staticmethod 88 | def __call__(img1, img2): 89 | img1 = img1.reshape(img1.shape[0], -1) 90 | img2 = img2.reshape(img2.shape[0], -1) 91 | mse = torch.mean((img1 - img2) ** 2, dim=1) 92 | return 10* torch.log10(1 / mse) 93 | 94 | def gaussian(window_size, sigma): 95 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 96 | return gauss/gauss.sum() 97 | 98 | def create_window(window_size, channel): 99 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 100 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 101 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 102 | return window 103 | 104 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 105 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 106 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 107 | 108 | mu1_sq = mu1.pow(2) 109 | mu2_sq = mu2.pow(2) 110 | mu1_mu2 = mu1*mu2 111 | 112 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 113 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 114 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 115 | 116 | C1 = 0.01**2 117 | C2 = 0.03**2 118 | 119 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 120 | 121 | if size_average: 122 | return ssim_map.mean() 123 | else: 124 | return ssim_map.mean(1).mean(1).mean(1) 125 | 126 | class SSIM(torch.nn.Module): 127 | def __init__(self, window_size = 11, size_average = False): 128 | super(SSIM, self).__init__() 129 | self.window_size = window_size 130 | self.size_average = size_average 131 | self.channel = 1 132 | self.window = create_window(window_size, self.channel) 133 | 134 | def forward(self, img1, img2): 135 | (_, channel, _, 
_) = img1.size() 136 | 137 | if channel == self.channel and self.window.data.type() == img1.data.type(): 138 | window = self.window 139 | else: 140 | window = create_window(self.window_size, channel) 141 | 142 | if img1.is_cuda: 143 | window = window.cuda(img1.get_device()) 144 | window = window.type_as(img1) 145 | 146 | self.window = window 147 | self.channel = channel 148 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) --------------------------------------------------------------------------------
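For reference, a minimal sketch showing how the `PSNR` and `SSIM` helpers defined in `utils/utils.py` can be called directly; the tensors below are illustrative placeholders (RGB images in [0, 1]), not data from the repository:

```python
import torch
from utils.utils import PSNR, SSIM

# Stand-in tensors for a restored frame and its ground truth, shape (B, 3, H, W), values in [0, 1].
pred = torch.rand(1, 3, 256, 256)
gt = torch.rand(1, 3, 256, 256)

psnr = PSNR()(pred, gt)   # per-sample PSNR in dB, shape (B,)
ssim = SSIM()(pred, gt)   # per-sample SSIM (size_average=False by default), shape (B,)
print(psnr.mean().item(), ssim.mean().item())
```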