├── .gitignore
├── README.md
├── codes
├── dataset.py
├── metrics.py
├── model
│   ├── __pycache__
│   │   └── networks.cpython-38.pyc
│   ├── bsn_model.py
│   ├── edsr_model.py
│   ├── ldn_model.py
│   └── networks.py
└── utils.py
├── imgs
├── gopro.png
├── high_calib_compress.gif
├── middle_calib_compress.gif
├── overview.png
└── rsb.png
├── log
├── BSN
│   ├── opt.txt
│   ├── test_log.txt
│   └── train_log.txt
├── Deblur
│   ├── opt.txt
│   ├── test_log.txt
│   └── train_log.txt
└── SR
│   ├── opt.txt
│   ├── test_log.txt
│   └── train_log.txt
├── model
├── BSN_1000.pth
├── DeblurNet_100.pth
└── SR_70.pth
├── scripts
├── GOPRO_dataset.md
├── XVFI-main
│   ├── (0000.png
│   ├── README.md
│   ├── XVFInet.py
│   ├── __pycache__
│   │   ├── XVFInet.cpython-38.pyc
│   │   └── utils.cpython-38.pyc
│   ├── checkpoint_dir
│   │   ├── XVFInet_Vimeo_exp1
│   │   │   └── info.txt
│   │   └── XVFInet_X4K1000FPS_exp1
│   │   │   ├── XVFInet_X4K1000FPS_exp1_latest.pt
│   │   │   └── info.txt
│   ├── compare_psnr_ssim.m
│   ├── figures
│   │   ├── .txt
│   │   ├── 003.gif
│   │   ├── 004.gif
│   │   ├── 045.gif
│   │   ├── 078.gif
│   │   ├── 081.gif
│   │   ├── 146.gif
│   │   ├── Figure5.PNG
│   │   ├── Table2.PNG
│   │   ├── results_045_resized_768.gif
│   │   ├── results_079_resized_768.gif
│   │   └── results_158_resized_768.gif
│   ├── main.py
│   ├── run.sh
│   ├── test.gif
│   ├── text_dir
│   │   └── XVFInet_X4K1000FPS_exp1.txt
│   └── utils.py
├── __pycache__
│   ├── generate_spike.cpython-38.pyc
│   └── utils.cpython-38.pyc
├── blur_syn.py
├── run.sh
├── sharp_extract.py
├── spike_simulate.py
├── utils.py
└── utils_spike.py
├── train_bsn.py
├── train_deblur.py
└── train_sr.py
/.gitignore:
--------------------------------------------------------------------------------
1 | scripts/GOPRO/*
2 | # log/*
3 | visual_ablation_distillation.py
4 | visual_gopro.py
5 | run.sh
6 | model/DeblurNet_150.pth
7 | model/NEW_GOPRO_9_reblur0_24_100.pth
8 | model/NEW_GOPRO_9_reblur10_24_full_100.pth
9 | model/NEW_GOPRO_9_reblur10_24_full_nolpips.pth
10 | model/NEW_GOPRO_9_reblur50_24_full_100.pth
11 | model/NEW_GOPRO_9_reblur70_24_full.pth
12 | model/NEW_GOPRO_9_reblur100_24_full_100.pth
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |

2 | SpikeReveal: Unlocking Temporal Sequences from Real Blurry Inputs with Spike Streams (NeurIPS 2024 Spotlight) 3 |

4 |
5 | If you like our project, please give us a star ⭐ on GitHub.
6 |
7 | 8 | [![arXiv](https://img.shields.io/badge/Arxiv-2403.09486-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2403.09486) 9 | [![License](https://img.shields.io/badge/License-MIT-yellow)](https://github.com/chenkang455/S-SDM) 10 | [![GitHub repo stars](https://img.shields.io/github/stars/chenkang455/S-SDM?style=flat&logo=github&logoColor=whitesmoke&label=Stars)](https://github.com/chenkang455/S-SDM/stargazers)  11 | 12 |
13 | 14 |

15 | 16 |

17 |
18 |
19 |
20 | ## 📕 Abstract
21 | > We begin with a theoretical analysis of the relationship between spike streams, blurry images, and sharp sequences, leading to the development of our Spike-guided Deblurring Model (SDM). We further construct a self-supervised processing pipeline by cascading a denoising network and a super-resolution network, which reduces the sensitivity of the SDM to spike noise and its reliance on spatial-resolution matching between the two modalities. To reduce processing time and make better use of the spatial-temporal spike information within this pipeline, we further design a Lightweight Deblurring Network (LDN) and train it on pseudo-labels from the teacher model, i.e., the established self-supervised pipeline. By further introducing a re-blurring loss during LDN training, we achieve better restoration performance and faster processing than the slow, structurally complex teacher model.
22 |
23 | ## 👀 Visual Comparisons
24 |
25 | Sequence reconstruction on the RSB dataset under different lighting conditions (the flicker is caused by GIF compression). 26 |

27 | 28 | 29 |

30 |
31 | 32 | 33 |
Image deblurring on the RSB dataset. 

35 | RSB Dataset 36 |

37 |
38 |
39 |
40 | ## 🗓️ TODO
41 | - [x] Release the scripts for simulating the GOPRO dataset.
42 | - [x] Release the training and testing code.
43 | - [x] Release the pretrained models.
44 | - [x] Release the synthetic/real-world dataset.
45 |
46 | ## 🕶 Dataset
47 |
48 | Guidance on synthesizing the spike-based GOPRO dataset can be found in [GOPRO_dataset](scripts/GOPRO_dataset.md).
49 |
50 | The converted GOPRO dataset can be found at [GOPRO](https://pan.baidu.com/s/1ZvRNF4kqVB8qe1K78hmnzg?pwd=1623); the real-world blurry RSB dataset will be made public once our manuscript is accepted.
51 |
52 | ## 🍭 Prepare
53 | Our S-SDM requires training BSN, EDSR, and LDN sequentially. We provide the trained weights at this [link](https://pan.baidu.com/s/1FGqlMFtnL5jwI39I5mNkTw?pwd=1623); they should be placed in the `model/` folder. Meanwhile, the downloaded/converted GOPRO dataset should be located under the project root folder. The project is structured as follows:
54 | ```
55 |
56 | ├── codes
57 | ├── imgs
58 | ├── log (train and evaluation results)
59 | ├── model
60 | │   ├── BSN_1000.pth
61 | │   └── ...
62 | ├── scripts
63 | ├── GOPRO
64 | │   ├── test
65 | │   └── train
66 | ├── train_bsn.py
67 | ├── train_deblur.py
68 | └── train_sr.py
69 | ```
70 |
71 | ## 🌅 Train
72 |
73 | Train BSN on the GOPRO dataset:
74 | ```
75 | python train_bsn.py --base_folder GOPRO/ --bsn_len 9 --data_type GOPRO
76 | ```
77 |
78 | Train EDSR on the GOPRO dataset:
79 | ```
80 | python train_sr.py --base_folder GOPRO/ --data_type GOPRO
81 | ```
82 |
83 | Train LDN on the GOPRO dataset:
84 | ```
85 | python train_deblur.py --base_folder GOPRO/ --bsn_path model/BSN_1000.pth --sr_path model/SR_70.pth --lambda_reblur 100 --data_type GOPRO
86 | ```
87 |
88 | ## 📊 Evaluate
89 |
90 | Evaluate BSN on the GOPRO dataset:
91 | ```
92 | python train_bsn.py --test_mode --bsn_path model/BSN_1000.pth --data_type GOPRO
93 | ```
94 |
95 | Evaluate EDSR on the GOPRO dataset:
96 | ```
97 | python train_sr.py --test_mode --bsn_path model/BSN_1000.pth --sr_path model/SR_70.pth --data_type GOPRO
98 | ```
99 |
100 | Evaluate LDN on the GOPRO dataset (a minimal inference sketch for the distilled LDN follows the Contact section below):
101 | ```
102 | python train_deblur.py --test_mode --bsn_path model/BSN_1000.pth --sr_path model/SR_70.pth --deblur_path model/DeblurNet_100.pth --data_type GOPRO
103 | ```
104 |
105 | ## 📞 Contact
106 | Should you have any questions, please feel free to contact [mrchenkang@whu.edu.cn](mailto:mrchenkang@whu.edu.cn). 
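For quick reference, the snippet below is a minimal, illustrative sketch of running the distilled LDN on a blurry frame and its spike stream. The constructor, input shapes, and checkpoint name follow `codes/model/ldn_model.py` and the Prepare section above; the loading code itself is an assumption for illustration and is not an excerpt from `train_deblur.py`.

```
import torch
from codes.model.ldn_model import Deblur_Net

# Illustrative sketch: load the distilled student (LDN). At test time it
# replaces the whole BSN -> EDSR -> SDM teacher cascade with one forward pass.
ldn = Deblur_Net(spike_dim=21)
ldn.load_state_dict(torch.load('model/DeblurNet_100.pth', map_location='cpu'))
ldn.eval()

# A blurry RGB frame at full resolution and its spike stream at 1/4 spatial
# resolution, matching the shapes used in the __main__ check of ldn_model.py.
blur = torch.zeros((1, 3, 720, 1280))
spike = torch.zeros((1, 21, 720 // 4, 1280 // 4))

with torch.no_grad():
    sharp = ldn(blur, spike)
print(sharp.shape)  # torch.Size([1, 3, 720, 1280])
```

Note that the `--lambda_reblur` option used when training LDN weights the re-blurring loss from the abstract, which encourages the average of the predicted sharp sequence to re-synthesize the blurry input; its value is logged as `Reblur Loss` in `log/Deblur/train_log.txt`.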
107 |
108 | ## 🤝 Citation
109 | If you find our work useful in your research, please cite:
110 |
111 | ```
112 | @article{chen2025spikereveal,
113 |   title={SpikeReveal: Unlocking Temporal Sequences from Real Blurry Inputs with Spike Streams},
114 |   author={Chen, Kang and Chen, Shiyan and Zhang, Jiyuan and Zhang, Baoyue and Zheng, Yajing and Huang, Tiejun and Yu, Zhaofei},
115 |   journal={Advances in Neural Information Processing Systems},
116 |   volume={37},
117 |   pages={62673--62696},
118 |   year={2025}
119 | }
120 | ```
121 |
--------------------------------------------------------------------------------
/codes/dataset.py:
--------------------------------------------------------------------------------
1 | from codes.utils import *
2 | import glob
3 | from torchvision import transforms
4 | # Spike Dataset
5 | class SpikeData(torch.utils.data.Dataset):
6 |     def __init__(self, root_dir, data_type = 'GOPRO', stage = 'train',
7 |                  use_resize = False,use_roi = False,roi_size = [256,256],
8 |                  use_small = False,spike_full = False):
9 |         """ Spike Dataset
10 |
11 |         Args:
12 |             root_dir (str): base folder of the dataset
13 |             data_type (str, optional): data type. Defaults to 'GOPRO'.
14 |             stage (str, optional): train / test. Defaults to 'train'.
15 |             use_resize (bool, optional): resize the blur / spike / sharp data to half resolution. Defaults to False.
16 |             use_roi (bool, optional): ROI operation. Defaults to False.
17 |             roi_size (list, optional): ROI size for the image, [ROI/4,ROI/4] for the spike. Defaults to [256,256].
18 |             use_small (bool, optional): small dataset for debugging. Defaults to False.
19 |             spike_full (bool, optional): full spike size (1280 * 720) instead of (320 * 180). Defaults to False.
20 |         """
21 |         self.root_dir = root_dir
22 |         self.data_type = data_type
23 |         self.use_resize = use_resize
24 |         self.use_roi = use_roi
25 |         self.roi_size = roi_size
26 |         self.spike_full = spike_full
27 |         # Real DataSet
28 |         if data_type == 'real':
29 |             pattern = os.path.join(self.root_dir,'spike_data', '*','*')
30 |             self.spike_list = sorted(glob.glob(pattern))
31 |             self.width = 400 * 4
32 |             self.height = 250 * 4
33 |         # GOPRO Synthetic DataSet
34 |         elif data_type == 'GOPRO':
35 |             if self.spike_full:
36 |                 pattern = os.path.join(self.root_dir, stage,'spike_full', '*','*')
37 |             else:
38 |                 pattern = os.path.join(self.root_dir, stage,'spike_data', '*','*')
39 |             self.spike_list = sorted(glob.glob(pattern))
40 |             if use_small == True:
41 |                 self.spike_list = self.spike_list[::10]
42 |             self.width = 1280
43 |             self.height = 720
44 |         self.resize = transforms.Resize((self.height // 2,self.width // 2),interpolation=transforms.InterpolationMode.NEAREST)
45 |         self.img_size = (self.height,self.width) if use_resize == False else (self.height // 2,self.width // 2)
46 |         self.length = len(self.spike_list)
47 |
48 |     def __getitem__(self, index: int):
49 |         # blur and spike load
50 |         if self.data_type in ['real']:
51 |             spike_name = self.spike_list[index]
52 |             spike = load_vidar_dat(spike_name,width=self.width //4 ,height=self.height // 4)
53 |             blur_name = spike_name.replace('.dat','.jpg').replace('spike_data','blur_data')
54 |             blur = cv2.imread(blur_name)
55 |         elif self.data_type in ['GOPRO']:
56 |             spike_name = self.spike_list[index]
57 |             if self.spike_full:
58 |                 spike = load_vidar_dat(spike_name,width=self.width ,height=self.height )
59 |                 blur_name = spike_name.replace('.dat','.png').replace('spike_full','blur_data')
60 |             else:
61 |                 spike = load_vidar_dat(spike_name,width=self.width // 4,height=self.height // 4)
62 |                 blur_name = 
spike_name.replace('.dat','.png').replace('spike_data','blur_data') 63 | blur = cv2.imread(blur_name) 64 | 65 | # sharp load 66 | if self.data_type in ['real']: 67 | sharp = np.zeros_like(blur) 68 | elif self.data_type in ['GOPRO']: 69 | if self.spike_full: 70 | sharp_name = spike_name.replace('.dat','.png').replace('spike_full','sharp_data') 71 | else: 72 | sharp_name = spike_name.replace('.dat','.png').replace('spike_data','sharp_data') 73 | sharp = cv2.imread(sharp_name) 74 | 75 | # channel & property exchange 76 | blur = torch.from_numpy(blur).permute((2,0,1)).float() / 255 77 | sharp = torch.from_numpy(sharp).permute((2,0,1)).float() / 255 78 | spike = torch.from_numpy(spike) 79 | # resize method (set true for synthetic NeRF dataset and false for real dataset) 80 | if self.use_resize == True: 81 | blur,spike,sharp = self.resize(blur),self.resize(spike),self.resize(sharp) 82 | # randomly crop 83 | if self.use_roi == True: 84 | if self.data_type not in ['GOPRO','real']: 85 | roiTL = (np.random.randint(0, self.img_size[0] -self.roi_size[0]+1), np.random.randint(0, self.img_size[1] -self.roi_size[1]+1)) 86 | roiBR = (roiTL[0]+self.roi_size[0],roiTL[1]+self.roi_size[1]) 87 | blur = blur[:,roiTL[0]:roiBR[0], roiTL[1]:roiBR[1]] 88 | spike = spike[:,roiTL[0]:roiBR[0], roiTL[1]:roiBR[1]] 89 | sharp = sharp[:,roiTL[0]:roiBR[0], roiTL[1]:roiBR[1]] 90 | else: 91 | roiTL = (np.random.randint(0, self.height // 4 - self.roi_size[0] // 4 +1), np.random.randint(0, self.width // 4 - self.roi_size[1] // 4+1)) 92 | roiBR = (roiTL[0]+self.roi_size[0]//4,roiTL[1]+self.roi_size[1]//4) 93 | blur = blur[:,4 * roiTL[0]:4 * roiBR[0], 4 * roiTL[1]:4 * roiBR[1]] 94 | if self.spike_full: 95 | spike = spike[:,4 * roiTL[0]:4 * roiBR[0], 4 * roiTL[1]:4 * roiBR[1]] 96 | else: 97 | spike = spike[:,roiTL[0]:roiBR[0], roiTL[1]:roiBR[1]] 98 | sharp = sharp[:,4 * roiTL[0]:4 *roiBR[0], 4 *roiTL[1]:4 * roiBR[1]] 99 | return blur,spike,sharp 100 | 101 | def __len__(self): 102 | return self.length -------------------------------------------------------------------------------- /codes/metrics.py: -------------------------------------------------------------------------------- 1 | from skimage import metrics 2 | import torch 3 | import torch.hub 4 | from lpips.lpips import LPIPS 5 | import os 6 | import numpy as np 7 | 8 | photometric = { 9 | "mse": None, 10 | "ssim": None, 11 | "psnr": None, 12 | "lpips": None 13 | } 14 | 15 | 16 | def compute_img_metric(im1t: torch.Tensor, im2t: torch.Tensor, 17 | metric="mse", margin=0, mask=None): 18 | """ 19 | im1t, im2t: torch.tensors with batched imaged shape, range from (0, 1) 20 | """ 21 | if metric not in photometric.keys(): 22 | raise RuntimeError(f"img_utils:: metric {metric} not recognized") 23 | if photometric[metric] is None: 24 | if metric == "mse": 25 | photometric[metric] = metrics.mean_squared_error 26 | elif metric == "ssim": 27 | photometric[metric] = metrics.structural_similarity 28 | elif metric == "psnr": 29 | photometric[metric] = metrics.peak_signal_noise_ratio 30 | elif metric == "lpips": 31 | photometric[metric] = LPIPS().cpu() 32 | 33 | if mask is not None: 34 | if mask.dim() == 3: 35 | mask = mask.unsqueeze(1) 36 | if mask.shape[1] == 1: 37 | mask = mask.expand(-1, 3, -1, -1) 38 | mask = mask.permute(0, 2, 3, 1).numpy() 39 | batchsz, hei, wid, _ = mask.shape 40 | if margin > 0: 41 | marginh = int(hei * margin) + 1 42 | marginw = int(wid * margin) + 1 43 | mask = mask[:, marginh:hei - marginh, marginw:wid - marginw] 44 | 45 | # convert from [0, 1] to [-1, 1] 46 | 
im1t = (im1t * 2 - 1).clamp(-1, 1) 47 | im2t = (im2t * 2 - 1).clamp(-1, 1) 48 | 49 | if im1t.dim() == 3: 50 | im1t = im1t.unsqueeze(0) 51 | im2t = im2t.unsqueeze(0) 52 | im1t = im1t.detach().cpu() 53 | im2t = im2t.detach().cpu() 54 | 55 | if im1t.shape[-1] == 3: 56 | im1t = im1t.permute(0, 3, 1, 2) 57 | im2t = im2t.permute(0, 3, 1, 2) 58 | 59 | im1 = im1t.permute(0, 2, 3, 1).numpy() 60 | im2 = im2t.permute(0, 2, 3, 1).numpy() 61 | batchsz, hei, wid, _ = im1.shape 62 | if margin > 0: 63 | marginh = int(hei * margin) + 1 64 | marginw = int(wid * margin) + 1 65 | im1 = im1[:, marginh:hei - marginh, marginw:wid - marginw] 66 | im2 = im2[:, marginh:hei - marginh, marginw:wid - marginw] 67 | values = [] 68 | 69 | for i in range(batchsz): 70 | if metric in ["mse", "psnr"]: 71 | if mask is not None: 72 | im1 = im1 * mask[i] 73 | im2 = im2 * mask[i] 74 | value = photometric[metric]( 75 | im1[i], im2[i] 76 | ) 77 | if mask is not None: 78 | hei, wid, _ = im1[i].shape 79 | pixelnum = mask[i, ..., 0].sum() 80 | value = value - 10 * np.log10(hei * wid / pixelnum) 81 | elif metric in ["ssim"]: 82 | value, ssimmap = photometric["ssim"]( 83 | im1[i], im2[i], multichannel=True, full=True 84 | ) 85 | if mask is not None: 86 | value = (ssimmap * mask[i]).sum() / mask[i].sum() 87 | elif metric in ["lpips"]: 88 | value = photometric[metric]( 89 | im1t[i:i + 1], im2t[i:i + 1] 90 | )[0,0,0,0] 91 | else: 92 | raise NotImplementedError 93 | values.append(value) 94 | 95 | return sum(values) / len(values) 96 | -------------------------------------------------------------------------------- /codes/model/__pycache__/networks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/codes/model/__pycache__/networks.cpython-38.pyc -------------------------------------------------------------------------------- /codes/model/bsn_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | # from utils import * 5 | import numpy as np 6 | from codes.model.networks import * 7 | 8 | class crop(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, x): 13 | N, C, H, W = x.shape 14 | x = x[0:N, 0:C, 0:H-1, 0:W] 15 | return x 16 | 17 | class shift(nn.Module): 18 | def __init__(self): 19 | super().__init__() 20 | self.shift_down = nn.ZeroPad2d((0,0,1,0)) 21 | self.crop = crop() 22 | 23 | def forward(self, x): 24 | x = self.shift_down(x) 25 | x = self.crop(x) 26 | return x 27 | 28 | class Conv(nn.Module): 29 | def __init__(self, in_channels, out_channels, bias=False, blind=True,stride=1,padding=0,kernel_size=3): 30 | super().__init__() 31 | self.blind = blind 32 | if blind: 33 | self.shift_down = nn.ZeroPad2d((0,0,1,0)) 34 | self.crop = crop() 35 | self.replicate = nn.ReplicationPad2d(1) 36 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,padding=padding,bias=bias) 37 | self.relu = nn.LeakyReLU(0.1, inplace=True) 38 | 39 | def forward(self, x): 40 | if self.blind: 41 | x = self.shift_down(x) 42 | x = self.replicate(x) 43 | x = self.conv(x) 44 | x = self.relu(x) 45 | if self.blind: 46 | x = self.crop(x) 47 | return x 48 | 49 | class Pool(nn.Module): 50 | def __init__(self, blind=True): 51 | super().__init__() 52 | self.blind = blind 53 | if blind: 54 | self.shift = shift() 55 | self.pool = nn.MaxPool2d(2) 56 | 57 | def 
forward(self, x): 58 | if self.blind: 59 | x = self.shift(x) 60 | x = self.pool(x) 61 | return x 62 | 63 | class rotate(nn.Module): 64 | def __init__(self): 65 | super().__init__() 66 | 67 | def forward(self, x): 68 | x90 = x.transpose(2,3).flip(3) 69 | x180 = x.flip(2).flip(3) 70 | x270 = x.transpose(2,3).flip(2) 71 | x = torch.cat((x,x90,x180,x270), dim=0) 72 | return x 73 | 74 | class unrotate(nn.Module): 75 | def __init__(self): 76 | super().__init__() 77 | 78 | def forward(self, x): 79 | x0, x90, x180, x270 = torch.chunk(x, 4, dim=0) 80 | x90 = x90.transpose(2,3).flip(2) 81 | x180 = x180.flip(2).flip(3) 82 | x270 = x270.transpose(2,3).flip(3) 83 | x = torch.cat((x0,x90,x180,x270), dim=1) 84 | return x 85 | 86 | class ENC_Conv(nn.Module): 87 | def __init__(self, in_channels, mid_channels, out_channels, bias=False, reduce=True, blind=True): 88 | super().__init__() 89 | self.reduce = reduce 90 | self.conv1 = Conv(in_channels, mid_channels, bias=bias, blind=blind) 91 | self.conv2 = Conv(mid_channels, mid_channels, bias=bias, blind=blind) 92 | self.conv3 = Conv(mid_channels, out_channels, bias=bias, blind=blind) 93 | if reduce: 94 | self.pool = Pool(blind=blind) 95 | 96 | def forward(self, x): 97 | x = self.conv1(x) 98 | x = self.conv2(x) 99 | x = self.conv3(x) 100 | if self.reduce: 101 | x = self.pool(x) 102 | return x 103 | 104 | class DEC_Conv(nn.Module): 105 | def __init__(self, in_channels, mid_channels, out_channels, bias=False, blind=True): 106 | super().__init__() 107 | self.upsample = nn.Upsample(scale_factor=2, mode='nearest') 108 | self.conv1 = Conv(in_channels, mid_channels, bias=bias, blind=blind) 109 | self.conv2 = Conv(mid_channels, mid_channels, bias=bias, blind=blind) 110 | self.conv3 = Conv(mid_channels, mid_channels, bias=bias, blind=blind) 111 | self.conv4 = Conv(mid_channels, out_channels, bias=bias, blind=blind) 112 | 113 | def forward(self, x, x_in): 114 | x = self.upsample(x) 115 | 116 | # Smart Padding 117 | diffY = x_in.size()[2] - x.size()[2] 118 | diffX = x_in.size()[3] - x.size()[3] 119 | x = F.pad(x, [diffX // 2, diffX - diffX // 2, 120 | diffY // 2, diffY - diffY // 2]) 121 | 122 | x = torch.cat((x, x_in), dim=1) 123 | x = self.conv1(x) 124 | x = self.conv2(x) 125 | x = self.conv3(x) 126 | x = self.conv4(x) 127 | return x 128 | 129 | class Blind_UNet(nn.Module): 130 | def __init__(self, n_channels=3, n_output=96, bias=False, blind=True): 131 | super().__init__() 132 | self.n_channels = n_channels 133 | self.bias = bias 134 | self.enc1 = ENC_Conv(n_channels, 48, 48, bias=bias, blind=blind) 135 | self.enc2 = ENC_Conv(48, 48, 48, bias=bias, blind=blind) 136 | self.enc3 = ENC_Conv(48, 96, 48, bias=bias, reduce=False, blind=blind) 137 | self.dec2 = DEC_Conv(96, 96, 96, bias=bias, blind=blind) 138 | self.dec1 = DEC_Conv(96+n_channels, 96, n_output, bias=bias, blind=blind) 139 | 140 | def forward(self, input): 141 | x1 = self.enc1(input) 142 | x2 = self.enc2(x1) 143 | x = self.enc3(x2) 144 | x = self.dec2(x, x1) 145 | x = self.dec1(x, input) 146 | return x 147 | 148 | class BSN(nn.Module): 149 | def __init__(self, n_channels=1, n_output=1, bias=False, blind=True, sigma_known=True): 150 | super().__init__() 151 | self.n_channels = n_channels 152 | self.c = n_channels 153 | self.n_output = n_output 154 | self.bias = bias 155 | self.blind = blind 156 | self.sigma_known = sigma_known 157 | self.rotate = rotate() 158 | self.unet = Blind_UNet(n_channels=1, bias=bias, blind=blind) 159 | self.shift = shift() 160 | self.unrotate = unrotate() 161 | self.nin_A = nn.Conv2d(384, 
384, 1, bias=bias) 162 | self.nin_B = nn.Conv2d(384, 96, 1, bias=bias) 163 | self.nin_C = nn.Conv2d(96, n_output, 1, bias=bias) 164 | 165 | def forward(self, x): 166 | N, C, H, W = x.shape 167 | # square 168 | if(H > W): 169 | diff = H - W 170 | x = F.pad(x, [diff // 2, diff - diff // 2, 0, 0], mode = 'reflect') 171 | elif(W > H): 172 | diff = W - H 173 | x = F.pad(x, [0, 0, diff // 2, diff - diff // 2], mode = 'reflect') 174 | x = self.rotate(x) 175 | x = self.unet(x) # 4 3 100 100 -> 4 96 100 100 176 | if self.blind: 177 | x = self.shift(x) 178 | x = self.unrotate(x) #4 96 100 100 -> 1 384 100 100 179 | x = F.leaky_relu_(self.nin_A(x), negative_slope=0.1) 180 | x = F.leaky_relu_(self.nin_B(x), negative_slope=0.1) 181 | x = self.nin_C(x) 182 | # Unsquare 183 | if(H > W): 184 | diff = H - W 185 | x = x[:, :, 0:H, (diff // 2):(diff // 2 + W)] 186 | elif(W > H): 187 | diff = W - H 188 | x = x[:, :, (diff // 2):(diff // 2 + H), 0:W] 189 | return x -------------------------------------------------------------------------------- /codes/model/edsr_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 7 | return nn.Conv2d( 8 | in_channels, out_channels, kernel_size, 9 | padding=(kernel_size//2), bias=bias) 10 | 11 | class MeanShift(nn.Conv2d): 12 | def __init__( 13 | self, rgb_range, 14 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 15 | 16 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 17 | std = torch.Tensor(rgb_std) 18 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 19 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 20 | for p in self.parameters(): 21 | p.requires_grad = False 22 | import torch 23 | import torch.nn as nn 24 | 25 | class MeanShift_GRAY(nn.Conv2d): 26 | def __init__( 27 | self, rgb_range, 28 | gray_mean=0.446, gray_std=1.0, sign=-1): 29 | 30 | super(MeanShift_GRAY, self).__init__(1, 1, kernel_size=1) 31 | std = torch.tensor([gray_std]) 32 | self.weight.data = torch.eye(1).view(1, 1, 1, 1) / std.view(1, 1, 1, 1) 33 | self.bias.data = sign * rgb_range * torch.tensor([gray_mean]) / std 34 | for p in self.parameters(): 35 | p.requires_grad = False 36 | 37 | class BasicBlock(nn.Sequential): 38 | def __init__( 39 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, 40 | bn=True, act=nn.ReLU(True)): 41 | 42 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)] 43 | if bn: 44 | m.append(nn.BatchNorm2d(out_channels)) 45 | if act is not None: 46 | m.append(act) 47 | 48 | super(BasicBlock, self).__init__(*m) 49 | 50 | class ResBlock(nn.Module): 51 | def __init__( 52 | self, conv, n_feats, kernel_size, 53 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 54 | 55 | super(ResBlock, self).__init__() 56 | m = [] 57 | for i in range(2): 58 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 59 | if bn: 60 | m.append(nn.BatchNorm2d(n_feats)) 61 | if i == 0: 62 | m.append(act) 63 | 64 | self.body = nn.Sequential(*m) 65 | self.res_scale = res_scale 66 | 67 | def forward(self, x): 68 | res = self.body(x).mul(self.res_scale) 69 | res += x 70 | 71 | return res 72 | 73 | class Upsampler(nn.Sequential): 74 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 75 | 76 | m = [] 77 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
78 | for _ in range(int(math.log(scale, 2))): 79 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 80 | m.append(nn.PixelShuffle(2)) 81 | if bn: 82 | m.append(nn.BatchNorm2d(n_feats)) 83 | if act == 'relu': 84 | m.append(nn.ReLU(True)) 85 | elif act == 'prelu': 86 | m.append(nn.PReLU(n_feats)) 87 | 88 | elif scale == 3: 89 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 90 | m.append(nn.PixelShuffle(3)) 91 | if bn: 92 | m.append(nn.BatchNorm2d(n_feats)) 93 | if act == 'relu': 94 | m.append(nn.ReLU(True)) 95 | elif act == 'prelu': 96 | m.append(nn.PReLU(n_feats)) 97 | else: 98 | raise NotImplementedError 99 | 100 | super(Upsampler, self).__init__(*m) 101 | 102 | 103 | 104 | def make_model(args, parent=False): 105 | return EDSR(args) 106 | 107 | class EDSR(nn.Module): 108 | def __init__(self,color_num = 1, conv=default_conv): 109 | super(EDSR, self).__init__() 110 | 111 | n_resblocks = 16 112 | n_feats = 64 113 | kernel_size = 3 114 | scale = 4 115 | act = nn.ReLU(True) 116 | url_name = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale) 117 | if color_num == 1: 118 | self.sub_mean = MeanShift_GRAY(1) 119 | self.add_mean = MeanShift_GRAY(1, sign=1) 120 | else: 121 | self.sub_mean = MeanShift(255) 122 | self.add_mean = MeanShift(255, sign=1) 123 | 124 | # define head module 125 | m_head = [conv(color_num, n_feats, kernel_size)] 126 | 127 | # define body module 128 | m_body = [ 129 | ResBlock( 130 | conv, n_feats, kernel_size, act=act, res_scale=1 131 | ) for _ in range(n_resblocks) 132 | ] 133 | m_body.append(conv(n_feats, n_feats, kernel_size)) 134 | 135 | # define tail module 136 | m_tail = [ 137 | Upsampler(conv, scale, n_feats, act=False), 138 | conv(n_feats, color_num, kernel_size) 139 | ] 140 | 141 | self.head = nn.Sequential(*m_head) 142 | self.body = nn.Sequential(*m_body) 143 | self.tail = nn.Sequential(*m_tail) 144 | 145 | def forward(self, x): 146 | x = self.sub_mean(x) 147 | x = self.head(x) 148 | 149 | res = self.body(x) 150 | res += x 151 | 152 | x = self.tail(res) 153 | x = self.add_mean(x) 154 | 155 | return x 156 | -------------------------------------------------------------------------------- /codes/model/ldn_model.py: -------------------------------------------------------------------------------- 1 | from thop import profile 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from codes.model.networks import * 6 | 7 | def conv(inDim,outDim,ks,s,p): 8 | # inDim,outDim,kernel_size,stride,padding 9 | conv = nn.Conv2d(inDim, outDim, kernel_size=ks, stride=s, padding=p) 10 | relu = nn.ReLU(True) 11 | seq = nn.Sequential(conv, relu) 12 | return seq 13 | 14 | def de_conv(inDim,outDim,ks,s,p,op): 15 | # inDim,outDim,kernel_size,stride,padding 16 | conv_t = nn.ConvTranspose2d(inDim,outDim, kernel_size=ks, stride=s, 17 | padding=p,output_padding= op) 18 | relu = nn.ReLU(inplace=True) 19 | seq = nn.Sequential(conv_t, relu) 20 | return seq 21 | 22 | class ChannelAttention(nn.Module): 23 | ## channel attention block 24 | def __init__(self, in_planes, ratio=16): 25 | super(ChannelAttention, self).__init__() 26 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 27 | self.max_pool = nn.AdaptiveMaxPool2d(1) 28 | 29 | self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 16, 1, bias=False), 30 | nn.ReLU(), 31 | nn.Conv2d(in_planes // 16, in_planes, 1, bias=False)) 32 | self.sigmoid = nn.Sigmoid() 33 | 34 | def forward(self, x): 35 | avg_out = self.fc(self.avg_pool(x)) 36 | max_out = self.fc(self.max_pool(x)) 37 | out = avg_out + max_out 38 | return self.sigmoid(out) 39 
| 40 | class SpatialAttention(nn.Module): 41 | ## spatial attention block 42 | def __init__(self, kernel_size=7): 43 | super(SpatialAttention, self).__init__() 44 | 45 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False) 46 | self.sigmoid = nn.Sigmoid() 47 | 48 | def forward(self, x): 49 | avg_out = torch.mean(x, dim=1, keepdim=True) 50 | max_out, _ = torch.max(x, dim=1, keepdim=True) 51 | x = torch.cat([avg_out, max_out], dim=1) 52 | x = self.conv1(x) 53 | return self.sigmoid(x) 54 | 55 | 56 | class Deblur_Net(nn.Module): 57 | def __init__(self,spike_dim = 21) -> None: 58 | super().__init__() 59 | # down_sample 60 | self.blur_enc = conv(3,16,5,2,2) 61 | self.blur_enc2 = conv(16,32,5,2,2) 62 | # sample_same 63 | self.spike_enc = conv(spike_dim,16,3,1,1) 64 | self.spike_enc2 = conv(16,32,3,1,1) 65 | # res 66 | self.resBlock1 = ResnetBlock(64,'zero',get_norm_layer('none'), False, True) 67 | self.resBlock2 = ResnetBlock(64,'zero',get_norm_layer('none'), False, True) 68 | # CBAM 69 | self.ca = ChannelAttention(64) 70 | self.sa = SpatialAttention() 71 | # up_sample 72 | self.decoder1 = de_conv(64,32,5,2,2,(1,1)) 73 | self.decoder2 = de_conv(32,16,5,2,2,(1,1)) 74 | self.pred = nn.Conv2d(3 + 16, 3, kernel_size=1, stride=1) 75 | 76 | def forward(self,blur,spike): 77 | # blur branch 78 | 79 | blur_re = self.blur_enc(blur) 80 | blur_re = self.blur_enc2(blur_re) 81 | # spike branch 82 | spike_re = self.spike_enc(spike) 83 | spike_re = self.spike_enc2(spike_re) 84 | # fusion 85 | fusion = torch.cat([blur_re,spike_re],dim = 1) 86 | fusion = self.ca(fusion) * fusion 87 | fusion = self.sa(fusion) * fusion 88 | fusion = self.resBlock1(fusion) 89 | fusion = self.resBlock2(fusion) 90 | fusion = self.decoder1(fusion) 91 | fusion = self.decoder2(fusion) 92 | result = self.pred(torch.cat([blur,fusion],dim = 1)) 93 | return result 94 | 95 | 96 | if __name__ == '__main__': 97 | net = Deblur_Net() 98 | blur = torch.zeros((1,3,720,1280)) 99 | spike = torch.zeros((1,21,720//4,1280//4)) 100 | flops, params = profile((net).cpu(), inputs=(blur,spike)) 101 | print("FLOPs=", str(flops/1e9) +'{}'.format("G")) 102 | total = sum(p.numel() for p in net.parameters()) 103 | print("Total params: %.5fM" % (total/1e6)) 104 | print(net(blur,spike)) 105 | -------------------------------------------------------------------------------- /codes/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | import imageio 5 | import os 6 | import torch.nn as nn 7 | import random 8 | # Save Network 9 | def save_network(network, save_path): 10 | if isinstance(network, nn.DataParallel): 11 | network = network.module 12 | state_dict = network.state_dict() 13 | for key, param in state_dict.items(): 14 | state_dict[key] = param.cpu() 15 | torch.save(state_dict, save_path) 16 | 17 | def set_random_seed(seed): 18 | """Set random seeds.""" 19 | torch.manual_seed(seed) 20 | torch.cuda.manual_seed_all(seed) 21 | np.random.seed(seed) 22 | random.seed(seed) 23 | torch.backends.cudnn.deterministic = True 24 | torch.cuda.manual_seed(seed) 25 | 26 | def save_opt(opt,opt_path): 27 | with open(opt_path, 'w') as f: 28 | for key, value in vars(opt).items(): 29 | f.write(f"{key}: {value}\n") 30 | 31 | def save_gif(image_list, gif_path = 'test', duration = 2,RGB = True,nor = False): 32 | imgs = [] 33 | os.makedirs('Video',exist_ok = True) 34 | with imageio.get_writer(os.path.join('Video',gif_path + '.gif'), mode='I',duration = 1000 * duration / 
len(image_list),loop=0) as writer:
35 |         for i in range(len(image_list)):
36 |             img = normal_img(image_list[i],RGB,nor)
37 |             writer.append_data(img)
38 |
39 | def save_video(image_list,path = 'test',duration = 2,RGB = True,nor = False):
40 |     os.makedirs('Video',exist_ok = True)
41 |     imgs = []
42 |     for i in range(len(image_list)):
43 |         img = normal_img(image_list[i],RGB,nor)
44 |         imgs.append(img)
45 |     imageio.mimwrite(os.path.join('Video',path + '.mp4'), imgs, fps=30, quality=8)
46 |
47 |
48 | def normal_img(img,RGB = True,nor = True):
49 |     if nor:
50 |         img = 255 * ((img - img.min()) / (img.max() - img.min()))
51 |     if (img.shape[0] == 3 or img.shape[0] == 1) and isinstance(img,torch.Tensor):
52 |         img = img.permute(1,2,0)
53 |     if isinstance(img,torch.Tensor):
54 |         img = np.array(img.detach().cpu())
55 |     if len(img.shape) == 2:
56 |         img = img[...,None]
57 |     if img.shape[-1] == 1:
58 |         img = np.repeat(img,3,axis = -1)
59 |     img = img.astype(np.uint8)
60 |     if RGB == False:
61 |         img = img[...,::-1]
62 |     return img
63 |
64 | def save_img(path = 'test.png',img = None,nor = True):
65 |     if nor:
66 |         img = 255 * ((img - img.min()) / (img.max() - img.min()))
67 |     if isinstance(img,torch.Tensor):
68 |         img = np.array(img.detach().cpu())
69 |     img = img.astype(np.uint8)
70 |     cv2.imwrite(path,img)
71 |
72 | def make_folder(path):
73 |     os.makedirs(path,exist_ok = True)
74 |
75 | class AverageMeter(object):
76 |     """Computes and stores the average and current value"""
77 |
78 |     def __init__(self):
79 |         self.reset()
80 |
81 |     def reset(self):
82 |         self.val = 0
83 |         self.avg = 0
84 |         self.sum = 0
85 |         self.count = 0
86 |
87 |     def update(self, val, n=1):
88 |         self.val = val
89 |         self.count += n
90 |         self.sum += val * n
91 |         self.avg = self.sum / self.count
92 |
93 | def video_to_spike(
94 |     sourefolder=None,
95 |     imgs = None,
96 |     savefolder_debug=None,
97 |     threshold=5.0,
98 |     init_noise=True,
99 |     format="png",
100 | ):
101 |     """
102 |     Convert a video (a folder of frames or a list of images) into a binary spike matrix
103 |     using integrate-and-fire accumulation with the given firing threshold.
104 |     :return: spikematrix: np.uint8 array of shape [T, H, W] with values in {0, 1}
105 |     """
106 |     if sourefolder is not None:
107 |         filelist = sorted(os.listdir(sourefolder))
108 |         datas = [fn for fn in filelist if fn.endswith(format)]
109 |
110 |         T = len(datas)
111 |
112 |         frame0 = cv2.imread(os.path.join(sourefolder, datas[0]))
113 |         H, W, C = frame0.shape
114 |
115 |         frame0 = cv2.cvtColor(frame0, cv2.COLOR_BGR2GRAY)
116 |
117 |         spikematrix = np.zeros([T, H, W], np.uint8)
118 |
119 |         if init_noise:
120 |             integral = np.random.random(size=([H,W])) * threshold
121 |         else:
122 |             integral = np.zeros([H, W])
123 |
124 |         Thr = np.ones_like(integral).astype(np.float32) * threshold
125 |
126 |         for t in range(0, T):
127 |             frame = cv2.imread(os.path.join(sourefolder, datas[t]))
128 |             if C > 1:
129 |                 gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
130 |             gray = gray / 255.0
131 |             integral += gray
132 |             fire = (integral - Thr) >= 0
133 |             fire_pos = fire.nonzero()
134 |
135 |             integral[fire_pos] -= threshold
136 |             spikematrix[t][fire_pos] = 1
137 |
138 |         if savefolder_debug:
139 |             np.save(os.path.join(savefolder_debug, "spike_debug.npy"), spikematrix)
140 |     elif imgs is not None:
141 |         frame0 = imgs[0]
142 |         H, W, C = frame0.shape
143 |         T = len(imgs)
144 |         spikematrix = np.zeros([T, H, W], np.uint8)
145 |
146 |         if init_noise:
147 |             integral = np.random.random(size=([H,W])) * threshold
148 |         else:
149 |             integral = np.zeros([H, W])
150 |
151 |         Thr = np.ones_like(integral).astype(np.float32) * threshold
152 |
153 |         for t in range(0, T):
154 |             frame = imgs[t]
155 |             if C > 1:
156 |                 gray = cv2.cvtColor(frame, 
cv2.COLOR_BGR2GRAY)
157 |             gray = gray / 255.0
158 |             integral += gray
159 |             fire = (integral - Thr) >= 0
160 |             fire_pos = fire.nonzero()
161 |
162 |             integral[fire_pos] -= threshold
163 |             spikematrix[t][fire_pos] = 1
164 |     return spikematrix
165 |
166 |
167 | def load_vidar_dat(filename, left_up=(0, 0), window=None, frame_cnt = None,height = 800, width = 800, **kwargs):
168 |     if isinstance(filename, str):
169 |         array = np.fromfile(filename, dtype=np.uint8)
170 |     elif isinstance(filename, (list, tuple)):
171 |         l = []
172 |         for name in filename:
173 |             a = np.fromfile(name, dtype=np.uint8)
174 |             l.append(a)
175 |         array = np.concatenate(l)
176 |     else:
177 |         raise NotImplementedError
178 |
179 |     if window is None:
180 |         window = (height - left_up[0], width - left_up[1])
181 |
182 |     len_per_frame = height * width // 8
183 |     framecnt = frame_cnt if frame_cnt is not None else len(array) // len_per_frame
184 |
185 |     spikes = []
186 |
187 |     for i in range(framecnt):
188 |         compr_frame = array[i * len_per_frame: (i + 1) * len_per_frame]
189 |         blist = []
190 |         for b in range(8):  # each byte packs 8 binary spikes, one per bit
191 |             blist.append(np.right_shift(np.bitwise_and(compr_frame, np.left_shift(1, b)), b))
192 |
193 |         frame_ = np.stack(blist).transpose()
194 |         frame_ = np.flipud(frame_.reshape((height, width), order='C'))
195 |
196 |         if window is not None:
197 |             spk = frame_[left_up[0]:left_up[0] + window[0], left_up[1]:left_up[1] + window[1]]
198 |         else:
199 |             spk = frame_
200 |
201 |         spk = spk.copy().astype(np.float32)[None]
202 |
203 |         spikes.append(spk)
204 |
205 |     return np.concatenate(spikes)
206 |
207 |
208 | import logging
209 | # log info
210 | def setup_logging(log_file):
211 |     logger = logging.getLogger('training_logger')
212 |     logger.setLevel(logging.INFO)
213 |     if logger.hasHandlers():
214 |         logger.handlers.clear()
215 |     file_handler = logging.FileHandler(log_file, mode='w')  # open the log file in 'w' (overwrite) mode
216 |     file_handler.setLevel(logging.INFO)
217 |     console_handler = logging.StreamHandler()
218 |     console_handler.setLevel(logging.INFO)
219 |     formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
220 |     file_handler.setFormatter(formatter)
221 |     console_handler.setFormatter(formatter)
222 |     logger.addHandler(file_handler)
223 |     logger.addHandler(console_handler)
224 |     return logger
225 |
226 | def normalize(arr):
227 |     return (arr - arr.min()) / (arr.max() - arr.min())
228 |
229 |
230 | def generate_labels(file_name):
231 |     num_part = file_name.split('/')[-1]
232 |     non_num_part = file_name.replace(num_part, '')
233 |     num = int(num_part)
234 |     labels = [non_num_part + str(num + 2 * i).zfill(len(num_part)) + '.png' for i in range(-3, 4)]
235 |     return labels
--------------------------------------------------------------------------------
/imgs/gopro.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/imgs/gopro.png
--------------------------------------------------------------------------------
/imgs/high_calib_compress.gif:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/imgs/high_calib_compress.gif
--------------------------------------------------------------------------------
/imgs/middle_calib_compress.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/imgs/middle_calib_compress.gif -------------------------------------------------------------------------------- /imgs/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/imgs/overview.png -------------------------------------------------------------------------------- /imgs/rsb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/imgs/rsb.png -------------------------------------------------------------------------------- /log/BSN/opt.txt: -------------------------------------------------------------------------------- 1 | base_folder: GOPRO/ 2 | save_folder: exp/BSN 3 | data_type: GOPRO 4 | exp_name: NEW_GOPRO_9_test_full 5 | epochs: 1001 6 | lr: 0.0003 7 | seed: 42 8 | bsn_len: 9 9 | width: 1280 10 | height: 720 11 | scheduler: test 12 | use_small: False 13 | spike_full: False 14 | test_mode: False 15 | bsn_path: ... 16 | -------------------------------------------------------------------------------- /log/BSN/test_log.txt: -------------------------------------------------------------------------------- 1 | 2024-03-14 17:32:06,108 - INFO - Namespace(base_folder='GOPRO/', bsn_len=9, bsn_path='model/BSN_1000.pth', data_type='GOPRO', epochs=1001, exp_name='test', height=720, lr=0.0003, save_folder='exp/BSN', seed=42, spike_full=False, test_mode=True, use_small=False, width=1280) 2 | 2024-03-14 17:32:06,391 - INFO - Start Training! 3 | 2024-03-14 17:40:11,730 - INFO - TFP: mse: 0.010 ssim: 0.733 psnr: 25.997 lpips: 0.211 4 | 2024-03-14 17:40:11,731 - INFO - BSN: mse: 0.008 ssim: 0.842 psnr: 27.365 lpips: 0.109 5 | -------------------------------------------------------------------------------- /log/Deblur/opt.txt: -------------------------------------------------------------------------------- 1 | base_folder: GOPRO/ 2 | save_folder: exp/Deblur 3 | data_type: GOPRO 4 | exp_name: NEW_GOPRO_9_reblur100_24_full 5 | bsn_path: exp/BSN/NEW_GOPRO_9_test_full/ckpts/BSN_1000.pth 6 | sr_path: exp/SR/NEW_GOPRO_9_bsn1000/ckpts/SR_70.pth 7 | deblur_path: exp/Deblur/NEW_GOPRO_9_reblur10_24_full/ckpts/DeblurNet_185.pth 8 | crop_size: 256 9 | epochs: 151 10 | lr: 0.001 11 | seed: 42 12 | spike_deblur_len: 21 13 | spike_bsn_len: 9 14 | lambda_tea: 1 15 | lambda_reblur: 100.0 16 | blur_step: 24 17 | use_small: False 18 | test_mode: False 19 | use_ssim: False 20 | bs: 4 21 | -------------------------------------------------------------------------------- /log/Deblur/test_log.txt: -------------------------------------------------------------------------------- 1 | 2024-03-15 12:47:41,145 - INFO - Namespace(base_folder='GOPRO/', blur_step=24, bs=4, bsn_path='model/BSN_1000.pth', data_type='GOPRO', deblur_path='model/DeblurNet_100.pth', epochs=101, exp_name='test', lambda_reblur=100, lambda_tea=1, lr=0.001, roi_size=512, save_folder='exp/Deblur', seed=42, spike_bsn_len=9, spike_deblur_len=21, sr_path='model/SR_70.pth', test_mode=True, use_small=False, use_ssim=False) 2 | 2024-03-15 12:47:42,745 - INFO - Start Training! 
3 | 2024-03-15 13:10:10,764 - INFO - SDM: ssim: 0.732 psnr: 26.362 4 | 2024-03-15 13:10:10,765 - INFO - Deblur_Net: ssim: 0.786 psnr: 27.928 5 | -------------------------------------------------------------------------------- /log/Deblur/train_log.txt: -------------------------------------------------------------------------------- 1 | 2024-02-21 11:11:15,473 - INFO - Namespace(base_folder='GOPRO/', save_folder='exp/Deblur', data_type='GOPRO', exp_name='NEW_GOPRO_9_reblur100_24_full', bsn_path='exp/BSN/NEW_GOPRO_9_test_full/ckpts/BSN_1000.pth', sr_path='exp/SR/NEW_GOPRO_9_bsn1000/ckpts/SR_70.pth', deblur_path='exp/Deblur/NEW_GOPRO_9_reblur10_24_full/ckpts/DeblurNet_185.pth', crop_size=256, epochs=151, lr=0.001, seed=42, spike_deblur_len=21, spike_bsn_len=9, lambda_tea=1, lambda_reblur=100.0, blur_step=24, use_small=False, test_mode=False, use_ssim=False, bs=4) 2 | 2024-02-21 11:11:18,625 - INFO - Start Training! 3 | 2024-02-21 11:15:07,104 - INFO - EPOCH 0/151: Total Train Loss: 1.981791118780772, Tea Loss: 0.5576680742558979, Reblur Loss: 1.4241230393642987 4 | 2024-02-21 11:20:35,766 - INFO - EPOCH 1/151: Total Train Loss: 0.7465140743431075, Tea Loss: 0.45105000272457735, Reblur Loss: 0.295464072908674 5 | 2024-02-21 11:24:23,129 - INFO - EPOCH 2/151: Total Train Loss: 0.56218161495217, Tea Loss: 0.3737060081907165, Reblur Loss: 0.18847560703560903 6 | 2024-02-21 11:28:10,566 - INFO - EPOCH 3/151: Total Train Loss: 0.4921396754004739, Tea Loss: 0.3222369218801523, Reblur Loss: 0.16990275352032153 7 | 2024-02-21 11:31:57,957 - INFO - EPOCH 4/151: Total Train Loss: 0.4323549967307549, Tea Loss: 0.29852057851496194, Reblur Loss: 0.13383441787713019 8 | 2024-02-21 11:35:45,372 - INFO - EPOCH 5/151: Total Train Loss: 0.3681258845638919, Tea Loss: 0.2753991732349643, Reblur Loss: 0.09272671207076028 9 | 2024-02-21 11:41:04,270 - INFO - EPOCH 6/151: Total Train Loss: 0.2876348569021597, Tea Loss: 0.23525384732913146, Reblur Loss: 0.05238100977058147 10 | 2024-02-21 11:44:51,678 - INFO - EPOCH 7/151: Total Train Loss: 0.2613711170690916, Tea Loss: 0.21692385534187417, Reblur Loss: 0.04444726203362663 11 | 2024-02-21 11:48:39,157 - INFO - EPOCH 8/151: Total Train Loss: 0.24577252315236375, Tea Loss: 0.20893679752752378, Reblur Loss: 0.03683572499589487 12 | 2024-02-21 11:52:26,640 - INFO - EPOCH 9/151: Total Train Loss: 0.2426675031066457, Tea Loss: 0.20484271858419692, Reblur Loss: 0.037824784236198124 13 | 2024-02-21 11:56:14,252 - INFO - EPOCH 10/151: Total Train Loss: 0.2309971262753268, Tea Loss: 0.19534648903262564, Reblur Loss: 0.03565063678308741 14 | 2024-02-21 12:01:35,223 - INFO - EPOCH 11/151: Total Train Loss: 0.22711923157239888, Tea Loss: 0.19388342555209157, Reblur Loss: 0.03323580622189231 15 | 2024-02-21 12:05:22,875 - INFO - EPOCH 12/151: Total Train Loss: 0.2145585086825606, Tea Loss: 0.1866386016080906, Reblur Loss: 0.027919907808239327 16 | 2024-02-21 12:09:10,451 - INFO - EPOCH 13/151: Total Train Loss: 0.21047288682553675, Tea Loss: 0.18185828832578865, Reblur Loss: 0.028614598761808562 17 | 2024-02-21 12:12:58,052 - INFO - EPOCH 14/151: Total Train Loss: 0.21015932614153082, Tea Loss: 0.1806152764207873, Reblur Loss: 0.029544049349827167 18 | 2024-02-21 12:16:45,702 - INFO - EPOCH 15/151: Total Train Loss: 0.20325037359675288, Tea Loss: 0.17644172584339654, Reblur Loss: 0.026808647801736734 19 | 2024-02-21 12:22:07,721 - INFO - EPOCH 16/151: Total Train Loss: 0.1940440669720307, Tea Loss: 0.17125921996382923, Reblur Loss: 0.022784847088835457 20 | 2024-02-21 12:25:55,181 
- INFO - EPOCH 17/151: Total Train Loss: 0.1925107462581618, Tea Loss: 0.1693128842167008, Reblur Loss: 0.02319786160200576 21 | 2024-02-21 12:29:42,707 - INFO - EPOCH 18/151: Total Train Loss: 0.19356717943371116, Tea Loss: 0.16971592559958948, Reblur Loss: 0.023851253963136053 22 | 2024-02-21 12:33:30,282 - INFO - EPOCH 19/151: Total Train Loss: 0.18988661661550596, Tea Loss: 0.16595214095724609, Reblur Loss: 0.023934475448611495 23 | 2024-02-21 12:37:17,834 - INFO - EPOCH 20/151: Total Train Loss: 0.1869647025545954, Tea Loss: 0.16441954630406905, Reblur Loss: 0.022545156766583908 24 | 2024-02-21 12:42:39,842 - INFO - EPOCH 21/151: Total Train Loss: 0.18639814550484413, Tea Loss: 0.16518269585711615, Reblur Loss: 0.021215449617490237 25 | 2024-02-21 12:46:28,692 - INFO - EPOCH 22/151: Total Train Loss: 0.18183056739243594, Tea Loss: 0.1610855016034919, Reblur Loss: 0.02074506588973653 26 | 2024-02-21 12:50:16,389 - INFO - EPOCH 23/151: Total Train Loss: 0.1875665687766426, Tea Loss: 0.16315624221062763, Reblur Loss: 0.024410326541824776 27 | 2024-02-21 12:54:04,005 - INFO - EPOCH 24/151: Total Train Loss: 0.1784304248177128, Tea Loss: 0.15827988965428752, Reblur Loss: 0.02015053503441088 28 | 2024-02-21 12:57:51,563 - INFO - EPOCH 25/151: Total Train Loss: 0.17520826700187864, Tea Loss: 0.15595671118466886, Reblur Loss: 0.019251555817209796 29 | 2024-02-21 13:03:13,996 - INFO - EPOCH 26/151: Total Train Loss: 0.176363579767607, Tea Loss: 0.1570910235375037, Reblur Loss: 0.019272556786477824 30 | 2024-02-21 13:07:01,638 - INFO - EPOCH 27/151: Total Train Loss: 0.1756473018493487, Tea Loss: 0.15534867197920232, Reblur Loss: 0.020298629934653575 31 | 2024-02-21 13:10:49,271 - INFO - EPOCH 28/151: Total Train Loss: 0.17550866305828094, Tea Loss: 0.15469595496401642, Reblur Loss: 0.02081270805797923 32 | 2024-02-21 13:14:36,833 - INFO - EPOCH 29/151: Total Train Loss: 0.1699268434728895, Tea Loss: 0.15181923270612568, Reblur Loss: 0.01810761091996839 33 | 2024-02-21 13:18:24,393 - INFO - EPOCH 30/151: Total Train Loss: 0.16909387423878625, Tea Loss: 0.15078474823143576, Reblur Loss: 0.018309126261347557 34 | 2024-02-21 13:23:46,279 - INFO - EPOCH 31/151: Total Train Loss: 0.17189460425149827, Tea Loss: 0.15200508085938244, Reblur Loss: 0.019889523835602777 35 | 2024-02-21 13:27:34,205 - INFO - EPOCH 32/151: Total Train Loss: 0.16703752579885128, Tea Loss: 0.14821302036057302, Reblur Loss: 0.01882450535764426 36 | 2024-02-21 13:31:23,375 - INFO - EPOCH 33/151: Total Train Loss: 0.16504471545869653, Tea Loss: 0.14643452690664308, Reblur Loss: 0.018610188056154428 37 | 2024-02-21 13:35:11,038 - INFO - EPOCH 34/151: Total Train Loss: 0.1643026904626326, Tea Loss: 0.1461101443994613, Reblur Loss: 0.01819254650665826 38 | 2024-02-21 13:38:58,722 - INFO - EPOCH 35/151: Total Train Loss: 0.1635446385581256, Tea Loss: 0.14505956334856165, Reblur Loss: 0.01848507593123815 39 | 2024-02-21 13:44:22,901 - INFO - EPOCH 36/151: Total Train Loss: 0.1661887692682671, Tea Loss: 0.14559309490960398, Reblur Loss: 0.020595673975651655 40 | 2024-02-21 13:48:10,525 - INFO - EPOCH 37/151: Total Train Loss: 0.16147938199760595, Tea Loss: 0.1428786571020688, Reblur Loss: 0.018600724242401844 41 | 2024-02-21 13:51:58,155 - INFO - EPOCH 38/151: Total Train Loss: 0.1581530040734774, Tea Loss: 0.14090757004239343, Reblur Loss: 0.01724543399076699 42 | 2024-02-21 13:55:45,762 - INFO - EPOCH 39/151: Total Train Loss: 0.1634903107628678, Tea Loss: 0.1447599632786466, Reblur Loss: 0.0187303473995555 43 | 2024-02-21 
13:59:33,394 - INFO - EPOCH 40/151: Total Train Loss: 0.16260723298762267, Tea Loss: 0.14292728078442735, Reblur Loss: 0.01967995221125873 44 | 2024-02-21 14:04:56,610 - INFO - EPOCH 41/151: Total Train Loss: 0.15650884378137012, Tea Loss: 0.13880101101093995, Reblur Loss: 0.017707833064744223 45 | 2024-02-21 14:08:44,278 - INFO - EPOCH 42/151: Total Train Loss: 0.1539752763651666, Tea Loss: 0.13657345406678847, Reblur Loss: 0.017401822056476172 46 | 2024-02-21 14:12:31,886 - INFO - EPOCH 43/151: Total Train Loss: 0.15841850093303822, Tea Loss: 0.14041390366755524, Reblur Loss: 0.018004597213070888 47 | 2024-02-21 14:16:19,511 - INFO - EPOCH 44/151: Total Train Loss: 0.15566800070273412, Tea Loss: 0.1376432954813495, Reblur Loss: 0.01802470464888331 48 | 2024-02-21 14:20:07,092 - INFO - EPOCH 45/151: Total Train Loss: 0.15683167250383467, Tea Loss: 0.1378219990761249, Reblur Loss: 0.019009673439804867 49 | 2024-02-21 14:25:30,825 - INFO - EPOCH 46/151: Total Train Loss: 0.15287813766823186, Tea Loss: 0.13571227503158315, Reblur Loss: 0.017165863527654313 50 | 2024-02-21 14:29:18,514 - INFO - EPOCH 47/151: Total Train Loss: 0.15194799763944758, Tea Loss: 0.13455750331992195, Reblur Loss: 0.017390494650124988 51 | 2024-02-21 14:33:06,217 - INFO - EPOCH 48/151: Total Train Loss: 0.1535301176565034, Tea Loss: 0.13565615332358844, Reblur Loss: 0.01787396442967576 52 | 2024-02-21 14:36:53,853 - INFO - EPOCH 49/151: Total Train Loss: 0.1541611450962174, Tea Loss: 0.13536655722242413, Reblur Loss: 0.0187945876278596 53 | 2024-02-21 14:40:41,541 - INFO - EPOCH 50/151: Total Train Loss: 0.15028684750779883, Tea Loss: 0.13304561966812456, Reblur Loss: 0.017241228154146825 54 | 2024-02-21 14:53:15,101 - INFO - SDM: mse: 0.010 ssim: 0.732 psnr: 26.362 lpips: 0.259 55 | 2024-02-21 14:53:15,106 - INFO - Deblur_Net: mse: 0.008 ssim: 0.776 psnr: 27.586 lpips: 0.264 56 | 2024-02-21 14:57:02,284 - INFO - EPOCH 51/151: Total Train Loss: 0.15076929337276526, Tea Loss: 0.1338261765835089, Reblur Loss: 0.016943117418201453 57 | 2024-02-21 15:00:49,880 - INFO - EPOCH 52/151: Total Train Loss: 0.15185928489867742, Tea Loss: 0.13381835553811225, Reblur Loss: 0.018040929525864845 58 | 2024-02-21 15:04:37,557 - INFO - EPOCH 53/151: Total Train Loss: 0.14984404589190628, Tea Loss: 0.1322726078363724, Reblur Loss: 0.017571437765251507 59 | 2024-02-21 15:08:25,127 - INFO - EPOCH 54/151: Total Train Loss: 0.14938556348090565, Tea Loss: 0.13165730376651277, Reblur Loss: 0.017728259468459205 60 | 2024-02-21 15:12:12,710 - INFO - EPOCH 55/151: Total Train Loss: 0.1466710292996266, Tea Loss: 0.1298487519288992, Reblur Loss: 0.016822277068349965 61 | 2024-02-21 15:17:35,712 - INFO - EPOCH 56/151: Total Train Loss: 0.1474687285818063, Tea Loss: 0.13019359347237136, Reblur Loss: 0.017275134859469545 62 | 2024-02-21 15:21:23,447 - INFO - EPOCH 57/151: Total Train Loss: 0.1507295116588667, Tea Loss: 0.13240870517311673, Reblur Loss: 0.018320806296260067 63 | 2024-02-21 15:25:11,137 - INFO - EPOCH 58/151: Total Train Loss: 0.14794867637353543, Tea Loss: 0.13034166179565124, Reblur Loss: 0.01760701478350085 64 | 2024-02-21 15:28:58,859 - INFO - EPOCH 59/151: Total Train Loss: 0.1444331771690092, Tea Loss: 0.12768478304534764, Reblur Loss: 0.0167483939019181 65 | 2024-02-21 15:32:46,617 - INFO - EPOCH 60/151: Total Train Loss: 0.14571062539950078, Tea Loss: 0.12870393090304874, Reblur Loss: 0.017006695077016756 66 | 2024-02-21 15:38:09,601 - INFO - EPOCH 61/151: Total Train Loss: 0.15109377325355233, Tea Loss: 0.1326590437458191, 
Reblur Loss: 0.018434729588367206 67 | 2024-02-21 15:41:57,392 - INFO - EPOCH 62/151: Total Train Loss: 0.1454399266800323, Tea Loss: 0.1282080735737111, Reblur Loss: 0.017231853106321193 68 | 2024-02-21 15:45:45,130 - INFO - EPOCH 63/151: Total Train Loss: 0.1435814949147629, Tea Loss: 0.12696699456219032, Reblur Loss: 0.01661450036466767 69 | 2024-02-21 15:49:32,973 - INFO - EPOCH 64/151: Total Train Loss: 0.14347542728825566, Tea Loss: 0.12685335059831668, Reblur Loss: 0.016622076875397136 70 | 2024-02-21 15:53:20,998 - INFO - EPOCH 65/151: Total Train Loss: 0.14517085844433153, Tea Loss: 0.12781795230391738, Reblur Loss: 0.01735290601946317 71 | 2024-02-21 15:58:43,330 - INFO - EPOCH 66/151: Total Train Loss: 0.1446716483214717, Tea Loss: 0.1272475728482911, Reblur Loss: 0.017424075803779936 72 | 2024-02-21 16:02:31,304 - INFO - EPOCH 67/151: Total Train Loss: 0.1435448806329723, Tea Loss: 0.1264306746455498, Reblur Loss: 0.017114206168848973 73 | 2024-02-21 16:06:19,216 - INFO - EPOCH 68/151: Total Train Loss: 0.14082041811762433, Tea Loss: 0.12448064908707812, Reblur Loss: 0.01633976892169033 74 | 2024-02-21 16:10:06,925 - INFO - EPOCH 69/151: Total Train Loss: 0.14659886720118584, Tea Loss: 0.12843387777155096, Reblur Loss: 0.01816498963121986 75 | 2024-02-21 16:13:54,798 - INFO - EPOCH 70/151: Total Train Loss: 0.14360976877150597, Tea Loss: 0.12647196563420357, Reblur Loss: 0.017137803302602075 76 | 2024-02-21 16:19:16,886 - INFO - EPOCH 71/151: Total Train Loss: 0.16021025964579025, Tea Loss: 0.13979781754476167, Reblur Loss: 0.02041244211715537 77 | 2024-02-21 16:23:04,451 - INFO - EPOCH 72/151: Total Train Loss: 0.1543136409350804, Tea Loss: 0.13649254250320006, Reblur Loss: 0.017821098323024455 78 | 2024-02-21 16:26:52,086 - INFO - EPOCH 73/151: Total Train Loss: 0.14720702142297448, Tea Loss: 0.1300513201203697, Reblur Loss: 0.01715570112520998 79 | 2024-02-21 16:30:39,705 - INFO - EPOCH 74/151: Total Train Loss: 0.14704671389225757, Tea Loss: 0.12937674810101976, Reblur Loss: 0.017669965980727693 80 | 2024-02-21 16:34:27,306 - INFO - EPOCH 75/151: Total Train Loss: 0.1434673515362141, Tea Loss: 0.12661278025283443, Reblur Loss: 0.016854570815702536 81 | 2024-02-21 16:39:49,710 - INFO - EPOCH 76/151: Total Train Loss: 0.1419998449293566, Tea Loss: 0.1251836247193865, Reblur Loss: 0.01681622104856359 82 | 2024-02-21 16:43:37,228 - INFO - EPOCH 77/151: Total Train Loss: 0.14403526752671122, Tea Loss: 0.1264730587085604, Reblur Loss: 0.017562208681073024 83 | 2024-02-21 16:47:24,841 - INFO - EPOCH 78/151: Total Train Loss: 0.14474546919008355, Tea Loss: 0.127241933016808, Reblur Loss: 0.017503536636920978 84 | 2024-02-21 16:51:12,997 - INFO - EPOCH 79/151: Total Train Loss: 0.14180082314974302, Tea Loss: 0.12507406135142107, Reblur Loss: 0.01672676197168502 85 | 2024-02-21 16:55:01,443 - INFO - EPOCH 80/151: Total Train Loss: 0.14336753639823946, Tea Loss: 0.12627104866556274, Reblur Loss: 0.017096487591567237 86 | 2024-02-21 17:00:25,386 - INFO - EPOCH 81/151: Total Train Loss: 0.1428917437404781, Tea Loss: 0.12610096681169616, Reblur Loss: 0.016790777166652216 87 | 2024-02-21 17:04:13,867 - INFO - EPOCH 82/151: Total Train Loss: 0.14355934830584052, Tea Loss: 0.12629533078505364, Reblur Loss: 0.017264017097458437 88 | 2024-02-21 17:08:02,698 - INFO - EPOCH 83/151: Total Train Loss: 0.14320857848697927, Tea Loss: 0.12600807868170016, Reblur Loss: 0.017200499970578785 89 | 2024-02-21 17:11:51,564 - INFO - EPOCH 84/151: Total Train Loss: 0.14073882974458463, Tea Loss: 
0.12396387397855907, Reblur Loss: 0.016774955568472288 90 | 2024-02-21 17:15:40,399 - INFO - EPOCH 85/151: Total Train Loss: 0.14087068154053253, Tea Loss: 0.12417404866450793, Reblur Loss: 0.01669663344046254 91 | 2024-02-21 17:21:13,439 - INFO - EPOCH 86/151: Total Train Loss: 0.1422664277352296, Tea Loss: 0.12493182070456542, Reblur Loss: 0.017334606982283778 92 | 2024-02-21 17:25:01,949 - INFO - EPOCH 87/151: Total Train Loss: 0.14089638091527024, Tea Loss: 0.12385588400549703, Reblur Loss: 0.01704049695815359 93 | 2024-02-21 17:28:50,445 - INFO - EPOCH 88/151: Total Train Loss: 0.1399132420683836, Tea Loss: 0.12308049456768738, Reblur Loss: 0.016832747403935436 94 | 2024-02-21 17:32:38,902 - INFO - EPOCH 89/151: Total Train Loss: 0.13818140953650207, Tea Loss: 0.12187661856283873, Reblur Loss: 0.016304791469562364 95 | 2024-02-21 17:36:27,002 - INFO - EPOCH 90/151: Total Train Loss: 0.14573748516184942, Tea Loss: 0.12821079980888409, Reblur Loss: 0.017526685252172863 96 | 2024-02-21 17:41:49,848 - INFO - EPOCH 91/151: Total Train Loss: 0.14120482869478532, Tea Loss: 0.12441673494262613, Reblur Loss: 0.016788093941649058 97 | 2024-02-21 17:45:37,999 - INFO - EPOCH 92/151: Total Train Loss: 0.1404902297051954, Tea Loss: 0.12345530028209026, Reblur Loss: 0.017034929011871803 98 | 2024-02-21 17:49:26,217 - INFO - EPOCH 93/151: Total Train Loss: 0.13812280343079464, Tea Loss: 0.121862444397691, Reblur Loss: 0.016260358750884666 99 | 2024-02-21 17:53:14,424 - INFO - EPOCH 94/151: Total Train Loss: 0.13867980148240086, Tea Loss: 0.12204451436475242, Reblur Loss: 0.016635287343423604 100 | 2024-02-21 17:57:02,561 - INFO - EPOCH 95/151: Total Train Loss: 0.14069516056930864, Tea Loss: 0.12385144284664294, Reblur Loss: 0.016843716847786915 101 | 2024-02-21 18:02:26,331 - INFO - EPOCH 96/151: Total Train Loss: 0.13945014778024706, Tea Loss: 0.1225252967182692, Reblur Loss: 0.016924851303879832 102 | 2024-02-21 18:06:14,652 - INFO - EPOCH 97/151: Total Train Loss: 0.1372118221345918, Tea Loss: 0.12085662469332352, Reblur Loss: 0.016355197259841803 103 | 2024-02-21 18:10:02,894 - INFO - EPOCH 98/151: Total Train Loss: 0.14199044339326553, Tea Loss: 0.12498571965601538, Reblur Loss: 0.017004724116229906 104 | 2024-02-21 18:13:51,253 - INFO - EPOCH 99/151: Total Train Loss: 0.13994478705254468, Tea Loss: 0.12313165651126341, Reblur Loss: 0.016813130605788457 105 | 2024-02-21 18:17:39,473 - INFO - EPOCH 100/151: Total Train Loss: 0.1380174332237863, Tea Loss: 0.1214408864513104, Reblur Loss: 0.016576546941807258 106 | 2024-02-21 18:30:09,010 - INFO - SDM: mse: 0.010 ssim: 0.732 psnr: 26.362 lpips: 0.259 107 | 2024-02-21 18:30:09,011 - INFO - Deblur_Net: mse: 0.007 ssim: 0.786 psnr: 27.928 lpips: 0.252 108 | 2024-02-21 18:33:57,030 - INFO - EPOCH 101/151: Total Train Loss: 0.1378288431923627, Tea Loss: 0.1212237568902763, Reblur Loss: 0.01660508629805469 109 | 2024-02-21 18:37:45,769 - INFO - EPOCH 102/151: Total Train Loss: 0.13755620919264755, Tea Loss: 0.12160171765140641, Reblur Loss: 0.015954491690414036 110 | 2024-02-21 18:41:34,434 - INFO - EPOCH 103/151: Total Train Loss: 0.138257441808393, Tea Loss: 0.12145415109731418, Reblur Loss: 0.016803290505462136 111 | 2024-02-21 18:45:22,749 - INFO - EPOCH 104/151: Total Train Loss: 0.13809712418230066, Tea Loss: 0.12142424404750139, Reblur Loss: 0.016672880183179657 112 | 2024-02-21 18:49:11,083 - INFO - EPOCH 105/151: Total Train Loss: 0.13738702982664108, Tea Loss: 0.12067290224554218, Reblur Loss: 0.016714127782683868 113 | 2024-02-21 18:54:33,484 
- INFO - EPOCH 106/151: Total Train Loss: 0.13537142006478783, Tea Loss: 0.11937310443425075, Reblur Loss: 0.015998315215272046 114 | 2024-02-21 18:58:21,890 - INFO - EPOCH 107/151: Total Train Loss: 0.1379812436895969, Tea Loss: 0.12124085948838816, Reblur Loss: 0.01674038420524045 115 | 2024-02-21 19:02:10,348 - INFO - EPOCH 108/151: Total Train Loss: 0.14069764654744754, Tea Loss: 0.12356308567059504, Reblur Loss: 0.01713456069542603 116 | 2024-02-21 19:05:58,802 - INFO - EPOCH 109/151: Total Train Loss: 0.138514020903544, Tea Loss: 0.12166547230182788, Reblur Loss: 0.016848548738793892 117 | 2024-02-21 19:09:47,265 - INFO - EPOCH 110/151: Total Train Loss: 0.13842449227581807, Tea Loss: 0.12143795095480882, Reblur Loss: 0.0169865418290034 118 | 2024-02-21 19:15:09,966 - INFO - EPOCH 111/151: Total Train Loss: 0.13819570491066227, Tea Loss: 0.1219581923572532, Reblur Loss: 0.016237512605821156 119 | 2024-02-21 19:18:58,428 - INFO - EPOCH 112/151: Total Train Loss: 0.13839785670821284, Tea Loss: 0.12155075126138085, Reblur Loss: 0.016847105438768606 120 | 2024-02-21 19:22:46,795 - INFO - EPOCH 113/151: Total Train Loss: 0.13677882712769818, Tea Loss: 0.12005357699089753, Reblur Loss: 0.016725250156959155 121 | 2024-02-21 19:26:35,302 - INFO - EPOCH 114/151: Total Train Loss: 0.13452625832645407, Tea Loss: 0.1182897954340621, Reblur Loss: 0.016236462710965505 122 | 2024-02-21 19:30:23,838 - INFO - EPOCH 115/151: Total Train Loss: 0.13718709688295017, Tea Loss: 0.12045202155907948, Reblur Loss: 0.016735075799611223 123 | 2024-02-21 19:35:46,377 - INFO - EPOCH 116/151: Total Train Loss: 0.13697202916527207, Tea Loss: 0.12038201193659853, Reblur Loss: 0.01659001698677158 124 | 2024-02-21 19:39:35,025 - INFO - EPOCH 117/151: Total Train Loss: 0.13589300969978432, Tea Loss: 0.11910017815245179, Reblur Loss: 0.016792831700537112 125 | 2024-02-21 19:43:23,577 - INFO - EPOCH 118/151: Total Train Loss: 0.13500712547467383, Tea Loss: 0.11859367100836395, Reblur Loss: 0.016413455091223314 126 | 2024-02-21 19:47:12,306 - INFO - EPOCH 119/151: Total Train Loss: 0.13317141862529697, Tea Loss: 0.1173663471877833, Reblur Loss: 0.01580507102224863 127 | 2024-02-21 19:51:04,440 - INFO - EPOCH 120/151: Total Train Loss: 0.13608860605077827, Tea Loss: 0.11919505568422796, Reblur Loss: 0.016893550374613694 128 | 2024-02-21 19:56:31,905 - INFO - EPOCH 121/151: Total Train Loss: 0.1362584302951763, Tea Loss: 0.11948918902770782, Reblur Loss: 0.016769240852203462 129 | 2024-02-21 20:00:24,714 - INFO - EPOCH 122/151: Total Train Loss: 0.13465122036732635, Tea Loss: 0.11808150193907997, Reblur Loss: 0.016569718218598015 130 | 2024-02-21 20:04:17,498 - INFO - EPOCH 123/151: Total Train Loss: 0.13412981302965254, Tea Loss: 0.11801906629945293, Reblur Loss: 0.016110746468139158 131 | 2024-02-21 20:08:10,286 - INFO - EPOCH 124/151: Total Train Loss: 0.1353475305360633, Tea Loss: 0.11868680578160595, Reblur Loss: 0.01666072496410572 132 | 2024-02-21 20:12:03,182 - INFO - EPOCH 125/151: Total Train Loss: 0.13473823502079232, Tea Loss: 0.11823303658853877, Reblur Loss: 0.01650519877091631 133 | 2024-02-21 20:17:33,181 - INFO - EPOCH 126/151: Total Train Loss: 0.13619285976731932, Tea Loss: 0.11934208008763078, Reblur Loss: 0.01685077938134278 134 | 2024-02-21 20:21:26,034 - INFO - EPOCH 127/151: Total Train Loss: 0.1337554477380984, Tea Loss: 0.11781134275775967, Reblur Loss: 0.015944105105321387 135 | 2024-02-21 20:25:18,917 - INFO - EPOCH 128/151: Total Train Loss: 0.13385922952002777, Tea Loss: 0.11756700577286931, 
Reblur Loss: 0.016292223755221862 136 | 2024-02-21 20:29:11,864 - INFO - EPOCH 129/151: Total Train Loss: 0.13444969645052246, Tea Loss: 0.117829749207476, Reblur Loss: 0.01661994808567164 137 | 2024-02-21 20:33:04,695 - INFO - EPOCH 130/151: Total Train Loss: 0.1344283643739048, Tea Loss: 0.11783283923095439, Reblur Loss: 0.016595525122791915 138 | 2024-02-21 20:38:33,836 - INFO - EPOCH 131/151: Total Train Loss: 0.13254633594255943, Tea Loss: 0.11647765380350543, Reblur Loss: 0.016068682638984737 139 | 2024-02-21 20:42:26,698 - INFO - EPOCH 132/151: Total Train Loss: 0.13348609583202378, Tea Loss: 0.11717449247966082, Reblur Loss: 0.0163116031548097 140 | 2024-02-21 20:46:19,622 - INFO - EPOCH 133/151: Total Train Loss: 0.13426269087698553, Tea Loss: 0.11791426914207863, Reblur Loss: 0.016348422267091222 141 | 2024-02-21 20:50:12,632 - INFO - EPOCH 134/151: Total Train Loss: 0.13518822267329023, Tea Loss: 0.11855961766206857, Reblur Loss: 0.016628604914460863 142 | 2024-02-21 20:54:05,636 - INFO - EPOCH 135/151: Total Train Loss: 0.13607895332368422, Tea Loss: 0.1191918769956151, Reblur Loss: 0.01688707603375504 143 | 2024-02-21 20:59:40,029 - INFO - EPOCH 136/151: Total Train Loss: 0.13298061241706213, Tea Loss: 0.11717816913153702, Reblur Loss: 0.01580244323311301 144 | 2024-02-21 21:03:32,881 - INFO - EPOCH 137/151: Total Train Loss: 0.13413896153628568, Tea Loss: 0.1176624967357813, Reblur Loss: 0.01647646492548706 145 | 2024-02-21 21:07:25,749 - INFO - EPOCH 138/151: Total Train Loss: 0.1341735723214748, Tea Loss: 0.11770590485045404, Reblur Loss: 0.01646766753955966 146 | 2024-02-21 21:11:18,778 - INFO - EPOCH 139/151: Total Train Loss: 0.1323508085174994, Tea Loss: 0.11597479141248769, Reblur Loss: 0.01637601697196563 147 | 2024-02-21 21:15:11,637 - INFO - EPOCH 140/151: Total Train Loss: 0.1313024704784026, Tea Loss: 0.11536454802854752, Reblur Loss: 0.015937922744169122 148 | 2024-02-21 21:20:45,685 - INFO - EPOCH 141/151: Total Train Loss: 0.14401600925953356, Tea Loss: 0.12672924763196475, Reblur Loss: 0.017286761490491045 149 | 2024-02-21 21:24:38,675 - INFO - EPOCH 142/151: Total Train Loss: 0.138105887871284, Tea Loss: 0.1207235726120668, Reblur Loss: 0.017382314892332534 150 | 2024-02-21 21:28:31,715 - INFO - EPOCH 143/151: Total Train Loss: 0.13395060392427238, Tea Loss: 0.1174174669436562, Reblur Loss: 0.016533136859665187 151 | 2024-02-21 21:32:24,643 - INFO - EPOCH 144/151: Total Train Loss: 0.13170721843129113, Tea Loss: 0.11586916924039006, Reblur Loss: 0.015838049380390934 152 | 2024-02-21 21:36:17,628 - INFO - EPOCH 145/151: Total Train Loss: 0.13255383874172771, Tea Loss: 0.11643216858952593, Reblur Loss: 0.01612167028524788 153 | 2024-02-21 21:41:53,258 - INFO - EPOCH 146/151: Total Train Loss: 0.13475415675025998, Tea Loss: 0.11823435321256712, Reblur Loss: 0.01651980377153143 154 | 2024-02-21 21:45:46,174 - INFO - EPOCH 147/151: Total Train Loss: 0.13381607143522858, Tea Loss: 0.11708475265539053, Reblur Loss: 0.01673131895320111 155 | 2024-02-21 21:49:39,078 - INFO - EPOCH 148/151: Total Train Loss: 0.13221219717424154, Tea Loss: 0.11615527588954735, Reblur Loss: 0.01605692087346083 156 | 2024-02-21 21:53:32,073 - INFO - EPOCH 149/151: Total Train Loss: 0.13134429555434685, Tea Loss: 0.1154829551711743, Reblur Loss: 0.01586134055250393 157 | 2024-02-21 21:57:25,062 - INFO - EPOCH 150/151: Total Train Loss: 0.13437037873061705, Tea Loss: 0.11755732788797064, Reblur Loss: 0.01681305057655423 158 | 2024-02-21 22:13:18,640 - INFO - SDM: mse: 0.010 ssim: 0.732 
psnr: 26.362 lpips: 0.259 159 | 2024-02-21 22:13:18,642 - INFO - Deblur_Net: mse: 0.007 ssim: 0.785 psnr: 27.902 lpips: 0.252 160 | -------------------------------------------------------------------------------- /log/SR/opt.txt: -------------------------------------------------------------------------------- 1 | base_folder: GOPRO/ 2 | save_folder: exp/SR 3 | data_type: GOPRO 4 | exp_name: NEW_GOPRO_9_bsn1000 5 | bsn_path: exp/BSN/NEW_GOPRO_9_test_full/ckpts/BSN_1000.pth 6 | sr_path: exp/SR/NEW_GOPRO_9/ckpts/SR_100.pth 7 | epochs: 101 8 | lr: 0.0002 9 | seed: 42 10 | bsn_len: 9 11 | use_small: False 12 | test_mode: False 13 | -------------------------------------------------------------------------------- /log/SR/test_log.txt: -------------------------------------------------------------------------------- 1 | 2024-03-15 11:42:47,789 - INFO - Namespace(base_folder='GOPRO', bsn_len=9, bsn_path='model/BSN_1000.pth', data_type='GOPRO', epochs=101, exp_name='test', lr=0.0002, save_folder='exp/SR', seed=42, sr_path='model/SR_70.pth', test_mode=True, use_small=False) 2 | 2024-03-15 11:42:48,131 - INFO - Start Training! 3 | 2024-03-15 11:42:48,131 - INFO - EPOCH 0/101: Train Loss: 0 4 | 2024-03-15 12:09:24,118 - INFO - SDM: ssim: 0.732 psnr: 26.362 5 | 2024-03-15 12:09:24,119 - INFO - Blur_SR: ssim: 0.945 psnr: 37.286 6 | 2024-03-15 12:09:24,120 - INFO - Blur_Resize: ssim: 0.855 psnr: 31.124 7 | 2024-03-15 12:09:24,120 - INFO - BSN_SR: ssim: 0.708 psnr: 26.144 8 | 2024-03-15 12:09:24,121 - INFO - BSN_Resize: ssim: 0.645 psnr: 24.490 9 | 2024-03-15 12:09:24,121 - INFO - TFP_Resize: ssim: 0.461 psnr: 22.943 10 | -------------------------------------------------------------------------------- /log/SR/train_log.txt: -------------------------------------------------------------------------------- 1 | 2024-02-18 21:44:09,917 - INFO - Namespace(base_folder='GOPRO/', save_folder='exp/SR', data_type='GOPRO', exp_name='NEW_GOPRO_9_bsn1000', bsn_path='exp/BSN/NEW_GOPRO_9_test_full/ckpts/BSN_1000.pth', sr_path='exp/SR/NEW_GOPRO_9/ckpts/SR_100.pth', epochs=101, lr=0.0002, seed=42, bsn_len=9, use_small=False, test_mode=False) 2 | 2024-02-18 21:44:11,448 - INFO - Start Training! 
3 | 2024-02-18 21:44:35,386 - INFO - EPOCH 0/101: Train Loss: 0.0014211533537479238 4 | 2024-02-18 22:04:49,354 - INFO - SDM: mse: 0.010 ssim: 0.717 psnr: 26.293 lpips: 0.287 5 | 2024-02-18 22:04:49,359 - INFO - Blur_SR: mse: 0.002 ssim: 0.918 psnr: 34.995 lpips: 0.187 6 | 2024-02-18 22:04:49,359 - INFO - Blur_Resize: mse: 0.004 ssim: 0.855 psnr: 31.124 lpips: 0.433 7 | 2024-02-18 22:04:49,359 - INFO - BSN_SR: mse: 0.012 ssim: 0.686 psnr: 25.692 lpips: 0.391 8 | 2024-02-18 22:04:49,360 - INFO - BSN_Resize: mse: 0.016 ssim: 0.645 psnr: 24.490 lpips: 0.411 9 | 2024-02-18 22:04:49,360 - INFO - TFP_Resize: mse: 0.021 ssim: 0.461 psnr: 22.943 lpips: 0.567 10 | 2024-02-18 22:05:10,622 - INFO - EPOCH 1/101: Train Loss: 0.00045315917495394214 11 | 2024-02-18 22:05:30,408 - INFO - EPOCH 2/101: Train Loss: 0.0004365248677423587 12 | 2024-02-18 22:05:49,968 - INFO - EPOCH 3/101: Train Loss: 0.0004334324128972165 13 | 2024-02-18 22:06:10,202 - INFO - EPOCH 4/101: Train Loss: 0.0004208742305651525 14 | 2024-02-18 22:06:30,051 - INFO - EPOCH 5/101: Train Loss: 0.000399760951347764 15 | 2024-02-18 22:08:16,945 - INFO - EPOCH 6/101: Train Loss: 0.00040361383598871346 16 | 2024-02-18 22:08:36,443 - INFO - EPOCH 7/101: Train Loss: 0.00039817116557024784 17 | 2024-02-18 22:08:56,198 - INFO - EPOCH 8/101: Train Loss: 0.0003857492040631107 18 | 2024-02-18 22:09:16,297 - INFO - EPOCH 9/101: Train Loss: 0.000374428103321375 19 | 2024-02-18 22:09:35,824 - INFO - EPOCH 10/101: Train Loss: 0.00038021592708940524 20 | 2024-02-18 22:29:48,360 - INFO - SDM: mse: 0.010 ssim: 0.731 psnr: 26.402 lpips: 0.268 21 | 2024-02-18 22:29:48,365 - INFO - Blur_SR: mse: 0.002 ssim: 0.933 psnr: 36.117 lpips: 0.141 22 | 2024-02-18 22:29:48,366 - INFO - Blur_Resize: mse: 0.004 ssim: 0.855 psnr: 31.124 lpips: 0.433 23 | 2024-02-18 22:29:48,366 - INFO - BSN_SR: mse: 0.011 ssim: 0.700 psnr: 25.961 lpips: 0.372 24 | 2024-02-18 22:29:48,367 - INFO - BSN_Resize: mse: 0.016 ssim: 0.645 psnr: 24.490 lpips: 0.411 25 | 2024-02-18 22:29:48,367 - INFO - TFP_Resize: mse: 0.021 ssim: 0.461 psnr: 22.943 lpips: 0.567 26 | 2024-02-18 22:30:09,554 - INFO - EPOCH 11/101: Train Loss: 0.0003734087890005989 27 | 2024-02-18 22:30:30,197 - INFO - EPOCH 12/101: Train Loss: 0.00038408602982733267 28 | 2024-02-18 22:30:49,718 - INFO - EPOCH 13/101: Train Loss: 0.0003796412013707989 29 | 2024-02-18 22:31:08,348 - INFO - EPOCH 14/101: Train Loss: 0.0003678143284048228 30 | 2024-02-18 22:31:28,271 - INFO - EPOCH 15/101: Train Loss: 0.00036798545054953294 31 | 2024-02-18 22:33:18,243 - INFO - EPOCH 16/101: Train Loss: 0.0003596360475687689 32 | 2024-02-18 22:33:37,820 - INFO - EPOCH 17/101: Train Loss: 0.00037178791592248363 33 | 2024-02-18 22:33:57,316 - INFO - EPOCH 18/101: Train Loss: 0.00035189093822347265 34 | 2024-02-18 22:34:16,896 - INFO - EPOCH 19/101: Train Loss: 0.0003537243204476066 35 | 2024-02-18 22:34:36,585 - INFO - EPOCH 20/101: Train Loss: 0.0003496722201501761 36 | 2024-02-18 22:54:58,323 - INFO - SDM: mse: 0.011 ssim: 0.720 psnr: 26.013 lpips: 0.277 37 | 2024-02-18 22:54:58,324 - INFO - Blur_SR: mse: 0.003 ssim: 0.896 psnr: 32.358 lpips: 0.131 38 | 2024-02-18 22:54:58,324 - INFO - Blur_Resize: mse: 0.004 ssim: 0.855 psnr: 31.124 lpips: 0.433 39 | 2024-02-18 22:54:58,325 - INFO - BSN_SR: mse: 0.013 ssim: 0.684 psnr: 25.471 lpips: 0.364 40 | 2024-02-18 22:54:58,325 - INFO - BSN_Resize: mse: 0.016 ssim: 0.645 psnr: 24.490 lpips: 0.411 41 | 2024-02-18 22:54:58,325 - INFO - TFP_Resize: mse: 0.021 ssim: 0.461 psnr: 22.943 lpips: 0.567 42 | 2024-02-18 
22:55:19,391 - INFO - EPOCH 21/101: Train Loss: 0.0003841655182153757 43 | 2024-02-18 22:55:39,402 - INFO - EPOCH 22/101: Train Loss: 0.000341238439826066 44 | 2024-02-18 22:55:58,865 - INFO - EPOCH 23/101: Train Loss: 0.0003420646372074826 45 | 2024-02-18 22:56:18,752 - INFO - EPOCH 24/101: Train Loss: 0.0003486189204140907 46 | 2024-02-18 22:56:38,449 - INFO - EPOCH 25/101: Train Loss: 0.00033939894858507836 47 | 2024-02-18 22:58:28,952 - INFO - EPOCH 26/101: Train Loss: 0.00033795114644290074 48 | 2024-02-18 22:58:47,653 - INFO - EPOCH 27/101: Train Loss: 0.000348002309133591 49 | 2024-02-18 22:59:06,305 - INFO - EPOCH 28/101: Train Loss: 0.0003404388012805726 50 | 2024-02-18 22:59:25,697 - INFO - EPOCH 29/101: Train Loss: 0.0003469659714658192 51 | 2024-02-18 22:59:44,931 - INFO - EPOCH 30/101: Train Loss: 0.00033793993217636755 52 | 2024-02-18 23:20:06,934 - INFO - SDM: mse: 0.010 ssim: 0.731 psnr: 26.371 lpips: 0.264 53 | 2024-02-18 23:20:06,939 - INFO - Blur_SR: mse: 0.001 ssim: 0.940 psnr: 36.720 lpips: 0.125 54 | 2024-02-18 23:20:06,940 - INFO - Blur_Resize: mse: 0.004 ssim: 0.855 psnr: 31.124 lpips: 0.433 55 | 2024-02-18 23:20:06,940 - INFO - BSN_SR: mse: 0.011 ssim: 0.704 psnr: 26.044 lpips: 0.364 56 | 2024-02-18 23:20:06,940 - INFO - BSN_Resize: mse: 0.016 ssim: 0.645 psnr: 24.490 lpips: 0.411 57 | 2024-02-18 23:20:06,940 - INFO - TFP_Resize: mse: 0.021 ssim: 0.461 psnr: 22.943 lpips: 0.567 58 | 2024-02-18 23:20:28,590 - INFO - EPOCH 31/101: Train Loss: 0.00033414859230956667 59 | 2024-02-18 23:20:48,187 - INFO - EPOCH 32/101: Train Loss: 0.0003287021171952219 60 | 2024-02-18 23:21:07,387 - INFO - EPOCH 33/101: Train Loss: 0.00032866710166495243 61 | 2024-02-18 23:21:27,734 - INFO - EPOCH 34/101: Train Loss: 0.00033242032808141814 62 | 2024-02-18 23:21:46,919 - INFO - EPOCH 35/101: Train Loss: 0.00033035686930532674 63 | 2024-02-18 23:23:34,118 - INFO - EPOCH 36/101: Train Loss: 0.00032164452114828875 64 | 2024-02-18 23:23:53,515 - INFO - EPOCH 37/101: Train Loss: 0.0003143733826748628 65 | 2024-02-18 23:24:13,141 - INFO - EPOCH 38/101: Train Loss: 0.00031801123744420085 66 | 2024-02-18 23:24:32,576 - INFO - EPOCH 39/101: Train Loss: 0.00032052480803023483 67 | 2024-02-18 23:24:51,224 - INFO - EPOCH 40/101: Train Loss: 0.00031738842848095703 68 | 2024-02-18 23:45:14,654 - INFO - SDM: mse: 0.010 ssim: 0.732 psnr: 26.372 lpips: 0.264 69 | 2024-02-18 23:45:14,656 - INFO - Blur_SR: mse: 0.001 ssim: 0.942 psnr: 36.864 lpips: 0.120 70 | 2024-02-18 23:45:14,656 - INFO - Blur_Resize: mse: 0.004 ssim: 0.855 psnr: 31.124 lpips: 0.433 71 | 2024-02-18 23:45:14,656 - INFO - BSN_SR: mse: 0.011 ssim: 0.707 psnr: 26.112 lpips: 0.359 72 | 2024-02-18 23:45:14,657 - INFO - BSN_Resize: mse: 0.016 ssim: 0.645 psnr: 24.490 lpips: 0.411 73 | 2024-02-18 23:45:14,657 - INFO - TFP_Resize: mse: 0.021 ssim: 0.461 psnr: 22.943 lpips: 0.567 74 | 2024-02-18 23:45:34,343 - INFO - EPOCH 41/101: Train Loss: 0.0003112662171033308 75 | 2024-02-18 23:45:54,126 - INFO - EPOCH 42/101: Train Loss: 0.0003128112463867171 76 | 2024-02-18 23:46:13,857 - INFO - EPOCH 43/101: Train Loss: 0.00038736918827748625 77 | 2024-02-18 23:46:33,461 - INFO - EPOCH 44/101: Train Loss: 0.000311614652130263 78 | 2024-02-18 23:46:53,385 - INFO - EPOCH 45/101: Train Loss: 0.00031099672609052187 79 | 2024-02-18 23:48:43,192 - INFO - EPOCH 46/101: Train Loss: 0.0003036633764786214 80 | 2024-02-18 23:49:02,453 - INFO - EPOCH 47/101: Train Loss: 0.0002961308712762623 81 | 2024-02-18 23:49:22,003 - INFO - EPOCH 48/101: Train Loss: 
0.0003085716928529767 82 | 2024-02-18 23:49:41,660 - INFO - EPOCH 49/101: Train Loss: 0.00030397324178023144 83 | 2024-02-18 23:50:01,265 - INFO - EPOCH 50/101: Train Loss: 0.00030308918473970957 84 | 2024-02-19 00:10:45,013 - INFO - SDM: mse: 0.010 ssim: 0.731 psnr: 26.374 lpips: 0.263 85 | 2024-02-19 00:10:45,018 - INFO - Blur_SR: mse: 0.001 ssim: 0.943 psnr: 37.075 lpips: 0.114 86 | 2024-02-19 00:10:45,018 - INFO - Blur_Resize: mse: 0.004 ssim: 0.855 psnr: 31.124 lpips: 0.433 87 | 2024-02-19 00:10:45,018 - INFO - BSN_SR: mse: 0.011 ssim: 0.706 psnr: 26.114 lpips: 0.356 88 | 2024-02-19 00:10:45,018 - INFO - BSN_Resize: mse: 0.016 ssim: 0.645 psnr: 24.490 lpips: 0.411 89 | 2024-02-19 00:10:45,019 - INFO - TFP_Resize: mse: 0.021 ssim: 0.461 psnr: 22.943 lpips: 0.567 90 | 2024-02-19 00:11:05,652 - INFO - EPOCH 51/101: Train Loss: 0.00029659506737154273 91 | 2024-02-19 00:11:25,623 - INFO - EPOCH 52/101: Train Loss: 0.00030188390887992507 92 | 2024-02-19 00:11:44,842 - INFO - EPOCH 53/101: Train Loss: 0.00028545065259851226 93 | 2024-02-19 00:12:05,023 - INFO - EPOCH 54/101: Train Loss: 0.00029897224343884774 94 | 2024-02-19 00:12:25,539 - INFO - EPOCH 55/101: Train Loss: 0.00029665410318447164 95 | 2024-02-19 00:14:11,359 - INFO - EPOCH 56/101: Train Loss: 0.00029271354684500363 96 | 2024-02-19 00:14:30,650 - INFO - EPOCH 57/101: Train Loss: 0.00028587002436946914 97 | 2024-02-19 00:14:49,898 - INFO - EPOCH 58/101: Train Loss: 0.00029207505800411533 98 | 2024-02-19 00:15:09,992 - INFO - EPOCH 59/101: Train Loss: 0.000288543489447178 99 | 2024-02-19 00:15:29,946 - INFO - EPOCH 60/101: Train Loss: 0.00029507283475003066 100 | 2024-02-19 00:35:49,782 - INFO - SDM: mse: 0.010 ssim: 0.735 psnr: 26.444 lpips: 0.262 101 | 2024-02-19 00:35:49,783 - INFO - Blur_SR: mse: 0.001 ssim: 0.943 psnr: 37.077 lpips: 0.122 102 | 2024-02-19 00:35:49,783 - INFO - Blur_Resize: mse: 0.004 ssim: 0.855 psnr: 31.124 lpips: 0.433 103 | 2024-02-19 00:35:49,784 - INFO - BSN_SR: mse: 0.011 ssim: 0.707 psnr: 26.121 lpips: 0.362 104 | 2024-02-19 00:35:49,784 - INFO - BSN_Resize: mse: 0.016 ssim: 0.645 psnr: 24.490 lpips: 0.411 105 | 2024-02-19 00:35:49,784 - INFO - TFP_Resize: mse: 0.021 ssim: 0.461 psnr: 22.943 lpips: 0.567 106 | 2024-02-19 00:36:11,003 - INFO - EPOCH 61/101: Train Loss: 0.000283065656955617 107 | 2024-02-19 00:36:31,520 - INFO - EPOCH 62/101: Train Loss: 0.00029362727483120665 108 | 2024-02-19 00:36:50,988 - INFO - EPOCH 63/101: Train Loss: 0.0002813149460675034 109 | 2024-02-19 00:37:10,857 - INFO - EPOCH 64/101: Train Loss: 0.00028284162286597455 110 | 2024-02-19 00:37:30,360 - INFO - EPOCH 65/101: Train Loss: 0.00027693967154507576 111 | 2024-02-19 00:39:19,888 - INFO - EPOCH 66/101: Train Loss: 0.0002798173035526733 112 | 2024-02-19 00:39:39,484 - INFO - EPOCH 67/101: Train Loss: 0.00028095614078512126 113 | 2024-02-19 00:39:59,174 - INFO - EPOCH 68/101: Train Loss: 0.000282799996639999 114 | 2024-02-19 00:40:19,057 - INFO - EPOCH 69/101: Train Loss: 0.00027430040608245873 115 | 2024-02-19 00:40:39,029 - INFO - EPOCH 70/101: Train Loss: 0.0002796248417802244 116 | 2024-02-19 01:00:59,204 - INFO - SDM: mse: 0.010 ssim: 0.732 psnr: 26.362 lpips: 0.259 117 | 2024-02-19 01:00:59,205 - INFO - Blur_SR: mse: 0.001 ssim: 0.945 psnr: 37.286 lpips: 0.109 118 | 2024-02-19 01:00:59,206 - INFO - Blur_Resize: mse: 0.004 ssim: 0.855 psnr: 31.124 lpips: 0.433 119 | 2024-02-19 01:00:59,206 - INFO - BSN_SR: mse: 0.011 ssim: 0.708 psnr: 26.144 lpips: 0.349 120 | 2024-02-19 01:00:59,206 - INFO - BSN_Resize: mse: 0.016 
ssim: 0.645 psnr: 24.490 lpips: 0.411 121 | 2024-02-19 01:00:59,206 - INFO - TFP_Resize: mse: 0.021 ssim: 0.461 psnr: 22.943 lpips: 0.567 122 | 2024-02-19 01:01:20,403 - INFO - EPOCH 71/101: Train Loss: 0.0002738387952042135 123 | 2024-02-19 01:01:39,201 - INFO - EPOCH 72/101: Train Loss: 0.0002767562307001624 124 | 2024-02-19 01:01:58,023 - INFO - EPOCH 73/101: Train Loss: 0.00027031866407467385 125 | 2024-02-19 01:02:16,770 - INFO - EPOCH 74/101: Train Loss: 0.00026943054254563903 126 | 2024-02-19 01:02:36,184 - INFO - EPOCH 75/101: Train Loss: 0.0002645513773584603 127 | 2024-02-19 01:04:26,003 - INFO - EPOCH 76/101: Train Loss: 0.0002584966490388136 128 | 2024-02-19 01:04:45,514 - INFO - EPOCH 77/101: Train Loss: 0.00027077237312121065 129 | 2024-02-19 01:05:05,113 - INFO - EPOCH 78/101: Train Loss: 0.000273223011823748 130 | 2024-02-19 01:05:23,908 - INFO - EPOCH 79/101: Train Loss: 0.00026525524041078204 131 | 2024-02-19 01:05:42,648 - INFO - EPOCH 80/101: Train Loss: 0.00026188710400430504 132 | 2024-02-19 01:26:04,975 - INFO - SDM: mse: 0.010 ssim: 0.734 psnr: 26.402 lpips: 0.261 133 | 2024-02-19 01:26:04,979 - INFO - Blur_SR: mse: 0.001 ssim: 0.945 psnr: 37.316 lpips: 0.113 134 | 2024-02-19 01:26:04,980 - INFO - Blur_Resize: mse: 0.004 ssim: 0.855 psnr: 31.124 lpips: 0.433 135 | 2024-02-19 01:26:04,980 - INFO - BSN_SR: mse: 0.011 ssim: 0.708 psnr: 26.142 lpips: 0.354 136 | 2024-02-19 01:26:04,980 - INFO - BSN_Resize: mse: 0.016 ssim: 0.645 psnr: 24.490 lpips: 0.411 137 | 2024-02-19 01:26:04,980 - INFO - TFP_Resize: mse: 0.021 ssim: 0.461 psnr: 22.943 lpips: 0.567 138 | 2024-02-19 01:26:25,886 - INFO - EPOCH 81/101: Train Loss: 0.0002638731355357056 139 | 2024-02-19 01:26:45,491 - INFO - EPOCH 82/101: Train Loss: 0.00026061319765857585 140 | 2024-02-19 01:27:04,218 - INFO - EPOCH 83/101: Train Loss: 0.0002606669278967585 141 | 2024-02-19 01:27:24,915 - INFO - EPOCH 84/101: Train Loss: 0.0002591617188584745 142 | -------------------------------------------------------------------------------- /model/BSN_1000.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/model/BSN_1000.pth -------------------------------------------------------------------------------- /model/DeblurNet_100.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/model/DeblurNet_100.pth -------------------------------------------------------------------------------- /model/SR_70.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/model/SR_70.pth -------------------------------------------------------------------------------- /scripts/GOPRO_dataset.md: -------------------------------------------------------------------------------- 1 | ## Step 1: Obtain the Original GOPRO Dataset 2 | Step into the `scripts` subfolder: 3 | ``` 4 | cd scripts 5 | ``` 6 | 7 | Download the [GOPRO_Large_all](https://drive.google.com/file/d/1rJTmM9_mLCNzBUUhYIGldBYgup279E_f/view) from the [GOPRO website](https://seungjunnah.github.io/Datasets/gopro) to get the sharp sequence for simulating spikes and synthesizing blurry frames. We also provide a [small GOPRO dataset](https://pan.baidu.com/s/1FGqlMFtnL5jwI39I5mNkTw?pwd=1623) for debugging. 
After downloading the data, rename the downloaded folder to `GOPRO` and place it in the `scripts` directory. Add a `raw_data` subfolder under both the `train` and `test` folders. The file structure is as follows:
8 | ```
9 | scripts
10 | ├── XVFI-main
11 | ├── run.sh
12 | ├── ...
13 | └── GOPRO
14 |     ├── train
15 |     │   └── raw_data
16 |     │       ├── GOPR0372_07_00
17 |     │       ├── ...
18 |     │       └── GOPR0884_11_00
19 |     └── test
20 |         └── raw_data
21 |             ├── GOPR0384_11_00
22 |             ├── ...
23 |             └── GOPR0881_11_01
24 | ```
25 | 
26 | If you prefer not to run the following steps one by one, you can skip them and simply run:
27 | ```
28 | bash run.sh
29 | ```
30 | 
31 | ## Step 2: Frame Interpolation
32 | We use the XVFI frame interpolation algorithm to insert 7 additional images between each pair of adjacent images, increasing the frame rate of the GOPRO image sequence. Note that this step is time-consuming and takes up a large amount of disk space. Run
33 | 
34 | ```
35 | cd XVFI-main/
36 | python main.py --custom_path ../GOPRO/test/raw_data --gpu 0 --phase test_custom --exp_num 1 --dataset X4K1000FPS --module_scale_factor 4 --S_tst 5 --multiple 8
37 | python main.py --custom_path ../GOPRO/train/raw_data --gpu 0 --phase test_custom --exp_num 1 --dataset X4K1000FPS --module_scale_factor 4 --S_tst 5 --multiple 8
38 | cd ..
39 | ```
40 | ## Step 3: Blur Synthesis
41 | We synthesize each blurred frame from 97 consecutive images of the interpolated sequence, corresponding to 13 images before interpolation (with ×8 interpolation, (13 - 1) × 8 + 1 = 97). Run
42 | 
43 | ```
44 | python blur_syn.py --raw_folder GOPRO/test/raw_data --overlap_len 7 --blur_num 13
45 | python blur_syn.py --raw_folder GOPRO/train/raw_data --overlap_len 7 --blur_num 13
46 | ```
47 | 
48 | ❗: Please note that an output such as `Synthesize blurry 000006.png from 000323.png to 000335.png` is not an error. This is because we rename our output files starting from `000000.png`, which may differ slightly from the original GOPRO naming, which starts from `000001.png`, `000323.png`, etc.
49 | 
50 | ## Step 4: Spike Simulation
51 | 
52 | We resize the images from `720×1280` to `180×320` and apply a physical spike generation model to simulate low-resolution spikes, obtaining the spike stream corresponding to the virtual exposure time in `Step 3: Blur Synthesis`. Run
53 | 
54 | ```
55 | python spike_simulate.py --raw_folder GOPRO/test/raw_data --overlap_len 7 --use_resize --blur_num 13
56 | python spike_simulate.py --raw_folder GOPRO/train/raw_data --overlap_len 7 --use_resize --blur_num 13
57 | ```
58 | 
59 | Since the spike stream also contains an additional 20 spike frames outside the exposure period, the start and end segments of the spike stream cannot be utilized. (For example, `000006.png` is the first frame of the `blur_data` folder, while `000019.dat` is the first spike file of the `spike_data` folder.)
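As a rough mental model of Steps 3 and 4 (this is a minimal sketch, not the actual code in `blur_syn.py` / `spike_simulate.py`; the function names and the threshold value are illustrative assumptions), blur synthesis can be viewed as a temporal average of the sharp frames within one virtual exposure, and spike simulation as an integrate-and-fire process in which each pixel accumulates incoming intensity and emits a binary spike whenever a threshold is crossed:

```python
import numpy as np

def synthesize_blur(frames):
    """Approximate a blurry frame as the temporal mean of the sharp frames
    covering one virtual exposure (e.g., the 97 interpolated images above)."""
    return np.mean(np.stack(frames, axis=0), axis=0)

def simulate_spikes(frames, threshold=2.0):
    """Integrate-and-fire spike generation: each pixel integrates intensity
    over time; when its accumulator reaches the threshold, a binary spike is
    fired and the threshold is subtracted (soft reset)."""
    height, width = frames[0].shape
    # Random initial charge avoids all pixels firing in lockstep.
    accumulator = np.random.uniform(0.0, threshold, size=(height, width))
    spikes = np.zeros((len(frames), height, width), dtype=np.uint8)
    for t, frame in enumerate(frames):   # frame: float array in [0, 1]
        accumulator += frame             # integrate the incoming light
        fired = accumulator >= threshold
        spikes[t][fired] = 1             # emit spikes ...
        accumulator[fired] -= threshold  # ... and soft-reset the charge
    return spikes
```

The actual scripts additionally handle details omitted here, such as the frame overlap controlled by `--overlap_len`, the `180×320` resizing behind `--use_resize`, and writing the spike stream to `.dat` files.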
60 | 
61 | ## Step 5: Sharp Frame Extraction
62 | 
63 | To obtain the single sharp frame corresponding to each blurry frame, run:
64 | 
65 | ```
66 | python sharp_extract.py --raw_folder GOPRO/test/raw_data --overlap_len 7 --blur_num 13
67 | python sharp_extract.py --raw_folder GOPRO/train/raw_data --overlap_len 7 --blur_num 13
68 | ```
69 | 
70 | To obtain the sharp sequence (7 images in this example) corresponding to each blurry frame, run:
71 | ```
72 | python sharp_extract.py --raw_folder GOPRO/test/raw_data --overlap_len 7 --blur_num 13 --multi
73 | python sharp_extract.py --raw_folder GOPRO/train/raw_data --overlap_len 7 --blur_num 13 --multi
74 | ```
75 | 
76 | ## Step 6: Final Structure
77 | Omitting the `raw_data` folder, the structure of the `Spike-GOPRO` dataset is as follows:
78 | ```
79 | GOPRO
80 | ├── test
81 | │   ├── blur_data
82 | │   ├── sharp_data
83 | │   └── spike_data
84 | └── train
85 |     ├── blur_data
86 |     ├── sharp_data
87 |     └── spike_data
88 | ```
89 | 
90 | 
--------------------------------------------------------------------------------
/scripts/XVFI-main/(0000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/(0000.png
--------------------------------------------------------------------------------
/scripts/XVFI-main/README.md:
--------------------------------------------------------------------------------
1 | # XVFI (ICCV2021, Oral) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/xvfi-extreme-video-frame-interpolation/video-frame-interpolation-on-x4k1000fps)](https://paperswithcode.com/sota/video-frame-interpolation-on-x4k1000fps?p=xvfi-extreme-video-frame-interpolation) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/xvfi-extreme-video-frame-interpolation/video-frame-interpolation-on-vimeo90k)](https://paperswithcode.com/sota/video-frame-interpolation-on-vimeo90k?p=xvfi-extreme-video-frame-interpolation)
2 | 
3 | 
4 | [![ArXiv](https://img.shields.io/badge/ArXiv-Paper-.svg)](https://arxiv.org/abs/2103.16206)
5 | [![ICCV2021](https://img.shields.io/badge/ICCV2021-Paper-.svg)](https://openaccess.thecvf.com/content/ICCV2021/papers/Sim_XVFI_eXtreme_Video_Frame_Interpolation_ICCV_2021_paper.pdf)
6 | [![GitHub Stars](https://img.shields.io/github/stars/JihyongOh/XVFI?style=social)](https://github.com/JihyongOh/XVFI)
7 | [![Demo views](https://img.shields.io/youtube/views/5qAiffYFJh8)](https://www.youtube.com/watch?v=5qAiffYFJh8&ab_channel=VICLabKAIST)
8 | ![visitors](https://visitor-badge.glitch.me/badge?page_id=JihyongOh/XVFI)
9 | 
10 | **This is the official repository of XVFI (eXtreme Video Frame Interpolation)**
11 | 
12 | \[[ArXiv_ver.](https://arxiv.org/abs/2103.16206)\] \[[ICCV2021_ver.](https://openaccess.thecvf.com/content/ICCV2021/papers/Sim_XVFI_eXtreme_Video_Frame_Interpolation_ICCV_2021_paper.pdf)\] \[[Supp.](https://openaccess.thecvf.com/content/ICCV2021/supplemental/Sim_XVFI_eXtreme_Video_ICCV_2021_supplemental.pdf)\] \[[Demo(YouTube)](https://www.youtube.com/watch?v=5qAiffYFJh8)\] \[[Oral12mins(YouTube)](https://www.youtube.com/watch?v=igwy1TJQiRc&t=13s)\] \[[Flowframes(GUI)](https://nmkd.itch.io/flowframes)\] \[[Poster](https://drive.google.com/file/d/16HtZdAKmUWLPkKR9FQX9ua-xDM3-ei6c/view?usp=sharing)\]
13 | 
14 | 
15 | 
16 | 
17 | Last Update: 20211130 - We provide extended input sequences for X-TEST.
Please refer to [X4K1000FPS](#X4K1000FPS)
18 | 
19 | We provide the training and test code along with the trained weights and the dataset (train+test) used for XVFI.
20 | If you find this repository useful, please consider citing our [paper](https://openaccess.thecvf.com/content/ICCV2021/papers/Sim_XVFI_eXtreme_Video_Frame_Interpolation_ICCV_2021_paper.pdf).
21 | 
22 | ### Examples of the VFI (x8 Multi-Frame Interpolation) results on X-TEST
23 | ![results_045_resized](/figures/results_045_resized_768.gif "results_045_resized")
24 | ![results_079_resized](/figures/results_079_resized_768.gif "results_079_resized")
25 | ![results_158_resized](/figures/results_158_resized_768.gif "results_158_resized")\
26 | The 4K@30fps input frames are interpolated into 4K@240fps frames. All results are encoded at 30fps to be played as x8 slow motion and are spatially down-scaled due to file-size limits. All methods are trained on X-TRAIN.
27 | 
28 | 
29 | ## Table of Contents
30 | 1. [X4K1000FPS](#X4K1000FPS)
31 | 1. [Requirements](#Requirements)
32 | 1. [Test](#Test)
33 | 1. [Test_Custom](#Test_Custom)
34 | 1. [Training](#Training)
35 | 1. [Collection_of_Visual_Results](#Collection_of_Visual_Results)
36 | 1. [Reference](#Reference)
37 | 1. [Contact](#Contact)
38 | 
39 | ## X4K1000FPS
40 | #### Dataset of high-resolution (4096×2160), high-fps (1000fps) video frames with extreme motion.
41 | ![003](/figures/003.gif "003") ![004](/figures/004.gif "004") ![045](/figures/045.gif "045")
42 | ![078](/figures/078.gif "078") ![081](/figures/081.gif "081") ![146](/figures/146.gif "146")\
43 | Some examples from the X4K1000FPS dataset, which consists of 1000-fps, 4K-resolution frames. Our dataset contains various scenes with extreme motion. (Displayed as spatiotemporally subsampled .gif files)
44 | 
45 | We provide our X4K1000FPS dataset, which consists of X-TEST and X-TRAIN. Please refer to our main/suppl. [paper](https://arxiv.org/abs/2103.16206) for the details of the dataset. You can download the dataset from this dropbox [link](https://www.dropbox.com/sh/duisote638etlv2/AABJw5Vygk94AWjGM4Se0Goza?dl=0).
46 | 
47 | `X-TEST` consists of 15 video clips, each with 33 frames of 4K-1000fps video. It follows the directory format below:
48 | ```
49 | ├──── YOUR_DIR/
50 |     ├──── test/
51 |         ├──── Type1/
52 |             ├──── TEST01/
53 |                 ├──── 0000.png
54 |                 ├──── ...
55 |                 └──── 0032.png
56 |             ├──── TEST02/
57 |                 ├──── 0000.png
58 |                 ├──── ...
59 |                 └──── 0032.png
60 |             ├──── ...
61 |         ├──── ...
62 | ```
63 | `Extended version of X-TEST` [issue#9](https://github.com/JihyongOh/XVFI/issues/9).
64 | As described in our paper, we assume that the number of input frames for VFI is fixed to 2 in X-TEST. However, for VFI methods that require more than 2 input frames, we provide an **extended version of X-TEST** which contains **8 input frames** (at a temporal distance of 32 frames) for each test sequence. The middle two adjacent frames among the 8 frames are the same input frames as in the original X-TEST. To sort the .png files properly by their file names, we added 1000 to the frame indices (e.g. '0000.png' and '0032.png' in the original version of X-TEST correspond to '1000.png' and '1032.png', respectively, in the extended version of X-TEST). Please note that the extended version consists of input frames only, without the ground-truth intermediate frames ('1001.png'~'1031.png'). In addition, for the sequence 'TEST11_078_f4977', '1064.png', '1096.png' and '1128.png' are replicated frames since '1064.png' is the last frame of the raw video file.
65 | The **extended version of X-TEST** can be downloaded from this [link](https://www.dropbox.com/s/tjmxauo05axoi5v/png_test_8_input_frames.zip?dl=0).
66 | 
67 | 
68 | `X-TRAIN` consists of 4,408 clips from 110 scenes of various types. Each clip contains 65 frames of 1000-fps video, and each frame is 768x768, cropped from a 4K frame. It follows the directory format below:
69 | ```
70 | ├──── YOUR_DIR/
71 |     ├──── train/
72 |         ├──── 002/
73 |             ├──── occ008.320/
74 |                 ├──── 0000.png
75 |                 ├──── ...
76 |                 └──── 0064.png
77 |             ├──── occ008.322/
78 |                 ├──── 0000.png
79 |                 ├──── ...
80 |                 └──── 0064.png
81 |             ├──── ...
82 |         ├──── ...
83 | ```
84 | 
85 | After downloading the files from the link, decompress `encoded_test.tar.gz` and `encoded_train.tar.gz`. The resulting .mp4 files can be decoded into .png files by running `mp4_decoding.py`. Please follow the instructions written in `mp4_decoding.py`.
86 | 
87 | 
88 | ## Requirements
89 | Our code is implemented using PyTorch 1.7 and was tested under the following settings:
90 | * Python 3.7
91 | * PyTorch 1.7.1
92 | * CUDA 10.2
93 | * cuDNN 7.6.5
94 | * NVIDIA TITAN RTX GPU
95 | * Ubuntu 16.04 LTS
96 | 
97 | **Caution**: since the "align_corners" option of "nn.functional.interpolate" and "nn.functional.grid_sample" in PyTorch 1.7 affects the results, we recommend following our settings.
98 | In particular, using other PyTorch versions may yield different performance.
99 | 
100 | 
101 | 
102 | ## Test
103 | ### Quick Start for X-TEST (x8 Multi-Frame Interpolation as in Table 2)
104 | 1. Download the source codes in a directory of your choice **\**.
105 | 2. First download our X-TEST test dataset by following the above section 'X4K1000FPS'.
106 | 3. Download the pre-trained weights, which were trained on X-TRAIN, from [this link](https://www.dropbox.com/s/xj2ixvay0e5ldma/XVFInet_X4K1000FPS_exp1_latest.pt?dl=0) and place them in **\/checkpoint_dir/XVFInet_X4K1000FPS_exp1**.
107 | ```
108 | XVFI
109 | └── checkpoint_dir
110 |     └── XVFInet_X4K1000FPS_exp1
111 |         ├── XVFInet_X4K1000FPS_exp1_latest.pt
112 | ```
113 | 4. Run **main.py** with the following options in parse_args:
114 | ```bash
115 | python main.py --gpu 0 --phase 'test' --exp_num 1 --dataset 'X4K1000FPS' --module_scale_factor 4 --S_tst 5 --multiple 8
116 | ```
117 | ==> It would yield **(PSNR/SSIM/tOF) = (30.12/0.870/2.15)**.
118 | ```bash
119 | python main.py --gpu 0 --phase 'test' --exp_num 1 --dataset 'X4K1000FPS' --module_scale_factor 4 --S_tst 3 --multiple 8
120 | ```
121 | ==> It would yield **(PSNR/SSIM/tOF) = (28.86/0.858/2.67)**.
122 | 
123 | 
124 | 
125 | ### Description
126 | * After running with the above test option, you can get the result images in **\/test_img_dir/XVFInet_X4K1000FPS_exp1**, then obtain the PSNR/SSIM/tOF results for each test clip as "total_metrics.csv" in the same folder.
127 | * Our proposed XVFI-Net can start from any downscaled input upward by regulating '--S_tst', which adjusts
128 | the number of scales for inference according to the input resolution or the motion magnitude.
129 | * You can get any Multi-Frame Interpolation (x M) result by regulating '--multiple'.
130 | 
131 | 
132 | 
133 | ### Quick Start for Vimeo90K (as in Fig. 8)
134 | 1. Download the source codes in a directory of your choice **\**.
135 | 2. First download the Vimeo90K dataset from [this link](http://toflow.csail.mit.edu/) (including 'tri_trainlist.txt') and place it in **\/vimeo_triplet**.
136 | ```
137 | XVFI
138 | └── vimeo_triplet
139 |     ├── sequences
140 |     readme.txt
141 |     tri_testlist.txt
142 |     tri_trainlist.txt
143 | ```
144 | 3. Download the pre-trained weights (XVFI-Net_v), which were trained on Vimeo90K, from [this link](https://www.dropbox.com/s/5v4dp81bto4x9xy/XVFInet_Vimeo_exp1_latest.pt?dl=0) and place them in **\/checkpoint_dir/XVFInet_Vimeo_exp1**.
145 | ```
146 | XVFI
147 | └── checkpoint_dir
148 |     └── XVFInet_Vimeo_exp1
149 |         ├── XVFInet_Vimeo_exp1_latest.pt
150 | ```
151 | 4. Run **main.py** with the following options in parse_args:
152 | ```bash
153 | python main.py --gpu 0 --phase 'test' --exp_num 1 --dataset 'Vimeo' --module_scale_factor 2 --S_tst 1 --multiple 2
154 | ```
155 | ==> It would yield **PSNR = 35.07** on Vimeo90K.
156 | 
157 | ### Description
158 | * After running with the above test option, you can get the result images in **\/test_img_dir/XVFInet_Vimeo_exp1**.
159 | * There are certain code lines in front of 'def main()' for convenience when running with the Vimeo option.
160 | * The SSIM result of 0.9760 as in Fig. 8 was measured by the MATLAB ssim function for a fair comparison after running the above guide, because other SOTA methods did so. We also upload the "compare_psnr_ssim.m" MATLAB file to obtain it.
161 | * ~~It should be noted that there is a typo "S_trn
162 | and S_tst are set to 2" in the current version of the XVFI paper, which should be modified to 1 (not 2), sorry for the inconvenience.~~ -> Updated in the latest arXiv version.
163 | 
164 | ## Test_Custom
165 | ### Quick Start for your own video data ('--custom_path') for any Multi-Frame Interpolation (x M)
166 | 1. Download the source codes in a directory of your choice **\**.
167 | 2. First prepare your own video datasets in **\/custom_path** by following the hierarchy below:
168 | ```
169 | XVFI
170 | └── custom_path
171 |     ├── scene1
172 |         ├── 'xxx.png'
173 |         ├── ...
174 |         └── 'xxx.png'
175 |     ...
176 | 
177 |     ├── sceneN
178 |         ├── 'xxxxx.png'
179 |         ├── ...
180 |         └── 'xxxxx.png'
181 | 
182 | ```
183 | 3. Download the pre-trained weights trained on [X-TRAIN](#quick-start-for-x-test-x8-multi-frame-interpolation-as-in-table-2) or [Vimeo90K](#quick-start-for-vimeo90k-as-in-fig-8) as described above.
184 | 
185 | 4. Run **main.py** with the following options in parse_args (e.g. x8 Multi-Frame Interpolation):
186 | ```bash
187 | # For the model trained on X-TRAIN
188 | python main.py --gpu 0 --phase 'test_custom' --exp_num 1 --dataset 'X4K1000FPS' --module_scale_factor 4 --S_tst 5 --multiple 8 --custom_path './custom_path'
189 | ```
190 | ```bash
191 | # For the model trained on Vimeo90K
192 | python main.py --gpu 0 --phase 'test_custom' --exp_num 1 --dataset 'Vimeo' --module_scale_factor 2 --S_tst 1 --multiple 8 --custom_path './custom_path'
193 | ```
194 | 
195 | 
196 | ### Description
197 | * Our proposed XVFI-Net can start from any downscaled input upward by regulating '--S_tst', which adjusts
198 | the number of scales for inference according to the input resolution or the motion magnitude.
199 | * You can get any Multi-Frame Interpolation (x M) result by regulating '--multiple'.
200 | * It only supports the '.png' format.
201 | * Since we cannot cover the diverse possibilities of naming rules for custom frames, please sort your own frames properly.
202 | 
203 | 
204 | ## Training
205 | ### Quick Start for X-TRAIN
206 | 1. Download the source codes in a directory of your choice **\**.
207 | 2.
First download our X-TRAIN train/val/test datasets by following the above section 'X4K1000FPS' and place them as below:
208 | ```
209 | XVFI
210 | └── X4K1000FPS
211 |     ├── train
212 |         ├── 002
213 |         ├── ...
214 |         └── 172
215 |     ├── val
216 |         ├── Type1
217 |         ├── Type2
218 |         ├── Type3
219 |     ├── test
220 |         ├── Type1
221 |         ├── Type2
222 |         ├── Type3
223 | 
224 | ```
225 | 3. Run **main.py** with the following options in parse_args:
226 | ```bash
227 | python main.py --phase 'train' --exp_num 1 --dataset 'X4K1000FPS' --module_scale_factor 4 --S_trn 3 --S_tst 5
228 | ```
229 | ### Quick Start for Vimeo90K
230 | 1. Download the source codes in a directory of your choice **\**.
231 | 2. First download the Vimeo90K dataset from [this link](http://toflow.csail.mit.edu/) (including 'tri_trainlist.txt') and place it in **\/vimeo_triplet**.
232 | ```
233 | XVFI
234 | └── vimeo_triplet
235 |     ├── sequences
236 |     readme.txt
237 |     tri_testlist.txt
238 |     tri_trainlist.txt
239 | ```
240 | 3. Run **main.py** with the following options in parse_args:
241 | ```bash
242 | python main.py --phase 'train' --exp_num 1 --dataset 'Vimeo' --module_scale_factor 2 --S_trn 1 --S_tst 1
243 | ```
244 | ### Description
245 | * You can freely regulate other arguments in the parser of **main.py**, [here](https://github.com/JihyongOh/XVFI/blob/484bdea1448c22459b10548a488909c268e1dde9/main.py#L12-L72).
246 | 
247 | ## Collection_of_Visual_Results
248 | * We also provide all visual results (x8 Multi-Frame Interpolation) on X-TEST for easier comparison, as below. Each zip file is about 1~1.5GB.
249 | * [AdaCoF*o*](https://www.dropbox.com/s/6ivl96nwrdl7oh1/AdaCoF_final_x8%20%28pretrained%2C%20original%29.zip?dl=0), [AdaCoF*f*](https://www.dropbox.com/s/3iqwzyns0jld2xp/AdaCoF_final_x8%20Retrain.zip?dl=0), [FeFlow*o*](https://www.dropbox.com/s/ukn8acqrim5vg7b/FeFlow_final_x8%20%28pretrained%2C%20original%29.zip?dl=0), [FeFlow*f*](https://www.dropbox.com/s/q26w3c9tm455jau/FeFlow_final_x8%20Retrain.zip?dl=0), [DAIN*o*](https://www.dropbox.com/s/yjtj4tvfhs2niqq/DAIN_final_x8%20%28pretrained%2C%20original%29.zip?dl=0), [DAIN*f*](https://www.dropbox.com/s/ftvimsx4czab5z4/DAIN_final_x8%20Retrain.zip?dl=0), [XVFI-Net](https://www.dropbox.com/s/3sbjjy226njk8by/XVFI-Net_final_Strn3_Stst3.zip?dl=0) (S*tst*=3), [XVFI-Net](https://www.dropbox.com/s/dgf61z08wab3jie/XVFI-Net_final_Strn3_Stst5.zip?dl=0) (S*tst*=5)
250 | * The quantitative comparisons (Table 2 and Figure 5) are attached below for reference.
251 | ![Table2](/figures/Table2.PNG "Table2")
252 | ![Figure5](/figures/Figure5.PNG "Figure5")\
253 | 
254 | 
255 | ## Reference
256 | > Hyeonjun Sim*, Jihyong Oh*, and Munchurl Kim, "XVFI: eXtreme Video Frame Interpolation", In _ICCV_, 2021. (* *equal contribution*)
257 | > 
258 | **BibTeX**
259 | ```bibtex
260 | @inproceedings{sim2021xvfi,
261 |   title={XVFI: eXtreme Video Frame Interpolation},
262 |   author={Sim, Hyeonjun and Oh, Jihyong and Kim, Munchurl},
263 |   booktitle={Proceedings of the IEEE International Conference on Computer Vision (ICCV)},
264 |   year={2021}
265 | }
266 | ```
267 | 
268 | 
269 | 
270 | ## Contact
271 | If you have any questions, please send an email to either \
272 | [[Hyeonjun Sim](https://sites.google.com/view/hjsim)] - flhy5836@kaist.ac.kr or \
273 | [[Jihyong Oh](https://sites.google.com/view/ozbro)] - jhoh94@kaist.ac.kr.
274 | 
275 | ## License
276 | The source codes and datasets can be freely used for research and education only. Any commercial use should get formal permission first.
277 | -------------------------------------------------------------------------------- /scripts/XVFI-main/XVFInet.py: -------------------------------------------------------------------------------- 1 | import functools, random 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import numpy as np 7 | 8 | 9 | class XVFInet(nn.Module): 10 | 11 | def __init__(self, args): 12 | super(XVFInet, self).__init__() 13 | self.args = args 14 | self.device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') # will be used as "x.to(device)" 15 | self.nf = args.nf 16 | self.scale = args.module_scale_factor 17 | self.vfinet = VFInet(args) 18 | self.lrelu = nn.ReLU() 19 | self.in_channels = 3 20 | self.channel_converter = nn.Sequential( 21 | nn.Conv3d(self.in_channels, self.nf, [1, 3, 3], [1, 1, 1], [0, 1, 1]), 22 | nn.ReLU()) 23 | 24 | self.rec_ext_ds_module = [self.channel_converter] 25 | self.rec_ext_ds = nn.Conv3d(self.nf, self.nf, [1, 3, 3], [1, 2, 2], [0, 1, 1]) 26 | for _ in range(int(np.log2(self.scale))): 27 | self.rec_ext_ds_module.append(self.rec_ext_ds) 28 | self.rec_ext_ds_module.append(nn.ReLU()) 29 | self.rec_ext_ds_module.append(nn.Conv3d(self.nf, self.nf, [1, 3, 3], 1, [0, 1, 1])) 30 | self.rec_ext_ds_module.append(RResBlock2D_3D(args, T_reduce_flag=False)) 31 | self.rec_ext_ds_module = nn.Sequential(*self.rec_ext_ds_module) 32 | 33 | self.rec_ctx_ds = nn.Conv3d(self.nf, self.nf, [1, 3, 3], [1, 2, 2], [0, 1, 1]) 34 | 35 | print("The lowest scale depth for training (S_trn): ", self.args.S_trn) 36 | print("The lowest scale depth for test (S_tst): ", self.args.S_tst) 37 | 38 | def forward(self, x, t_value, is_training=True): 39 | ''' 40 | x shape : [B,C,T,H,W] 41 | t_value shape : [B,1] ############### 42 | ''' 43 | B, C, T, H, W = x.size() 44 | B2, C2 = t_value.size() 45 | assert C2 == 1, "t_value shape is [B,]" 46 | assert T % 2 == 0, "T must be an even number" 47 | t_value = t_value.view(B, 1, 1, 1) 48 | 49 | flow_l = None 50 | feat_x = self.rec_ext_ds_module(x) 51 | feat_x_list = [feat_x] 52 | self.lowest_depth_level = self.args.S_trn if is_training else self.args.S_tst 53 | for level in range(1, self.lowest_depth_level+1): 54 | feat_x = self.rec_ctx_ds(feat_x) 55 | feat_x_list.append(feat_x) 56 | 57 | if is_training: 58 | out_l_list = [] 59 | flow_refine_l_list = [] 60 | out_l, flow_l, flow_refine_l = self.vfinet(x, feat_x_list[self.args.S_trn], flow_l, t_value, level=self.args.S_trn, is_training=True) 61 | out_l_list.append(out_l) 62 | flow_refine_l_list.append(flow_refine_l) 63 | for level in range(self.args.S_trn-1, 0, -1): ## self.args.S_trn, self.args.S_trn-1, ..., 1. level 0 is not included 64 | out_l, flow_l = self.vfinet(x, feat_x_list[level], flow_l, t_value, level=level, is_training=True) 65 | out_l_list.append(out_l) 66 | out_l, flow_l, flow_refine_l, occ_0_l0 = self.vfinet(x, feat_x_list[0], flow_l, t_value, level=0, is_training=True) 67 | out_l_list.append(out_l) 68 | flow_refine_l_list.append(flow_refine_l) 69 | return out_l_list[::-1], flow_refine_l_list[::-1], occ_0_l0, torch.mean(x, dim=2) # out_l_list should be reversed. [out_l0, out_l1, ...] 70 | 71 | else: # Testing 72 | for level in range(self.args.S_tst, 0, -1): ## self.args.S_tst, self.args.S_tst-1, ..., 1. 
level 0 is not included 73 | flow_l = self.vfinet(x, feat_x_list[level], flow_l, t_value, level=level, is_training=False) 74 | out_l = self.vfinet(x, feat_x_list[0], flow_l, t_value, level=0, is_training=False) 75 | return out_l 76 | 77 | 78 | class VFInet(nn.Module): 79 | 80 | def __init__(self, args): 81 | super(VFInet, self).__init__() 82 | self.args = args 83 | self.device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') # will be used as "x.to(device)" 84 | self.nf = args.nf 85 | self.scale = args.module_scale_factor 86 | self.in_channels = 3 87 | 88 | self.conv_flow_bottom = nn.Sequential( 89 | nn.Conv2d(2*self.nf, 2*self.nf, [4,4], 2, [1,1]), 90 | nn.ReLU(), 91 | nn.Conv2d(2*self.nf, 4*self.nf, [4,4], 2, [1,1]), 92 | nn.ReLU(), 93 | nn.UpsamplingNearest2d(scale_factor=2), 94 | nn.Conv2d(4 * self.nf, 2 * self.nf, [3, 3], 1, [1, 1]), 95 | nn.ReLU(), 96 | nn.UpsamplingNearest2d(scale_factor=2), 97 | nn.Conv2d(2 * self.nf, self.nf, [3, 3], 1, [1, 1]), 98 | nn.ReLU(), 99 | nn.Conv2d(self.nf, 6, [3,3], 1, [1,1]), 100 | ) 101 | 102 | self.conv_flow1 = nn.Conv2d(2*self.nf, self.nf, [3, 3], 1, [1, 1]) 103 | 104 | self.conv_flow2 = nn.Sequential( 105 | nn.Conv2d(2*self.nf + 4, 2 * self.nf, [4, 4], 2, [1, 1]), 106 | nn.ReLU(), 107 | nn.Conv2d(2 * self.nf, 4 * self.nf, [4, 4], 2, [1, 1]), 108 | nn.ReLU(), 109 | nn.UpsamplingNearest2d(scale_factor=2), 110 | nn.Conv2d(4 * self.nf, 2 * self.nf, [3, 3], 1, [1, 1]), 111 | nn.ReLU(), 112 | nn.UpsamplingNearest2d(scale_factor=2), 113 | nn.Conv2d(2 * self.nf, self.nf, [3, 3], 1, [1, 1]), 114 | nn.ReLU(), 115 | nn.Conv2d(self.nf, 6, [3, 3], 1, [1, 1]), 116 | ) 117 | 118 | self.conv_flow3 = nn.Sequential( 119 | nn.Conv2d(4 + self.nf * 4, self.nf, [1, 1], 1, [0, 0]), 120 | nn.ReLU(), 121 | nn.Conv2d(self.nf, 2 * self.nf, [4, 4], 2, [1, 1]), 122 | nn.ReLU(), 123 | nn.Conv2d(2 * self.nf, 4 * self.nf, [4, 4], 2, [1, 1]), 124 | nn.ReLU(), 125 | nn.UpsamplingNearest2d(scale_factor=2), 126 | nn.Conv2d(4 * self.nf, 2 * self.nf, [3, 3], 1, [1, 1]), 127 | nn.ReLU(), 128 | nn.UpsamplingNearest2d(scale_factor=2), 129 | nn.Conv2d(2 * self.nf, self.nf, [3, 3], 1, [1, 1]), 130 | nn.ReLU(), 131 | nn.Conv2d(self.nf, 4, [3, 3], 1, [1, 1]), 132 | ) 133 | 134 | self.refine_unet = RefineUNet(args) 135 | self.lrelu = nn.ReLU() 136 | 137 | def forward(self, x, feat_x, flow_l_prev, t_value, level, is_training): 138 | ''' 139 | x shape : [B,C,T,H,W] 140 | t_value shape : [B,1] ############### 141 | ''' 142 | B, C, T, H, W = x.size() 143 | assert T % 2 == 0, "T must be an even number" 144 | 145 | ####################### For a single level 146 | l = 2 ** level 147 | x_l = x.permute(0,2,1,3,4) 148 | x_l = x_l.contiguous().view(B * T, C, H, W) 149 | 150 | if level == 0: 151 | pass 152 | else: 153 | x_l = F.interpolate(x_l, scale_factor=(1.0 / l, 1.0 / l), mode='bicubic', align_corners=False) 154 | ''' 155 | Down pixel-shuffle 156 | ''' 157 | x_l = x_l.view(B, T, C, H//l, W//l) 158 | x_l = x_l.permute(0,2,1,3,4) 159 | 160 | B, C, T, H, W = x_l.size() 161 | 162 | ## Feature extraction 163 | feat0_l = feat_x[:,:,0,:,:] 164 | feat1_l = feat_x[:,:,1,:,:] 165 | 166 | ## Flow estimation 167 | if flow_l_prev is None: 168 | flow_l_tmp = self.conv_flow_bottom(torch.cat((feat0_l, feat1_l), dim=1)) 169 | flow_l = flow_l_tmp[:,:4,:,:] 170 | else: 171 | up_flow_l_prev = 2.0*F.interpolate(flow_l_prev.detach(), scale_factor=(2,2), mode='bilinear', align_corners=False) 172 | warped_feat1_l = self.bwarp(feat1_l, up_flow_l_prev[:,:2,:,:]) 173 | warped_feat0_l = 
self.bwarp(feat0_l, up_flow_l_prev[:,2:,:,:]) 174 | flow_l_tmp = self.conv_flow2(torch.cat([self.conv_flow1(torch.cat([feat0_l, warped_feat1_l],dim=1)), self.conv_flow1(torch.cat([feat1_l, warped_feat0_l],dim=1)), up_flow_l_prev],dim=1)) 175 | flow_l = flow_l_tmp[:,:4,:,:] + up_flow_l_prev 176 | 177 | if not is_training and level!=0: 178 | return flow_l 179 | 180 | flow_01_l = flow_l[:,:2,:,:] 181 | flow_10_l = flow_l[:,2:,:,:] 182 | z_01_l = torch.sigmoid(flow_l_tmp[:,4:5,:,:]) 183 | z_10_l = torch.sigmoid(flow_l_tmp[:,5:6,:,:]) 184 | 185 | ## Complementary Flow Reversal (CFR) 186 | flow_forward, norm0_l = self.z_fwarp(flow_01_l, t_value * flow_01_l, z_01_l) ## Actually, F (t) -> (t+1). Translation only. Not normalized yet 187 | flow_backward, norm1_l = self.z_fwarp(flow_10_l, (1-t_value) * flow_10_l, z_10_l) ## Actually, F (1-t) -> (-t). Translation only. Not normalized yet 188 | 189 | flow_t0_l = -(1-t_value) * ((t_value)*flow_forward) + (t_value) * ((t_value)*flow_backward) # The numerator of Eq.(1) in the paper. 190 | flow_t1_l = (1-t_value) * ((1-t_value)*flow_forward) - (t_value) * ((1-t_value)*flow_backward) # The numerator of Eq.(2) in the paper. 191 | 192 | norm_l = (1-t_value)*norm0_l + t_value*norm1_l 193 | mask_ = (norm_l.detach() > 0).type(norm_l.type()) 194 | flow_t0_l = (1-mask_) * flow_t0_l + mask_ * (flow_t0_l.clone() / (norm_l.clone() + (1-mask_))) # Divide the numerator with denominator in Eq.(1) 195 | flow_t1_l = (1-mask_) * flow_t1_l + mask_ * (flow_t1_l.clone() / (norm_l.clone() + (1-mask_))) # Divide the numerator with denominator in Eq.(2) 196 | 197 | ## Feature warping 198 | warped0_l = self.bwarp(feat0_l, flow_t0_l) 199 | warped1_l = self.bwarp(feat1_l, flow_t1_l) 200 | 201 | ## Flow refinement 202 | flow_refine_l = torch.cat([feat0_l, warped0_l, warped1_l, feat1_l, flow_t0_l, flow_t1_l], dim=1) 203 | flow_refine_l = self.conv_flow3(flow_refine_l) + torch.cat([flow_t0_l, flow_t1_l], dim=1) 204 | flow_t0_l = flow_refine_l[:, :2, :, :] 205 | flow_t1_l = flow_refine_l[:, 2:4, :, :] 206 | 207 | warped0_l = self.bwarp(feat0_l, flow_t0_l) 208 | warped1_l = self.bwarp(feat1_l, flow_t1_l) 209 | 210 | ## Flow upscale 211 | flow_t0_l = self.scale * F.interpolate(flow_t0_l, scale_factor=(self.scale, self.scale), mode='bilinear',align_corners=False) 212 | flow_t1_l = self.scale * F.interpolate(flow_t1_l, scale_factor=(self.scale, self.scale), mode='bilinear',align_corners=False) 213 | 214 | ## Image warping and blending 215 | warped_img0_l = self.bwarp(x_l[:,:,0,:,:], flow_t0_l) 216 | warped_img1_l = self.bwarp(x_l[:,:,1,:,:], flow_t1_l) 217 | 218 | refine_out = self.refine_unet(torch.cat([F.pixel_shuffle(torch.cat([feat0_l, feat1_l, warped0_l, warped1_l],dim=1), self.scale), x_l[:,:,0,:,:], x_l[:,:,1,:,:], warped_img0_l, warped_img1_l, flow_t0_l, flow_t1_l],dim=1)) 219 | occ_0_l = torch.sigmoid(refine_out[:, 0:1, :, :]) 220 | occ_1_l = 1-occ_0_l 221 | 222 | out_l = (1-t_value)*occ_0_l*warped_img0_l + t_value*occ_1_l*warped_img1_l 223 | out_l = out_l / ( (1-t_value)*occ_0_l + t_value*occ_1_l ) + refine_out[:, 1:4, :, :] 224 | 225 | if not is_training and level==0: 226 | return out_l 227 | 228 | if is_training: 229 | if flow_l_prev is None: 230 | # if level == self.args.S_trn: 231 | return out_l, flow_l, flow_refine_l[:, 0:4, :, :] 232 | elif level != 0: 233 | return out_l, flow_l 234 | else: # level==0 235 | return out_l, flow_l, flow_refine_l[:, 0:4, :, :], occ_0_l 236 | 237 | def bwarp(self, x, flo): 238 | ''' 239 | x: [B, C, H, W] (im2) 240 | flo: [B, 2, H, W] flow 241 | 
''' 242 | B, C, H, W = x.size() 243 | # mesh grid 244 | xx = torch.arange(0, W).view(1, 1, 1, W).expand(B, 1, H, W) 245 | yy = torch.arange(0, H).view(1, 1, H, 1).expand(B, 1, H, W) 246 | grid = torch.cat((xx, yy), 1).float() 247 | 248 | if x.is_cuda: 249 | grid = grid.to(self.device) 250 | vgrid = torch.autograd.Variable(grid) + flo 251 | 252 | # scale grid to [-1,1] 253 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0 254 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0 255 | 256 | vgrid = vgrid.permute(0, 2, 3, 1) # [B,H,W,2] 257 | output = nn.functional.grid_sample(x, vgrid, align_corners=True) 258 | mask = torch.autograd.Variable(torch.ones(x.size())).to(self.device) 259 | mask = nn.functional.grid_sample(mask, vgrid, align_corners=True) 260 | 261 | # mask[mask<0.9999] = 0 262 | # mask[mask>0] = 1 263 | mask = mask.masked_fill_(mask < 0.999, 0) 264 | mask = mask.masked_fill_(mask > 0, 1) 265 | 266 | return output * mask 267 | 268 | def fwarp(self, img, flo): 269 | 270 | """ 271 | -img: image (N, C, H, W) 272 | -flo: optical flow (N, 2, H, W) 273 | elements of flo is in [0, H] and [0, W] for dx, dy 274 | https://github.com/lyh-18/EQVI/blob/EQVI-master/models/forward_warp_gaussian.py 275 | """ 276 | 277 | # (x1, y1) (x1, y2) 278 | # +---------------+ 279 | # | | 280 | # | o(x, y) | 281 | # | | 282 | # | | 283 | # | | 284 | # | | 285 | # +---------------+ 286 | # (x2, y1) (x2, y2) 287 | 288 | N, C, _, _ = img.size() 289 | 290 | # translate start-point optical flow to end-point optical flow 291 | y = flo[:, 0:1:, :] 292 | x = flo[:, 1:2, :, :] 293 | 294 | x = x.repeat(1, C, 1, 1) 295 | y = y.repeat(1, C, 1, 1) 296 | 297 | # Four point of square (x1, y1), (x1, y2), (x2, y1), (y2, y2) 298 | x1 = torch.floor(x) 299 | x2 = x1 + 1 300 | y1 = torch.floor(y) 301 | y2 = y1 + 1 302 | 303 | # firstly, get gaussian weights 304 | w11, w12, w21, w22 = self.get_gaussian_weights(x, y, x1, x2, y1, y2) 305 | 306 | # secondly, sample each weighted corner 307 | img11, o11 = self.sample_one(img, x1, y1, w11) 308 | img12, o12 = self.sample_one(img, x1, y2, w12) 309 | img21, o21 = self.sample_one(img, x2, y1, w21) 310 | img22, o22 = self.sample_one(img, x2, y2, w22) 311 | 312 | imgw = img11 + img12 + img21 + img22 313 | o = o11 + o12 + o21 + o22 314 | 315 | return imgw, o 316 | 317 | 318 | def z_fwarp(self, img, flo, z): 319 | """ 320 | -img: image (N, C, H, W) 321 | -flo: optical flow (N, 2, H, W) 322 | elements of flo is in [0, H] and [0, W] for dx, dy 323 | modified from https://github.com/lyh-18/EQVI/blob/EQVI-master/models/forward_warp_gaussian.py 324 | """ 325 | 326 | # (x1, y1) (x1, y2) 327 | # +---------------+ 328 | # | | 329 | # | o(x, y) | 330 | # | | 331 | # | | 332 | # | | 333 | # | | 334 | # +---------------+ 335 | # (x2, y1) (x2, y2) 336 | 337 | N, C, _, _ = img.size() 338 | 339 | # translate start-point optical flow to end-point optical flow 340 | y = flo[:, 0:1:, :] 341 | x = flo[:, 1:2, :, :] 342 | 343 | x = x.repeat(1, C, 1, 1) 344 | y = y.repeat(1, C, 1, 1) 345 | 346 | # Four point of square (x1, y1), (x1, y2), (x2, y1), (y2, y2) 347 | x1 = torch.floor(x) 348 | x2 = x1 + 1 349 | y1 = torch.floor(y) 350 | y2 = y1 + 1 351 | 352 | # firstly, get gaussian weights 353 | w11, w12, w21, w22 = self.get_gaussian_weights(x, y, x1, x2, y1, y2, z+1e-5) 354 | 355 | # secondly, sample each weighted corner 356 | img11, o11 = self.sample_one(img, x1, y1, w11) 357 | img12, o12 = self.sample_one(img, x1, y2, w12) 358 | img21, o21 = self.sample_one(img, 
x2, y1, w21) 359 | img22, o22 = self.sample_one(img, x2, y2, w22) 360 | 361 | imgw = img11 + img12 + img21 + img22 362 | o = o11 + o12 + o21 + o22 363 | 364 | return imgw, o 365 | 366 | 367 | def get_gaussian_weights(self, x, y, x1, x2, y1, y2, z=1.0): 368 | # z 0.0 ~ 1.0 369 | w11 = z * torch.exp(-((x - x1) ** 2 + (y - y1) ** 2)) 370 | w12 = z * torch.exp(-((x - x1) ** 2 + (y - y2) ** 2)) 371 | w21 = z * torch.exp(-((x - x2) ** 2 + (y - y1) ** 2)) 372 | w22 = z * torch.exp(-((x - x2) ** 2 + (y - y2) ** 2)) 373 | 374 | return w11, w12, w21, w22 375 | 376 | def sample_one(self, img, shiftx, shifty, weight): 377 | """ 378 | Input: 379 | -img (N, C, H, W) 380 | -shiftx, shifty (N, c, H, W) 381 | """ 382 | 383 | N, C, H, W = img.size() 384 | 385 | # flatten all (all restored as Tensors) 386 | flat_shiftx = shiftx.view(-1) 387 | flat_shifty = shifty.view(-1) 388 | flat_basex = torch.arange(0, H, requires_grad=False).view(-1, 1)[None, None].to(self.device).long().repeat(N, C,1,W).view(-1) 389 | flat_basey = torch.arange(0, W, requires_grad=False).view(1, -1)[None, None].to(self.device).long().repeat(N, C,H,1).view(-1) 390 | flat_weight = weight.view(-1) 391 | flat_img = img.contiguous().view(-1) 392 | 393 | # The corresponding positions in I1 394 | idxn = torch.arange(0, N, requires_grad=False).view(N, 1, 1, 1).to(self.device).long().repeat(1, C, H, W).view(-1) 395 | idxc = torch.arange(0, C, requires_grad=False).view(1, C, 1, 1).to(self.device).long().repeat(N, 1, H, W).view(-1) 396 | idxx = flat_shiftx.long() + flat_basex 397 | idxy = flat_shifty.long() + flat_basey 398 | 399 | # recording the inside part the shifted 400 | mask = idxx.ge(0) & idxx.lt(H) & idxy.ge(0) & idxy.lt(W) 401 | 402 | # Mask off points out of boundaries 403 | ids = (idxn * C * H * W + idxc * H * W + idxx * W + idxy) 404 | ids_mask = torch.masked_select(ids, mask).clone().to(self.device) 405 | 406 | # Note here! 
accumulate must be True for proper backpropagation through put_ 407 | img_warp = torch.zeros([N * C * H * W, ]).to(self.device) 408 | img_warp.put_(ids_mask, torch.masked_select(flat_img * flat_weight, mask), accumulate=True) 409 | 410 | one_warp = torch.zeros([N * C * H * W, ]).to(self.device) 411 | one_warp.put_(ids_mask, torch.masked_select(flat_weight, mask), accumulate=True) 412 | 413 | return img_warp.view(N, C, H, W), one_warp.view(N, C, H, W) 414 | 415 | class RefineUNet(nn.Module): 416 | def __init__(self, args): 417 | super(RefineUNet, self).__init__() 418 | self.args = args 419 | self.scale = args.module_scale_factor 420 | self.nf = args.nf 421 | self.conv1 = nn.Conv2d(self.nf, self.nf, [3,3], 1, [1,1]) 422 | self.conv2 = nn.Conv2d(self.nf, self.nf, [3,3], 1, [1,1]) 423 | self.lrelu = nn.ReLU() # despite the name, this is a plain ReLU 424 | self.NN = nn.UpsamplingNearest2d(scale_factor=2) 425 | 426 | self.enc1 = nn.Conv2d((4*self.nf)//self.scale//self.scale + 4*args.img_ch + 4, self.nf, [4, 4], 2, [1, 1]) 427 | self.enc2 = nn.Conv2d(self.nf, 2*self.nf, [4, 4], 2, [1, 1]) 428 | self.enc3 = nn.Conv2d(2*self.nf, 4*self.nf, [4, 4], 2, [1, 1]) 429 | self.dec0 = nn.Conv2d(4*self.nf, 4*self.nf, [3, 3], 1, [1, 1]) 430 | self.dec1 = nn.Conv2d(4*self.nf + 2*self.nf, 2*self.nf, [3, 3], 1, [1, 1]) ## input concatenated with enc2 431 | self.dec2 = nn.Conv2d(2*self.nf + self.nf, self.nf, [3, 3], 1, [1, 1]) ## input concatenated with enc1 432 | self.dec3 = nn.Conv2d(self.nf, 1+args.img_ch, [3, 3], 1, [1, 1]) ## input added with warped image 433 | 434 | def forward(self, concat): 435 | enc1 = self.lrelu(self.enc1(concat)) 436 | enc2 = self.lrelu(self.enc2(enc1)) 437 | out = self.lrelu(self.enc3(enc2)) 438 | 439 | out = self.lrelu(self.dec0(out)) 440 | out = self.NN(out) 441 | 442 | out = torch.cat((out,enc2),dim=1) 443 | out = self.lrelu(self.dec1(out)) 444 | 445 | out = self.NN(out) 446 | out = torch.cat((out,enc1),dim=1) 447 | out = self.lrelu(self.dec2(out)) 448 | 449 | out = self.NN(out) 450 | out = self.dec3(out) 451 | return out 452 | 453 | class ResBlock2D_3D(nn.Module): 454 | ## Shape of input [B,C,T,H,W] 455 | ## Shape of output [B,C,T,H,W] 456 | def __init__(self, args): 457 | super(ResBlock2D_3D, self).__init__() 458 | self.args = args 459 | self.nf = args.nf 460 | 461 | self.conv3x3_1 = nn.Conv3d(self.nf, self.nf, [1,3,3], 1, [0,1,1]) 462 | self.conv3x3_2 = nn.Conv3d(self.nf, self.nf, [1,3,3], 1, [0,1,1]) 463 | self.lrelu = nn.ReLU() 464 | 465 | def forward(self, x): 466 | ''' 467 | x shape : [B,C,T,H,W] 468 | ''' 469 | B, C, T, H, W = x.size() 470 | 471 | out = self.conv3x3_2(self.lrelu(self.conv3x3_1(x))) 472 | 473 | return x + out 474 | 475 | class RResBlock2D_3D(nn.Module): 476 | 477 | def __init__(self, args, T_reduce_flag=False): 478 | super(RResBlock2D_3D, self).__init__() 479 | self.args = args 480 | self.nf = args.nf 481 | self.T_reduce_flag = T_reduce_flag 482 | self.resblock1 = ResBlock2D_3D(self.args) 483 | self.resblock2 = ResBlock2D_3D(self.args) 484 | if T_reduce_flag: 485 | self.reduceT_conv = nn.Conv3d(self.nf, self.nf, [3,1,1], 1, [0,0,0]) 486 | 487 | def forward(self, x): 488 | ''' 489 | x shape : [B,C,T,H,W] 490 | ''' 491 | out = self.resblock1(x) 492 | out = self.resblock2(out) 493 | if self.T_reduce_flag: 494 | return self.reduceT_conv(out + x) 495 | else: 496 | return out + x 497 |
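The splat above flattens every (n, c, y, x) target into a linear index and relies on put_ with accumulate=True so that contributions landing in the same cell are summed rather than overwritten. A minimal, self-contained sketch of that mechanism (toy values, independent of the model code):

```python
import torch

# Two of the three values target cell 2; accumulate=True sums them,
# which is what makes the gaussian-weighted forward warp correct.
vals = torch.tensor([1.0, 2.0, 3.0])
idx = torch.tensor([0, 2, 2])
buf = torch.zeros(4)
buf.put_(idx, vals, accumulate=True)
print(buf)  # tensor([1., 0., 5., 0.])
```

-------------------------------------------------------------------------------- /scripts/XVFI-main/__pycache__/XVFInet.cpython-38.pyc: --------------------------------------------------------------------------------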
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/__pycache__/XVFInet.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/XVFI-main/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/XVFI-main/checkpoint_dir/XVFInet_Vimeo_exp1/info.txt: -------------------------------------------------------------------------------- 1 | *place the downloaded pre-trained weights 'XVFInet_Vimeo_exp1_latest.pt' in this folder* -------------------------------------------------------------------------------- /scripts/XVFI-main/checkpoint_dir/XVFInet_X4K1000FPS_exp1/XVFInet_X4K1000FPS_exp1_latest.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/checkpoint_dir/XVFInet_X4K1000FPS_exp1/XVFInet_X4K1000FPS_exp1_latest.pt -------------------------------------------------------------------------------- /scripts/XVFI-main/checkpoint_dir/XVFInet_X4K1000FPS_exp1/info.txt: -------------------------------------------------------------------------------- 1 | *place the downloaded pre-trained weights 'XVFInet_X4K1000FPS_exp1_latest.pt' in this folder* -------------------------------------------------------------------------------- /scripts/XVFI-main/compare_psnr_ssim.m: -------------------------------------------------------------------------------- 1 | close all; 2 | clear;clc; 3 | fid = fopen('./vimeo_triplet/tri_testlist.txt','r'); 4 | pred_path = './test_img_dir/XVFInet_Vimeo_exp1/epoch_00199_final_x2_S_tst1/'; 5 | 6 | gt_path = './vimeo_triplet/sequences/'; 7 | sample = fgetl(fid); 8 | cnt = 0; 9 | 10 | while ischar(sample) 11 | cnt = cnt+1; 12 | 13 | pred = imread(strcat(pred_path, sample, '/im2.png')); 14 | gt = imread(strcat(gt_path, sample, '/im2.png')); 15 | 16 | 17 | [h,w,c] = size(pred); 18 | 19 | pred_y = pred; 20 | gt_y = gt; 21 | 22 | total_psnr1(cnt) = psnr(pred, gt); 23 | total_ssim1(cnt) = ssim(pred_y, gt_y); 24 | fprintf('%s : %f, %f \n', sample, total_psnr1(cnt), total_ssim1(cnt)) 25 | 26 | sample = fgetl(fid); 27 | end 28 | cnt 29 | mean(total_psnr1) 30 | mean(total_ssim1) 31 | -------------------------------------------------------------------------------- /scripts/XVFI-main/figures/.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /scripts/XVFI-main/figures/003.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/figures/003.gif -------------------------------------------------------------------------------- /scripts/XVFI-main/figures/004.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/figures/004.gif -------------------------------------------------------------------------------- /scripts/XVFI-main/figures/045.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/figures/045.gif -------------------------------------------------------------------------------- /scripts/XVFI-main/figures/078.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/figures/078.gif -------------------------------------------------------------------------------- /scripts/XVFI-main/figures/081.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/figures/081.gif -------------------------------------------------------------------------------- /scripts/XVFI-main/figures/146.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/figures/146.gif -------------------------------------------------------------------------------- /scripts/XVFI-main/figures/Figure5.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/figures/Figure5.PNG -------------------------------------------------------------------------------- /scripts/XVFI-main/figures/Table2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/figures/Table2.PNG -------------------------------------------------------------------------------- /scripts/XVFI-main/figures/results_045_resized_768.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/figures/results_045_resized_768.gif -------------------------------------------------------------------------------- /scripts/XVFI-main/figures/results_079_resized_768.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/figures/results_079_resized_768.gif -------------------------------------------------------------------------------- /scripts/XVFI-main/figures/results_158_resized_768.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/figures/results_158_resized_768.gif -------------------------------------------------------------------------------- /scripts/XVFI-main/main.py: -------------------------------------------------------------------------------- 1 | import argparse, os, shutil, time, random, torch, cv2, datetime, torch.utils.data, math 2 | import torch.backends.cudnn as cudnn 3 | import torch.optim as optim 4 | import numpy as np 5 | 6 | from torch.autograd import Variable 7 | from utils import * 8 | from XVFInet import * 9 | from collections import Counter 10 | 11 | 12 | def parse_args(): 13 | desc = "PyTorch implementation for XVFI" 14 | parser = 
argparse.ArgumentParser(description=desc) 15 | parser.add_argument('--gpu', type=int, default=0, help='gpu index') 16 | parser.add_argument('--net_type', type=str, default='XVFInet', choices=['XVFInet'], help='The type of Net') 17 | parser.add_argument('--net_object', default=XVFInet, choices=[XVFInet], help='The type of Net') 18 | parser.add_argument('--exp_num', type=int, default=1, help='The experiment number') 19 | parser.add_argument('--phase', type=str, default='test', choices=['train', 'test', 'test_custom', 'metrics_evaluation',]) 20 | parser.add_argument('--continue_training', action='store_true', default=False, help='continue the training') 21 | 22 | """ Information of directories """ 23 | parser.add_argument('--test_img_dir', type=str, default='./test_img_dir', help='test_img_dir path') 24 | parser.add_argument('--text_dir', type=str, default='./text_dir', help='text_dir path') 25 | parser.add_argument('--checkpoint_dir', type=str, default='./checkpoint_dir', help='checkpoint_dir') 26 | parser.add_argument('--log_dir', type=str, default='./log_dir', help='Directory name to save training logs') 27 | 28 | parser.add_argument('--dataset', default='X4K1000FPS', choices=['X4K1000FPS', 'Vimeo'], 29 | help='Training/test Dataset') 30 | 31 | # parser.add_argument('--train_data_path', type=str, default='./X4K1000FPS/train') 32 | # parser.add_argument('--val_data_path', type=str, default='./X4K1000FPS/val') 33 | # parser.add_argument('--test_data_path', type=str, default='./X4K1000FPS/test') 34 | parser.add_argument('--train_data_path', type=str, default='../Datasets/VIC_4K_1000FPS/train') 35 | parser.add_argument('--val_data_path', type=str, default='../Datasets/VIC_4K_1000FPS/val') 36 | parser.add_argument('--test_data_path', type=str, default='../Datasets/VIC_4K_1000FPS/test') 37 | 38 | 39 | parser.add_argument('--vimeo_data_path', type=str, default='./vimeo_triplet') 40 | 41 | """ Hyperparameters for Training (when [phase=='train']) """ 42 | parser.add_argument('--epochs', type=int, default=200, help='The number of epochs to run') 43 | parser.add_argument('--freq_display', type=int, default=100, help='The number of iterations frequency for display') 44 | parser.add_argument('--save_img_num', type=int, default=4, 45 | help='The number of saved image while training for visualization. 
It should be smaller than the batch_size') 46 | parser.add_argument('--init_lr', type=float, default=1e-4, help='The initial learning rate') 47 | parser.add_argument('--lr_dec_fac', type=float, default=0.25, help='step - lr_decreasing_factor') 48 | parser.add_argument('--lr_milestones', type=int, default=[100, 150, 180]) 49 | parser.add_argument('--lr_dec_start', type=int, default=0, 50 | help='When scheduler is StepLR, lr decreases from epoch at lr_dec_start') 51 | parser.add_argument('--batch_size', type=int, default=8, help='The batch size.') 52 | parser.add_argument('--weight_decay', type=float, default=0, help='for optim., weight decay (default: 0)') 53 | 54 | parser.add_argument('--need_patch', default=True, help='get patch from image while training') 55 | parser.add_argument('--img_ch', type=int, default=3, help='base number of channels for image') 56 | parser.add_argument('--nf', type=int, default=64, help='base number of channels for feature maps') # 64 57 | parser.add_argument('--module_scale_factor', type=int, default=4, help='spatial reduction for pixelshuffle') 58 | parser.add_argument('--patch_size', type=int, default=384, help='patch size') 59 | parser.add_argument('--num_thrds', type=int, default=4, help='number of threads for data loading') 60 | parser.add_argument('--loss_type', default='L1', choices=['L1', 'MSE', 'L1_Charbonnier_loss'], help='Loss type') 61 | 62 | parser.add_argument('--S_trn', type=int, default=3, help='The lowest scale depth for training') 63 | parser.add_argument('--S_tst', type=int, default=5, help='The lowest scale depth for test') 64 | 65 | """ Weighting Parameters Lambda for Losses (when [phase=='train']) """ 66 | parser.add_argument('--rec_lambda', type=float, default=1.0, help='Lambda for Reconstruction Loss') 67 | 68 | """ Settings for Testing (when [phase=='test' or 'test_custom']) """ 69 | parser.add_argument('--saving_flow_flag', default=False) 70 | parser.add_argument('--multiple', type=int, default=8, help='Due to the indexing of the file names, we recommend using a power of 2. (e.g. 2, 4, 8, 16 ...). 
CAUTION : For the provided X-TEST, multiple should be one of [2, 4, 8, 16, 32].') 71 | parser.add_argument('--metrics_types', type=list, default=["PSNR", "SSIM", "tOF"], choices=["PSNR", "SSIM", "tOF"]) 72 | 73 | """ Settings for test_custom (when [phase=='test_custom']) """ 74 | parser.add_argument('--custom_path', type=str, default='./custom_path', help='path for custom video containing frames') 75 | 76 | return check_args(parser.parse_args()) 77 | 78 | 79 | def check_args(args): 80 | # --checkpoint_dir 81 | check_folder(args.checkpoint_dir) 82 | 83 | # --text_dir 84 | check_folder(args.text_dir) 85 | 86 | # --log_dir 87 | check_folder(args.log_dir) 88 | 89 | # --test_img_dir 90 | check_folder(args.test_img_dir) 91 | 92 | return args 93 | 94 | 95 | def main(): 96 | args = parse_args() 97 | if args.dataset == 'Vimeo': 98 | if args.phase != 'test_custom': 99 | args.multiple = 2 100 | args.S_trn = 1 101 | args.S_tst = 1 102 | args.module_scale_factor = 2 103 | args.patch_size = 256 104 | args.batch_size = 16 105 | print('vimeo triplet data dir : ', args.vimeo_data_path) 106 | 107 | print("Exp:", args.exp_num) 108 | args.model_dir = args.net_type + '_' + args.dataset + '_exp' + str( 109 | args.exp_num) # ex) model_dir = "XVFInet_X4K1000FPS_exp1" 110 | 111 | if args is None: 112 | exit() 113 | for arg in vars(args): 114 | print('# {} : {}'.format(arg, getattr(args, arg))) 115 | device = torch.device( 116 | 'cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') # will be used as "x.to(device)" 117 | torch.cuda.set_device(device) # change allocation of current GPU 118 | # caution!!!! if not "torch.cuda.set_device()": 119 | # RuntimeError: grid_sampler(): expected input and grid to be on same device, but input is on cuda:1 and grid is on cuda:0 120 | print('Available devices: ', torch.cuda.device_count()) 121 | print('Current cuda device: ', torch.cuda.current_device()) 122 | print('Current cuda device name: ', torch.cuda.get_device_name(device)) 123 | if args.gpu is not None: 124 | print("Use GPU: {} is used".format(args.gpu)) 125 | 126 | SM = save_manager(args) 127 | 128 | """ Initialize a model """ 129 | model_net = args.net_object(args).apply(weights_init).to(device) 130 | criterion = [set_rec_loss(args).to(device), set_smoothness_loss().to(device)] 131 | 132 | # to enable the inbuilt cudnn auto-tuner 133 | # to find the best algorithm to use for your hardware. 
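With the static shapes used here (fixed patch_size), this autotuning pays off after the first few iterations; with varying input shapes it would re-benchmark repeatedly and could slow training down.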
134 | cudnn.benchmark = True 135 | 136 | if args.phase == "train": 137 | train(model_net, criterion, device, SM, args) 138 | epoch = args.epochs - 1 139 | 140 | elif args.phase == "test" or args.phase == "metrics_evaluation" or args.phase == 'test_custom': 141 | checkpoint = SM.load_model() 142 | model_net.load_state_dict(checkpoint['state_dict_Model']) 143 | epoch = checkpoint['last_epoch'] 144 | 145 | postfix = '_final_x' + str(args.multiple) + '_S_tst' + str(args.S_tst) 146 | if args.phase != "metrics_evaluation": 147 | print("\n-------------------------------------- Final Test starts -------------------------------------- ") 148 | print('Evaluate on test set (final test) with multiple = %d ' % (args.multiple)) 149 | 150 | final_test_loader = get_test_data(args, multiple=args.multiple, 151 | validation=False) # multiple is only used for X4K1000FPS 152 | 153 | testLoss, testPSNR, testSSIM, final_pred_save_path = test(final_test_loader, model_net, 154 | criterion, epoch, 155 | args, device, 156 | multiple=args.multiple, 157 | postfix=postfix, validation=False) 158 | SM.write_info('Final 4k frames PSNR : {:.4}\n'.format(testPSNR)) 159 | 160 | if args.dataset == 'X4K1000FPS' and args.phase != 'test_custom': 161 | final_pred_save_path = os.path.join(args.test_img_dir, args.model_dir, 'epoch_' + str(epoch).zfill(5)) + postfix 162 | metrics_evaluation_X_Test(final_pred_save_path, args.test_data_path, args.metrics_types, 163 | flow_flag=args.saving_flow_flag, multiple=args.multiple) 164 | 165 | 166 | 167 | print("------------------------- Test has been ended. -------------------------\n") 168 | print("Exp:", args.exp_num) 169 | 170 | 171 | def train(model_net, criterion, device, save_manager, args): 172 | SM = save_manager 173 | multi_scale_recon_loss = criterion[0] 174 | smoothness_loss = criterion[1] 175 | 176 | optimizer = optim.Adam(model_net.parameters(), lr=args.init_lr, betas=(0.9, 0.999), 177 | weight_decay=args.weight_decay) # optimizer 178 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_milestones, gamma=args.lr_dec_fac) 179 | 180 | last_epoch = 0 181 | best_PSNR = 0.0 182 | 183 | if args.continue_training: 184 | checkpoint = SM.load_model() 185 | last_epoch = checkpoint['last_epoch'] + 1 186 | best_PSNR = checkpoint['best_PSNR'] 187 | model_net.load_state_dict(checkpoint['state_dict_Model']) 188 | optimizer.load_state_dict(checkpoint['state_dict_Optimizer']) 189 | scheduler.load_state_dict(checkpoint['state_dict_Scheduler']) 190 | print("Optimizer and Scheduler have been reloaded. 
") 191 | scheduler.milestones = Counter(args.lr_milestones) 192 | scheduler.gamma = args.lr_dec_fac 193 | print("scheduler.milestones : {}, scheduler.gamma : {}".format(scheduler.milestones, scheduler.gamma)) 194 | start_epoch = last_epoch 195 | 196 | # switch to train mode 197 | model_net.train() 198 | 199 | start_time = time.time() 200 | 201 | SM.write_info('Epoch\ttrainLoss\ttestPSNR\tbest_PSNR\n') 202 | print("[*] Training starts") 203 | 204 | # Main training loop for total epochs (start from 'epoch=0') 205 | valid_loader = get_test_data(args, multiple=4, validation=True) # multiple is only used for X4K1000FPS 206 | 207 | for epoch in range(start_epoch, args.epochs): 208 | train_loader = get_train_data(args, 209 | max_t_step_size=32) # max_t_step_size (temporal distance) is only used for X4K1000FPS 210 | 211 | batch_time = AverageClass('batch_time[s]:', ':6.3f') 212 | losses = AverageClass('Loss:', ':.4e') 213 | progress = ProgressMeter(len(train_loader), batch_time, losses, prefix="Epoch: [{}]".format(epoch)) 214 | 215 | print('Start epoch {} at [{:s}], learning rate : [{}]'.format(epoch, (str(datetime.now())[:-7]), 216 | optimizer.param_groups[0]['lr'])) 217 | 218 | # train for one epoch 219 | for trainIndex, (frames, t_value) in enumerate(train_loader): 220 | 221 | input_frames = frames[:, :, :-1, :] # [B, C, T, H, W] 222 | frameT = frames[:, :, -1, :] # [B, C, H, W] 223 | 224 | # Getting the input and the target from the training set 225 | input_frames = Variable(input_frames.to(device)) 226 | frameT = Variable(frameT.to(device)) # ground truth for frameT 227 | t_value = Variable(t_value.to(device)) # [B,1] 228 | 229 | optimizer.zero_grad() 230 | # compute output 231 | pred_frameT_pyramid, pred_flow_pyramid, occ_map, simple_mean = model_net(input_frames, t_value) 232 | rec_loss = 0.0 233 | smooth_loss = 0.0 234 | for l, pred_frameT_l in enumerate(pred_frameT_pyramid): 235 | rec_loss += args.rec_lambda * multi_scale_recon_loss(pred_frameT_l, 236 | F.interpolate(frameT, scale_factor=1 / (2 ** l), 237 | mode='bicubic', align_corners=False)) 238 | smooth_loss += 0.5 * smoothness_loss(pred_flow_pyramid[0], 239 | F.interpolate(frameT, scale_factor=1 / args.module_scale_factor, 240 | mode='bicubic', 241 | align_corners=False)) # Apply 1st order edge-aware smoothness loss to the fineset level 242 | rec_loss /= len(pred_frameT_pyramid) 243 | pred_frameT = pred_frameT_pyramid[0] # final result I^0_t at original scale (s=0) 244 | pred_coarse_flow = 2 ** (args.S_trn) * F.interpolate(pred_flow_pyramid[-1], scale_factor=2 ** ( 245 | args.S_trn) * args.module_scale_factor, mode='bicubic', align_corners=False) 246 | pred_fine_flow = F.interpolate(pred_flow_pyramid[0], scale_factor=args.module_scale_factor, mode='bicubic', 247 | align_corners=False) 248 | 249 | total_loss = rec_loss + smooth_loss 250 | 251 | # compute gradient and do SGD step 252 | total_loss.backward() # Backpropagate 253 | optimizer.step() # Optimizer update 254 | 255 | # measure accumulated time and update average "batch" time consumptions via "AverageClass" 256 | # update average values via "AverageClass" 257 | losses.update(total_loss.item(), 1) 258 | batch_time.update(time.time() - start_time) 259 | start_time = time.time() 260 | 261 | if trainIndex % args.freq_display == 0: 262 | progress.print(trainIndex) 263 | batch_images = get_batch_images(args, save_img_num=args.save_img_num, 264 | save_images=[pred_frameT, pred_coarse_flow, pred_fine_flow, frameT, 265 | simple_mean, occ_map]) 266 | 
cv2.imwrite(os.path.join(args.log_dir, '{:03d}_{:04d}_training.png'.format(epoch, trainIndex)), batch_images) 267 | 268 | 269 | 270 | if epoch >= args.lr_dec_start: 271 | scheduler.step() 272 | 273 | # if (epoch + 1) % 10 == 0 or epoch==0: 274 | val_multiple = 4 if args.dataset == 'X4K1000FPS' else 2 275 | print('\nEvaluate on test set (validation while training) with multiple = {}'.format(val_multiple)) 276 | postfix = '_val_' + str(val_multiple) + '_S_tst' + str(args.S_tst) 277 | testLoss, testPSNR, testSSIM, final_pred_save_path = test(valid_loader, model_net, criterion, epoch, args, 278 | device, multiple=val_multiple, postfix=postfix, 279 | validation=True) 280 | 281 | # remember best best_PSNR and best_SSIM and save checkpoint 282 | print("best_PSNR : {:.3f}, testPSNR : {:.3f}".format(best_PSNR, testPSNR)) 283 | best_PSNR_flag = testPSNR > best_PSNR 284 | best_PSNR = max(testPSNR, best_PSNR) 285 | # save checkpoint. 286 | combined_state_dict = { 287 | 'net_type': args.net_type, 288 | 'last_epoch': epoch, 289 | 'batch_size': args.batch_size, 290 | 'trainLoss': losses.avg, 291 | 'testLoss': testLoss, 292 | 'testPSNR': testPSNR, 293 | 'best_PSNR': best_PSNR, 294 | 'state_dict_Model': model_net.state_dict(), 295 | 'state_dict_Optimizer': optimizer.state_dict(), 296 | 'state_dict_Scheduler': scheduler.state_dict()} 297 | 298 | SM.save_best_model(combined_state_dict, best_PSNR_flag) 299 | 300 | if (epoch + 1) % 10 == 0: 301 | SM.save_epc_model(combined_state_dict, epoch) 302 | SM.write_info('{}\t{:.4}\t{:.4}\t{:.4}\n'.format(epoch, losses.avg, testPSNR, best_PSNR)) 303 | 304 | print("------------------------- Training has been ended. -------------------------\n") 305 | print("information of model:", args.model_dir) 306 | print("best_PSNR of model:", best_PSNR) 307 | 308 | 309 | def test(test_loader, model_net, criterion, epoch, args, device, multiple, postfix, validation): 310 | batch_time = AverageClass('Time:', ':6.3f') 311 | losses = AverageClass('testLoss:', ':.4e') 312 | PSNRs = AverageClass('testPSNR:', ':.4e') 313 | SSIMs = AverageClass('testSSIM:', ':.4e') 314 | args.divide = 2 ** (args.S_tst) * args.module_scale_factor * 4 315 | 316 | # progress = ProgressMeter(len(test_loader), batch_time, accm_time, losses, PSNRs, SSIMs, prefix='Test after Epoch[{}]: '.format(epoch)) 317 | progress = ProgressMeter(len(test_loader), PSNRs, SSIMs, prefix='Test after Epoch[{}]: '.format(epoch)) 318 | 319 | multi_scale_recon_loss = criterion[0] 320 | 321 | # switch to evaluate mode 322 | model_net.eval() 323 | 324 | print("------------------------------------------- Test ----------------------------------------------") 325 | with torch.no_grad(): 326 | start_time = time.time() 327 | for testIndex, (frames, t_value, scene_name, frameRange) in enumerate(test_loader): 328 | # Shape of 'frames' : [1,C,T+1,H,W] 329 | frameT = frames[:, :, -1, :, :] # [1,C,H,W] 330 | It_Path, I0_Path, I1_Path = frameRange 331 | 332 | frameT = Variable(frameT.to(device)) # ground truth for frameT 333 | t_value = Variable(t_value.to(device)) 334 | 335 | if (testIndex % (multiple - 1)) == 0: 336 | input_frames = frames[:, :, :-1, :, :] # [1,C,T,H,W] 337 | input_frames = Variable(input_frames.to(device)) 338 | 339 | B, C, T, H, W = input_frames.size() 340 | H_padding = (args.divide - H % args.divide) % args.divide 341 | W_padding = (args.divide - W % args.divide) % args.divide 342 | if H_padding != 0 or W_padding != 0: 343 | input_frames = F.pad(input_frames, (0, W_padding, 0, H_padding), "constant") 344 | 345 | 346 | 
pred_frameT = model_net(input_frames, t_value, is_training=False) 347 | 348 | if H_padding != 0 or W_padding != 0: 349 | pred_frameT = pred_frameT[:, :, :H, :W] 350 | 351 | 352 | 353 | if args.phase != 'test_custom': 354 | test_loss = args.rec_lambda * multi_scale_recon_loss(pred_frameT, frameT) 355 | 356 | pred_frameT = np.squeeze(pred_frameT.detach().cpu().numpy()) 357 | frameT = np.squeeze(frameT.detach().cpu().numpy()) 358 | 359 | """ compute PSNR & SSIM """ 360 | output_img = np.around(denorm255_np(np.transpose(pred_frameT, [1, 2, 0]))) # [h,w,c] and [-1,1] to [0,255] 361 | target_img = denorm255_np(np.transpose(frameT, [1, 2, 0])) # [h,w,c] and [-1,1] to [0,255] 362 | 363 | test_psnr = psnr(target_img, output_img) 364 | test_ssim = ssim_bgr(target_img, output_img) ############### CAUTION: calculation for BGR 365 | 366 | """ save frame0 & frame1 """ 367 | if validation: 368 | epoch_save_path = os.path.join(args.test_img_dir, args.model_dir, 'latest' + postfix) 369 | else: 370 | epoch_save_path = os.path.join(args.test_img_dir, args.model_dir, 371 | 'epoch_' + str(epoch).zfill(5) + postfix) 372 | check_folder(epoch_save_path) 373 | scene_save_path = os.path.join(epoch_save_path, scene_name[0]) 374 | check_folder(scene_save_path) 375 | 376 | if (testIndex % (multiple - 1)) == 0: 377 | save_input_frames = frames[:, :, :-1, :, :] 378 | cv2.imwrite(os.path.join(scene_save_path, I0_Path[0]), 379 | np.transpose(np.squeeze(denorm255_np(save_input_frames[:, :, 0, :, :].detach().numpy())), 380 | [1, 2, 0]).astype(np.uint8)) 381 | cv2.imwrite(os.path.join(scene_save_path, I1_Path[0]), 382 | np.transpose(np.squeeze(denorm255_np(save_input_frames[:, :, 1, :, :].detach().numpy())), 383 | [1, 2, 0]).astype(np.uint8)) 384 | 385 | cv2.imwrite(os.path.join(scene_save_path, It_Path[0]), output_img.astype(np.uint8)) 386 | 387 | # measure 388 | losses.update(test_loss.item(), 1) 389 | PSNRs.update(test_psnr, 1) 390 | SSIMs.update(test_ssim, 1) 391 | 392 | # measure elapsed time 393 | batch_time.update(time.time() - start_time) 394 | start_time = time.time() 395 | 396 | if (testIndex % (multiple - 1)) == multiple - 2: 397 | progress.print(testIndex) 398 | 399 | else: 400 | epoch_save_path = args.custom_path 401 | scene_save_path = os.path.join(epoch_save_path, scene_name[0]) 402 | pred_frameT = np.squeeze(pred_frameT.detach().cpu().numpy()) 403 | output_img = np.around(denorm255_np(np.transpose(pred_frameT, [1, 2, 0]))) # [h,w,c] and [-1,1] to [0,255] 404 | print(os.path.join(scene_save_path, It_Path[0])) 405 | cv2.imwrite(os.path.join(scene_save_path, It_Path[0]), output_img.astype(np.uint8)) 406 | 407 | losses.update(0.0, 1) 408 | PSNRs.update(0.0, 1) 409 | SSIMs.update(0.0, 1) 410 | 411 | print("-----------------------------------------------------------------------------------------------") 412 | 413 | return losses.avg, PSNRs.avg, SSIMs.avg, epoch_save_path 414 | 415 | 416 | if __name__ == '__main__': 417 | main() 418 | -------------------------------------------------------------------------------- /scripts/XVFI-main/run.sh: -------------------------------------------------------------------------------- 1 | python main.py --gpu 0 --phase test_custom --exp_num 1 --dataset X4K1000FPS --module_scale_factor 4 --S_tst 5 --multiple 8 --custom_path D:\Code\SpikeCamera\Nerf\blend_files\dataset\raw_data 2 | -------------------------------------------------------------------------------- /scripts/XVFI-main/test.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/test.gif -------------------------------------------------------------------------------- /scripts/XVFI-main/text_dir/XVFInet_X4K1000FPS_exp1.txt: -------------------------------------------------------------------------------- 1 | ----- Model parameters ----- 2 | 2023-11-14 19:45:33 3 | gpu : 0 4 | net_type : XVFInet 5 | net_object : <class 'XVFInet.XVFInet'> 6 | exp_num : 1 7 | phase : test_custom 8 | continue_training : False 9 | test_img_dir : ./test_img_dir 10 | text_dir : ./text_dir 11 | checkpoint_dir : ./checkpoint_dir 12 | log_dir : ./log_dir 13 | dataset : X4K1000FPS 14 | train_data_path : ../Datasets/VIC_4K_1000FPS/train 15 | val_data_path : ../Datasets/VIC_4K_1000FPS/val 16 | test_data_path : ../Datasets/VIC_4K_1000FPS/test 17 | vimeo_data_path : ./vimeo_triplet 18 | epochs : 200 19 | freq_display : 100 20 | save_img_num : 4 21 | init_lr : 0.0001 22 | lr_dec_fac : 0.25 23 | lr_milestones : [100, 150, 180] 24 | lr_dec_start : 0 25 | batch_size : 8 26 | weight_decay : 0 27 | need_patch : True 28 | img_ch : 3 29 | nf : 64 30 | module_scale_factor : 4 31 | patch_size : 384 32 | num_thrds : 4 33 | loss_type : L1 34 | S_trn : 3 35 | S_tst : 5 36 | rec_lambda : 1.0 37 | saving_flow_flag : False 38 | multiple : 8 39 | metrics_types : ['PSNR', 'SSIM', 'tOF'] 40 | custom_path : /d/Code/SpikeCamera/Nerf/blend_files/dataset/train_bg 41 | model_dir : XVFInet_X4K1000FPS_exp1 42 | Final 4k frames PSNR : 0.0 43 | Final 4k frames PSNR : 0.0 44 | Final 4k frames PSNR : 0.0 45 | Final 4k frames PSNR : 0.0 46 | Final 4k frames PSNR : 0.0 47 | Final 4k frames PSNR : 0.0 48 | Final 4k frames PSNR : 0.0 49 | Final 4k frames PSNR : 0.0 50 | Final 4k frames PSNR : 0.0 51 | Final 4k frames PSNR : 0.0 52 | Final 4k frames PSNR : 0.0 53 | Final 4k frames PSNR : 0.0 54 | Final 4k frames PSNR : 0.0 55 | Final 4k frames PSNR : 0.0 56 | Final 4k frames PSNR : 0.0 57 | Final 4k frames PSNR : 0.0 58 | Final 4k frames PSNR : 0.0 59 | Final 4k frames PSNR : 0.0 60 | Final 4k frames PSNR : 0.0 61 | Final 4k frames PSNR : 0.0 62 | Final 4k frames PSNR : 0.0 63 | -------------------------------------------------------------------------------- /scripts/__pycache__/generate_spike.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/__pycache__/generate_spike.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/blur_syn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from tqdm import trange 5 | import argparse 6 | from torchvision import transforms 7 | 8 | #? convert the imgs under raw_folder to blurry imgs on blur_folder 9 | #? Structure as: 10 | #? base_folder 11 | #? ├── blur_folder 12 | #? 
├── raw_folder 13 | 14 | # main function 15 | if __name__ == '__main__': 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--raw_folder', type=str, default='GOPRO/test/raw_data') 18 | parser.add_argument('--img_type', type=str, default='.png') 19 | parser.add_argument('--overlap_len', type=int, default= 7, 20 | help = 'Overlap length between two blurry images. Assume [0-12] -> blur_6, [13,25] -> blur_19, overlap_len is the length of interpolated frames between 12.png and 13.png,i.e.,12_0.png,...12_6.png') 21 | parser.add_argument('--height', type=int, default= 720) 22 | parser.add_argument('--width', type=int, default= 1280) 23 | parser.add_argument('--use_resize', action='store_true', help = 'Resize the image size to the half.') 24 | parser.add_argument('--blur_num', type=int, default= 13,help = 'Number of images before interpolation to synthesize one blurry frame.') 25 | 26 | opt = parser.parse_args() 27 | width = opt.width 28 | height = opt.height 29 | use_resize = opt.use_resize 30 | raw_folder = opt.raw_folder 31 | base_folder = os.path.dirname(raw_folder) 32 | blur_folder = os.path.join(base_folder,'blur_data') 33 | os.makedirs(blur_folder,exist_ok = True) 34 | 35 | for dirpath, sub_dirs, sub_files in os.walk(raw_folder): 36 | if len(sub_files) == 0 or sub_files[0].endswith(opt.img_type) == False: 37 | continue 38 | print(dirpath) 39 | output_folder = dirpath.replace(raw_folder,blur_folder) 40 | os.makedirs(output_folder,exist_ok = True) 41 | sub_files = sorted(sub_files) 42 | bais = 0 # bais that overlap between two blurry imgs 43 | num_blur_raw = opt.blur_num # number of sharp imgs before interpolated per blurry one 44 | num_inter = 7 # number of interpolated imgs between two imgs 45 | num_blur = (num_blur_raw - 1) * (num_inter + 1) + 1 # number of interpolated imgs per blurry one 46 | str_len = len(sub_files[0].split('.')[0]) 47 | imgs = [] 48 | start = 0 49 | for i in trange(len(sub_files)): 50 | if i + bais >= len(sub_files): 51 | break 52 | file_name = sub_files[i + bais] 53 | img = cv2.imread(os.path.join(dirpath,file_name)) 54 | imgs.append(img) 55 | # synthesize the blurry image 56 | if i % (num_blur) == num_blur - 1: 57 | blur_img = np.mean(np.stack(imgs,axis = 0),axis = 0) 58 | end = i + bais 59 | # (num_blur_raw - 1 + (opt.overlap_len + 1) // (num_inter + 1)): number of imgs per blur 60 | # ((i + 1) // num_blur - 1): denotes the order of blur, 0 at first 61 | # num_blur_raw // 2: middle frame 62 | if use_resize: 63 | blur_img = cv2.resize(blur_img,(width // 2,height//2),interpolation=cv2.INTER_LINEAR) 64 | # ! the former denotes the renamed order, while the latter denotes the name in the raw image folder. 65 | print(f"Synthesize blurry {str((num_blur_raw - 1 + (opt.overlap_len + 1) // (num_inter + 1)) * ((i + 1) // num_blur - 1) + num_blur_raw // 2).zfill(str_len)}.png from {sub_files[start]} to {sub_files[end]}") 66 | cv2.imwrite(os.path.join(output_folder,f"{str((num_blur_raw - 1 + (opt.overlap_len + 1) // (num_inter + 1)) * ((i + 1) // num_blur - 1) + num_blur_raw // 2).zfill(str_len)}.png"),blur_img) 67 | imgs = [] 68 | bais += opt.overlap_len 69 | start = i + bais + 1 70 | -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | # todo GOPRO Data 2 | # 1. Download Data 3 | 4 | #? Structure: 5 | #? GOPRO 6 | #? ├── train 7 | #? │ ├── raw_data 8 | #? │ │ ├── GOPR0372_07_00 9 | #? │ │ ├── ... 10 | #? │ │ └── GOPR0884_11_00 11 | #? 
│ 12 | #? ├── test 13 | #? │ ├── raw_data 14 | #? │ │ ├── GOPR0384_11_00 15 | #? │ │ ├── ... 16 | #? │ │ └── GOPR0881_11_01 17 | 18 | # 2. Interpolate frames 19 | cd XVFI-main/ 20 | python main.py --custom_path ../GOPRO/test/raw_data --gpu 0 --phase test_custom --exp_num 1 --dataset X4K1000FPS --module_scale_factor 4 --S_tst 5 --multiple 8 21 | python main.py --custom_path ../GOPRO/train/raw_data --gpu 0 --phase test_custom --exp_num 1 --dataset X4K1000FPS --module_scale_factor 4 --S_tst 5 --multiple 8 22 | cd .. 23 | 24 | # 3. Synthesize the blurry image 25 | python blur_syn.py --raw_folder GOPRO/test/raw_data --overlap_len 7 --blur_num 13 26 | python blur_syn.py --raw_folder GOPRO/train/raw_data --overlap_len 7 --blur_num 13 27 | 28 | # 4. Simulate the spike 29 | python spike_simulate.py --raw_folder GOPRO/test/raw_data --overlap_len 7 --use_resize --blur_num 13 30 | python spike_simulate.py --raw_folder GOPRO/train/raw_data --overlap_len 7 --use_resize --blur_num 13 31 | 32 | # 5. extract GT from raw_folder 33 | ## single frame 34 | python sharp_extract.py --raw_folder GOPRO/test/raw_data --overlap_len 7 --blur_num 13 35 | python sharp_extract.py --raw_folder GOPRO/train/raw_data --overlap_len 7 --blur_num 13 36 | 37 | ## sequence 38 | # python sharp_extract.py --raw_folder GOPRO/test/raw_data --overlap_len 7 --blur_num 13 --multi 39 | # python sharp_extract.py --raw_folder GOPRO/train/raw_data --overlap_len 7 --blur_num 13 --multi -------------------------------------------------------------------------------- /scripts/sharp_extract.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from tqdm import trange 5 | import argparse 6 | 7 | #? extract the imgs under raw_folder to sharp sequence on sharp_folder 8 | #? Structure as: 9 | #? base_folder 10 | #? ├── blur_folder 11 | #? ├── raw_folder 12 | #? └── spike_folder 13 | 14 | # main function 15 | if __name__ == '__main__': 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--raw_folder', type=str, default='GOPRO/test/raw_data') 18 | parser.add_argument('--img_type', type=str, default='.png') 19 | parser.add_argument('--overlap_len', type=int, default= 7, 20 | help = 'Overlap length between two blurry images. 
Assume [0-12] -> blur_6, [13,25] -> blur_19, overlap_len is the length of interpolated frames between 12.png and 13.png,i.e.,12_0.png,...12_6.png') 21 | parser.add_argument('--height', type=int, default= 720) 22 | parser.add_argument('--width', type=int, default= 1280) 23 | parser.add_argument('--use_resize', action='store_true', help = 'Resize the image size to the half.') 24 | parser.add_argument('--blur_num', type=int, default= 13,help = 'Number of images before interpolation to synthesize one blurry frame.') 25 | parser.add_argument('--multi', action='store_true',default = False,help= 'extract gt sequence from raw_folder') 26 | 27 | opt = parser.parse_args() 28 | width = opt.width 29 | height = opt.height 30 | use_resize = opt.use_resize 31 | raw_folder = opt.raw_folder 32 | base_folder = os.path.dirname(raw_folder) 33 | blur_folder = os.path.join(base_folder,'blur_data') 34 | sharp_folder = os.path.join(base_folder,'sharp_data') 35 | 36 | for dirpath, sub_dirs, sub_files in os.walk(raw_folder): 37 | if len(sub_files) == 0 or sub_files[0].endswith(opt.img_type) == False: 38 | continue 39 | print(dirpath) 40 | output_folder = dirpath.replace(raw_folder,sharp_folder) 41 | os.makedirs(output_folder,exist_ok = True) 42 | sub_files = sorted(sub_files) 43 | bais = 0 # bais that overlap between two blurry imgs 44 | num_blur_raw = opt.blur_num # number of sharp imgs before interpolated per blurry one 45 | num_inter = 7 # number of interpolated imgs between two imgs 46 | num_blur = num_blur_raw * (num_inter + 1) + 1 # number of interpolated imgs per blurry one 47 | str_len = len(sub_files[0].split('.')[0]) 48 | imgs = [] 49 | idx = 0 50 | loop_bais = 0 51 | for i in trange(len(sub_files)): 52 | if i % (num_inter + 1) == 0: 53 | if idx % num_blur_raw == num_blur_raw // 2 or (opt.multi == True and ((idx - loop_bais) % 12) % 2 == 0): # % 2 is set to extract 7 images 54 | file_name = sub_files[i] 55 | print(file_name) 56 | img = cv2.imread(os.path.join(dirpath,file_name)) 57 | if use_resize: 58 | img = cv2.resize(img,(width // 2,height//2),interpolation=cv2.INTER_LINEAR) 59 | cv2.imwrite(os.path.join(output_folder,f"{str(idx).zfill(str_len)}.png"),img) 60 | if idx % num_blur_raw == num_blur_raw - 1: 61 | loop_bais += 1 62 | idx += 1 63 | -------------------------------------------------------------------------------- /scripts/spike_simulate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from tqdm import trange 5 | from utils import * 6 | from utils_spike import * 7 | import argparse 8 | 9 | #? convert the imgs under raw_folder to spike sequence on spike_folder 10 | #? Structure as: 11 | #? base_folder 12 | #? ├── blur_folder 13 | #? ├── raw_folder 14 | #? └── spike_folder 15 | 16 | # main function 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--raw_folder', type=str, default=r'GOPRO/test/raw_data') 20 | parser.add_argument('--img_type', type=str, default='.png') 21 | parser.add_argument('--overlap_len', type=int, default= 7, 22 | help = 'Overlap length between two blurry images. 
Assume [0-12] -> blur_6, [13,25] -> blur_19, overlap_len is the length of interpolated frames between 12.png and 13.png,i.e.,12_0.png,...12_6.png') 23 | parser.add_argument('--height', type=int, default= 720) 24 | parser.add_argument('--width', type=int, default= 1280) 25 | parser.add_argument('--use_resize', action='store_true', help = 'Resize the image size to the half.') 26 | parser.add_argument('--blur_num', type=int, default= 13,help = 'Number of images before interpolation to synthesize one blurry frame.') 27 | parser.add_argument('--spike_add', type=int, default= 20,help = 'Additional spike out of the exposure period.') 28 | 29 | opt = parser.parse_args() 30 | width = opt.width 31 | height = opt.height 32 | resize = opt.use_resize 33 | raw_folder = opt.raw_folder 34 | base_folder = os.path.dirname(raw_folder) 35 | spike_folder = os.path.join(base_folder,'spike_data') 36 | os.makedirs(spike_folder,exist_ok = True) 37 | for dirpath, sub_dirs, sub_files in os.walk(raw_folder): 38 | if len(sub_files) == 0 or sub_files[0].endswith(opt.img_type) == False: 39 | continue 40 | print(dirpath) 41 | output_folder = dirpath.replace(raw_folder,spike_folder) 42 | os.makedirs(output_folder,exist_ok = True) 43 | sub_files = sorted(sub_files) 44 | bais = 0 # bais that overlap between two blurry imgs 45 | num_blur_raw = opt.blur_num # number of sharp imgs before interpolated per blurry one 46 | num_inter = 7 # number of interpolated imgs between two imgs 47 | num_blur = (num_blur_raw - 1) * (num_inter + 1) + 1 # number of interpolated imgs per blurry one 48 | spike_add = opt.spike_add # additional spike out of the exposure period 49 | num_omit = 1 # reduce the number of spike sequence to [num/num_omit] 50 | str_len = len(sub_files[0].split('.')[0]) 51 | imgs = [] 52 | start = 0 53 | for i in trange(len(sub_files)): 54 | if i + bais >= len(sub_files): 55 | break 56 | file_name = sub_files[i + bais] 57 | img = cv2.imread(os.path.join(dirpath,file_name)) 58 | img = img.astype(np.float32) / 255 59 | if resize == True: 60 | img = cv2.resize(img,(width // 4,height // 4),interpolation=cv2.INTER_LINEAR) 61 | # GRAY=0.3*R+0.59*G+0.11*B 62 | img = 0.11 * img[...,0] + 0.59 * img[...,1] + 0.3 * img[...,2] 63 | imgs.append(img) 64 | # simulate the spike sequence during the exposure 65 | if i % (num_blur) == num_blur - 1: 66 | end = i + bais 67 | # skip the first blurry image 68 | if start == 0: 69 | imgs = [] 70 | bais += opt.overlap_len 71 | start = i + bais + 1 72 | continue 73 | if end + spike_add >= len(sub_files): 74 | break 75 | # add the spike data out of the exposure period 76 | for jj in range(1,spike_add + 1): 77 | img_start = cv2.imread(os.path.join(dirpath,sub_files[start - jj])) 78 | if resize == True: 79 | img_start = cv2.resize(img_start,(width // 4,height // 4),interpolation=cv2.INTER_LINEAR ) 80 | img_start = img_start.astype(np.float32) / 255 81 | img_start = 0.11 * img_start[...,0] + 0.59 * img_start[...,1] + 0.3 * img_start[...,2] 82 | img_end = cv2.imread(os.path.join(dirpath,sub_files[end + jj])) 83 | img_end = img_end.astype(np.float32) / 255 84 | if resize == True: 85 | img_end = cv2.resize(img_end,(width // 4,height // 4),interpolation=cv2.INTER_LINEAR ) 86 | img_end = 0.11 * img_end[...,0] + 0.59 * img_end[...,1] + 0.3 * img_end[...,2] 87 | imgs.append(img_end) 88 | imgs.insert(0,img_start) 89 | #! 
reduce the spike sequence to every num_omit-th frame, i.e. to len(imgs) / num_omit frames 90 | imgs = imgs[::num_omit] 91 | # todo from zjy 92 | spike = SimulationSimple_Video(imgs) 93 | noise = Inherent_Noise_fast_torch(spike.shape[0], H=spike.shape[1], W=spike.shape[2]) 94 | spike = torch.bitwise_or(spike.permute((1,2,0)), noise) 95 | spike = spike.permute((2, 0, 1)) 96 | # todo from csy 97 | # spike = v2s_interface(imgs,threshold=2) 98 | 99 | # save spike 100 | print(spike.shape) 101 | SpikeToRaw(os.path.join(output_folder, str((num_blur_raw - 1 + (opt.overlap_len + 1) // (num_inter + 1)) * ((i + 1) // num_blur - 1) + num_blur_raw // 2).zfill(str_len) + ".dat"), spike) 102 | print(f"Generating spikes from {sub_files[start - spike_add]} to {sub_files[end + spike_add]}") 103 | print(f"Blur area ranges from {sub_files[start]} to {sub_files[end]}") 104 | imgs = [] 105 | bais += opt.overlap_len 106 | start = i + bais + 1
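# A worked example of the windowing above (a sketch, assuming the default
# blur_num=13, num_inter=7, overlap_len=7): each exposure window covers
# num_blur = (13 - 1) * (7 + 1) + 1 = 97 interpolated frames; consecutive
# windows sit 13 - 1 + (7 + 1) // (7 + 1) = 13 original frames apart; and
# window k is saved under the middle original frame index 13 * k + 13 // 2,
# i.e. [0-12] -> 0006, [13-25] -> 0019, matching the blurry images written
# by blur_syn.py.
-------------------------------------------------------------------------------- /scripts/utils.py: --------------------------------------------------------------------------------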
1 | import cv2 2 | import torch 3 | import numpy as np 4 | import imageio 5 | # from moviepy.editor import ImageSequenceClip 6 | import os 7 | 8 | 9 | def save_gif(image_list, gif_path = 'test', duration = 2,RGB = True,nor = False): 10 | imgs = [] 11 | os.makedirs('Video',exist_ok = True) 12 | with imageio.get_writer(os.path.join('Video',gif_path + '.gif'), mode='I',duration = 1000 * duration / len(image_list),loop=0) as writer: 13 | for i in range(len(image_list)): 14 | img = normal_img(image_list[i],RGB,nor) 15 | writer.append_data(img) 16 | 17 | def save_video(image_list,path = 'test',duration = 2,RGB = True,nor = False): 18 | os.makedirs('Video',exist_ok = True) 19 | img_size = (image_list[0].shape[1], image_list[0].shape[0]) 20 | fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G') 21 | videowriter = cv2.VideoWriter(os.path.join('Video',path + '.avi'), fourcc, len(image_list) / duration, img_size) 22 | for i in range(len(image_list)): 23 | img = normal_img(image_list[i],RGB,nor) 24 | videowriter.write(img) 25 | 26 | 27 | def normal_img(img,RGB = True,nor = True): 28 | if nor: 29 | img = 255 * ((img - img.min()) / (img.max() - img.min())) 30 | if isinstance(img,torch.Tensor): 31 | img = np.array(img.detach().cpu()) 32 | if len(img.shape) == 2: 33 | img = img[...,None] 34 | if img.shape[-1] == 1: 35 | img = np.repeat(img,3,axis = -1) 36 | img = img.astype(np.uint8) 37 | if RGB == False: 38 | img = img[...,::-1] 39 | return img 40 | 41 | def save_img(path = 'test.png',img = None,nor = True): 42 | if nor: 43 | img = 255 * ((img - img.min()) / (img.max() - img.min())) 44 | if isinstance(img,torch.Tensor): 45 | img = np.array(img.detach().cpu()) 46 | img = img.astype(np.uint8) 47 | cv2.imwrite(os.path.join(os.getcwd(),'imgs',path),img) 48 | 49 | def make_folder(path): 50 | os.makedirs(path,exist_ok = True) 51 | 52 | def video_to_spike( 53 | sourefolder=None, 54 | imgs = None, 55 | savefolder_debug=None, 56 | threshold=5.0, 57 | init_noise=True, 58 | format="png", 59 | ): 60 | """ 61 | Integrate-and-fire spike simulator: per pixel, normalized gray values are accumulated, and a spike is fired (with `threshold` subtracted) whenever the accumulator reaches `threshold`. 62 | :param sourefolder / imgs: read frames from a folder, or take them directly as a list; :param init_noise: randomize the initial accumulator phase in [0, threshold) so the first spikes are desynchronized 63 | :return: spikematrix: uint8 array of shape [T, H, W] with values in {0, 1} 64 | """ 65 | if sourefolder is not None: 66 | filelist = sorted(os.listdir(sourefolder)) 67 | datas = [fn for fn in filelist if fn.endswith(format)] 68 | 69 | T = len(datas) 70 | 71 | frame0 = cv2.imread(os.path.join(sourefolder, datas[0])) 72 | H, W, C = frame0.shape 73 | 74 | frame0 = cv2.cvtColor(frame0, cv2.COLOR_BGR2GRAY) 75 | 76 | spikematrix = np.zeros([T, H, W], np.uint8) 77 | 78 | if init_noise: 79 | integral = np.random.random(size=([H,W])) * threshold 80 | else: 81 | integral = np.zeros([H,W]) 82 | 83 | Thr = np.ones_like(integral).astype(np.float32) * threshold 84 | 85 | for t in range(0, T): 86 | frame = cv2.imread(os.path.join(sourefolder, datas[t])) 87 | if C > 1: 88 | gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) 89 | gray = gray / 255.0 90 | integral += gray 91 | fire = (integral - Thr) >= 0 92 | fire_pos = fire.nonzero() 93 | 94 | integral[fire_pos] -= threshold 95 | spikematrix[t][fire_pos] = 1 96 | 97 | if savefolder_debug: 98 | np.save(os.path.join(savefolder_debug, "spike_debug.npy"), spikematrix) 99 | elif imgs is not None: 100 | frame0 = imgs[0] 101 | H, W, C = frame0.shape 102 | T = len(imgs) 103 | spikematrix = np.zeros([T, H, W], np.uint8) 104 | 105 | if init_noise: 106 | integral = np.random.random(size=([H,W])) * threshold 107 | else: 108 | integral = np.zeros([H,W]) 109 | 110 | Thr = np.ones_like(integral).astype(np.float32) * threshold 111 | 112 | for t in range(0, T): 113 | frame = imgs[t] 114 | if C > 1: 115 | gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) 116 | gray = gray / 255.0 117 | integral += gray 118 | fire = (integral - Thr) >= 0 119 | fire_pos = fire.nonzero() 120 | 121 | integral[fire_pos] -= threshold 122 | spikematrix[t][fire_pos] = 1 123 | return spikematrix 124 | 125 | 126 | def load_vidar_dat(filename, left_up=(0, 0), window=None, frame_cnt = None, height = 800,width = 800): 127 | if isinstance(filename, str): 128 | array = np.fromfile(filename, dtype=np.uint8) 129 | elif isinstance(filename, (list, tuple)): 130 | l = [] 131 | for name in filename: 132 | a = np.fromfile(name, dtype=np.uint8) 133 | l.append(a) 134 | array = np.concatenate(l) 135 | else: 136 | raise NotImplementedError 137 | if window is None: 138 | window = (height - left_up[0], width - left_up[1]) 139 | len_per_frame = height * width // 8 140 | framecnt = frame_cnt if frame_cnt is not None else len(array) // len_per_frame 141 | spikes = [] 142 | for i in range(framecnt): 143 | compr_frame = array[i * len_per_frame: (i + 1) * len_per_frame] 144 | blist = [] 145 | for b in range(8): 146 | blist.append(np.right_shift(np.bitwise_and(compr_frame, np.left_shift(1, b)), b)) 147 | frame_ = np.stack(blist).transpose() 148 | frame_ = np.flipud(frame_.reshape((height, width), order='C')) 149 | if window is not None: 150 | spk = frame_[left_up[0]:left_up[0] + window[0], left_up[1]:left_up[1] + window[1]] 151 | else: 152 | spk = frame_ 153 | spk = torch.from_numpy(spk.copy().astype(np.float32)) 154 | spikes.append(spk) 155 | return torch.stack(spikes,dim = 0)
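The .dat layout read by load_vidar_dat above (and written by SpikeToRaw in utils_spike.py below) packs 8 consecutive pixels of one binary frame into each byte, bit b holding pixel offset b, row-major over a vertically flipped frame. A minimal sketch of the pack/unpack round trip (toy values, independent of the repo code):

```python
import numpy as np

# Pack 8 spikes into one byte (bit b <- spike b), then unpack them,
# mirroring SpikeToRaw's `spike * base` sum and load_vidar_dat's bit shifts.
bits = np.array([1, 0, 1, 1, 0, 0, 0, 1], dtype=np.uint8)
byte = np.uint8((bits * (1 << np.arange(8))).sum())  # 141 == 0b10001101
unpacked = (byte >> np.arange(8)) & 1                # low bit first
assert (unpacked == bits).all()
```

-------------------------------------------------------------------------------- /scripts/utils_spike.py: --------------------------------------------------------------------------------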
1 | from tqdm import trange 2 | import torch 3 | import os 4 | import numpy as np 5 | from PIL import Image 6 | import cv2 7 | import torchvision.transforms as transforms 8 | 9 | 10 | def Inherent_Noise_fast_torch(T, mu=140.0, std=50.0, H=250, W=400): 11 | """ 12 | Generate Gaussian distributed inherent noise. 13 | args: 14 | - T: Simulation time length 15 | - mu: Mean 16 | - std: Standard deviation 17 | return: 18 | - The array records the location of noise 19 | """ 20 | shape = [H, W, T] 21 | size = H * W * T 22 | gaussian = torch.normal(mu, std, size=(size,)) 23 | noise = torch.zeros(size, dtype=torch.int16) 24 | keys = torch.cumsum(gaussian, dim=0) 25 | keys = keys[keys < size] [... lines 26-62 are missing from this dump: the remainder of Inherent_Noise_fast_torch and the head of a per-frame integrate-and-fire routine, of which only the tail survives below ...] vol_to_Thr = vol > Vth 63 | syn_rec[n][vol_to_Thr] = 1 64 | vol[vol_to_Thr] = vol[vol_to_Thr] % Vth 65 | # for t in range(250): 66 | # # print(g.shape, noise_1[t].shape) 67 | # vol += g * K 68 | # vol_to_Thr = vol >= Vth 69 | # syn_rec[n][vol_to_Thr] = 1 70 | # vol[vol_to_Thr] = 0 71 | return syn_rec 72 | 73 | def v2s_interface(imgs, savefolder=None, threshold=5.0): 74 | T = len(imgs) 75 | # frame0[..., 0:2] = 0 76 | # cv2.imshow('red', frame0) 77 | # cv2.waitKey(0) 78 | H, W = imgs[0].shape 79 | # exit(0) 80 | 81 | spikematrix = np.zeros([T, H, W], np.uint8) 82 | # integral = np.array(frame0gray).astype(np.float32) 83 | integral = np.random.random(size=([H,W])) * threshold 84 | Thr = np.ones_like(integral).astype(np.float32) * threshold 85 | 86 | for t in range(0, T): 87 | # print('spike frame %s' % datas[t]) 88 | frame = imgs[t] 89 | # gray = cv2.resize(frame, None, fx=0.25, fy=0.25, interpolation=cv2.INTER_AREA) 90 | gray = frame 91 | integral += gray 92 | fire = (integral - Thr) >= 0 93 | fire_pos = fire.nonzero() 94 | integral[fire_pos] -= threshold 95 | # integral[fire_pos] = 0.0 96 | spikematrix[t][fire_pos] = 1 97 | return spikematrix 98 | 99 | def SimulationSimple_Video(I): 100 | # Initialize Sensor Parameters 101 | Vth = 2.0 102 | Eta = 10**(-13)*1.09 103 | Lambda = 10**(-4)*1.83 104 | Cpd = 10.0**(-15)*15 105 | CLK = 10.0**(6)*10 106 | delta_t = 2 / CLK 107 | K = delta_t * Eta / (Lambda * Cpd) # integration gain per clock step 108 | # print(K) 109 | # vol = 0 110 | T = len(I) 111 | H, W = I[0].shape 112 | vol = torch.zeros(size=(H,W)) 113 | syn_rec = torch.zeros(size=(T+1,H,W), dtype=torch.int16) 114 | 115 | for n in range(T+1): 116 | g = torch.rand(size=(H,W)) if n == 0 else I[n-1] 117 | # g = I 118 | # print(noise_1.shape) 119 | 120 | # print(g_all[:, 100,100]) 121 | # flag, vol = inner_clock(g * K, Vth, vol) 122 | # print(torch.max(flag), torch.max(vol)) 123 | # syn_rec[n] = flag 124 | vol += g * K * 250 125 | vol_to_Thr = vol > Vth 126 | syn_rec[n][vol_to_Thr] = 1 127 | vol[vol_to_Thr] = vol[vol_to_Thr] % Vth 128 | # for t in range(250): 129 | # # print(g.shape, noise_1[t].shape) 130 | # vol += g * K 131 | # vol_to_Thr = vol >= Vth 132 | # syn_rec[n][vol_to_Thr] = 1 133 | # vol[vol_to_Thr] = 0 134 | return syn_rec[1:] 135 | 136 | 137 | def SpikeToRaw(save_path, SpikeSeq, filpud=True, delete_if_exists=True): 138 | """ 139 | save spike sequence to .dat file 140 | save_path: full saving path (string) 141 | SpikeSeq: Numpy array (T x H x W) 142 | Rui Zhao 143 | """ 144 | if delete_if_exists: 145 | if os.path.exists(save_path): 146 | os.remove(save_path) 147 | 148 | sfn, h, w = SpikeSeq.shape 149 | remainder = int((h * w) % 8) 150 | # assert (h * w) % 8 == 0 151 | base = np.power(2, np.linspace(0, 7, 8)) 152 | fid = open(save_path, 'ab') 153 | for img_id in range(sfn): 154 | if filpud: 155 | # mimic the camera's vertically flipped image 156 | spike = np.flipud(SpikeSeq[img_id, :, :]) 157 | else: 158 | spike = SpikeSeq[img_id, :, :] 159 | # numpy flattens row-major by default, and the data is stored row-major as well 160 | # spike = spike.flatten() 161 | if remainder == 0: 162 | spike = spike.flatten() 163 | else: 164 | spike = np.concatenate([spike.flatten(), np.array([0]*(8-remainder))]) 165 | spike = spike.reshape([-1, 8]) 166 | data = spike * base 
171 | 
172 | def load_vidar_dat(filename, frame_cnt=None, width=640, height=480, reverse_spike=True):
173 |     '''
174 |     output: (frame_cnt, height, width) {0,1} float32
175 |     '''
176 |     array = np.fromfile(filename, dtype=np.uint8)
177 | 
178 |     len_per_frame = height * width // 8
179 |     framecnt = frame_cnt if frame_cnt is not None else len(array) // len_per_frame
180 | 
181 |     spikes = []
182 |     for i in range(framecnt):
183 |         compr_frame = array[i * len_per_frame: (i + 1) * len_per_frame]
184 |         blist = []
185 |         for b in range(8):
186 |             blist.append(np.right_shift(np.bitwise_and(
187 |                 compr_frame, np.left_shift(1, b)), b))
188 | 
189 |         frame_ = np.stack(blist).transpose()
190 |         frame_ = frame_.reshape((height, width), order='C')
191 |         if reverse_spike:
192 |             frame_ = np.flipud(frame_)
193 |         spikes.append(frame_)
194 | 
195 |     return np.array(spikes).astype(np.float32)
196 | 
197 | if __name__ == "__main__":
198 |     # main()
199 |     imgs = np.ones((100, 300, 300))
200 |     print(v2s_interface(imgs).shape)
201 |     # spike = load_vidar_dat('I:\\Datasets\\REDS\\train\\train_spike\\000\\00000007.dat', width=260, height=640)
202 |     # cv2.imwrite('test_spike_load.png', np.mean(spike, axis=0)[..., None] * 255)
203 | 
--------------------------------------------------------------------------------
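A quick round-trip sanity check for the two .dat helpers above (our own snippet, not part of the repository; `tmp_spike.dat` is a scratch path): SpikeToRaw packs eight binary spikes per byte and flips each frame, and load_vidar_dat undoes both.

    import numpy as np
    spk = (np.random.rand(16, 480, 640) > 0.9).astype(np.uint8)    # T x H x W binary spikes
    SpikeToRaw('tmp_spike.dat', spk, flipud=True)                  # 8 spikes -> 1 byte, rows flipped
    back = load_vidar_dat('tmp_spike.dat', width=640, height=480)  # unpack, flip back
    assert np.array_equal(back, spk.astype(np.float32))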
/train_bsn.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | import sys
7 | sys.path.append("..")
8 | import os
9 | import argparse
10 | from datetime import datetime
11 | from torch.optim import Adam
12 | from torchvision import transforms
13 | from torch.utils.tensorboard import SummaryWriter
14 | from tqdm import tqdm
15 | import shutil
16 | from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR
17 | from codes.utils import *
18 | from codes.dataset import *
19 | from codes.metrics import compute_img_metric
20 | from codes.model.bsn_model import BSN
21 | 
22 | 
23 | # SDM model
24 | def sdm_deblur(blur, spike, spike_idx):
25 |     """ Basic SDM model.
26 | 
27 |     Args:
28 |         blur (tensor): blurry input. [bs,3,h,w]
29 |         spike (tensor): spike sequence. [bs,137,h,w]
30 |         spike_idx (int): central idx of the short-exposure spike stream
31 | 
32 |     Returns:
33 |         rgb_sdm: deblur result. [bs,3,h,w]
34 |     """
35 |     global spike_bsn_len
36 |     spike_sum = torch.sum(spike[:, 20:-20], dim=1, keepdim=True)
37 |     spike_bsn = spike[:, spike_idx - spike_bsn_len // 2:spike_idx + spike_bsn_len // 2 + 1, :, :]
38 |     rgb_sdm = blur / spike_sum
39 |     rgb_sdm[spike_sum.repeat(1, 3, 1, 1) == 0] = 0  # avoid division-by-zero artifacts
40 |     rgb_sdm = rgb_sdm * torch.sum(spike_bsn, dim=1, keepdim=True) * 97 / spike_bsn_len
41 |     rgb_sdm = rgb_sdm.clip(0, 1)
42 |     return rgb_sdm
43 | 
44 | # TFP model
45 | def cal_tfp(spike, spike_idx, tfp_len):
46 |     """TFP model.
47 | 
48 |     Args:
49 |         spike (tensor): spike sequence. [bs,137,h,w]
50 |         spike_idx (int): central idx of the virtual exposure window
51 |         tfp_len (int): length of the virtual exposure window. 97 for long-TFP, [7,9,11] for short-TFP.
52 | 
53 |     Returns:
54 |         tfp_pred: tfp result
55 |     """
56 |     spike = spike[:, spike_idx - tfp_len // 2:spike_idx + tfp_len // 2 + 1, :, :]
57 |     tfp_pred = torch.mean(spike, dim=1, keepdim=True)
58 |     return tfp_pred
59 | 
60 | if __name__ == '__main__':
61 |     # parameters
62 |     parser = argparse.ArgumentParser()
63 |     parser.add_argument('--base_folder', type=str, default=r'GOPRO/', help='base folder of the GOPRO dataset')
64 |     parser.add_argument('--save_folder', type=str, default='exp/BSN', help='folder for experimental results')
65 |     parser.add_argument('--data_type', type=str, default='GOPRO', help='dataset type')
66 |     parser.add_argument('--exp_name', type=str, default='test', help='experiment name')
67 |     parser.add_argument('--epochs', type=int, default=1001)
68 |     parser.add_argument('--lr', type=float, default=3e-4)
69 |     parser.add_argument('--seed', type=int, default=42)
70 |     parser.add_argument('--bsn_len', type=int, default=9, help='spike length for the BSN input TFP image')
71 |     parser.add_argument('--width', type=int, default=1280)
72 |     parser.add_argument('--height', type=int, default=720)
73 |     parser.add_argument('--use_small', action='store_true', default=False, help='train on the small GOPRO dataset for debugging')
74 |     parser.add_argument('--spike_full', action='store_true', default=False, help='train BSN on the high-resolution spike stream (1280 x 720)')
75 |     parser.add_argument('--test_mode', action='store_true', default=False, help='only compute the test metrics')
76 |     parser.add_argument('--bsn_path', type=str, default='model/BSN_1000.pth')
77 |     opt = parser.parse_args()
78 | 
79 |     # prepare
80 |     ckpt_folder = f"{opt.save_folder}/{opt.exp_name}/ckpts"
81 |     img_folder = f"{opt.save_folder}/{opt.exp_name}/imgs"
82 |     os.makedirs(ckpt_folder, exist_ok=True)
83 |     os.makedirs(img_folder, exist_ok=True)
84 |     set_random_seed(opt.seed)
85 |     save_opt(opt, f"{opt.save_folder}/{opt.exp_name}/opt.txt")
86 |     log_file = f"{opt.save_folder}/{opt.exp_name}/results.txt"
87 |     logger = setup_logging(log_file)
88 |     if os.path.exists(f'{opt.save_folder}/{opt.exp_name}/tensorboard'):
89 |         shutil.rmtree(f'{opt.save_folder}/{opt.exp_name}/tensorboard')
90 |     writer = SummaryWriter(f'{opt.save_folder}/{opt.exp_name}/tensorboard')
91 |     resize_method = transforms.Resize((opt.height // 4, opt.width // 4), interpolation=transforms.InterpolationMode.BILINEAR)
92 |     logger.info(opt)
93 | 
94 |     # train and test data splitting
95 |     train_dataset = SpikeData(opt.base_folder, opt.data_type, 'train', use_roi=True,
96 |                               roi_size=[128 * 4, 128 * 4], use_small=opt.use_small, spike_full=opt.spike_full)
97 |     train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4, pin_memory=True)
98 |     test_dataset = SpikeData(opt.base_folder, opt.data_type, 'test', use_roi=False,
99 |                              use_small=opt.use_small, spike_full=opt.spike_full)
100 |     test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True)
101 | 
102 |     # config for network and training parameters
103 |     bsn_teacher = BSN(n_channels=1, n_output=1).cuda()
104 |     if opt.test_mode:
105 |         bsn_teacher.load_state_dict(torch.load(opt.bsn_path))
106 |     optim = Adam(bsn_teacher.parameters(), lr=opt.lr)
107 |     scheduler = CosineAnnealingLR(optim, T_max=opt.epochs)
108 | 
109 |     criterion = nn.MSELoss()
110 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
111 |     spike_bsn_len = opt.bsn_len
112 |     # -------------------- train ----------------------
113 |     train_start = datetime.now()
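    # --- Illustrative sanity check of sdm_deblur above (our toy numbers, not
    # --- part of the original script): for a static scene the blur equals the
    # --- sharp frame, and the two spike sums cancel exactly.
    # blur  = torch.full((1, 3, 8, 8), 0.5)           # static scene -> blur == sharp
    # spike = torch.ones((1, 137, 8, 8))              # every timestep fires
    # out   = sdm_deblur(blur, spike, spike_idx=68)   # center of the exposure window
    # assert torch.allclose(out, blur)                # 0.5/97 * 9 * 97/9 == 0.5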
114 |     logger.info("Start Training!")
115 |     for epoch in range(opt.epochs):
116 |         if not opt.test_mode:
117 |             train_loss = AverageMeter()
118 |             for batch_idx, (blur, spike, sharp) in enumerate(tqdm(train_loader)):
119 |                 # read the data
120 |                 blur, spike = blur.to(device), spike.to(device)
121 |                 # reconstruct the initial result
122 |                 for spike_idx in range(20, 117, 3 * 8):
123 |                     # TFP Part
124 |                     tfp = cal_tfp(spike, spike_idx, spike_bsn_len)
125 |                     tfp_bsn = bsn_teacher(tfp).clip(0, 1)
126 |                     loss = criterion(tfp, tfp_bsn)
127 |                     optim.zero_grad()
128 |                     loss.backward()
129 |                     optim.step()
130 |                     writer.add_scalar('Training Loss', loss.item())
131 |                     train_loss.update(loss.item())
132 |             logger.info(f"EPOCH {epoch}/{opt.epochs}: Train Loss: {train_loss.avg}")
133 |             writer.add_scalar('Epoch Loss', train_loss.avg, epoch)
134 |             scheduler.step()
135 |         # visualization result
136 |         if epoch % 100 == 0:
137 |             with torch.no_grad():
138 |                 # save the network
139 |                 save_network(bsn_teacher, f"{ckpt_folder}/BSN_{epoch}.pth")
140 |                 # visualization
141 |                 for batch_idx, (blur, spike, sharp) in enumerate(tqdm(test_loader)):
142 |                     if batch_idx in [int(i) for i in np.linspace(0, len(test_loader) - 1, 5)]:
143 |                         blur, spike = blur.to(device), spike.to(device)
144 |                         spike_idx = len(spike[0]) // 2
145 |                         # TFP Part
146 |                         tfp = cal_tfp(spike, spike_idx, spike_bsn_len)
147 |                         tfp_bsn = bsn_teacher(tfp).clip(0, 1)
148 |                         # visualization
149 |                         save_img(img=normal_img(blur[0]), path=f'{img_folder}/{epoch:03}_{batch_idx:04}_blur.png')
150 |                         save_img(img=normal_img(sharp[0]), path=f'{img_folder}/{epoch:03}_{batch_idx:04}_sharp.png')
151 |                         save_img(img=normal_img(tfp[0]), path=f'{img_folder}/{epoch:03}_{batch_idx:04}_tfp.png')
152 |                         save_img(img=normal_img(tfp_bsn[0]), path=f'{img_folder}/{epoch:03}_{batch_idx:04}_tfp_bsn.png')
153 |                     else:
154 |                         continue
155 |         # save metric result
156 |         if epoch % 100 == 0:
157 |             with torch.no_grad():
158 |                 # calculate the metric
159 |                 metrics = {}
160 |                 method_list = ['TFP', 'BSN']
161 |                 metric_list = ['mse', 'ssim', 'psnr', 'lpips']
162 |                 for method_name in method_list:
163 |                     metrics[method_name] = {}  # initialize the metric dict for each method
164 |                     for metric_name in metric_list:
165 |                         metrics[method_name][metric_name] = AverageMeter()
166 |                 for batch_idx, (blur, spike, sharp) in enumerate(tqdm(test_loader)):
167 |                     blur, spike = blur.to(device), spike.to(device)
168 |                     sharp = 0.11 * sharp[:, 0:1] + 0.59 * sharp[:, 1:2] + 0.3 * sharp[:, 2:3]  # approximate grayscale
169 |                     sharp = resize_method(sharp)
170 |                     spike_idx = len(spike[0]) // 2
171 |                     # TFP and BSN
172 |                     tfp = cal_tfp(spike, spike_idx, spike_bsn_len)
173 |                     tfp_bsn = bsn_teacher(tfp).clip(0, 1)
174 |                     # Metric
175 |                     for key in metric_list:
176 |                         metrics['TFP'][key].update(compute_img_metric(tfp, sharp, key))
177 |                         metrics['BSN'][key].update(compute_img_metric(tfp_bsn, sharp, key))
178 |                 # Print all results
179 |                 for method_name in method_list:
180 |                     re_msg = ''
181 |                     for metric_name in metric_list:
182 |                         re_msg += metric_name + ": " + "{:.3f}".format(metrics[method_name][metric_name].avg) + " "
183 |                         writer.add_scalar(f'{method_name}/{metric_name}', metrics[method_name][metric_name].avg, epoch)
184 |                     logger.info(f"{method_name}: " + re_msg)
185 |         # test mode
186 |         if opt.test_mode:
187 |             break
188 |     writer.close()
189 | 
--------------------------------------------------------------------------------
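For orientation before the next two files: a condensed sketch (ours; the helper names mirror the definitions in train_deblur.py below, and the window lengths are its defaults) of how the frozen BSN and SR networks are chained into the SDM teacher that supplies pseudo-labels for the lightweight deblurring network.

    import torch

    def sdm_teacher(blur, spike, spike_idx, bsn_net, sr_net, cal_tfp, sdm_deblur):
        # Long virtual exposure spanning the whole blur window (97 timesteps)
        tfp_long = cal_tfp(spike, spike.shape[1] // 2, 97)
        # Short virtual exposure around the target instant, denoised by BSN
        tfp_short = bsn_net(cal_tfp(spike, spike_idx, 9)).clip(0, 1)
        # EDSR lifts both TFP images to the RGB resolution, then SDM divides
        sr_long = sr_net(tfp_long).clip(0, 1)
        sr_short = sr_net(tfp_short).clip(0, 1)
        return sdm_deblur(blur, sr_long, sr_short).clip(0, 1)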
/train_deblur.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | import sys
7 | sys.path.append("..")
8 | import os
9 | import argparse
10 | from datetime import datetime
11 | from torch.optim import Adam
12 | from codes.utils import *
13 | from codes.metrics import compute_img_metric
14 | import glob
15 | from torch.utils.tensorboard import SummaryWriter
16 | from tqdm import tqdm
17 | import shutil
18 | from codes.dataset import *
19 | from codes.model.ldn_model import Deblur_Net
20 | from codes.model.bsn_model import BSN
21 | from codes.model.edsr_model import EDSR
22 | import lpips
23 | from pytorch_msssim import SSIM
24 | from torch.optim.lr_scheduler import CosineAnnealingLR
25 | 
26 | # SDM model
27 | def sdm_deblur(blur, tfp_long, tfp_short):
28 |     """ General SDM model.
29 | 
30 |     Args:
31 |         blur (tensor): blurry input. [bs,3,h,w]
32 |         tfp_long (tensor): long tfp corresponding to the blurry input. [bs,1,h,w]
33 |         tfp_short (tensor): short tfp corresponding to the short-exposure image. [bs,1,h,w]
34 | 
35 |     Returns:
36 |         rgb_sdm: deblur result. [bs,3,h,w]
37 |     """
38 |     tfp_long = tfp_long.repeat(1, 3, 1, 1)
39 |     tfp_short = tfp_short.repeat(1, 3, 1, 1)
40 |     rgb_sdm = blur / tfp_long
41 |     rgb_sdm[tfp_long == 0] = 0
42 |     rgb_sdm = rgb_sdm * tfp_short
43 |     return rgb_sdm
44 | 
45 | 
46 | # TFP model
47 | def cal_tfp(spike, spike_idx, tfp_len):
48 |     """TFP model.
49 | 
50 |     Args:
51 |         spike (tensor): spike sequence. [bs,137,h,w]
52 |         spike_idx (int): central idx of the virtual exposure window
53 |         tfp_len (int): length of the virtual exposure window. 97 for long-TFP, [7,9,11] for short-TFP.
54 | 
55 |     Returns:
56 |         tfp_pred: tfp result
57 |     """
58 |     spike = spike[:, spike_idx - tfp_len // 2:spike_idx + tfp_len // 2 + 1, :, :]
59 |     tfp_pred = torch.mean(spike, dim=1, keepdim=True)
60 |     return tfp_pred
61 | 
62 | if __name__ == '__main__':
63 |     # parameters
64 |     parser = argparse.ArgumentParser()
65 |     parser.add_argument('--base_folder', default='GOPRO/')
66 |     parser.add_argument('--save_folder', default='exp/Deblur')
67 |     parser.add_argument('--data_type', default='GOPRO')
68 |     parser.add_argument('--exp_name', default='test')
69 |     parser.add_argument('--bsn_path', default='model/BSN_1000.pth')
70 |     parser.add_argument('--sr_path', default='model/SR_70.pth')
71 |     parser.add_argument('--deblur_path', default='model/DeblurNet_100.pth')
72 |     parser.add_argument('--epochs', type=int, default=101)
73 |     parser.add_argument('--lr', type=float, default=1e-3)
74 |     parser.add_argument('--seed', type=int, default=42)
75 |     parser.add_argument('--spike_deblur_len', type=int, default=21)
76 |     parser.add_argument('--spike_bsn_len', type=int, default=9)
77 |     parser.add_argument('--lambda_tea', type=float, default=1)
78 |     parser.add_argument('--lambda_reblur', type=float, default=100)
79 |     parser.add_argument('--blur_step', type=int, default=24)
80 |     parser.add_argument('--use_small', action='store_true', default=False)
81 |     parser.add_argument('--test_mode', action='store_true', default=False)
82 |     parser.add_argument('--use_ssim', action='store_true', default=False, help='use ssim loss or not')
83 |     parser.add_argument('--bs', type=int, default=4)
84 |     parser.add_argument('--roi_size', type=int, default=512)
85 |     opt = parser.parse_args()
86 | 
87 |     # prepare
88 |     ckpt_folder = f"{opt.save_folder}/{opt.exp_name}/ckpts"
89 |     img_folder = f"{opt.save_folder}/{opt.exp_name}/imgs"
90 |     os.makedirs(ckpt_folder, exist_ok=True)
91 |     os.makedirs(img_folder, exist_ok=True)
92 |     set_random_seed(opt.seed)
93 |     save_opt(opt, f"{opt.save_folder}/{opt.exp_name}/opt.txt")
94 |     log_file = f"{opt.save_folder}/{opt.exp_name}/results.txt"
95 |     logger = setup_logging(log_file)
96 |     if os.path.exists(f'{opt.save_folder}/{opt.exp_name}/tensorboard'):
97 |         shutil.rmtree(f'{opt.save_folder}/{opt.exp_name}/tensorboard')
98 |     writer = SummaryWriter(f'{opt.save_folder}/{opt.exp_name}/tensorboard')
99 |     logger.info(opt)
100 | 
101 |     # train and test data splitting
102 |     train_dataset = SpikeData(opt.base_folder, opt.data_type, 'test', use_roi=True, roi_size=[opt.roi_size, opt.roi_size], use_small=opt.use_small)
103 |     train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4, pin_memory=True)
104 |     test_dataset = SpikeData(opt.base_folder, opt.data_type, 'test', use_roi=False, use_small=opt.use_small)
105 |     test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True)
106 | 
107 |     # config for network and training parameters
108 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
109 |     # BSN
110 |     bsn_net = BSN(n_channels=1, n_output=1).to(device)
111 |     bsn_net.load_state_dict(torch.load(opt.bsn_path))
112 |     for param in bsn_net.parameters():
113 |         param.requires_grad = False
114 |     # SR
115 |     sr_net = EDSR(color_num=1).to(device)
116 |     sr_net.load_state_dict(torch.load(opt.sr_path))
117 |     for param in sr_net.parameters():
118 |         param.requires_grad = False
119 |     # Deblur
120 |     spike_bsn_len = opt.spike_bsn_len
121 |     spike_deblur_len = opt.spike_deblur_len
122 |     deblur_net = Deblur_Net(spike_dim=spike_deblur_len).to(device)
123 |     if opt.test_mode:
124 |         deblur_net.load_state_dict(torch.load(opt.deblur_path))
125 |     # other settings
126 |     optim = Adam(deblur_net.parameters(), lr=opt.lr)
127 |     scheduler = CosineAnnealingLR(optim, T_max=opt.epochs)
128 |     loss_lpips = lpips.LPIPS(net='vgg').to(device)
129 |     loss_ssim = SSIM(data_range=1.0, size_average=True, channel=3).to(device)
130 |     loss_mse = nn.MSELoss()
131 |     # -------------------- train ----------------------
132 |     train_start = datetime.now()
133 |     logger.info("Start Training!")
134 |     for epoch in range(opt.epochs):
135 |         # loss definition
136 |         train_loss_all = AverageMeter()
137 |         tea_loss_all = AverageMeter()
138 |         reblur_loss_all = AverageMeter()
139 |         if not opt.test_mode:
140 |             for batch_idx, (blur, spike, sharp) in enumerate(tqdm(train_loader)):
141 |                 # read the data
142 |                 blur, spike = blur.to(device), spike.to(device)
143 |                 # reconstruct the initial result
144 |                 train_loss = 0
145 |                 tea_loss = 0
146 |                 reblur_loss = 0
147 |                 reblur = []
148 |                 spike_start, spike_end, spike_step = 20, 117, opt.blur_step
149 |                 spike_num = (spike_end - spike_start - 1) // spike_step + 1
150 |                 for spike_idx in range(spike_start, spike_end, spike_step):
151 |                     # Long-TFP Part
152 |                     tfp_long = cal_tfp(spike, len(spike[0]) // 2, 97)
153 |                     tfp_long_sr = sr_net(tfp_long).clip(0, 1)
154 |                     # Short-TFP Part
155 |                     tfp = cal_tfp(spike, spike_idx, spike_bsn_len)
156 |                     tfp_bsn = bsn_net(tfp).clip(0, 1)
157 |                     tfp_bsn_sr = sr_net(tfp_bsn).clip(0, 1)
158 |                     deblur_tea = sdm_deblur(blur, tfp_long_sr, tfp_bsn_sr).clip(0, 1)
159 |                     # Deblur_Net Part
160 |                     spike_roi = spike[:, spike_idx - spike_deblur_len // 2:spike_idx + spike_deblur_len // 2 + 1]
161 |                     # tfp_long computed above is reused here
162 |                     deblur_pred = deblur_net(blur, spike_roi)
163 |                     deblur_pred = deblur_pred.clip(0, 1)
164 |                     # Loss
165 |                     if opt.use_ssim:
166 |                         tea_loss += opt.lambda_tea * (1 - loss_ssim(deblur_tea, deblur_pred)) / spike_num
167 |                     tea_loss += opt.lambda_tea * torch.mean(loss_lpips(deblur_tea, deblur_pred)) / spike_num
168 |                     reblur.append(deblur_pred)
169 |                 reblur = torch.mean(torch.stack(reblur, dim=0), dim=0, keepdim=False)
170 |                 reblur_loss += opt.lambda_reblur * loss_mse(blur, reblur)
171 |                 train_loss = tea_loss + reblur_loss
172 |                 # optimize
173 |                 optim.zero_grad()
174 |                 train_loss.backward()
175 |                 optim.step()
176 |                 writer.add_scalar('Training Loss', train_loss.item())
177 |                 # update
178 |                 train_loss_all.update(train_loss.item())
179 |                 tea_loss_all.update(tea_loss.item())
180 |                 reblur_loss_all.update(reblur_loss.item())
181 |             scheduler.step()  # once per epoch, matching T_max=opt.epochs
182 |             logger.info(f"EPOCH {epoch}/{opt.epochs}: Total Train Loss: {train_loss_all.avg}, Tea Loss: {tea_loss_all.avg}, Reblur Loss: {reblur_loss_all.avg}")
183 |             writer.add_scalar('Epoch Loss', train_loss_all.avg, epoch)
184 |         # visualization result
185 |         if epoch % 5 == 0:
186 |             with torch.no_grad():
187 |                 # save the network
188 |                 save_network(deblur_net, f"{ckpt_folder}/DeblurNet_{epoch}.pth")
189 |                 # visualization
190 |                 for batch_idx, (blur, spike, sharp) in enumerate(tqdm(test_loader)):
191 |                     blur, spike = blur.to(device), spike.to(device)
192 |                     spike_idx = len(spike[0]) // 2
193 |                     # Long-TFP Part
194 |                     tfp_long = cal_tfp(spike, len(spike[0]) // 2, 97)
195 |                     tfp_long_sr = sr_net(tfp_long).clip(0, 1)
196 |                     # Short-TFP Part
197 |                     tfp = cal_tfp(spike, spike_idx, spike_bsn_len)
198 |                     tfp_bsn = bsn_net(tfp).clip(0, 1)
199 |                     tfp_bsn_sr = sr_net(tfp_bsn).clip(0, 1)
200 |                     deblur_tea = sdm_deblur(blur, tfp_long_sr, tfp_bsn_sr).clip(0, 1)
201 |                     # Deblur_Net Part
202 |                     spike_roi = spike[:, spike_idx - spike_deblur_len // 2:spike_idx + spike_deblur_len // 2 + 1]
203 |                     deblur_pred = deblur_net(blur, spike_roi)
204 |                     deblur_pred = deblur_pred.clip(0, 1)
205 |                     # visualization
206 |                     if batch_idx in [int(i) for i in np.linspace(0, len(test_loader) - 1, 5)]:
207 |                         save_img(img=normal_img(deblur_tea[0]), path=f'{img_folder}/{epoch:03}_{batch_idx:04}_tea.png')
208 |                         save_img(img=normal_img(blur[0]), path=f'{img_folder}/{epoch:03}_{batch_idx:04}_blur.png')
209 |                         save_img(img=normal_img(sharp[0]), path=f'{img_folder}/{epoch:03}_{batch_idx:04}_sharp.png')
210 |                         save_img(img=normal_img(deblur_pred[0]), path=f'{img_folder}/{epoch:03}_{batch_idx:04}_deblur.png')
211 | 
212 |         # save metric result
213 |         if epoch % 50 == 0:
214 |             with torch.no_grad():
215 |                 # calculate the metric
216 |                 metrics = {}
217 |                 method_list = ['SDM', 'Deblur_Net']
218 |                 # metric_list = ['mse','ssim','psnr','lpips']
219 |                 metric_list = ['ssim', 'psnr']
220 |                 for method_name in method_list:
221 |                     metrics[method_name] = {}  # initialize the metric dict for each method
222 |                     for metric_name in metric_list:
223 |                         metrics[method_name][metric_name] = AverageMeter()
224 |                 for batch_idx, (blur, spike, sharp) in enumerate(tqdm(test_loader)):
225 |                     blur, spike = blur.to(device), spike.to(device)
226 |                     spike_idx = len(spike[0]) // 2  # define locally instead of reusing a leftover value
227 |                     tfp_long = cal_tfp(spike, len(spike[0]) // 2, 97)
228 |                     tfp_long_sr = sr_net(tfp_long).clip(0, 1)
229 |                     # Short-TFP Part
230 |                     tfp = cal_tfp(spike, spike_idx, spike_bsn_len)
231 |                     tfp_bsn = bsn_net(tfp).clip(0, 1)
232 |                     tfp_bsn_sr = sr_net(tfp_bsn).clip(0, 1)
233 |                     deblur_tea = sdm_deblur(blur, tfp_long_sr, tfp_bsn_sr).clip(0, 1)
234 |                     # Deblur_Net Part
235 |                     spike_roi = spike[:, spike_idx - spike_deblur_len // 2:spike_idx + spike_deblur_len // 2 + 1]
236 |                     # tfp_long computed above is reused here
237 |                     deblur_pred = deblur_net(blur, spike_roi)
238 |                     deblur_pred = deblur_pred.clip(0, 1)
239 |                     # Metric
240 |                     for key in metric_list:
241 |                         metrics['SDM'][key].update(compute_img_metric(deblur_tea, sharp, key))
242 |                         metrics['Deblur_Net'][key].update(compute_img_metric(deblur_pred, sharp, key))
243 |                 # Print all results
244 |                 for method_name in method_list:
245 |                     re_msg = ''
246 |                     for metric_name in metric_list:
247 |                         re_msg += metric_name + ": " + "{:.3f}".format(metrics[method_name][metric_name].avg) + " "
248 |                         writer.add_scalar(f'{method_name}/{metric_name}', metrics[method_name][metric_name].avg, epoch)
249 |                     logger.info(f"{method_name}: " + re_msg)
250 |     writer.close()
--------------------------------------------------------------------------------
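The reblur loss above exploits the physical blur model: averaging the LDN's per-instant predictions should reproduce the blurry input, so no sharp labels are needed. A standalone sketch of that term (our toy shapes and helper name; the script itself inlines this logic):

    import torch
    import torch.nn as nn

    def reblur_loss(blur, sharp_preds, lam=100.0):
        # The blurry frame is (approximately) the temporal mean of the sharp
        # frames, so re-averaging the predictions and comparing them with the
        # input gives a self-supervised constraint.
        reblur = torch.stack(sharp_preds, dim=0).mean(dim=0)
        return lam * nn.functional.mse_loss(reblur, blur)

    # toy usage: 5 predicted sharp frames, as with blur_step=24 over indices 20..116
    preds = [torch.rand(1, 3, 64, 64) for _ in range(5)]
    loss = reblur_loss(torch.rand(1, 3, 64, 64), preds)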
/train_sr.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | import sys
7 | sys.path.append("..")
8 | import os
9 | import argparse
10 | from datetime import datetime
11 | from torch.optim import Adam
12 | from torchvision import transforms
13 | from torch.utils.tensorboard import SummaryWriter
14 | from tqdm import tqdm
15 | import shutil
16 | 
17 | from codes.dataset import *
18 | from codes.utils import *
19 | from codes.metrics import compute_img_metric
20 | from codes.model.bsn_model import BSN
21 | from codes.model.edsr_model import EDSR
22 | 
23 | 
24 | # SDM model
25 | def sdm_deblur(blur, tfp_long, tfp_short):
26 |     """ General SDM model.
27 | 
28 |     Args:
29 |         blur (tensor): blurry input. [bs,3,h,w]
30 |         tfp_long (tensor): long tfp corresponding to the blurry input. [bs,1,h,w]
31 |         tfp_short (tensor): short tfp corresponding to the short-exposure image. [bs,1,h,w]
32 | 
33 |     Returns:
34 |         rgb_sdm: deblur result. [bs,3,h,w]
35 |     """
36 |     tfp_long = tfp_long.repeat(1, 3, 1, 1)
37 |     tfp_short = tfp_short.repeat(1, 3, 1, 1)
38 |     rgb_sdm = blur / tfp_long
39 |     rgb_sdm[tfp_long == 0] = 0
40 |     rgb_sdm = rgb_sdm * tfp_short
41 |     return rgb_sdm
42 | 
43 | # TFP model
44 | def cal_tfp(spike, spike_idx, tfp_len):
45 |     """TFP model.
46 | 
47 |     Args:
48 |         spike (tensor): spike sequence. [bs,137,h,w]
49 |         spike_idx (int): central idx of the virtual exposure window
50 |         tfp_len (int): length of the virtual exposure window. 97 for long-TFP, [7,9,11] for short-TFP.
51 | 
52 |     Returns:
53 |         tfp_pred: tfp result
54 |     """
55 |     spike = spike[:, spike_idx - tfp_len // 2:spike_idx + tfp_len // 2 + 1, :, :]
56 |     tfp_pred = torch.mean(spike, dim=1, keepdim=True)
57 |     return tfp_pred
58 | 
59 | # main function
60 | if __name__ == '__main__':
61 |     # parameters
62 |     parser = argparse.ArgumentParser()
63 |     parser.add_argument('--base_folder', type=str, default=r'GOPRO')
64 |     parser.add_argument('--save_folder', type=str, default='exp/SR')
65 |     parser.add_argument('--data_type', type=str, default='GOPRO')
66 |     parser.add_argument('--exp_name', type=str, default='test')
67 |     parser.add_argument('--bsn_path', type=str, default='model/BSN_1000.pth')
68 |     parser.add_argument('--sr_path', type=str, default='model/SR_70.pth')
69 |     parser.add_argument('--epochs', type=int, default=101)
70 |     parser.add_argument('--lr', type=float, default=2e-4)
71 |     parser.add_argument('--seed', type=int, default=42)
72 |     parser.add_argument('--bsn_len', type=int, default=9)
73 |     parser.add_argument('--use_small', action='store_true', default=False)
74 |     parser.add_argument('--test_mode', action='store_true', default=False)
75 |     opt = parser.parse_args()
76 | 
77 |     # prepare
78 |     ckpt_folder = f"{opt.save_folder}/{opt.exp_name}/ckpts"
79 |     img_folder = f"{opt.save_folder}/{opt.exp_name}/imgs"
80 |     os.makedirs(ckpt_folder, exist_ok=True)
81 |     os.makedirs(img_folder, exist_ok=True)
82 |     set_random_seed(opt.seed)
83 |     save_opt(opt, f"{opt.save_folder}/{opt.exp_name}/opt.txt")
84 |     log_file = f"{opt.save_folder}/{opt.exp_name}/results.txt"
85 |     logger = setup_logging(log_file)
86 |     if os.path.exists(f'{opt.save_folder}/{opt.exp_name}/tensorboard'):
87 |         shutil.rmtree(f'{opt.save_folder}/{opt.exp_name}/tensorboard')
88 |     writer = SummaryWriter(f'{opt.save_folder}/{opt.exp_name}/tensorboard')
89 |     logger.info(opt)
90 | 
91 |     # train and test data splitting
92 |     train_dataset = SpikeData(opt.base_folder, opt.data_type, 'train', use_roi=True,
93 |                               roi_size=[128 * 4, 128 * 4], use_small=opt.use_small)
94 |     train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4, pin_memory=True)
95 |     test_dataset = SpikeData(opt.base_folder, opt.data_type, 'test', use_roi=False, use_small=opt.use_small)
96 |     test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True)
97 | 
98 |     # config for network and training parameters
99 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
100 |     sr_net = EDSR(color_num=1).to(device)
101 |     if opt.test_mode:
102 |         sr_net.load_state_dict(torch.load(opt.sr_path))
103 |     optim = Adam(sr_net.parameters(), lr=opt.lr)
104 |     bsn_net = BSN(n_channels=1, n_output=1).to(device)
105 |     bsn_net.load_state_dict(torch.load(opt.bsn_path))
106 |     for param in bsn_net.parameters():
107 |         param.requires_grad = False
108 |     criterion = nn.MSELoss()
109 |     spike_bsn_len = opt.bsn_len
110 |     resize_method = transforms.Resize((720, 1280), interpolation=transforms.InterpolationMode.NEAREST)
111 |     # -------------------- train ----------------------
112 |     train_start = datetime.now()
113 |     logger.info("Start Training!")
114 |     for epoch in range(opt.epochs):
115 |         train_loss = AverageMeter()
116 |         if not opt.test_mode:
117 |             for batch_idx, (blur, spike, sharp) in enumerate(tqdm(train_loader)):
118 |                 # read the data
119 |                 blur = 0.11 * blur[:, 0:1] + 0.59 * blur[:, 1:2] + 0.3 * blur[:, 2:3]  # approximate grayscale
120 |                 blur, spike = blur.to(device), spike.to(device)
121 |                 for spike_idx in [len(spike[0]) // 2]:  # center of the exposure window
122 |                     # Long-TFP Part
123 |                     tfp_long = cal_tfp(spike, spike_idx, 97)
124 |                     sr_tfp = sr_net(tfp_long).clip(0, 1)
125 |                     loss = criterion(sr_tfp, blur)
126 |                     optim.zero_grad()
127 |                     loss.backward()
128 |                     optim.step()
129 |                     writer.add_scalar('Training Loss', loss.item())
130 |                     train_loss.update(loss.item())
131 |         logger.info(f"EPOCH {epoch}/{opt.epochs}: Train Loss: {train_loss.avg}")
132 |         writer.add_scalar('Epoch Loss', train_loss.avg, epoch)
133 |         # visualization result
134 |         if epoch % 5 == 0:
135 |             with torch.no_grad():
136 |                 # save the network
137 |                 save_network(sr_net, f"{ckpt_folder}/SR_{epoch}.pth")
138 |                 # visualization
139 |                 for batch_idx, (blur, spike, sharp) in enumerate(tqdm(test_loader)):
140 |                     if batch_idx % 200 == 0:
141 |                         blur, spike = blur.to(device), spike.to(device)
142 |                         spike_idx = len(spike[0]) // 2
143 |                         # Long-TFP Part
144 |                         tfp_long = cal_tfp(spike, spike_idx, 97)
145 |                         tfp_long_sr = sr_net(tfp_long).clip(0, 1)
146 |                         save_img(img=normal_img(blur[0]), path=f'{img_folder}/{epoch:03}_{batch_idx:04}_blur.png')
147 |                         save_img(img=normal_img(sharp[0]), path=f'{img_folder}/{epoch:03}_{batch_idx:04}_sharp.png')
148 |                         save_img(img=normal_img(tfp_long[0]), path=f'{img_folder}/{epoch:03}_{batch_idx:04}_tfp_long.png')
149 |                         save_img(img=normal_img(tfp_long_sr[0]), path=f'{img_folder}/{epoch:03}_{batch_idx:04}_tfp_long_sr.png')
150 |                         # Short-TFP Part
151 |                         tfp = cal_tfp(spike, spike_idx, spike_bsn_len)
152 |                         tfp_resize = resize_method(tfp).clip(0, 1)
153 |                         tfp_bsn = bsn_net(tfp).clip(0, 1)
154 |                         tfp_bsn_resize = resize_method(tfp_bsn).clip(0, 1)
155 |                         tfp_bsn_sr = sr_net(tfp_bsn).clip(0, 1)
156 |                         deblur = sdm_deblur(blur, tfp_long_sr, tfp_bsn_sr)
157 |                         save_img(img=normal_img(tfp[0]), path=f'{img_folder}/{epoch:03}_{batch_idx:04}_tfp_short.png')
158 |                         save_img(img=normal_img(tfp_bsn[0]), path=f'{img_folder}/{epoch:03}_{batch_idx:04}_tfp_short_bsn.png')
159 |                         save_img(img=normal_img(tfp_bsn_sr[0]), path=f'{img_folder}/{epoch:03}_{batch_idx:04}_tfp_short_bsn_sr.png')
160 |                         save_img(img=normal_img(tfp_bsn_resize[0]), path=f'{img_folder}/{epoch:03}_{batch_idx:04}_tfp_short_bsn_resize.png')
161 |                         save_img(img=normal_img(tfp_resize[0]), path=f'{img_folder}/{epoch:03}_{batch_idx:04}_tfp_short_resize.png')
162 |                         save_img(img=normal_img(deblur[0]), path=f'{img_folder}/{epoch:03}_{batch_idx:04}_deblur.png')
163 |                         if opt.test_mode:
164 |                             break
165 |                     else:
166 |                         continue
167 |         # save metric result
168 |         if epoch % 10 == 0:
169 |             with torch.no_grad():
170 |                 # calculate the metric
171 |                 metrics = {}
172 |                 method_list = ['SDM', 'Blur_SR', 'Blur_Resize', 'BSN_SR', 'BSN_Resize', 'TFP_Resize']
173 |                 # metric_list = ['mse','ssim','psnr','lpips']
174 |                 metric_list = ['ssim', 'psnr']
175 |                 for method_name in method_list:
176 |                     metrics[method_name] = {}  # initialize the metric dict for each method
177 |                     for metric_name in metric_list:
178 |                         metrics[method_name][metric_name] = AverageMeter()
179 |                 for batch_idx, (blur, spike, sharp) in enumerate(tqdm(test_loader)):
180 |                     blur, spike = blur.to(device), spike.to(device)
181 |                     blur_gray = 0.11 * blur[:, 0:1] + 0.59 * blur[:, 1:2] + 0.3 * blur[:, 2:3]
182 |                     sharp_gray = 0.11 * sharp[:, 0:1] + 0.59 * sharp[:, 1:2] + 0.3 * sharp[:, 2:3]
183 |                     spike_idx = len(spike[0]) // 2
184 |                     # Metric
185 |                     # Long-TFP Part
186 |                     tfp_long = cal_tfp(spike, spike_idx, 97)
187 |                     tfp_long_resize = resize_method(tfp_long).clip(0, 1)
188 |                     tfp_long_sr = sr_net(tfp_long).clip(0, 1)
189 |                     # Short-TFP Part
190 |                     tfp = cal_tfp(spike, spike_idx, spike_bsn_len)
191 |                     tfp_resize = resize_method(tfp).clip(0, 1)
192 |                     tfp_bsn = bsn_net(tfp).clip(0, 1)
193 |                     tfp_bsn_resize = resize_method(tfp_bsn).clip(0, 1)
194 |                     tfp_bsn_sr = sr_net(tfp_bsn).clip(0, 1)
195 |                     deblur = sdm_deblur(blur, tfp_long_sr, tfp_bsn_sr)
196 |                     for key in metric_list:
197 |                         # SDM
198 |                         metrics['SDM'][key].update(compute_img_metric(deblur, sharp, key))
199 |                         # BLUR
200 |                         metrics['Blur_SR'][key].update(compute_img_metric(tfp_long_sr, blur_gray, key))
201 |                         metrics['Blur_Resize'][key].update(compute_img_metric(tfp_long_resize, blur_gray, key))
202 |                         # BSN
203 |                         metrics['BSN_SR'][key].update(compute_img_metric(tfp_bsn_sr, sharp_gray, key))
204 |                         metrics['BSN_Resize'][key].update(compute_img_metric(tfp_bsn_resize, sharp_gray, key))
205 |                         # TFP
206 |                         metrics['TFP_Resize'][key].update(compute_img_metric(tfp_resize, sharp_gray, key))
207 | 
208 |                 # Print all results
209 |                 for method_name in method_list:
210 |                     re_msg = ''
211 |                     for metric_name in metric_list:
212 |                         re_msg += metric_name + ": " + "{:.3f}".format(metrics[method_name][metric_name].avg) + " "
213 |                         writer.add_scalar(f'{method_name}/{metric_name}', metrics[method_name][metric_name].avg, epoch)
214 |                     logger.info(f"{method_name}: " + re_msg)
215 |         # stop
216 |         if opt.test_mode:
217 |             break
218 |     writer.close()
--------------------------------------------------------------------------------
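The three scripts are trained in sequence: BSN first (self-supervised on short TFPs), then the EDSR-based SR network (supervised by the blurry image through the long TFP), and finally the lightweight deblurring network distilled from the assembled teacher. A plausible driver (ours; the flags all appear in the argparse definitions above, but the repository's own scripts/run.sh may differ):

    import subprocess

    # stage order implied by the checkpoint dependencies:
    # train_sr.py loads --bsn_path, and train_deblur.py loads both --bsn_path and --sr_path
    for cmd in [
        "python train_bsn.py    --base_folder GOPRO/ --exp_name bsn",
        "python train_sr.py     --base_folder GOPRO/ --exp_name sr  --bsn_path model/BSN_1000.pth",
        "python train_deblur.py --base_folder GOPRO/ --exp_name ldn --bsn_path model/BSN_1000.pth --sr_path model/SR_70.pth",
    ]:
        subprocess.run(cmd.split(), check=True)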