├── .gitignore
├── README.md
├── codes
│   ├── dataset.py
│   ├── metrics.py
│   ├── model
│   │   ├── __pycache__
│   │   │   └── networks.cpython-38.pyc
│   │   ├── bsn_model.py
│   │   ├── edsr_model.py
│   │   ├── ldn_model.py
│   │   └── networks.py
│   └── utils.py
├── imgs
│   ├── gopro.png
│   ├── high_calib_compress.gif
│   ├── middle_calib_compress.gif
│   ├── overview.png
│   └── rsb.png
├── log
│   ├── BSN
│   │   ├── opt.txt
│   │   ├── test_log.txt
│   │   └── train_log.txt
│   ├── Deblur
│   │   ├── opt.txt
│   │   ├── test_log.txt
│   │   └── train_log.txt
│   └── SR
│       ├── opt.txt
│       ├── test_log.txt
│       └── train_log.txt
├── model
│   ├── BSN_1000.pth
│   ├── DeblurNet_100.pth
│   └── SR_70.pth
├── scripts
│   ├── GOPRO_dataset.md
│   ├── XVFI-main
│   │   ├── (0000.png
│   │   ├── README.md
│   │   ├── XVFInet.py
│   │   ├── __pycache__
│   │   │   ├── XVFInet.cpython-38.pyc
│   │   │   └── utils.cpython-38.pyc
│   │   ├── checkpoint_dir
│   │   │   ├── XVFInet_Vimeo_exp1
│   │   │   │   └── info.txt
│   │   │   └── XVFInet_X4K1000FPS_exp1
│   │   │       ├── XVFInet_X4K1000FPS_exp1_latest.pt
│   │   │       └── info.txt
│   │   ├── compare_psnr_ssim.m
│   │   ├── figures
│   │   │   ├── .txt
│   │   │   ├── 003.gif
│   │   │   ├── 004.gif
│   │   │   ├── 045.gif
│   │   │   ├── 078.gif
│   │   │   ├── 081.gif
│   │   │   ├── 146.gif
│   │   │   ├── Figure5.PNG
│   │   │   ├── Table2.PNG
│   │   │   ├── results_045_resized_768.gif
│   │   │   ├── results_079_resized_768.gif
│   │   │   └── results_158_resized_768.gif
│   │   ├── main.py
│   │   ├── run.sh
│   │   ├── test.gif
│   │   ├── text_dir
│   │   │   └── XVFInet_X4K1000FPS_exp1.txt
│   │   └── utils.py
│   ├── __pycache__
│   │   ├── generate_spike.cpython-38.pyc
│   │   └── utils.cpython-38.pyc
│   ├── blur_syn.py
│   ├── run.sh
│   ├── sharp_extract.py
│   ├── spike_simulate.py
│   ├── utils.py
│   └── utils_spike.py
├── train_bsn.py
├── train_deblur.py
└── train_sr.py
/.gitignore:
--------------------------------------------------------------------------------
1 | scripts/GOPRO/*
2 | # log/*
3 | visual_ablation_distillation.py
4 | visual_gopro.py
5 | run.sh
6 | model/DeblurNet_150.pth
7 | model/NEW_GOPRO_9_reblur0_24_100.pth
8 | model/NEW_GOPRO_9_reblur10_24_full_100.pth
9 | model/NEW_GOPRO_9_reblur10_24_full_nolpips.pth
10 | model/NEW_GOPRO_9_reblur50_24_full_100.pth
11 | model/NEW_GOPRO_9_reblur70_24_full.pth
12 | model/NEW_GOPRO_9_reblur100_24_full_100.pth
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
4 |
5 | If you like our project, please give us a star ⭐ on GitHub.
6 |
7 |
8 | [arXiv](https://arxiv.org/abs/2403.09486)
9 | [GitHub](https://github.com/chenkang455/S-SDM)
10 | [Stargazers](https://github.com/chenkang455/S-SDM/stargazers)
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 | ## 📕 Abstract
21 | > We begin with a theoretical analysis of the relationship between spike streams, blurry images, and sharp sequences, leading to the development of our Spike-guided Deblurring Model (SDM). We further construct a self-supervised processing pipeline by cascading the denoising network and the super-resolution network, reducing the SDM's sensitivity to spike noise and its reliance on spatial-resolution matching between the two modalities. To reduce the time consumption and enhance the utilization of spatial-temporal spike information within this pipeline, we further design a Lightweight Deblurring Network (LDN) and train it on pseudo-labels from the teacher model, i.e., the established self-supervised processing pipeline. By further introducing a re-blurring loss during LDN training, we achieve better restoration performance and faster processing speed than the teacher model, whose pipeline is lengthy and whose structure is complicated.
22 |
23 | ## 👀 Visual Comparisons
24 |
25 | Sequence reconstruction on the RSB dataset under different light conditions. (Flicker is caused by GIF compression.)
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 | Image deblurring on the RSB dataset.
34 |
35 |
36 |
37 |
38 |
39 |
40 | ## 🗓️ TODO
41 | - [x] Release the scripts for simulating the GOPRO dataset.
42 | - [x] Release the training and testing code.
43 | - [x] Release the pretrained model.
44 | - [x] Release the synthetic/real-world dataset.
45 |
46 | ## 🕶 Dataset
47 |
48 | Guidance on synthesizing the spike-based GOPRO dataset can be found in [GOPRO_dataset](scripts/GOPRO_dataset.md).
49 |
50 | The converted GOPRO dataset is available at [GOPRO](https://pan.baidu.com/s/1ZvRNF4kqVB8qe1K78hmnzg?pwd=1623); the real-world blurry RSB dataset will be made public once our manuscript is accepted.
51 |
52 | ## 🍭 Prepare
53 | Our S-SDM requires training BSN, EDSR, and LDN sequentially. We provide the trained weights via this [link](https://pan.baidu.com/s/1FGqlMFtnL5jwI39I5mNkTw?pwd=1623); place them in the `model/` folder. Meanwhile, the downloaded/converted GOPRO dataset should be located under the project root folder. The structure of our project is organized as:
54 | ```
55 |
56 | ├── codes
57 | ├── imgs
58 | ├── log (train and evaluation results)
59 | ├── model
60 | │   ├── BSN_1000.pth
61 | │   └── ...
62 | ├── scripts
63 | ├── GOPRO
64 | │   ├── test
65 | │   └── train
66 | ├── train_bsn.py
67 | ├── train_deblur.py
68 | └── train_sr.py
69 | ```
70 |
71 | ## 🌅 Train
72 |
73 | Train BSN on the GOPRO dataset:
74 | ```
75 | python train_bsn.py --base_folder GOPRO/ --bsn_len 9 --data_type GOPRO
76 | ```
77 |
78 | Train EDSR on the GOPRO dataset:
79 | ```
80 | python train_sr.py --base_folder GOPRO/ --data_type GOPRO
81 | ```
82 |
83 | Train LDN on the GOPRO dataset:
84 | ```
85 | python train_deblur.py --base_folder GOPRO/ --bsn_path model/BSN_1000.pth --sr_path model/SR_70.pth --lambda_reblur 100 --data_type GOPRO
86 | ```
87 |
88 | ## 📊 Evaluate
89 |
90 | Evaluate BSN on the GOPRO dataset:
91 | ```
92 | python train_bsn.py --test_mode --bsn_path model/BSN_1000.pth --data_type GOPRO
93 | ```
94 |
95 | Evaluate EDSR on the GOPRO dataset:
96 | ```
97 | python train_sr.py --test_mode --bsn_path model/BSN_1000.pth --sr_path model/SR_70.pth --data_type GOPRO
98 | ```
99 |
100 | Evaluate LDN on the GOPRO dataset:
101 | ```
102 | python train_deblur.py --test_mode --bsn_path model/BSN_1000.pth --sr_path model/SR_70.pth --deblur_path model/DeblurNet_100.pth --data_type GOPRO
103 | ```
104 |
105 | ## 📞 Contact
106 | Should you have any questions, please feel free to contact [mrchenkang@whu.edu.cn](mailto:mrchenkang@whu.edu.cn).
107 |
108 | ## 🤝 Citation
109 | If you find our work useful in your research, please cite:
110 |
111 | ```
112 | @article{chen2025spikereveal,
113 | title={Spikereveal: Unlocking temporal sequences from real blurry inputs with spike streams},
114 | author={Chen, Kang and Chen, Shiyan and Zhang, Jiyuan and Zhang, Baoyue and Zheng, Yajing and Huang, Tiejun and Yu, Zhaofei},
115 | journal={Advances in Neural Information Processing Systems},
116 | volume={37},
117 | pages={62673--62696},
118 | year={2025}
119 | }
120 | ```
121 |
--------------------------------------------------------------------------------
/codes/dataset.py:
--------------------------------------------------------------------------------
1 | from codes.utils import *
2 | import glob
3 | from torchvision import transforms
4 | # Spike Dataset
5 | class SpikeData(torch.utils.data.Dataset):
6 |     def __init__(self, root_dir, data_type = 'GOPRO', stage = 'train',
7 |                  use_resize = False, use_roi = False, roi_size = [256,256],
8 |                  use_small = False, spike_full = False):
9 |         """ Spike Dataset
10 | 
11 |         Args:
12 |             root_dir (str): base folder of the dataset
13 |             data_type (str, optional): data type. Defaults to 'GOPRO'.
14 |             stage (str, optional): train / test. Defaults to 'train'.
15 |             use_resize (bool, optional): resize blur/spike/sharp to half resolution. Defaults to False.
16 |             use_roi (bool, optional): ROI operation. Defaults to False.
17 |             roi_size (list, optional): ROI size for the image, [ROI/4, ROI/4] for the spike. Defaults to [256,256].
18 |             use_small (bool, optional): small dataset for debugging. Defaults to False.
19 |             spike_full (bool, optional): full spike size (1280 * 720) instead of (320 * 180). Defaults to False.
20 |         """
21 |         self.root_dir = root_dir
22 |         self.data_type = data_type
23 |         self.use_resize = use_resize
24 |         self.use_roi = use_roi
25 |         self.roi_size = roi_size
26 |         self.spike_full = spike_full
27 |         # Real DataSet
28 |         if data_type == 'real':
29 |             pattern = os.path.join(self.root_dir, 'spike_data', '*', '*')
30 |             self.spike_list = sorted(glob.glob(pattern))
31 |             self.width = 400 * 4
32 |             self.height = 250 * 4
33 |         # GOPRO Synthetic DataSet
34 |         elif data_type == 'GOPRO':
35 |             if self.spike_full:
36 |                 pattern = os.path.join(self.root_dir, stage, 'spike_full', '*', '*')
37 |             else:
38 |                 pattern = os.path.join(self.root_dir, stage, 'spike_data', '*', '*')
39 |             self.spike_list = sorted(glob.glob(pattern))
40 |             if use_small:
41 |                 self.spike_list = self.spike_list[::10]
42 |             self.width = 1280
43 |             self.height = 720
44 |         self.resize = transforms.Resize((self.height // 2, self.width // 2), interpolation=transforms.InterpolationMode.NEAREST)
45 |         self.img_size = (self.height, self.width) if not use_resize else (self.height // 2, self.width // 2)
46 |         self.length = len(self.spike_list)
47 | 
48 |     def __getitem__(self, index: int):
49 |         # blur and spike load
50 |         if self.data_type in ['real']:
51 |             spike_name = self.spike_list[index]
52 |             spike = load_vidar_dat(spike_name, width=self.width // 4, height=self.height // 4)
53 |             blur_name = spike_name.replace('.dat', '.jpg').replace('spike_data', 'blur_data')
54 |             blur = cv2.imread(blur_name)
55 |         elif self.data_type in ['GOPRO']:
56 |             spike_name = self.spike_list[index]
57 |             if self.spike_full:
58 |                 spike = load_vidar_dat(spike_name, width=self.width, height=self.height)
59 |                 blur_name = spike_name.replace('.dat', '.png').replace('spike_full', 'blur_data')
60 |             else:
61 |                 spike = load_vidar_dat(spike_name, width=self.width // 4, height=self.height // 4)
62 |                 blur_name = spike_name.replace('.dat', '.png').replace('spike_data', 'blur_data')
63 |             blur = cv2.imread(blur_name)
64 | 
65 |         # sharp load
66 |         if self.data_type in ['real']:
67 |             sharp = np.zeros_like(blur)
68 |         elif self.data_type in ['GOPRO']:
69 |             if self.spike_full:
70 |                 sharp_name = spike_name.replace('.dat', '.png').replace('spike_full', 'sharp_data')
71 |             else:
72 |                 sharp_name = spike_name.replace('.dat', '.png').replace('spike_data', 'sharp_data')
73 |             sharp = cv2.imread(sharp_name)
74 | 
75 |         # channel & property exchange
76 |         blur = torch.from_numpy(blur).permute((2,0,1)).float() / 255
77 |         sharp = torch.from_numpy(sharp).permute((2,0,1)).float() / 255
78 |         spike = torch.from_numpy(spike)
79 |         # resize method (set True for the synthetic dataset and False for the real dataset)
80 |         if self.use_resize:
81 |             blur, spike, sharp = self.resize(blur), self.resize(spike), self.resize(sharp)
82 |         # randomly crop
83 |         if self.use_roi:
84 |             if self.data_type not in ['GOPRO','real']:
85 |                 roiTL = (np.random.randint(0, self.img_size[0] - self.roi_size[0] + 1), np.random.randint(0, self.img_size[1] - self.roi_size[1] + 1))
86 |                 roiBR = (roiTL[0] + self.roi_size[0], roiTL[1] + self.roi_size[1])
87 |                 blur = blur[:, roiTL[0]:roiBR[0], roiTL[1]:roiBR[1]]
88 |                 spike = spike[:, roiTL[0]:roiBR[0], roiTL[1]:roiBR[1]]
89 |                 sharp = sharp[:, roiTL[0]:roiBR[0], roiTL[1]:roiBR[1]]
90 |             else:
91 |                 roiTL = (np.random.randint(0, self.height // 4 - self.roi_size[0] // 4 + 1), np.random.randint(0, self.width // 4 - self.roi_size[1] // 4 + 1))
92 |                 roiBR = (roiTL[0] + self.roi_size[0] // 4, roiTL[1] + self.roi_size[1] // 4)
93 |                 blur = blur[:, 4 * roiTL[0]:4 * roiBR[0], 4 * roiTL[1]:4 * roiBR[1]]
94 |                 if self.spike_full:
95 |                     spike = spike[:, 4 * roiTL[0]:4 * roiBR[0], 4 * roiTL[1]:4 * roiBR[1]]
96 |                 else:
97 |                     spike = spike[:, roiTL[0]:roiBR[0], roiTL[1]:roiBR[1]]
98 |                 sharp = sharp[:, 4 * roiTL[0]:4 * roiBR[0], 4 * roiTL[1]:4 * roiBR[1]]
99 |         return blur, spike, sharp
100 | 
101 |     def __len__(self):
102 |         return self.length
--------------------------------------------------------------------------------
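
A minimal usage sketch of `SpikeData` (not part of the repo; it assumes the converted GOPRO dataset is unpacked under `GOPRO/` as described in the README's Prepare section). The loader returns `(blur, spike, sharp)` triples, with the spike stream at 1/4 of the image resolution by default:

```
import torch
from codes.dataset import SpikeData

# Hypothetical usage; the GOPRO/ layout follows the README's "Prepare" section.
dataset = SpikeData(root_dir='GOPRO/', data_type='GOPRO', stage='train',
                    use_roi=True, roi_size=[256, 256])
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
blur, spike, sharp = next(iter(loader))
print(blur.shape)   # (4, 3, 256, 256)  cropped blurry images
print(spike.shape)  # (4, T, 64, 64)    spike stream; T = frames per .dat file
print(sharp.shape)  # (4, 3, 256, 256)  ground-truth sharp images
```
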
/codes/metrics.py:
--------------------------------------------------------------------------------
1 | from skimage import metrics
2 | import torch
3 | import torch.hub
4 | from lpips.lpips import LPIPS
5 | import os
6 | import numpy as np
7 | 
8 | photometric = {
9 |     "mse": None,
10 |     "ssim": None,
11 |     "psnr": None,
12 |     "lpips": None
13 | }
14 | 
15 | 
16 | def compute_img_metric(im1t: torch.Tensor, im2t: torch.Tensor,
17 |                        metric="mse", margin=0, mask=None):
18 |     """
19 |     im1t, im2t: torch.Tensors with batched image shape, values in [0, 1]
20 |     """
21 |     if metric not in photometric.keys():
22 |         raise RuntimeError(f"img_utils:: metric {metric} not recognized")
23 |     if photometric[metric] is None:
24 |         if metric == "mse":
25 |             photometric[metric] = metrics.mean_squared_error
26 |         elif metric == "ssim":
27 |             photometric[metric] = metrics.structural_similarity
28 |         elif metric == "psnr":
29 |             photometric[metric] = metrics.peak_signal_noise_ratio
30 |         elif metric == "lpips":
31 |             photometric[metric] = LPIPS().cpu()
32 | 
33 |     if mask is not None:
34 |         if mask.dim() == 3:
35 |             mask = mask.unsqueeze(1)
36 |         if mask.shape[1] == 1:
37 |             mask = mask.expand(-1, 3, -1, -1)
38 |         mask = mask.permute(0, 2, 3, 1).numpy()
39 |         batchsz, hei, wid, _ = mask.shape
40 |         if margin > 0:
41 |             marginh = int(hei * margin) + 1
42 |             marginw = int(wid * margin) + 1
43 |             mask = mask[:, marginh:hei - marginh, marginw:wid - marginw]
44 | 
45 |     # convert from [0, 1] to [-1, 1]
46 |     im1t = (im1t * 2 - 1).clamp(-1, 1)
47 |     im2t = (im2t * 2 - 1).clamp(-1, 1)
48 | 
49 |     if im1t.dim() == 3:
50 |         im1t = im1t.unsqueeze(0)
51 |         im2t = im2t.unsqueeze(0)
52 |     im1t = im1t.detach().cpu()
53 |     im2t = im2t.detach().cpu()
54 | 
55 |     if im1t.shape[-1] == 3:
56 |         im1t = im1t.permute(0, 3, 1, 2)
57 |         im2t = im2t.permute(0, 3, 1, 2)
58 | 
59 |     im1 = im1t.permute(0, 2, 3, 1).numpy()
60 |     im2 = im2t.permute(0, 2, 3, 1).numpy()
61 |     batchsz, hei, wid, _ = im1.shape
62 |     if margin > 0:
63 |         marginh = int(hei * margin) + 1
64 |         marginw = int(wid * margin) + 1
65 |         im1 = im1[:, marginh:hei - marginh, marginw:wid - marginw]
66 |         im2 = im2[:, marginh:hei - marginh, marginw:wid - marginw]
67 |     values = []
68 | 
69 |     for i in range(batchsz):
70 |         if metric in ["mse", "psnr"]:
71 |             if mask is not None:
72 |                 im1 = im1 * mask[i]
73 |                 im2 = im2 * mask[i]
74 |             value = photometric[metric](
75 |                 im1[i], im2[i]
76 |             )
77 |             if mask is not None:
78 |                 hei, wid, _ = im1[i].shape
79 |                 pixelnum = mask[i, ..., 0].sum()
80 |                 value = value - 10 * np.log10(hei * wid / pixelnum)
81 |         elif metric in ["ssim"]:
82 |             value, ssimmap = photometric["ssim"](
83 |                 im1[i], im2[i], multichannel=True, full=True
84 |             )
85 |             if mask is not None:
86 |                 value = (ssimmap * mask[i]).sum() / mask[i].sum()
87 |         elif metric in ["lpips"]:
88 |             value = photometric[metric](
89 |                 im1t[i:i + 1], im2t[i:i + 1]
90 |             )[0, 0, 0, 0]
91 |         else:
92 |             raise NotImplementedError
93 |         values.append(value)
94 | 
95 |     return sum(values) / len(values)
96 |
--------------------------------------------------------------------------------
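
A hedged usage sketch of `compute_img_metric` (not from the repo). Metric backends are instantiated lazily into the `photometric` dict on first use, so LPIPS weights are only loaded when `metric='lpips'` is requested; the SSIM branch passes `multichannel=True`, so this assumes a scikit-image version that still accepts that argument:

```
import torch
from codes.metrics import compute_img_metric

# Hypothetical tensors standing in for restored and ground-truth images,
# batched as (N, 3, H, W) with values in [0, 1].
pred = torch.rand(2, 3, 64, 64)
gt = torch.rand(2, 3, 64, 64)
for m in ['mse', 'psnr', 'ssim', 'lpips']:
    print(m, compute_img_metric(pred, gt, metric=m))
```
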
/codes/model/__pycache__/networks.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/codes/model/__pycache__/networks.cpython-38.pyc
--------------------------------------------------------------------------------
/codes/model/bsn_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | # from utils import *
5 | import numpy as np
6 | from codes.model.networks import *
7 | 
8 | class crop(nn.Module):
9 |     def __init__(self):
10 |         super().__init__()
11 | 
12 |     def forward(self, x):
13 |         N, C, H, W = x.shape
14 |         x = x[0:N, 0:C, 0:H-1, 0:W]
15 |         return x
16 | 
17 | class shift(nn.Module):
18 |     def __init__(self):
19 |         super().__init__()
20 |         self.shift_down = nn.ZeroPad2d((0,0,1,0))
21 |         self.crop = crop()
22 | 
23 |     def forward(self, x):
24 |         x = self.shift_down(x)
25 |         x = self.crop(x)
26 |         return x
27 | 
28 | class Conv(nn.Module):
29 |     def __init__(self, in_channels, out_channels, bias=False, blind=True, stride=1, padding=0, kernel_size=3):
30 |         super().__init__()
31 |         self.blind = blind
32 |         if blind:
33 |             self.shift_down = nn.ZeroPad2d((0,0,1,0))
34 |             self.crop = crop()
35 |         self.replicate = nn.ReplicationPad2d(1)
36 |         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
37 |         self.relu = nn.LeakyReLU(0.1, inplace=True)
38 | 
39 |     def forward(self, x):
40 |         if self.blind:
41 |             x = self.shift_down(x)
42 |         x = self.replicate(x)
43 |         x = self.conv(x)
44 |         x = self.relu(x)
45 |         if self.blind:
46 |             x = self.crop(x)
47 |         return x
48 | 
49 | class Pool(nn.Module):
50 |     def __init__(self, blind=True):
51 |         super().__init__()
52 |         self.blind = blind
53 |         if blind:
54 |             self.shift = shift()
55 |         self.pool = nn.MaxPool2d(2)
56 | 
57 |     def forward(self, x):
58 |         if self.blind:
59 |             x = self.shift(x)
60 |         x = self.pool(x)
61 |         return x
62 | 
63 | class rotate(nn.Module):
64 |     def __init__(self):
65 |         super().__init__()
66 | 
67 |     def forward(self, x):
68 |         x90 = x.transpose(2,3).flip(3)
69 |         x180 = x.flip(2).flip(3)
70 |         x270 = x.transpose(2,3).flip(2)
71 |         x = torch.cat((x, x90, x180, x270), dim=0)
72 |         return x
73 | 
74 | class unrotate(nn.Module):
75 |     def __init__(self):
76 |         super().__init__()
77 | 
78 |     def forward(self, x):
79 |         x0, x90, x180, x270 = torch.chunk(x, 4, dim=0)
80 |         x90 = x90.transpose(2,3).flip(2)
81 |         x180 = x180.flip(2).flip(3)
82 |         x270 = x270.transpose(2,3).flip(3)
83 |         x = torch.cat((x0, x90, x180, x270), dim=1)
84 |         return x
85 | 
86 | class ENC_Conv(nn.Module):
87 |     def __init__(self, in_channels, mid_channels, out_channels, bias=False, reduce=True, blind=True):
88 |         super().__init__()
89 |         self.reduce = reduce
90 |         self.conv1 = Conv(in_channels, mid_channels, bias=bias, blind=blind)
91 |         self.conv2 = Conv(mid_channels, mid_channels, bias=bias, blind=blind)
92 |         self.conv3 = Conv(mid_channels, out_channels, bias=bias, blind=blind)
93 |         if reduce:
94 |             self.pool = Pool(blind=blind)
95 | 
96 |     def forward(self, x):
97 |         x = self.conv1(x)
98 |         x = self.conv2(x)
99 |         x = self.conv3(x)
100 |         if self.reduce:
101 |             x = self.pool(x)
102 |         return x
103 | 
104 | class DEC_Conv(nn.Module):
105 |     def __init__(self, in_channels, mid_channels, out_channels, bias=False, blind=True):
106 |         super().__init__()
107 |         self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
108 |         self.conv1 = Conv(in_channels, mid_channels, bias=bias, blind=blind)
109 |         self.conv2 = Conv(mid_channels, mid_channels, bias=bias, blind=blind)
110 |         self.conv3 = Conv(mid_channels, mid_channels, bias=bias, blind=blind)
111 |         self.conv4 = Conv(mid_channels, out_channels, bias=bias, blind=blind)
112 | 
113 |     def forward(self, x, x_in):
114 |         x = self.upsample(x)
115 | 
116 |         # Smart Padding
117 |         diffY = x_in.size()[2] - x.size()[2]
118 |         diffX = x_in.size()[3] - x.size()[3]
119 |         x = F.pad(x, [diffX // 2, diffX - diffX // 2,
120 |                       diffY // 2, diffY - diffY // 2])
121 | 
122 |         x = torch.cat((x, x_in), dim=1)
123 |         x = self.conv1(x)
124 |         x = self.conv2(x)
125 |         x = self.conv3(x)
126 |         x = self.conv4(x)
127 |         return x
128 | 
129 | class Blind_UNet(nn.Module):
130 |     def __init__(self, n_channels=3, n_output=96, bias=False, blind=True):
131 |         super().__init__()
132 |         self.n_channels = n_channels
133 |         self.bias = bias
134 |         self.enc1 = ENC_Conv(n_channels, 48, 48, bias=bias, blind=blind)
135 |         self.enc2 = ENC_Conv(48, 48, 48, bias=bias, blind=blind)
136 |         self.enc3 = ENC_Conv(48, 96, 48, bias=bias, reduce=False, blind=blind)
137 |         self.dec2 = DEC_Conv(96, 96, 96, bias=bias, blind=blind)
138 |         self.dec1 = DEC_Conv(96+n_channels, 96, n_output, bias=bias, blind=blind)
139 | 
140 |     def forward(self, input):
141 |         x1 = self.enc1(input)
142 |         x2 = self.enc2(x1)
143 |         x = self.enc3(x2)
144 |         x = self.dec2(x, x1)
145 |         x = self.dec1(x, input)
146 |         return x
147 | 
148 | class BSN(nn.Module):
149 |     def __init__(self, n_channels=1, n_output=1, bias=False, blind=True, sigma_known=True):
150 |         super().__init__()
151 |         self.n_channels = n_channels
152 |         self.c = n_channels
153 |         self.n_output = n_output
154 |         self.bias = bias
155 |         self.blind = blind
156 |         self.sigma_known = sigma_known
157 |         self.rotate = rotate()
158 |         self.unet = Blind_UNet(n_channels=1, bias=bias, blind=blind)
159 |         self.shift = shift()
160 |         self.unrotate = unrotate()
161 |         self.nin_A = nn.Conv2d(384, 384, 1, bias=bias)
162 |         self.nin_B = nn.Conv2d(384, 96, 1, bias=bias)
163 |         self.nin_C = nn.Conv2d(96, n_output, 1, bias=bias)
164 | 
165 |     def forward(self, x):
166 |         N, C, H, W = x.shape
167 |         # square
168 |         if H > W:
169 |             diff = H - W
170 |             x = F.pad(x, [diff // 2, diff - diff // 2, 0, 0], mode='reflect')
171 |         elif W > H:
172 |             diff = W - H
173 |             x = F.pad(x, [0, 0, diff // 2, diff - diff // 2], mode='reflect')
174 |         x = self.rotate(x)
175 |         x = self.unet(x)      # (4N, C, S, S) -> (4N, 96, S, S)
176 |         if self.blind:
177 |             x = self.shift(x)
178 |         x = self.unrotate(x)  # (4N, 96, S, S) -> (N, 384, S, S)
179 |         x = F.leaky_relu_(self.nin_A(x), negative_slope=0.1)
180 |         x = F.leaky_relu_(self.nin_B(x), negative_slope=0.1)
181 |         x = self.nin_C(x)
182 |         # unsquare
183 |         if H > W:
184 |             diff = H - W
185 |             x = x[:, :, 0:H, (diff // 2):(diff // 2 + W)]
186 |         elif W > H:
187 |             diff = W - H
188 |             x = x[:, :, (diff // 2):(diff // 2 + H), 0:W]
189 |         return x
--------------------------------------------------------------------------------
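
The shift/crop pairs above implement the blind-spot trick: each convolution and pooling step displaces the feature map one pixel downward, and the four-fold rotation covers all four directions, so every output pixel is predicted without seeing its own input value. That is what allows BSN to be trained self-supervised on the noisy spike representations (the TFP images referenced in `log/BSN/test_log.txt`). A minimal forward-pass sketch, not from the repo, with illustrative sizes:

```
import torch
from codes.model.bsn_model import BSN

net = BSN(n_channels=1, n_output=1)
noisy = torch.rand(1, 1, 180, 320)  # e.g. a noisy TFP frame at spike resolution
with torch.no_grad():
    denoised = net(noisy)           # non-square inputs are reflect-padded to square internally
print(denoised.shape)               # torch.Size([1, 1, 180, 320])
```
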
/codes/model/edsr_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | 
6 | def default_conv(in_channels, out_channels, kernel_size, bias=True):
7 |     return nn.Conv2d(
8 |         in_channels, out_channels, kernel_size,
9 |         padding=(kernel_size//2), bias=bias)
10 | 
11 | class MeanShift(nn.Conv2d):
12 |     def __init__(
13 |         self, rgb_range,
14 |         rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
15 | 
16 |         super(MeanShift, self).__init__(3, 3, kernel_size=1)
17 |         std = torch.Tensor(rgb_std)
18 |         self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
19 |         self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
20 |         for p in self.parameters():
21 |             p.requires_grad = False
22 | 
23 | 
24 | 
25 | class MeanShift_GRAY(nn.Conv2d):
26 |     def __init__(
27 |         self, rgb_range,
28 |         gray_mean=0.446, gray_std=1.0, sign=-1):
29 | 
30 |         super(MeanShift_GRAY, self).__init__(1, 1, kernel_size=1)
31 |         std = torch.tensor([gray_std])
32 |         self.weight.data = torch.eye(1).view(1, 1, 1, 1) / std.view(1, 1, 1, 1)
33 |         self.bias.data = sign * rgb_range * torch.tensor([gray_mean]) / std
34 |         for p in self.parameters():
35 |             p.requires_grad = False
36 | 
37 | class BasicBlock(nn.Sequential):
38 |     def __init__(
39 |         self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False,
40 |         bn=True, act=nn.ReLU(True)):
41 | 
42 |         m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
43 |         if bn:
44 |             m.append(nn.BatchNorm2d(out_channels))
45 |         if act is not None:
46 |             m.append(act)
47 | 
48 |         super(BasicBlock, self).__init__(*m)
49 | 
50 | class ResBlock(nn.Module):
51 |     def __init__(
52 |         self, conv, n_feats, kernel_size,
53 |         bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
54 | 
55 |         super(ResBlock, self).__init__()
56 |         m = []
57 |         for i in range(2):
58 |             m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
59 |             if bn:
60 |                 m.append(nn.BatchNorm2d(n_feats))
61 |             if i == 0:
62 |                 m.append(act)
63 | 
64 |         self.body = nn.Sequential(*m)
65 |         self.res_scale = res_scale
66 | 
67 |     def forward(self, x):
68 |         res = self.body(x).mul(self.res_scale)
69 |         res += x
70 | 
71 |         return res
72 | 
73 | class Upsampler(nn.Sequential):
74 |     def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
75 | 
76 |         m = []
77 |         if (scale & (scale - 1)) == 0:    # is scale a power of 2?
78 |             for _ in range(int(math.log(scale, 2))):
79 |                 m.append(conv(n_feats, 4 * n_feats, 3, bias))
80 |                 m.append(nn.PixelShuffle(2))
81 |                 if bn:
82 |                     m.append(nn.BatchNorm2d(n_feats))
83 |                 if act == 'relu':
84 |                     m.append(nn.ReLU(True))
85 |                 elif act == 'prelu':
86 |                     m.append(nn.PReLU(n_feats))
87 | 
88 |         elif scale == 3:
89 |             m.append(conv(n_feats, 9 * n_feats, 3, bias))
90 |             m.append(nn.PixelShuffle(3))
91 |             if bn:
92 |                 m.append(nn.BatchNorm2d(n_feats))
93 |             if act == 'relu':
94 |                 m.append(nn.ReLU(True))
95 |             elif act == 'prelu':
96 |                 m.append(nn.PReLU(n_feats))
97 |         else:
98 |             raise NotImplementedError
99 | 
100 |         super(Upsampler, self).__init__(*m)
101 | 
102 | 
103 | 
104 | def make_model(args, parent=False):
105 |     return EDSR(args)
106 | 
107 | class EDSR(nn.Module):
108 |     def __init__(self, color_num=1, conv=default_conv):
109 |         super(EDSR, self).__init__()
110 | 
111 |         n_resblocks = 16
112 |         n_feats = 64
113 |         kernel_size = 3
114 |         scale = 4
115 |         act = nn.ReLU(True)
116 |         url_name = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale)
117 |         if color_num == 1:
118 |             self.sub_mean = MeanShift_GRAY(1)
119 |             self.add_mean = MeanShift_GRAY(1, sign=1)
120 |         else:
121 |             self.sub_mean = MeanShift(255)
122 |             self.add_mean = MeanShift(255, sign=1)
123 | 
124 |         # define head module
125 |         m_head = [conv(color_num, n_feats, kernel_size)]
126 | 
127 |         # define body module
128 |         m_body = [
129 |             ResBlock(
130 |                 conv, n_feats, kernel_size, act=act, res_scale=1
131 |             ) for _ in range(n_resblocks)
132 |         ]
133 |         m_body.append(conv(n_feats, n_feats, kernel_size))
134 | 
135 |         # define tail module
136 |         m_tail = [
137 |             Upsampler(conv, scale, n_feats, act=False),
138 |             conv(n_feats, color_num, kernel_size)
139 |         ]
140 | 
141 |         self.head = nn.Sequential(*m_head)
142 |         self.body = nn.Sequential(*m_body)
143 |         self.tail = nn.Sequential(*m_tail)
144 | 
145 |     def forward(self, x):
146 |         x = self.sub_mean(x)
147 |         x = self.head(x)
148 | 
149 |         res = self.body(x)
150 |         res += x
151 | 
152 |         x = self.tail(res)
153 |         x = self.add_mean(x)
154 | 
155 |         return x
156 |
--------------------------------------------------------------------------------
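
A minimal forward-pass sketch of the x4 EDSR configured here (not from the repo): with `color_num=1` it maps a single-channel frame at spike resolution back to image resolution via two PixelShuffle(2) stages:

```
import torch
from codes.model.edsr_model import EDSR

net = EDSR(color_num=1)
lr = torch.rand(1, 1, 180, 320)  # e.g. a denoised frame at spike resolution
with torch.no_grad():
    sr = net(lr)
print(sr.shape)  # torch.Size([1, 1, 720, 1280])
```
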
/codes/model/ldn_model.py:
--------------------------------------------------------------------------------
1 | from thop import profile
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from codes.model.networks import *
6 | 
7 | def conv(inDim, outDim, ks, s, p):
8 |     # inDim, outDim, kernel_size, stride, padding
9 |     conv = nn.Conv2d(inDim, outDim, kernel_size=ks, stride=s, padding=p)
10 |     relu = nn.ReLU(True)
11 |     seq = nn.Sequential(conv, relu)
12 |     return seq
13 | 
14 | def de_conv(inDim, outDim, ks, s, p, op):
15 |     # inDim, outDim, kernel_size, stride, padding, output_padding
16 |     conv_t = nn.ConvTranspose2d(inDim, outDim, kernel_size=ks, stride=s,
17 |                                 padding=p, output_padding=op)
18 |     relu = nn.ReLU(inplace=True)
19 |     seq = nn.Sequential(conv_t, relu)
20 |     return seq
21 | 
22 | class ChannelAttention(nn.Module):
23 |     ## channel attention block
24 |     def __init__(self, in_planes, ratio=16):
25 |         super(ChannelAttention, self).__init__()
26 |         self.avg_pool = nn.AdaptiveAvgPool2d(1)
27 |         self.max_pool = nn.AdaptiveMaxPool2d(1)
28 | 
29 |         self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
30 |                                 nn.ReLU(),
31 |                                 nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))
32 |         self.sigmoid = nn.Sigmoid()
33 | 
34 |     def forward(self, x):
35 |         avg_out = self.fc(self.avg_pool(x))
36 |         max_out = self.fc(self.max_pool(x))
37 |         out = avg_out + max_out
38 |         return self.sigmoid(out)
39 | 
40 | class SpatialAttention(nn.Module):
41 |     ## spatial attention block
42 |     def __init__(self, kernel_size=7):
43 |         super(SpatialAttention, self).__init__()
44 | 
45 |         self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
46 |         self.sigmoid = nn.Sigmoid()
47 | 
48 |     def forward(self, x):
49 |         avg_out = torch.mean(x, dim=1, keepdim=True)
50 |         max_out, _ = torch.max(x, dim=1, keepdim=True)
51 |         x = torch.cat([avg_out, max_out], dim=1)
52 |         x = self.conv1(x)
53 |         return self.sigmoid(x)
54 | 
55 | 
56 | class Deblur_Net(nn.Module):
57 |     def __init__(self, spike_dim=21) -> None:
58 |         super().__init__()
59 |         # blur branch: two stride-2 convs downsample to spike resolution
60 |         self.blur_enc = conv(3, 16, 5, 2, 2)
61 |         self.blur_enc2 = conv(16, 32, 5, 2, 2)
62 |         # spike branch: same-resolution encoding
63 |         self.spike_enc = conv(spike_dim, 16, 3, 1, 1)
64 |         self.spike_enc2 = conv(16, 32, 3, 1, 1)
65 |         # residual blocks
66 |         self.resBlock1 = ResnetBlock(64, 'zero', get_norm_layer('none'), False, True)
67 |         self.resBlock2 = ResnetBlock(64, 'zero', get_norm_layer('none'), False, True)
68 |         # CBAM
69 |         self.ca = ChannelAttention(64)
70 |         self.sa = SpatialAttention()
71 |         # up_sample
72 |         self.decoder1 = de_conv(64, 32, 5, 2, 2, (1,1))
73 |         self.decoder2 = de_conv(32, 16, 5, 2, 2, (1,1))
74 |         self.pred = nn.Conv2d(3 + 16, 3, kernel_size=1, stride=1)
75 | 
76 |     def forward(self, blur, spike):
77 |         # blur branch
78 | 
79 |         blur_re = self.blur_enc(blur)
80 |         blur_re = self.blur_enc2(blur_re)
81 |         # spike branch
82 |         spike_re = self.spike_enc(spike)
83 |         spike_re = self.spike_enc2(spike_re)
84 |         # fusion
85 |         fusion = torch.cat([blur_re, spike_re], dim=1)
86 |         fusion = self.ca(fusion) * fusion
87 |         fusion = self.sa(fusion) * fusion
88 |         fusion = self.resBlock1(fusion)
89 |         fusion = self.resBlock2(fusion)
90 |         fusion = self.decoder1(fusion)
91 |         fusion = self.decoder2(fusion)
92 |         result = self.pred(torch.cat([blur, fusion], dim=1))
93 |         return result
94 | 
95 | 
96 | if __name__ == '__main__':
97 |     net = Deblur_Net()
98 |     blur = torch.zeros((1, 3, 720, 1280))
99 |     spike = torch.zeros((1, 21, 720 // 4, 1280 // 4))
100 |     flops, params = profile(net.cpu(), inputs=(blur, spike))
101 |     print(f"FLOPs = {flops / 1e9}G")
102 |     total = sum(p.numel() for p in net.parameters())
103 |     print("Total params: %.5fM" % (total / 1e6))
104 |     print(net(blur, spike))
105 |
--------------------------------------------------------------------------------
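
`train_deblur.py` itself is not included in this dump, but `log/Deblur/opt.txt` records a teacher term (`lambda_tea`) and a re-blurring term (`lambda_reblur`), matching the abstract's description of distilling the LDN from the teacher pipeline with a re-blurring loss. The sketch below is a hedged guess at how such an objective could be combined; the function name, shapes, and the L1 choice are assumptions, not the repo's code:

```
import torch
import torch.nn.functional as F

def distill_loss(pred_frames, teacher_frames, blur, lambda_tea=1.0, lambda_reblur=100.0):
    # pred_frames / teacher_frames: lists of (N, 3, H, W) sharp predictions
    # from the LDN and from the teacher pipeline (pseudo-labels).
    tea_loss = sum(F.l1_loss(p, t) for p, t in zip(pred_frames, teacher_frames)) / len(pred_frames)
    # Re-blurring: by the blur formation model, the temporal average of the
    # recovered sharp frames should reproduce the blurry input.
    reblur_loss = F.l1_loss(torch.stack(pred_frames).mean(dim=0), blur)
    return lambda_tea * tea_loss + lambda_reblur * reblur_loss
```
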
/codes/utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import numpy as np
4 | import imageio
5 | import os
6 | import torch.nn as nn
7 | import random
8 | # Save Network
9 | def save_network(network, save_path):
10 |     if isinstance(network, nn.DataParallel):
11 |         network = network.module
12 |     state_dict = network.state_dict()
13 |     for key, param in state_dict.items():
14 |         state_dict[key] = param.cpu()
15 |     torch.save(state_dict, save_path)
16 | 
17 | def set_random_seed(seed):
18 |     """Set random seeds."""
19 |     torch.manual_seed(seed)
20 |     torch.cuda.manual_seed_all(seed)
21 |     np.random.seed(seed)
22 |     random.seed(seed)
23 |     torch.backends.cudnn.deterministic = True
24 |     torch.cuda.manual_seed(seed)
25 | 
26 | def save_opt(opt, opt_path):
27 |     with open(opt_path, 'w') as f:
28 |         for key, value in vars(opt).items():
29 |             f.write(f"{key}: {value}\n")
30 | 
31 | def save_gif(image_list, gif_path='test', duration=2, RGB=True, nor=False):
32 |     imgs = []
33 |     os.makedirs('Video', exist_ok=True)
34 |     with imageio.get_writer(os.path.join('Video', gif_path + '.gif'), mode='I', duration=1000 * duration / len(image_list), loop=0) as writer:
35 |         for i in range(len(image_list)):
36 |             img = normal_img(image_list[i], RGB, nor)
37 |             writer.append_data(img)
38 | 
39 | def save_video(image_list, path='test', duration=2, RGB=True, nor=False):
40 |     os.makedirs('Video', exist_ok=True)
41 |     imgs = []
42 |     for i in range(len(image_list)):
43 |         img = normal_img(image_list[i], RGB, nor)
44 |         imgs.append(img)
45 |     imageio.mimwrite(os.path.join('Video', path + '.mp4'), imgs, fps=30, quality=8)
46 | 
47 | 
48 | def normal_img(img, RGB=True, nor=True):
49 |     if nor:
50 |         img = 255 * ((img - img.min()) / (img.max() - img.min()))
51 |     if (img.shape[0] == 3 or img.shape[0] == 1) and isinstance(img, torch.Tensor):
52 |         img = img.permute(1, 2, 0)
53 |     if isinstance(img, torch.Tensor):
54 |         img = np.array(img.detach().cpu())
55 |     if len(img.shape) == 2:
56 |         img = img[..., None]
57 |     if img.shape[-1] == 1:
58 |         img = np.repeat(img, 3, axis=-1)
59 |     img = img.astype(np.uint8)
60 |     if not RGB:
61 |         img = img[..., ::-1]
62 |     return img
63 | 
64 | def save_img(path='test.png', img=None, nor=True):
65 |     if nor:
66 |         img = 255 * ((img - img.min()) / (img.max() - img.min()))
67 |     if isinstance(img, torch.Tensor):
68 |         img = np.array(img.detach().cpu())
69 |     img = img.astype(np.uint8)
70 |     cv2.imwrite(path, img)
71 | 
72 | def make_folder(path):
73 |     os.makedirs(path, exist_ok=True)
74 | 
75 | class AverageMeter(object):
76 |     """Computes and stores the average and current value"""
77 | 
78 |     def __init__(self):
79 |         self.reset()
80 | 
81 |     def reset(self):
82 |         self.val = 0
83 |         self.avg = 0
84 |         self.sum = 0
85 |         self.count = 0
86 | 
87 |     def update(self, val, n=1):
88 |         self.val = val
89 |         self.count += n
90 |         self.sum += val * n
91 |         self.avg = self.sum / self.count
92 | 
93 | def video_to_spike(
94 |     sourefolder=None,
95 |     imgs=None,
96 |     savefolder_debug=None,
97 |     threshold=5.0,
98 |     init_noise=True,
99 |     format="png",
100 | ):
101 |     """
102 |     Convert an image sequence into a spike stream via integrate-and-fire simulation.
103 |     :param sourefolder: folder of input frames; alternatively pass the frames directly via imgs
104 |     :return: spikematrix: uint8 array of shape [T, H, W] with values in {0, 1}
105 |     """
106 |     if sourefolder is not None:
107 |         filelist = sorted(os.listdir(sourefolder))
108 |         datas = [fn for fn in filelist if fn.endswith(format)]
109 | 
110 |         T = len(datas)
111 | 
112 |         frame0 = cv2.imread(os.path.join(sourefolder, datas[0]))
113 |         H, W, C = frame0.shape
114 | 
115 |         frame0 = cv2.cvtColor(frame0, cv2.COLOR_BGR2GRAY)
116 | 
117 |         spikematrix = np.zeros([T, H, W], np.uint8)
118 | 
119 |         if init_noise:
120 |             integral = np.random.random(size=([H, W])) * threshold
121 |         else:
122 |             integral = np.zeros([H, W])
123 | 
124 |         Thr = np.ones_like(integral).astype(np.float32) * threshold
125 | 
126 |         for t in range(0, T):
127 |             frame = cv2.imread(os.path.join(sourefolder, datas[t]))
128 |             if C > 1:
129 |                 gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
130 |             gray = gray / 255.0
131 |             integral += gray
132 |             fire = (integral - Thr) >= 0
133 |             fire_pos = fire.nonzero()
134 | 
135 |             integral[fire_pos] -= threshold
136 |             spikematrix[t][fire_pos] = 1
137 | 
138 |         if savefolder_debug:
139 |             np.save(os.path.join(savefolder_debug, "spike_debug.npy"), spikematrix)
140 |     elif imgs is not None:
141 |         frame0 = imgs[0]
142 |         H, W, C = frame0.shape
143 |         T = len(imgs)
144 |         spikematrix = np.zeros([T, H, W], np.uint8)
145 | 
146 |         if init_noise:
147 |             integral = np.random.random(size=([H, W])) * threshold
148 |         else:
149 |             integral = np.zeros([H, W])
150 | 
151 |         Thr = np.ones_like(integral).astype(np.float32) * threshold
152 | 
153 |         for t in range(0, T):
154 |             frame = imgs[t]
155 |             if C > 1:
156 |                 gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
157 |             gray = gray / 255.0
158 |             integral += gray
159 |             fire = (integral - Thr) >= 0
160 |             fire_pos = fire.nonzero()
161 | 
162 |             integral[fire_pos] -= threshold
163 |             spikematrix[t][fire_pos] = 1
164 |     return spikematrix
165 | 
166 | 
167 | def load_vidar_dat(filename, left_up=(0, 0), window=None, frame_cnt=None, height=800, width=800, **kwargs):
168 |     if isinstance(filename, str):
169 |         array = np.fromfile(filename, dtype=np.uint8)
170 |     elif isinstance(filename, (list, tuple)):
171 |         l = []
172 |         for name in filename:
173 |             a = np.fromfile(name, dtype=np.uint8)
174 |             l.append(a)
175 |         array = np.concatenate(l)
176 |     else:
177 |         raise NotImplementedError
178 | 
179 |     if window is None:
180 |         window = (height - left_up[0], width - left_up[1])
181 | 
182 |     len_per_frame = height * width // 8
183 |     framecnt = frame_cnt if frame_cnt is not None else len(array) // len_per_frame
184 | 
185 |     spikes = []
186 | 
187 |     for i in range(framecnt):
188 |         compr_frame = array[i * len_per_frame: (i + 1) * len_per_frame]
189 |         blist = []
190 |         for b in range(8):
191 |             blist.append(np.right_shift(np.bitwise_and(compr_frame, np.left_shift(1, b)), b))
192 | 
193 |         frame_ = np.stack(blist).transpose()
194 |         frame_ = np.flipud(frame_.reshape((height, width), order='C'))
195 | 
196 |         if window is not None:
197 |             spk = frame_[left_up[0]:left_up[0] + window[0], left_up[1]:left_up[1] + window[1]]
198 |         else:
199 |             spk = frame_
200 | 
201 |         spk = spk.copy().astype(np.float32)[None]
202 | 
203 |         spikes.append(spk)
204 | 
205 |     return np.concatenate(spikes)
206 | 
207 | 
208 | import logging
209 | # log info
210 | def setup_logging(log_file):
211 |     logger = logging.getLogger('training_logger')
212 |     logger.setLevel(logging.INFO)
213 |     if logger.hasHandlers():
214 |         logger.handlers.clear()
215 |     file_handler = logging.FileHandler(log_file, mode='w')  # open the log file in 'w' mode
216 |     file_handler.setLevel(logging.INFO)
217 |     console_handler = logging.StreamHandler()
218 |     console_handler.setLevel(logging.INFO)
219 |     formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
220 |     file_handler.setFormatter(formatter)
221 |     console_handler.setFormatter(formatter)
222 |     logger.addHandler(file_handler)
223 |     logger.addHandler(console_handler)
224 |     return logger
225 | 
226 | def normalize(arr):
227 |     return (arr - arr.min()) / (arr.max() - arr.min())
228 | 
229 | 
230 | def generate_labels(file_name):
231 |     num_part = file_name.split('/')[-1]
232 |     non_num_part = file_name.replace(num_part, '')
233 |     num = int(num_part)
234 |     labels = [non_num_part + str(num + 2 * i).zfill(len(num_part)) + '.png' for i in range(-3, 4)]
235 |     return labels
--------------------------------------------------------------------------------
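
A hypothetical round-trip sketch for the spike helpers (not from the repo): `video_to_spike` integrates grayscale frames and fires a spike whenever the accumulator crosses the threshold, so averaging the resulting stream over time gives a TFP-style image in which brighter pixels fire more often:

```
import numpy as np
from codes.utils import video_to_spike

# Three synthetic gray frames of increasing brightness (values assumed).
frames = [np.full((180, 320, 3), g, dtype=np.uint8) for g in (120, 180, 240)]
spikes = video_to_spike(imgs=frames, threshold=1.0)
print(spikes.shape)               # (3, 180, 320), uint8 values in {0, 1}
print(spikes.mean(axis=0).max())  # temporal average, i.e. a TFP-style image
```
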
/imgs/gopro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/imgs/gopro.png
--------------------------------------------------------------------------------
/imgs/high_calib_compress.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/imgs/high_calib_compress.gif
--------------------------------------------------------------------------------
/imgs/middle_calib_compress.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/imgs/middle_calib_compress.gif
--------------------------------------------------------------------------------
/imgs/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/imgs/overview.png
--------------------------------------------------------------------------------
/imgs/rsb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/imgs/rsb.png
--------------------------------------------------------------------------------
/log/BSN/opt.txt:
--------------------------------------------------------------------------------
1 | base_folder: GOPRO/
2 | save_folder: exp/BSN
3 | data_type: GOPRO
4 | exp_name: NEW_GOPRO_9_test_full
5 | epochs: 1001
6 | lr: 0.0003
7 | seed: 42
8 | bsn_len: 9
9 | width: 1280
10 | height: 720
11 | scheduler: test
12 | use_small: False
13 | spike_full: False
14 | test_mode: False
15 | bsn_path: ...
16 |
--------------------------------------------------------------------------------
/log/BSN/test_log.txt:
--------------------------------------------------------------------------------
1 | 2024-03-14 17:32:06,108 - INFO - Namespace(base_folder='GOPRO/', bsn_len=9, bsn_path='model/BSN_1000.pth', data_type='GOPRO', epochs=1001, exp_name='test', height=720, lr=0.0003, save_folder='exp/BSN', seed=42, spike_full=False, test_mode=True, use_small=False, width=1280)
2 | 2024-03-14 17:32:06,391 - INFO - Start Training!
3 | 2024-03-14 17:40:11,730 - INFO - TFP: mse: 0.010 ssim: 0.733 psnr: 25.997 lpips: 0.211
4 | 2024-03-14 17:40:11,731 - INFO - BSN: mse: 0.008 ssim: 0.842 psnr: 27.365 lpips: 0.109
5 |
--------------------------------------------------------------------------------
/log/Deblur/opt.txt:
--------------------------------------------------------------------------------
1 | base_folder: GOPRO/
2 | save_folder: exp/Deblur
3 | data_type: GOPRO
4 | exp_name: NEW_GOPRO_9_reblur100_24_full
5 | bsn_path: exp/BSN/NEW_GOPRO_9_test_full/ckpts/BSN_1000.pth
6 | sr_path: exp/SR/NEW_GOPRO_9_bsn1000/ckpts/SR_70.pth
7 | deblur_path: exp/Deblur/NEW_GOPRO_9_reblur10_24_full/ckpts/DeblurNet_185.pth
8 | crop_size: 256
9 | epochs: 151
10 | lr: 0.001
11 | seed: 42
12 | spike_deblur_len: 21
13 | spike_bsn_len: 9
14 | lambda_tea: 1
15 | lambda_reblur: 100.0
16 | blur_step: 24
17 | use_small: False
18 | test_mode: False
19 | use_ssim: False
20 | bs: 4
21 |
--------------------------------------------------------------------------------
/log/Deblur/test_log.txt:
--------------------------------------------------------------------------------
1 | 2024-03-15 12:47:41,145 - INFO - Namespace(base_folder='GOPRO/', blur_step=24, bs=4, bsn_path='model/BSN_1000.pth', data_type='GOPRO', deblur_path='model/DeblurNet_100.pth', epochs=101, exp_name='test', lambda_reblur=100, lambda_tea=1, lr=0.001, roi_size=512, save_folder='exp/Deblur', seed=42, spike_bsn_len=9, spike_deblur_len=21, sr_path='model/SR_70.pth', test_mode=True, use_small=False, use_ssim=False)
2 | 2024-03-15 12:47:42,745 - INFO - Start Training!
3 | 2024-03-15 13:10:10,764 - INFO - SDM: ssim: 0.732 psnr: 26.362
4 | 2024-03-15 13:10:10,765 - INFO - Deblur_Net: ssim: 0.786 psnr: 27.928
5 |
--------------------------------------------------------------------------------
/log/Deblur/train_log.txt:
--------------------------------------------------------------------------------
1 | 2024-02-21 11:11:15,473 - INFO - Namespace(base_folder='GOPRO/', save_folder='exp/Deblur', data_type='GOPRO', exp_name='NEW_GOPRO_9_reblur100_24_full', bsn_path='exp/BSN/NEW_GOPRO_9_test_full/ckpts/BSN_1000.pth', sr_path='exp/SR/NEW_GOPRO_9_bsn1000/ckpts/SR_70.pth', deblur_path='exp/Deblur/NEW_GOPRO_9_reblur10_24_full/ckpts/DeblurNet_185.pth', crop_size=256, epochs=151, lr=0.001, seed=42, spike_deblur_len=21, spike_bsn_len=9, lambda_tea=1, lambda_reblur=100.0, blur_step=24, use_small=False, test_mode=False, use_ssim=False, bs=4)
2 | 2024-02-21 11:11:18,625 - INFO - Start Training!
3 | 2024-02-21 11:15:07,104 - INFO - EPOCH 0/151: Total Train Loss: 1.981791118780772, Tea Loss: 0.5576680742558979, Reblur Loss: 1.4241230393642987
4 | 2024-02-21 11:20:35,766 - INFO - EPOCH 1/151: Total Train Loss: 0.7465140743431075, Tea Loss: 0.45105000272457735, Reblur Loss: 0.295464072908674
5 | 2024-02-21 11:24:23,129 - INFO - EPOCH 2/151: Total Train Loss: 0.56218161495217, Tea Loss: 0.3737060081907165, Reblur Loss: 0.18847560703560903
6 | 2024-02-21 11:28:10,566 - INFO - EPOCH 3/151: Total Train Loss: 0.4921396754004739, Tea Loss: 0.3222369218801523, Reblur Loss: 0.16990275352032153
7 | 2024-02-21 11:31:57,957 - INFO - EPOCH 4/151: Total Train Loss: 0.4323549967307549, Tea Loss: 0.29852057851496194, Reblur Loss: 0.13383441787713019
8 | 2024-02-21 11:35:45,372 - INFO - EPOCH 5/151: Total Train Loss: 0.3681258845638919, Tea Loss: 0.2753991732349643, Reblur Loss: 0.09272671207076028
9 | 2024-02-21 11:41:04,270 - INFO - EPOCH 6/151: Total Train Loss: 0.2876348569021597, Tea Loss: 0.23525384732913146, Reblur Loss: 0.05238100977058147
10 | 2024-02-21 11:44:51,678 - INFO - EPOCH 7/151: Total Train Loss: 0.2613711170690916, Tea Loss: 0.21692385534187417, Reblur Loss: 0.04444726203362663
11 | 2024-02-21 11:48:39,157 - INFO - EPOCH 8/151: Total Train Loss: 0.24577252315236375, Tea Loss: 0.20893679752752378, Reblur Loss: 0.03683572499589487
12 | 2024-02-21 11:52:26,640 - INFO - EPOCH 9/151: Total Train Loss: 0.2426675031066457, Tea Loss: 0.20484271858419692, Reblur Loss: 0.037824784236198124
13 | 2024-02-21 11:56:14,252 - INFO - EPOCH 10/151: Total Train Loss: 0.2309971262753268, Tea Loss: 0.19534648903262564, Reblur Loss: 0.03565063678308741
14 | 2024-02-21 12:01:35,223 - INFO - EPOCH 11/151: Total Train Loss: 0.22711923157239888, Tea Loss: 0.19388342555209157, Reblur Loss: 0.03323580622189231
15 | 2024-02-21 12:05:22,875 - INFO - EPOCH 12/151: Total Train Loss: 0.2145585086825606, Tea Loss: 0.1866386016080906, Reblur Loss: 0.027919907808239327
16 | 2024-02-21 12:09:10,451 - INFO - EPOCH 13/151: Total Train Loss: 0.21047288682553675, Tea Loss: 0.18185828832578865, Reblur Loss: 0.028614598761808562
17 | 2024-02-21 12:12:58,052 - INFO - EPOCH 14/151: Total Train Loss: 0.21015932614153082, Tea Loss: 0.1806152764207873, Reblur Loss: 0.029544049349827167
18 | 2024-02-21 12:16:45,702 - INFO - EPOCH 15/151: Total Train Loss: 0.20325037359675288, Tea Loss: 0.17644172584339654, Reblur Loss: 0.026808647801736734
19 | 2024-02-21 12:22:07,721 - INFO - EPOCH 16/151: Total Train Loss: 0.1940440669720307, Tea Loss: 0.17125921996382923, Reblur Loss: 0.022784847088835457
20 | 2024-02-21 12:25:55,181 - INFO - EPOCH 17/151: Total Train Loss: 0.1925107462581618, Tea Loss: 0.1693128842167008, Reblur Loss: 0.02319786160200576
21 | 2024-02-21 12:29:42,707 - INFO - EPOCH 18/151: Total Train Loss: 0.19356717943371116, Tea Loss: 0.16971592559958948, Reblur Loss: 0.023851253963136053
22 | 2024-02-21 12:33:30,282 - INFO - EPOCH 19/151: Total Train Loss: 0.18988661661550596, Tea Loss: 0.16595214095724609, Reblur Loss: 0.023934475448611495
23 | 2024-02-21 12:37:17,834 - INFO - EPOCH 20/151: Total Train Loss: 0.1869647025545954, Tea Loss: 0.16441954630406905, Reblur Loss: 0.022545156766583908
24 | 2024-02-21 12:42:39,842 - INFO - EPOCH 21/151: Total Train Loss: 0.18639814550484413, Tea Loss: 0.16518269585711615, Reblur Loss: 0.021215449617490237
25 | 2024-02-21 12:46:28,692 - INFO - EPOCH 22/151: Total Train Loss: 0.18183056739243594, Tea Loss: 0.1610855016034919, Reblur Loss: 0.02074506588973653
26 | 2024-02-21 12:50:16,389 - INFO - EPOCH 23/151: Total Train Loss: 0.1875665687766426, Tea Loss: 0.16315624221062763, Reblur Loss: 0.024410326541824776
27 | 2024-02-21 12:54:04,005 - INFO - EPOCH 24/151: Total Train Loss: 0.1784304248177128, Tea Loss: 0.15827988965428752, Reblur Loss: 0.02015053503441088
28 | 2024-02-21 12:57:51,563 - INFO - EPOCH 25/151: Total Train Loss: 0.17520826700187864, Tea Loss: 0.15595671118466886, Reblur Loss: 0.019251555817209796
29 | 2024-02-21 13:03:13,996 - INFO - EPOCH 26/151: Total Train Loss: 0.176363579767607, Tea Loss: 0.1570910235375037, Reblur Loss: 0.019272556786477824
30 | 2024-02-21 13:07:01,638 - INFO - EPOCH 27/151: Total Train Loss: 0.1756473018493487, Tea Loss: 0.15534867197920232, Reblur Loss: 0.020298629934653575
31 | 2024-02-21 13:10:49,271 - INFO - EPOCH 28/151: Total Train Loss: 0.17550866305828094, Tea Loss: 0.15469595496401642, Reblur Loss: 0.02081270805797923
32 | 2024-02-21 13:14:36,833 - INFO - EPOCH 29/151: Total Train Loss: 0.1699268434728895, Tea Loss: 0.15181923270612568, Reblur Loss: 0.01810761091996839
33 | 2024-02-21 13:18:24,393 - INFO - EPOCH 30/151: Total Train Loss: 0.16909387423878625, Tea Loss: 0.15078474823143576, Reblur Loss: 0.018309126261347557
34 | 2024-02-21 13:23:46,279 - INFO - EPOCH 31/151: Total Train Loss: 0.17189460425149827, Tea Loss: 0.15200508085938244, Reblur Loss: 0.019889523835602777
35 | 2024-02-21 13:27:34,205 - INFO - EPOCH 32/151: Total Train Loss: 0.16703752579885128, Tea Loss: 0.14821302036057302, Reblur Loss: 0.01882450535764426
36 | 2024-02-21 13:31:23,375 - INFO - EPOCH 33/151: Total Train Loss: 0.16504471545869653, Tea Loss: 0.14643452690664308, Reblur Loss: 0.018610188056154428
37 | 2024-02-21 13:35:11,038 - INFO - EPOCH 34/151: Total Train Loss: 0.1643026904626326, Tea Loss: 0.1461101443994613, Reblur Loss: 0.01819254650665826
38 | 2024-02-21 13:38:58,722 - INFO - EPOCH 35/151: Total Train Loss: 0.1635446385581256, Tea Loss: 0.14505956334856165, Reblur Loss: 0.01848507593123815
39 | 2024-02-21 13:44:22,901 - INFO - EPOCH 36/151: Total Train Loss: 0.1661887692682671, Tea Loss: 0.14559309490960398, Reblur Loss: 0.020595673975651655
40 | 2024-02-21 13:48:10,525 - INFO - EPOCH 37/151: Total Train Loss: 0.16147938199760595, Tea Loss: 0.1428786571020688, Reblur Loss: 0.018600724242401844
41 | 2024-02-21 13:51:58,155 - INFO - EPOCH 38/151: Total Train Loss: 0.1581530040734774, Tea Loss: 0.14090757004239343, Reblur Loss: 0.01724543399076699
42 | 2024-02-21 13:55:45,762 - INFO - EPOCH 39/151: Total Train Loss: 0.1634903107628678, Tea Loss: 0.1447599632786466, Reblur Loss: 0.0187303473995555
43 | 2024-02-21 13:59:33,394 - INFO - EPOCH 40/151: Total Train Loss: 0.16260723298762267, Tea Loss: 0.14292728078442735, Reblur Loss: 0.01967995221125873
44 | 2024-02-21 14:04:56,610 - INFO - EPOCH 41/151: Total Train Loss: 0.15650884378137012, Tea Loss: 0.13880101101093995, Reblur Loss: 0.017707833064744223
45 | 2024-02-21 14:08:44,278 - INFO - EPOCH 42/151: Total Train Loss: 0.1539752763651666, Tea Loss: 0.13657345406678847, Reblur Loss: 0.017401822056476172
46 | 2024-02-21 14:12:31,886 - INFO - EPOCH 43/151: Total Train Loss: 0.15841850093303822, Tea Loss: 0.14041390366755524, Reblur Loss: 0.018004597213070888
47 | 2024-02-21 14:16:19,511 - INFO - EPOCH 44/151: Total Train Loss: 0.15566800070273412, Tea Loss: 0.1376432954813495, Reblur Loss: 0.01802470464888331
48 | 2024-02-21 14:20:07,092 - INFO - EPOCH 45/151: Total Train Loss: 0.15683167250383467, Tea Loss: 0.1378219990761249, Reblur Loss: 0.019009673439804867
49 | 2024-02-21 14:25:30,825 - INFO - EPOCH 46/151: Total Train Loss: 0.15287813766823186, Tea Loss: 0.13571227503158315, Reblur Loss: 0.017165863527654313
50 | 2024-02-21 14:29:18,514 - INFO - EPOCH 47/151: Total Train Loss: 0.15194799763944758, Tea Loss: 0.13455750331992195, Reblur Loss: 0.017390494650124988
51 | 2024-02-21 14:33:06,217 - INFO - EPOCH 48/151: Total Train Loss: 0.1535301176565034, Tea Loss: 0.13565615332358844, Reblur Loss: 0.01787396442967576
52 | 2024-02-21 14:36:53,853 - INFO - EPOCH 49/151: Total Train Loss: 0.1541611450962174, Tea Loss: 0.13536655722242413, Reblur Loss: 0.0187945876278596
53 | 2024-02-21 14:40:41,541 - INFO - EPOCH 50/151: Total Train Loss: 0.15028684750779883, Tea Loss: 0.13304561966812456, Reblur Loss: 0.017241228154146825
54 | 2024-02-21 14:53:15,101 - INFO - SDM: mse: 0.010 ssim: 0.732 psnr: 26.362 lpips: 0.259
55 | 2024-02-21 14:53:15,106 - INFO - Deblur_Net: mse: 0.008 ssim: 0.776 psnr: 27.586 lpips: 0.264
56 | 2024-02-21 14:57:02,284 - INFO - EPOCH 51/151: Total Train Loss: 0.15076929337276526, Tea Loss: 0.1338261765835089, Reblur Loss: 0.016943117418201453
57 | 2024-02-21 15:00:49,880 - INFO - EPOCH 52/151: Total Train Loss: 0.15185928489867742, Tea Loss: 0.13381835553811225, Reblur Loss: 0.018040929525864845
58 | 2024-02-21 15:04:37,557 - INFO - EPOCH 53/151: Total Train Loss: 0.14984404589190628, Tea Loss: 0.1322726078363724, Reblur Loss: 0.017571437765251507
59 | 2024-02-21 15:08:25,127 - INFO - EPOCH 54/151: Total Train Loss: 0.14938556348090565, Tea Loss: 0.13165730376651277, Reblur Loss: 0.017728259468459205
60 | 2024-02-21 15:12:12,710 - INFO - EPOCH 55/151: Total Train Loss: 0.1466710292996266, Tea Loss: 0.1298487519288992, Reblur Loss: 0.016822277068349965
61 | 2024-02-21 15:17:35,712 - INFO - EPOCH 56/151: Total Train Loss: 0.1474687285818063, Tea Loss: 0.13019359347237136, Reblur Loss: 0.017275134859469545
62 | 2024-02-21 15:21:23,447 - INFO - EPOCH 57/151: Total Train Loss: 0.1507295116588667, Tea Loss: 0.13240870517311673, Reblur Loss: 0.018320806296260067
63 | 2024-02-21 15:25:11,137 - INFO - EPOCH 58/151: Total Train Loss: 0.14794867637353543, Tea Loss: 0.13034166179565124, Reblur Loss: 0.01760701478350085
64 | 2024-02-21 15:28:58,859 - INFO - EPOCH 59/151: Total Train Loss: 0.1444331771690092, Tea Loss: 0.12768478304534764, Reblur Loss: 0.0167483939019181
65 | 2024-02-21 15:32:46,617 - INFO - EPOCH 60/151: Total Train Loss: 0.14571062539950078, Tea Loss: 0.12870393090304874, Reblur Loss: 0.017006695077016756
66 | 2024-02-21 15:38:09,601 - INFO - EPOCH 61/151: Total Train Loss: 0.15109377325355233, Tea Loss: 0.1326590437458191, Reblur Loss: 0.018434729588367206
67 | 2024-02-21 15:41:57,392 - INFO - EPOCH 62/151: Total Train Loss: 0.1454399266800323, Tea Loss: 0.1282080735737111, Reblur Loss: 0.017231853106321193
68 | 2024-02-21 15:45:45,130 - INFO - EPOCH 63/151: Total Train Loss: 0.1435814949147629, Tea Loss: 0.12696699456219032, Reblur Loss: 0.01661450036466767
69 | 2024-02-21 15:49:32,973 - INFO - EPOCH 64/151: Total Train Loss: 0.14347542728825566, Tea Loss: 0.12685335059831668, Reblur Loss: 0.016622076875397136
70 | 2024-02-21 15:53:20,998 - INFO - EPOCH 65/151: Total Train Loss: 0.14517085844433153, Tea Loss: 0.12781795230391738, Reblur Loss: 0.01735290601946317
71 | 2024-02-21 15:58:43,330 - INFO - EPOCH 66/151: Total Train Loss: 0.1446716483214717, Tea Loss: 0.1272475728482911, Reblur Loss: 0.017424075803779936
72 | 2024-02-21 16:02:31,304 - INFO - EPOCH 67/151: Total Train Loss: 0.1435448806329723, Tea Loss: 0.1264306746455498, Reblur Loss: 0.017114206168848973
73 | 2024-02-21 16:06:19,216 - INFO - EPOCH 68/151: Total Train Loss: 0.14082041811762433, Tea Loss: 0.12448064908707812, Reblur Loss: 0.01633976892169033
74 | 2024-02-21 16:10:06,925 - INFO - EPOCH 69/151: Total Train Loss: 0.14659886720118584, Tea Loss: 0.12843387777155096, Reblur Loss: 0.01816498963121986
75 | 2024-02-21 16:13:54,798 - INFO - EPOCH 70/151: Total Train Loss: 0.14360976877150597, Tea Loss: 0.12647196563420357, Reblur Loss: 0.017137803302602075
76 | 2024-02-21 16:19:16,886 - INFO - EPOCH 71/151: Total Train Loss: 0.16021025964579025, Tea Loss: 0.13979781754476167, Reblur Loss: 0.02041244211715537
77 | 2024-02-21 16:23:04,451 - INFO - EPOCH 72/151: Total Train Loss: 0.1543136409350804, Tea Loss: 0.13649254250320006, Reblur Loss: 0.017821098323024455
78 | 2024-02-21 16:26:52,086 - INFO - EPOCH 73/151: Total Train Loss: 0.14720702142297448, Tea Loss: 0.1300513201203697, Reblur Loss: 0.01715570112520998
79 | 2024-02-21 16:30:39,705 - INFO - EPOCH 74/151: Total Train Loss: 0.14704671389225757, Tea Loss: 0.12937674810101976, Reblur Loss: 0.017669965980727693
80 | 2024-02-21 16:34:27,306 - INFO - EPOCH 75/151: Total Train Loss: 0.1434673515362141, Tea Loss: 0.12661278025283443, Reblur Loss: 0.016854570815702536
81 | 2024-02-21 16:39:49,710 - INFO - EPOCH 76/151: Total Train Loss: 0.1419998449293566, Tea Loss: 0.1251836247193865, Reblur Loss: 0.01681622104856359
82 | 2024-02-21 16:43:37,228 - INFO - EPOCH 77/151: Total Train Loss: 0.14403526752671122, Tea Loss: 0.1264730587085604, Reblur Loss: 0.017562208681073024
83 | 2024-02-21 16:47:24,841 - INFO - EPOCH 78/151: Total Train Loss: 0.14474546919008355, Tea Loss: 0.127241933016808, Reblur Loss: 0.017503536636920978
84 | 2024-02-21 16:51:12,997 - INFO - EPOCH 79/151: Total Train Loss: 0.14180082314974302, Tea Loss: 0.12507406135142107, Reblur Loss: 0.01672676197168502
85 | 2024-02-21 16:55:01,443 - INFO - EPOCH 80/151: Total Train Loss: 0.14336753639823946, Tea Loss: 0.12627104866556274, Reblur Loss: 0.017096487591567237
86 | 2024-02-21 17:00:25,386 - INFO - EPOCH 81/151: Total Train Loss: 0.1428917437404781, Tea Loss: 0.12610096681169616, Reblur Loss: 0.016790777166652216
87 | 2024-02-21 17:04:13,867 - INFO - EPOCH 82/151: Total Train Loss: 0.14355934830584052, Tea Loss: 0.12629533078505364, Reblur Loss: 0.017264017097458437
88 | 2024-02-21 17:08:02,698 - INFO - EPOCH 83/151: Total Train Loss: 0.14320857848697927, Tea Loss: 0.12600807868170016, Reblur Loss: 0.017200499970578785
89 | 2024-02-21 17:11:51,564 - INFO - EPOCH 84/151: Total Train Loss: 0.14073882974458463, Tea Loss: 0.12396387397855907, Reblur Loss: 0.016774955568472288
90 | 2024-02-21 17:15:40,399 - INFO - EPOCH 85/151: Total Train Loss: 0.14087068154053253, Tea Loss: 0.12417404866450793, Reblur Loss: 0.01669663344046254
91 | 2024-02-21 17:21:13,439 - INFO - EPOCH 86/151: Total Train Loss: 0.1422664277352296, Tea Loss: 0.12493182070456542, Reblur Loss: 0.017334606982283778
92 | 2024-02-21 17:25:01,949 - INFO - EPOCH 87/151: Total Train Loss: 0.14089638091527024, Tea Loss: 0.12385588400549703, Reblur Loss: 0.01704049695815359
93 | 2024-02-21 17:28:50,445 - INFO - EPOCH 88/151: Total Train Loss: 0.1399132420683836, Tea Loss: 0.12308049456768738, Reblur Loss: 0.016832747403935436
94 | 2024-02-21 17:32:38,902 - INFO - EPOCH 89/151: Total Train Loss: 0.13818140953650207, Tea Loss: 0.12187661856283873, Reblur Loss: 0.016304791469562364
95 | 2024-02-21 17:36:27,002 - INFO - EPOCH 90/151: Total Train Loss: 0.14573748516184942, Tea Loss: 0.12821079980888409, Reblur Loss: 0.017526685252172863
96 | 2024-02-21 17:41:49,848 - INFO - EPOCH 91/151: Total Train Loss: 0.14120482869478532, Tea Loss: 0.12441673494262613, Reblur Loss: 0.016788093941649058
97 | 2024-02-21 17:45:37,999 - INFO - EPOCH 92/151: Total Train Loss: 0.1404902297051954, Tea Loss: 0.12345530028209026, Reblur Loss: 0.017034929011871803
98 | 2024-02-21 17:49:26,217 - INFO - EPOCH 93/151: Total Train Loss: 0.13812280343079464, Tea Loss: 0.121862444397691, Reblur Loss: 0.016260358750884666
99 | 2024-02-21 17:53:14,424 - INFO - EPOCH 94/151: Total Train Loss: 0.13867980148240086, Tea Loss: 0.12204451436475242, Reblur Loss: 0.016635287343423604
100 | 2024-02-21 17:57:02,561 - INFO - EPOCH 95/151: Total Train Loss: 0.14069516056930864, Tea Loss: 0.12385144284664294, Reblur Loss: 0.016843716847786915
101 | 2024-02-21 18:02:26,331 - INFO - EPOCH 96/151: Total Train Loss: 0.13945014778024706, Tea Loss: 0.1225252967182692, Reblur Loss: 0.016924851303879832
102 | 2024-02-21 18:06:14,652 - INFO - EPOCH 97/151: Total Train Loss: 0.1372118221345918, Tea Loss: 0.12085662469332352, Reblur Loss: 0.016355197259841803
103 | 2024-02-21 18:10:02,894 - INFO - EPOCH 98/151: Total Train Loss: 0.14199044339326553, Tea Loss: 0.12498571965601538, Reblur Loss: 0.017004724116229906
104 | 2024-02-21 18:13:51,253 - INFO - EPOCH 99/151: Total Train Loss: 0.13994478705254468, Tea Loss: 0.12313165651126341, Reblur Loss: 0.016813130605788457
105 | 2024-02-21 18:17:39,473 - INFO - EPOCH 100/151: Total Train Loss: 0.1380174332237863, Tea Loss: 0.1214408864513104, Reblur Loss: 0.016576546941807258
106 | 2024-02-21 18:30:09,010 - INFO - SDM: mse: 0.010 ssim: 0.732 psnr: 26.362 lpips: 0.259
107 | 2024-02-21 18:30:09,011 - INFO - Deblur_Net: mse: 0.007 ssim: 0.786 psnr: 27.928 lpips: 0.252
108 | 2024-02-21 18:33:57,030 - INFO - EPOCH 101/151: Total Train Loss: 0.1378288431923627, Tea Loss: 0.1212237568902763, Reblur Loss: 0.01660508629805469
109 | 2024-02-21 18:37:45,769 - INFO - EPOCH 102/151: Total Train Loss: 0.13755620919264755, Tea Loss: 0.12160171765140641, Reblur Loss: 0.015954491690414036
110 | 2024-02-21 18:41:34,434 - INFO - EPOCH 103/151: Total Train Loss: 0.138257441808393, Tea Loss: 0.12145415109731418, Reblur Loss: 0.016803290505462136
111 | 2024-02-21 18:45:22,749 - INFO - EPOCH 104/151: Total Train Loss: 0.13809712418230066, Tea Loss: 0.12142424404750139, Reblur Loss: 0.016672880183179657
112 | 2024-02-21 18:49:11,083 - INFO - EPOCH 105/151: Total Train Loss: 0.13738702982664108, Tea Loss: 0.12067290224554218, Reblur Loss: 0.016714127782683868
113 | 2024-02-21 18:54:33,484 - INFO - EPOCH 106/151: Total Train Loss: 0.13537142006478783, Tea Loss: 0.11937310443425075, Reblur Loss: 0.015998315215272046
114 | 2024-02-21 18:58:21,890 - INFO - EPOCH 107/151: Total Train Loss: 0.1379812436895969, Tea Loss: 0.12124085948838816, Reblur Loss: 0.01674038420524045
115 | 2024-02-21 19:02:10,348 - INFO - EPOCH 108/151: Total Train Loss: 0.14069764654744754, Tea Loss: 0.12356308567059504, Reblur Loss: 0.01713456069542603
116 | 2024-02-21 19:05:58,802 - INFO - EPOCH 109/151: Total Train Loss: 0.138514020903544, Tea Loss: 0.12166547230182788, Reblur Loss: 0.016848548738793892
117 | 2024-02-21 19:09:47,265 - INFO - EPOCH 110/151: Total Train Loss: 0.13842449227581807, Tea Loss: 0.12143795095480882, Reblur Loss: 0.0169865418290034
118 | 2024-02-21 19:15:09,966 - INFO - EPOCH 111/151: Total Train Loss: 0.13819570491066227, Tea Loss: 0.1219581923572532, Reblur Loss: 0.016237512605821156
119 | 2024-02-21 19:18:58,428 - INFO - EPOCH 112/151: Total Train Loss: 0.13839785670821284, Tea Loss: 0.12155075126138085, Reblur Loss: 0.016847105438768606
120 | 2024-02-21 19:22:46,795 - INFO - EPOCH 113/151: Total Train Loss: 0.13677882712769818, Tea Loss: 0.12005357699089753, Reblur Loss: 0.016725250156959155
121 | 2024-02-21 19:26:35,302 - INFO - EPOCH 114/151: Total Train Loss: 0.13452625832645407, Tea Loss: 0.1182897954340621, Reblur Loss: 0.016236462710965505
122 | 2024-02-21 19:30:23,838 - INFO - EPOCH 115/151: Total Train Loss: 0.13718709688295017, Tea Loss: 0.12045202155907948, Reblur Loss: 0.016735075799611223
123 | 2024-02-21 19:35:46,377 - INFO - EPOCH 116/151: Total Train Loss: 0.13697202916527207, Tea Loss: 0.12038201193659853, Reblur Loss: 0.01659001698677158
124 | 2024-02-21 19:39:35,025 - INFO - EPOCH 117/151: Total Train Loss: 0.13589300969978432, Tea Loss: 0.11910017815245179, Reblur Loss: 0.016792831700537112
125 | 2024-02-21 19:43:23,577 - INFO - EPOCH 118/151: Total Train Loss: 0.13500712547467383, Tea Loss: 0.11859367100836395, Reblur Loss: 0.016413455091223314
126 | 2024-02-21 19:47:12,306 - INFO - EPOCH 119/151: Total Train Loss: 0.13317141862529697, Tea Loss: 0.1173663471877833, Reblur Loss: 0.01580507102224863
127 | 2024-02-21 19:51:04,440 - INFO - EPOCH 120/151: Total Train Loss: 0.13608860605077827, Tea Loss: 0.11919505568422796, Reblur Loss: 0.016893550374613694
128 | 2024-02-21 19:56:31,905 - INFO - EPOCH 121/151: Total Train Loss: 0.1362584302951763, Tea Loss: 0.11948918902770782, Reblur Loss: 0.016769240852203462
129 | 2024-02-21 20:00:24,714 - INFO - EPOCH 122/151: Total Train Loss: 0.13465122036732635, Tea Loss: 0.11808150193907997, Reblur Loss: 0.016569718218598015
130 | 2024-02-21 20:04:17,498 - INFO - EPOCH 123/151: Total Train Loss: 0.13412981302965254, Tea Loss: 0.11801906629945293, Reblur Loss: 0.016110746468139158
131 | 2024-02-21 20:08:10,286 - INFO - EPOCH 124/151: Total Train Loss: 0.1353475305360633, Tea Loss: 0.11868680578160595, Reblur Loss: 0.01666072496410572
132 | 2024-02-21 20:12:03,182 - INFO - EPOCH 125/151: Total Train Loss: 0.13473823502079232, Tea Loss: 0.11823303658853877, Reblur Loss: 0.01650519877091631
133 | 2024-02-21 20:17:33,181 - INFO - EPOCH 126/151: Total Train Loss: 0.13619285976731932, Tea Loss: 0.11934208008763078, Reblur Loss: 0.01685077938134278
134 | 2024-02-21 20:21:26,034 - INFO - EPOCH 127/151: Total Train Loss: 0.1337554477380984, Tea Loss: 0.11781134275775967, Reblur Loss: 0.015944105105321387
135 | 2024-02-21 20:25:18,917 - INFO - EPOCH 128/151: Total Train Loss: 0.13385922952002777, Tea Loss: 0.11756700577286931, Reblur Loss: 0.016292223755221862
136 | 2024-02-21 20:29:11,864 - INFO - EPOCH 129/151: Total Train Loss: 0.13444969645052246, Tea Loss: 0.117829749207476, Reblur Loss: 0.01661994808567164
137 | 2024-02-21 20:33:04,695 - INFO - EPOCH 130/151: Total Train Loss: 0.1344283643739048, Tea Loss: 0.11783283923095439, Reblur Loss: 0.016595525122791915
138 | 2024-02-21 20:38:33,836 - INFO - EPOCH 131/151: Total Train Loss: 0.13254633594255943, Tea Loss: 0.11647765380350543, Reblur Loss: 0.016068682638984737
139 | 2024-02-21 20:42:26,698 - INFO - EPOCH 132/151: Total Train Loss: 0.13348609583202378, Tea Loss: 0.11717449247966082, Reblur Loss: 0.0163116031548097
140 | 2024-02-21 20:46:19,622 - INFO - EPOCH 133/151: Total Train Loss: 0.13426269087698553, Tea Loss: 0.11791426914207863, Reblur Loss: 0.016348422267091222
141 | 2024-02-21 20:50:12,632 - INFO - EPOCH 134/151: Total Train Loss: 0.13518822267329023, Tea Loss: 0.11855961766206857, Reblur Loss: 0.016628604914460863
142 | 2024-02-21 20:54:05,636 - INFO - EPOCH 135/151: Total Train Loss: 0.13607895332368422, Tea Loss: 0.1191918769956151, Reblur Loss: 0.01688707603375504
143 | 2024-02-21 20:59:40,029 - INFO - EPOCH 136/151: Total Train Loss: 0.13298061241706213, Tea Loss: 0.11717816913153702, Reblur Loss: 0.01580244323311301
144 | 2024-02-21 21:03:32,881 - INFO - EPOCH 137/151: Total Train Loss: 0.13413896153628568, Tea Loss: 0.1176624967357813, Reblur Loss: 0.01647646492548706
145 | 2024-02-21 21:07:25,749 - INFO - EPOCH 138/151: Total Train Loss: 0.1341735723214748, Tea Loss: 0.11770590485045404, Reblur Loss: 0.01646766753955966
146 | 2024-02-21 21:11:18,778 - INFO - EPOCH 139/151: Total Train Loss: 0.1323508085174994, Tea Loss: 0.11597479141248769, Reblur Loss: 0.01637601697196563
147 | 2024-02-21 21:15:11,637 - INFO - EPOCH 140/151: Total Train Loss: 0.1313024704784026, Tea Loss: 0.11536454802854752, Reblur Loss: 0.015937922744169122
148 | 2024-02-21 21:20:45,685 - INFO - EPOCH 141/151: Total Train Loss: 0.14401600925953356, Tea Loss: 0.12672924763196475, Reblur Loss: 0.017286761490491045
149 | 2024-02-21 21:24:38,675 - INFO - EPOCH 142/151: Total Train Loss: 0.138105887871284, Tea Loss: 0.1207235726120668, Reblur Loss: 0.017382314892332534
150 | 2024-02-21 21:28:31,715 - INFO - EPOCH 143/151: Total Train Loss: 0.13395060392427238, Tea Loss: 0.1174174669436562, Reblur Loss: 0.016533136859665187
151 | 2024-02-21 21:32:24,643 - INFO - EPOCH 144/151: Total Train Loss: 0.13170721843129113, Tea Loss: 0.11586916924039006, Reblur Loss: 0.015838049380390934
152 | 2024-02-21 21:36:17,628 - INFO - EPOCH 145/151: Total Train Loss: 0.13255383874172771, Tea Loss: 0.11643216858952593, Reblur Loss: 0.01612167028524788
153 | 2024-02-21 21:41:53,258 - INFO - EPOCH 146/151: Total Train Loss: 0.13475415675025998, Tea Loss: 0.11823435321256712, Reblur Loss: 0.01651980377153143
154 | 2024-02-21 21:45:46,174 - INFO - EPOCH 147/151: Total Train Loss: 0.13381607143522858, Tea Loss: 0.11708475265539053, Reblur Loss: 0.01673131895320111
155 | 2024-02-21 21:49:39,078 - INFO - EPOCH 148/151: Total Train Loss: 0.13221219717424154, Tea Loss: 0.11615527588954735, Reblur Loss: 0.01605692087346083
156 | 2024-02-21 21:53:32,073 - INFO - EPOCH 149/151: Total Train Loss: 0.13134429555434685, Tea Loss: 0.1154829551711743, Reblur Loss: 0.01586134055250393
157 | 2024-02-21 21:57:25,062 - INFO - EPOCH 150/151: Total Train Loss: 0.13437037873061705, Tea Loss: 0.11755732788797064, Reblur Loss: 0.01681305057655423
158 | 2024-02-21 22:13:18,640 - INFO - SDM: mse: 0.010 ssim: 0.732 psnr: 26.362 lpips: 0.259
159 | 2024-02-21 22:13:18,642 - INFO - Deblur_Net: mse: 0.007 ssim: 0.785 psnr: 27.902 lpips: 0.252
160 |
--------------------------------------------------------------------------------
/log/SR/opt.txt:
--------------------------------------------------------------------------------
1 | base_folder: GOPRO/
2 | save_folder: exp/SR
3 | data_type: GOPRO
4 | exp_name: NEW_GOPRO_9_bsn1000
5 | bsn_path: exp/BSN/NEW_GOPRO_9_test_full/ckpts/BSN_1000.pth
6 | sr_path: exp/SR/NEW_GOPRO_9/ckpts/SR_100.pth
7 | epochs: 101
8 | lr: 0.0002
9 | seed: 42
10 | bsn_len: 9
11 | use_small: False
12 | test_mode: False
13 |
--------------------------------------------------------------------------------
/log/SR/test_log.txt:
--------------------------------------------------------------------------------
1 | 2024-03-15 11:42:47,789 - INFO - Namespace(base_folder='GOPRO', bsn_len=9, bsn_path='model/BSN_1000.pth', data_type='GOPRO', epochs=101, exp_name='test', lr=0.0002, save_folder='exp/SR', seed=42, sr_path='model/SR_70.pth', test_mode=True, use_small=False)
2 | 2024-03-15 11:42:48,131 - INFO - Start Training!
3 | 2024-03-15 11:42:48,131 - INFO - EPOCH 0/101: Train Loss: 0
4 | 2024-03-15 12:09:24,118 - INFO - SDM: ssim: 0.732 psnr: 26.362
5 | 2024-03-15 12:09:24,119 - INFO - Blur_SR: ssim: 0.945 psnr: 37.286
6 | 2024-03-15 12:09:24,120 - INFO - Blur_Resize: ssim: 0.855 psnr: 31.124
7 | 2024-03-15 12:09:24,120 - INFO - BSN_SR: ssim: 0.708 psnr: 26.144
8 | 2024-03-15 12:09:24,121 - INFO - BSN_Resize: ssim: 0.645 psnr: 24.490
9 | 2024-03-15 12:09:24,121 - INFO - TFP_Resize: ssim: 0.461 psnr: 22.943
10 |
--------------------------------------------------------------------------------
/log/SR/train_log.txt:
--------------------------------------------------------------------------------
1 | 2024-02-18 21:44:09,917 - INFO - Namespace(base_folder='GOPRO/', save_folder='exp/SR', data_type='GOPRO', exp_name='NEW_GOPRO_9_bsn1000', bsn_path='exp/BSN/NEW_GOPRO_9_test_full/ckpts/BSN_1000.pth', sr_path='exp/SR/NEW_GOPRO_9/ckpts/SR_100.pth', epochs=101, lr=0.0002, seed=42, bsn_len=9, use_small=False, test_mode=False)
2 | 2024-02-18 21:44:11,448 - INFO - Start Training!
3 | 2024-02-18 21:44:35,386 - INFO - EPOCH 0/101: Train Loss: 0.0014211533537479238
4 | 2024-02-18 22:04:49,354 - INFO - SDM: mse: 0.010 ssim: 0.717 psnr: 26.293 lpips: 0.287
5 | 2024-02-18 22:04:49,359 - INFO - Blur_SR: mse: 0.002 ssim: 0.918 psnr: 34.995 lpips: 0.187
6 | 2024-02-18 22:04:49,359 - INFO - Blur_Resize: mse: 0.004 ssim: 0.855 psnr: 31.124 lpips: 0.433
7 | 2024-02-18 22:04:49,359 - INFO - BSN_SR: mse: 0.012 ssim: 0.686 psnr: 25.692 lpips: 0.391
8 | 2024-02-18 22:04:49,360 - INFO - BSN_Resize: mse: 0.016 ssim: 0.645 psnr: 24.490 lpips: 0.411
9 | 2024-02-18 22:04:49,360 - INFO - TFP_Resize: mse: 0.021 ssim: 0.461 psnr: 22.943 lpips: 0.567
10 | 2024-02-18 22:05:10,622 - INFO - EPOCH 1/101: Train Loss: 0.00045315917495394214
11 | 2024-02-18 22:05:30,408 - INFO - EPOCH 2/101: Train Loss: 0.0004365248677423587
12 | 2024-02-18 22:05:49,968 - INFO - EPOCH 3/101: Train Loss: 0.0004334324128972165
13 | 2024-02-18 22:06:10,202 - INFO - EPOCH 4/101: Train Loss: 0.0004208742305651525
14 | 2024-02-18 22:06:30,051 - INFO - EPOCH 5/101: Train Loss: 0.000399760951347764
15 | 2024-02-18 22:08:16,945 - INFO - EPOCH 6/101: Train Loss: 0.00040361383598871346
16 | 2024-02-18 22:08:36,443 - INFO - EPOCH 7/101: Train Loss: 0.00039817116557024784
17 | 2024-02-18 22:08:56,198 - INFO - EPOCH 8/101: Train Loss: 0.0003857492040631107
18 | 2024-02-18 22:09:16,297 - INFO - EPOCH 9/101: Train Loss: 0.000374428103321375
19 | 2024-02-18 22:09:35,824 - INFO - EPOCH 10/101: Train Loss: 0.00038021592708940524
20 | 2024-02-18 22:29:48,360 - INFO - SDM: mse: 0.010 ssim: 0.731 psnr: 26.402 lpips: 0.268
21 | 2024-02-18 22:29:48,365 - INFO - Blur_SR: mse: 0.002 ssim: 0.933 psnr: 36.117 lpips: 0.141
22 | 2024-02-18 22:29:48,366 - INFO - Blur_Resize: mse: 0.004 ssim: 0.855 psnr: 31.124 lpips: 0.433
23 | 2024-02-18 22:29:48,366 - INFO - BSN_SR: mse: 0.011 ssim: 0.700 psnr: 25.961 lpips: 0.372
24 | 2024-02-18 22:29:48,367 - INFO - BSN_Resize: mse: 0.016 ssim: 0.645 psnr: 24.490 lpips: 0.411
25 | 2024-02-18 22:29:48,367 - INFO - TFP_Resize: mse: 0.021 ssim: 0.461 psnr: 22.943 lpips: 0.567
26 | 2024-02-18 22:30:09,554 - INFO - EPOCH 11/101: Train Loss: 0.0003734087890005989
27 | 2024-02-18 22:30:30,197 - INFO - EPOCH 12/101: Train Loss: 0.00038408602982733267
28 | 2024-02-18 22:30:49,718 - INFO - EPOCH 13/101: Train Loss: 0.0003796412013707989
29 | 2024-02-18 22:31:08,348 - INFO - EPOCH 14/101: Train Loss: 0.0003678143284048228
30 | 2024-02-18 22:31:28,271 - INFO - EPOCH 15/101: Train Loss: 0.00036798545054953294
31 | 2024-02-18 22:33:18,243 - INFO - EPOCH 16/101: Train Loss: 0.0003596360475687689
32 | 2024-02-18 22:33:37,820 - INFO - EPOCH 17/101: Train Loss: 0.00037178791592248363
33 | 2024-02-18 22:33:57,316 - INFO - EPOCH 18/101: Train Loss: 0.00035189093822347265
34 | 2024-02-18 22:34:16,896 - INFO - EPOCH 19/101: Train Loss: 0.0003537243204476066
35 | 2024-02-18 22:34:36,585 - INFO - EPOCH 20/101: Train Loss: 0.0003496722201501761
36 | 2024-02-18 22:54:58,323 - INFO - SDM: mse: 0.011 ssim: 0.720 psnr: 26.013 lpips: 0.277
37 | 2024-02-18 22:54:58,324 - INFO - Blur_SR: mse: 0.003 ssim: 0.896 psnr: 32.358 lpips: 0.131
38 | 2024-02-18 22:54:58,324 - INFO - Blur_Resize: mse: 0.004 ssim: 0.855 psnr: 31.124 lpips: 0.433
39 | 2024-02-18 22:54:58,325 - INFO - BSN_SR: mse: 0.013 ssim: 0.684 psnr: 25.471 lpips: 0.364
40 | 2024-02-18 22:54:58,325 - INFO - BSN_Resize: mse: 0.016 ssim: 0.645 psnr: 24.490 lpips: 0.411
41 | 2024-02-18 22:54:58,325 - INFO - TFP_Resize: mse: 0.021 ssim: 0.461 psnr: 22.943 lpips: 0.567
42 | 2024-02-18 22:55:19,391 - INFO - EPOCH 21/101: Train Loss: 0.0003841655182153757
43 | 2024-02-18 22:55:39,402 - INFO - EPOCH 22/101: Train Loss: 0.000341238439826066
44 | 2024-02-18 22:55:58,865 - INFO - EPOCH 23/101: Train Loss: 0.0003420646372074826
45 | 2024-02-18 22:56:18,752 - INFO - EPOCH 24/101: Train Loss: 0.0003486189204140907
46 | 2024-02-18 22:56:38,449 - INFO - EPOCH 25/101: Train Loss: 0.00033939894858507836
47 | 2024-02-18 22:58:28,952 - INFO - EPOCH 26/101: Train Loss: 0.00033795114644290074
48 | 2024-02-18 22:58:47,653 - INFO - EPOCH 27/101: Train Loss: 0.000348002309133591
49 | 2024-02-18 22:59:06,305 - INFO - EPOCH 28/101: Train Loss: 0.0003404388012805726
50 | 2024-02-18 22:59:25,697 - INFO - EPOCH 29/101: Train Loss: 0.0003469659714658192
51 | 2024-02-18 22:59:44,931 - INFO - EPOCH 30/101: Train Loss: 0.00033793993217636755
52 | 2024-02-18 23:20:06,934 - INFO - SDM: mse: 0.010 ssim: 0.731 psnr: 26.371 lpips: 0.264
53 | 2024-02-18 23:20:06,939 - INFO - Blur_SR: mse: 0.001 ssim: 0.940 psnr: 36.720 lpips: 0.125
54 | 2024-02-18 23:20:06,940 - INFO - Blur_Resize: mse: 0.004 ssim: 0.855 psnr: 31.124 lpips: 0.433
55 | 2024-02-18 23:20:06,940 - INFO - BSN_SR: mse: 0.011 ssim: 0.704 psnr: 26.044 lpips: 0.364
56 | 2024-02-18 23:20:06,940 - INFO - BSN_Resize: mse: 0.016 ssim: 0.645 psnr: 24.490 lpips: 0.411
57 | 2024-02-18 23:20:06,940 - INFO - TFP_Resize: mse: 0.021 ssim: 0.461 psnr: 22.943 lpips: 0.567
58 | 2024-02-18 23:20:28,590 - INFO - EPOCH 31/101: Train Loss: 0.00033414859230956667
59 | 2024-02-18 23:20:48,187 - INFO - EPOCH 32/101: Train Loss: 0.0003287021171952219
60 | 2024-02-18 23:21:07,387 - INFO - EPOCH 33/101: Train Loss: 0.00032866710166495243
61 | 2024-02-18 23:21:27,734 - INFO - EPOCH 34/101: Train Loss: 0.00033242032808141814
62 | 2024-02-18 23:21:46,919 - INFO - EPOCH 35/101: Train Loss: 0.00033035686930532674
63 | 2024-02-18 23:23:34,118 - INFO - EPOCH 36/101: Train Loss: 0.00032164452114828875
64 | 2024-02-18 23:23:53,515 - INFO - EPOCH 37/101: Train Loss: 0.0003143733826748628
65 | 2024-02-18 23:24:13,141 - INFO - EPOCH 38/101: Train Loss: 0.00031801123744420085
66 | 2024-02-18 23:24:32,576 - INFO - EPOCH 39/101: Train Loss: 0.00032052480803023483
67 | 2024-02-18 23:24:51,224 - INFO - EPOCH 40/101: Train Loss: 0.00031738842848095703
68 | 2024-02-18 23:45:14,654 - INFO - SDM: mse: 0.010 ssim: 0.732 psnr: 26.372 lpips: 0.264
69 | 2024-02-18 23:45:14,656 - INFO - Blur_SR: mse: 0.001 ssim: 0.942 psnr: 36.864 lpips: 0.120
70 | 2024-02-18 23:45:14,656 - INFO - Blur_Resize: mse: 0.004 ssim: 0.855 psnr: 31.124 lpips: 0.433
71 | 2024-02-18 23:45:14,656 - INFO - BSN_SR: mse: 0.011 ssim: 0.707 psnr: 26.112 lpips: 0.359
72 | 2024-02-18 23:45:14,657 - INFO - BSN_Resize: mse: 0.016 ssim: 0.645 psnr: 24.490 lpips: 0.411
73 | 2024-02-18 23:45:14,657 - INFO - TFP_Resize: mse: 0.021 ssim: 0.461 psnr: 22.943 lpips: 0.567
74 | 2024-02-18 23:45:34,343 - INFO - EPOCH 41/101: Train Loss: 0.0003112662171033308
75 | 2024-02-18 23:45:54,126 - INFO - EPOCH 42/101: Train Loss: 0.0003128112463867171
76 | 2024-02-18 23:46:13,857 - INFO - EPOCH 43/101: Train Loss: 0.00038736918827748625
77 | 2024-02-18 23:46:33,461 - INFO - EPOCH 44/101: Train Loss: 0.000311614652130263
78 | 2024-02-18 23:46:53,385 - INFO - EPOCH 45/101: Train Loss: 0.00031099672609052187
79 | 2024-02-18 23:48:43,192 - INFO - EPOCH 46/101: Train Loss: 0.0003036633764786214
80 | 2024-02-18 23:49:02,453 - INFO - EPOCH 47/101: Train Loss: 0.0002961308712762623
81 | 2024-02-18 23:49:22,003 - INFO - EPOCH 48/101: Train Loss: 0.0003085716928529767
82 | 2024-02-18 23:49:41,660 - INFO - EPOCH 49/101: Train Loss: 0.00030397324178023144
83 | 2024-02-18 23:50:01,265 - INFO - EPOCH 50/101: Train Loss: 0.00030308918473970957
84 | 2024-02-19 00:10:45,013 - INFO - SDM: mse: 0.010 ssim: 0.731 psnr: 26.374 lpips: 0.263
85 | 2024-02-19 00:10:45,018 - INFO - Blur_SR: mse: 0.001 ssim: 0.943 psnr: 37.075 lpips: 0.114
86 | 2024-02-19 00:10:45,018 - INFO - Blur_Resize: mse: 0.004 ssim: 0.855 psnr: 31.124 lpips: 0.433
87 | 2024-02-19 00:10:45,018 - INFO - BSN_SR: mse: 0.011 ssim: 0.706 psnr: 26.114 lpips: 0.356
88 | 2024-02-19 00:10:45,018 - INFO - BSN_Resize: mse: 0.016 ssim: 0.645 psnr: 24.490 lpips: 0.411
89 | 2024-02-19 00:10:45,019 - INFO - TFP_Resize: mse: 0.021 ssim: 0.461 psnr: 22.943 lpips: 0.567
90 | 2024-02-19 00:11:05,652 - INFO - EPOCH 51/101: Train Loss: 0.00029659506737154273
91 | 2024-02-19 00:11:25,623 - INFO - EPOCH 52/101: Train Loss: 0.00030188390887992507
92 | 2024-02-19 00:11:44,842 - INFO - EPOCH 53/101: Train Loss: 0.00028545065259851226
93 | 2024-02-19 00:12:05,023 - INFO - EPOCH 54/101: Train Loss: 0.00029897224343884774
94 | 2024-02-19 00:12:25,539 - INFO - EPOCH 55/101: Train Loss: 0.00029665410318447164
95 | 2024-02-19 00:14:11,359 - INFO - EPOCH 56/101: Train Loss: 0.00029271354684500363
96 | 2024-02-19 00:14:30,650 - INFO - EPOCH 57/101: Train Loss: 0.00028587002436946914
97 | 2024-02-19 00:14:49,898 - INFO - EPOCH 58/101: Train Loss: 0.00029207505800411533
98 | 2024-02-19 00:15:09,992 - INFO - EPOCH 59/101: Train Loss: 0.000288543489447178
99 | 2024-02-19 00:15:29,946 - INFO - EPOCH 60/101: Train Loss: 0.00029507283475003066
100 | 2024-02-19 00:35:49,782 - INFO - SDM: mse: 0.010 ssim: 0.735 psnr: 26.444 lpips: 0.262
101 | 2024-02-19 00:35:49,783 - INFO - Blur_SR: mse: 0.001 ssim: 0.943 psnr: 37.077 lpips: 0.122
102 | 2024-02-19 00:35:49,783 - INFO - Blur_Resize: mse: 0.004 ssim: 0.855 psnr: 31.124 lpips: 0.433
103 | 2024-02-19 00:35:49,784 - INFO - BSN_SR: mse: 0.011 ssim: 0.707 psnr: 26.121 lpips: 0.362
104 | 2024-02-19 00:35:49,784 - INFO - BSN_Resize: mse: 0.016 ssim: 0.645 psnr: 24.490 lpips: 0.411
105 | 2024-02-19 00:35:49,784 - INFO - TFP_Resize: mse: 0.021 ssim: 0.461 psnr: 22.943 lpips: 0.567
106 | 2024-02-19 00:36:11,003 - INFO - EPOCH 61/101: Train Loss: 0.000283065656955617
107 | 2024-02-19 00:36:31,520 - INFO - EPOCH 62/101: Train Loss: 0.00029362727483120665
108 | 2024-02-19 00:36:50,988 - INFO - EPOCH 63/101: Train Loss: 0.0002813149460675034
109 | 2024-02-19 00:37:10,857 - INFO - EPOCH 64/101: Train Loss: 0.00028284162286597455
110 | 2024-02-19 00:37:30,360 - INFO - EPOCH 65/101: Train Loss: 0.00027693967154507576
111 | 2024-02-19 00:39:19,888 - INFO - EPOCH 66/101: Train Loss: 0.0002798173035526733
112 | 2024-02-19 00:39:39,484 - INFO - EPOCH 67/101: Train Loss: 0.00028095614078512126
113 | 2024-02-19 00:39:59,174 - INFO - EPOCH 68/101: Train Loss: 0.000282799996639999
114 | 2024-02-19 00:40:19,057 - INFO - EPOCH 69/101: Train Loss: 0.00027430040608245873
115 | 2024-02-19 00:40:39,029 - INFO - EPOCH 70/101: Train Loss: 0.0002796248417802244
116 | 2024-02-19 01:00:59,204 - INFO - SDM: mse: 0.010 ssim: 0.732 psnr: 26.362 lpips: 0.259
117 | 2024-02-19 01:00:59,205 - INFO - Blur_SR: mse: 0.001 ssim: 0.945 psnr: 37.286 lpips: 0.109
118 | 2024-02-19 01:00:59,206 - INFO - Blur_Resize: mse: 0.004 ssim: 0.855 psnr: 31.124 lpips: 0.433
119 | 2024-02-19 01:00:59,206 - INFO - BSN_SR: mse: 0.011 ssim: 0.708 psnr: 26.144 lpips: 0.349
120 | 2024-02-19 01:00:59,206 - INFO - BSN_Resize: mse: 0.016 ssim: 0.645 psnr: 24.490 lpips: 0.411
121 | 2024-02-19 01:00:59,206 - INFO - TFP_Resize: mse: 0.021 ssim: 0.461 psnr: 22.943 lpips: 0.567
122 | 2024-02-19 01:01:20,403 - INFO - EPOCH 71/101: Train Loss: 0.0002738387952042135
123 | 2024-02-19 01:01:39,201 - INFO - EPOCH 72/101: Train Loss: 0.0002767562307001624
124 | 2024-02-19 01:01:58,023 - INFO - EPOCH 73/101: Train Loss: 0.00027031866407467385
125 | 2024-02-19 01:02:16,770 - INFO - EPOCH 74/101: Train Loss: 0.00026943054254563903
126 | 2024-02-19 01:02:36,184 - INFO - EPOCH 75/101: Train Loss: 0.0002645513773584603
127 | 2024-02-19 01:04:26,003 - INFO - EPOCH 76/101: Train Loss: 0.0002584966490388136
128 | 2024-02-19 01:04:45,514 - INFO - EPOCH 77/101: Train Loss: 0.00027077237312121065
129 | 2024-02-19 01:05:05,113 - INFO - EPOCH 78/101: Train Loss: 0.000273223011823748
130 | 2024-02-19 01:05:23,908 - INFO - EPOCH 79/101: Train Loss: 0.00026525524041078204
131 | 2024-02-19 01:05:42,648 - INFO - EPOCH 80/101: Train Loss: 0.00026188710400430504
132 | 2024-02-19 01:26:04,975 - INFO - SDM: mse: 0.010 ssim: 0.734 psnr: 26.402 lpips: 0.261
133 | 2024-02-19 01:26:04,979 - INFO - Blur_SR: mse: 0.001 ssim: 0.945 psnr: 37.316 lpips: 0.113
134 | 2024-02-19 01:26:04,980 - INFO - Blur_Resize: mse: 0.004 ssim: 0.855 psnr: 31.124 lpips: 0.433
135 | 2024-02-19 01:26:04,980 - INFO - BSN_SR: mse: 0.011 ssim: 0.708 psnr: 26.142 lpips: 0.354
136 | 2024-02-19 01:26:04,980 - INFO - BSN_Resize: mse: 0.016 ssim: 0.645 psnr: 24.490 lpips: 0.411
137 | 2024-02-19 01:26:04,980 - INFO - TFP_Resize: mse: 0.021 ssim: 0.461 psnr: 22.943 lpips: 0.567
138 | 2024-02-19 01:26:25,886 - INFO - EPOCH 81/101: Train Loss: 0.0002638731355357056
139 | 2024-02-19 01:26:45,491 - INFO - EPOCH 82/101: Train Loss: 0.00026061319765857585
140 | 2024-02-19 01:27:04,218 - INFO - EPOCH 83/101: Train Loss: 0.0002606669278967585
141 | 2024-02-19 01:27:24,915 - INFO - EPOCH 84/101: Train Loss: 0.0002591617188584745
142 |
--------------------------------------------------------------------------------
/model/BSN_1000.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/model/BSN_1000.pth
--------------------------------------------------------------------------------
/model/DeblurNet_100.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/model/DeblurNet_100.pth
--------------------------------------------------------------------------------
/model/SR_70.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/model/SR_70.pth
--------------------------------------------------------------------------------
/scripts/GOPRO_dataset.md:
--------------------------------------------------------------------------------
1 | ## Step 1: Obtain the Original GOPRO Dataset
2 | Step into the `scripts` subfolder:
3 | ```
4 | cd scripts
5 | ```
6 |
7 | Download the [GOPRO_Large_all](https://drive.google.com/file/d/1rJTmM9_mLCNzBUUhYIGldBYgup279E_f/view) archive from the [GOPRO website](https://seungjunnah.github.io/Datasets/gopro) to get the sharp sequences for simulating spikes and synthesizing blurry frames. We also provide a [small GOPRO dataset](https://pan.baidu.com/s/1FGqlMFtnL5jwI39I5mNkTw?pwd=1623) for debugging. After downloading the data, rename the extracted folder to `GOPRO` and place it in the `scripts` directory. Add a subfolder `raw_data` in both the `train` and `test` folders. The file structure is as follows:
8 | ```
9 | scripts
10 | ├── XVFI-main
11 | ├── run.sh
12 | ├── ...
13 | └── GOPRO
14 | ├── train
15 | │ └── raw_data
16 | │ ├── GOPR0372_07_00
17 | │ ├── ...
18 | │ └── GOPR0884_11_00
19 | └── test
20 | └── raw_data
21 | ├── GOPR0384_11_00
22 | ├── ...
23 | └── GOPR0881_11_01
24 | ```
25 |
26 | If you prefer not to run the following steps one by one, you can simply run:
27 | ```
28 | bash run.sh
29 | ```
30 |
31 | ## Step 2: Frame Interpolation
32 | We use the XVFI frame interpolation algorithm to insert 7 additional images between each pair of adjacent images, increasing the frame rate of the GOPRO image sequence by 8×. Note that this step is time-consuming and takes up a large amount of disk space. Run
33 |
34 | ```
35 | cd XVFI-main/
36 | python main.py --custom_path ../GOPRO/test/raw_data --gpu 0 --phase test_custom --exp_num 1 --dataset X4K1000FPS --module_scale_factor 4 --S_tst 5 --multiple 8
37 | python main.py --custom_path ../GOPRO/train/raw_data --gpu 0 --phase test_custom --exp_num 1 --dataset X4K1000FPS --module_scale_factor 4 --S_tst 5 --multiple 8
38 | cd ..
39 | ```
40 | ## Step 3: Blur Synthesis
41 | We synthesize a blurred frame from 97 images of the dataset after frame interpolation, corresponding to 13 images before interpolation since (13 - 1) × 8 + 1 = 97. Run
42 |
43 | ```
44 | python blur_syn.py --raw_folder GOPRO/test/raw_data --overlap_len 7 --blur_num 13
45 | python blur_syn.py --raw_folder GOPRO/train/raw_data --overlap_len 7 --blur_num 13
46 | ```
47 |
48 | ❗: Please note that an output such as `Synthesize blurry 000006.png from 000323.png to 000335.png` is not an error. We rename our output files starting from `000000.png`, which may differ slightly from the GOPRO dataset naming that starts from `000001.png`, `000323.png`, etc.
49 |
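For reference, the sketch below shows the long-exposure averaging idea behind this step under stated assumptions (uniform weighting in pixel space, an illustrative folder path); the exact procedure is implemented in `blur_syn.py` and may differ, e.g. in color handling or file naming.

```python
import glob
import cv2
import numpy as np

def synthesize_blur(frame_paths):
    """Approximate a long exposure by averaging consecutive sharp frames."""
    acc = np.zeros_like(cv2.imread(frame_paths[0]), dtype=np.float64)
    for path in frame_paths:
        acc += cv2.imread(path).astype(np.float64)
    return np.clip(acc / len(frame_paths), 0, 255).astype(np.uint8)

# One blurry frame from the first 97 interpolated frames of a sequence
# (the sequence name here is illustrative):
frames = sorted(glob.glob('GOPRO/test/raw_data/GOPR0384_11_00/*.png'))
blur = synthesize_blur(frames[:97])
cv2.imwrite('blur_000000.png', blur)
```

Averaging the 97 consecutive sharp frames approximates integrating the scene radiance over the virtual exposure period.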
50 | ## Step 4: Spike Simulation
51 |
52 | We resize the images from `720×1280` to `180×320` and apply a physical spike-generation model to simulate low-resolution spikes, obtaining the spike stream corresponding to the virtual exposure time defined in `Step 3: Blur Synthesis`. Run
53 |
54 | ```
55 | python spike_simulate.py --raw_folder GOPRO/test/raw_data --overlap_len 7 --use_resize --blur_num 13
56 | python spike_simulate.py --raw_folder GOPRO/train/raw_data --overlap_len 7 --use_resize --blur_num 13
57 | ```
58 |
59 | Since the spike stream also contains 20 additional spike frames outside the exposure period, the start and end segments of the spike stream cannot be utilized. (For example, `000006.png` is the first frame of the `blur_data` folder, while `000019.dat` is the first spike file of the `spike_data` folder.)
60 |
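For intuition, here is a minimal sketch of the integrate-and-fire model commonly used to simulate spike cameras; the threshold value, the random initial potential, and the array-based output are assumptions here, and the exact settings (including the `.dat` packing) live in `spike_simulate.py`.

```python
import numpy as np

def simulate_spikes(frames, threshold=2.0):
    """Integrate-and-fire spike simulation over a stack of grayscale frames.

    frames: (T, H, W) array of sharp frames in [0, 255].
    Returns a (T, H, W) binary spike stream.
    """
    acc = np.random.rand(*frames.shape[1:]) * threshold  # random initial potential
    spikes = np.zeros(frames.shape, dtype=np.uint8)
    for t in range(frames.shape[0]):
        acc += frames[t] / 255.0          # integrate normalized intensity
        fired = acc >= threshold
        spikes[t][fired] = 1              # emit a spike where the threshold is crossed
        acc[fired] -= threshold           # reset by subtraction, keeping residual charge
    return spikes
```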
61 | ## Step 5: Sharp Extract
62 |
63 | To obtain the single sharp frame corresponding to each blurry frame:
64 |
65 | ```
66 | python sharp_extract.py --raw_folder GOPRO/test/raw_data --overlap_len 7 --blur_num 13
67 | python sharp_extract.py --raw_folder GOPRO/train/raw_data --overlap_len 7 --blur_num 13
68 | ```
69 |
70 | To obtain the sharp sequence (7 images in this example) corresponding to each blurry frame:
71 | ```
72 | python sharp_extract.py --raw_folder GOPRO/test/raw_data --overlap_len 7 --blur_num 13 --multi
73 | python sharp_extract.py --raw_folder GOPRO/train/raw_data --overlap_len 7 --blur_num 13 --multi
74 | ```
75 |
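As a hypothetical illustration of the underlying index arithmetic, assuming the single sharp ground truth sits at the center of the 97-frame virtual exposure window (the exact convention is defined in `sharp_extract.py`):

```python
multiple = 8                            # XVFI x8 interpolation
blur_num = 13                           # frames per exposure before interpolation
window = (blur_num - 1) * multiple + 1  # 97 interpolated frames per blurry image
center = window // 2                    # index 48: center of the exposure window
print(window, center)                   # -> 97 48
```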
76 | ## Step 6: Final Structure
77 | Omitting the `raw_data` folder, the structure of the `Spike-GOPRO` dataset is as follows:
78 | ```
79 | GOPRO
80 | ├── test
81 | │ ├── blur_data
82 | │ ├── sharp_data
83 | │ └── spike_data
84 | └── train
85 | ├── blur_data
86 | ├── sharp_data
87 | └── spike_data
88 | ```
89 |
90 |
--------------------------------------------------------------------------------
/scripts/XVFI-main/(0000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/(0000.png
--------------------------------------------------------------------------------
/scripts/XVFI-main/README.md:
--------------------------------------------------------------------------------
1 | # XVFI (ICCV2021, Oral) [](https://paperswithcode.com/sota/video-frame-interpolation-on-x4k1000fps?p=xvfi-extreme-video-frame-interpolation) [](https://paperswithcode.com/sota/video-frame-interpolation-on-vimeo90k?p=xvfi-extreme-video-frame-interpolation)
2 |
3 |
4 | [](https://arxiv.org/abs/2103.16206)
5 | [](https://openaccess.thecvf.com/content/ICCV2021/papers/Sim_XVFI_eXtreme_Video_Frame_Interpolation_ICCV_2021_paper.pdf)
6 | [](https://github.com/JihyongOh/XVFI)
7 | [](https://www.youtube.com/watch?v=5qAiffYFJh8&ab_channel=VICLabKAIST)
8 | 
9 |
10 | **This is the official repository of XVFI (eXtreme Video Frame Interpolation)**
11 |
12 | \[[ArXiv_ver.](https://arxiv.org/abs/2103.16206)\] \[[ICCV2021_ver.](https://openaccess.thecvf.com/content/ICCV2021/papers/Sim_XVFI_eXtreme_Video_Frame_Interpolation_ICCV_2021_paper.pdf)\] \[[Supp.](https://openaccess.thecvf.com/content/ICCV2021/supplemental/Sim_XVFI_eXtreme_Video_ICCV_2021_supplemental.pdf)\] \[[Demo(YouTube)](https://www.youtube.com/watch?v=5qAiffYFJh8)\] \[[Oral12mins(YouTube)](https://www.youtube.com/watch?v=igwy1TJQiRc&t=13s)\] \[[Flowframes(GUI)](https://nmkd.itch.io/flowframes)\] \[[Poster](https://drive.google.com/file/d/16HtZdAKmUWLPkKR9FQX9ua-xDM3-ei6c/view?usp=sharing)\]
13 |
14 |
15 |
16 |
17 | Last Update: 20211130 - We provide extended input sequences for X-TEST. Please refer to [X4K1000FPS](#X4K1000FPS)
18 |
19 | We provide the training and test code along with the trained weights and the dataset (train+test) used for XVFI.
20 | If you find this repository useful, please consider citing our [paper](https://openaccess.thecvf.com/content/ICCV2021/papers/Sim_XVFI_eXtreme_Video_Frame_Interpolation_ICCV_2021_paper.pdf).
21 |
22 | ### Examples of the VFI (x8 Multi-Frame Interpolation) results on X-TEST
23 | 
24 | 
25 | \
26 | The 4K@30fps input frames are interpolated to be 4K@240fps frames. All results are encoded at 30fps to be played as x8 slow motion and are spatially down-scaled due to file-size limits. All methods are trained on X-TRAIN.
27 |
28 |
29 | ## Table of Contents
30 | 1. [X4K1000FPS](#X4K1000FPS)
31 | 1. [Requirements](#Requirements)
32 | 1. [Test](#Test)
33 | 1. [Test_Custom](#Test_Custom)
34 | 1. [Training](#Training)
35 | 1. [Collection_of_Visual_Results](#Collection_of_Visual_Results)
36 | 1. [Reference](#Reference)
37 | 1. [Contact](#Contact)
38 |
39 | ## X4K1000FPS
40 | #### Dataset of high-resolution (4096×2160), high-fps (1000fps) video frames with extreme motion.
41 |   
42 |   \
43 | Some examples from the X4K1000FPS dataset, which consists of 1000-fps, 4K-resolution frames. Our dataset contains various scenes with extreme motions. (Displayed as spatiotemporally subsampled .gif files)
44 |
45 | We provide our X4K1000FPS dataset which consists of X-TEST and X-TRAIN. Please refer to our main/suppl. [paper](https://arxiv.org/abs/2103.16206) for the details of the dataset. You can download the dataset from this dropbox [link](https://www.dropbox.com/sh/duisote638etlv2/AABJw5Vygk94AWjGM4Se0Goza?dl=0).
46 |
47 | `X-TEST` consists of 15 video clips, each with 33 frames of 4K 1000-fps video. It follows the directory format below:
48 | ```
49 | ├──── YOUR_DIR/
50 | ├──── test/
51 | ├──── Type1/
52 | ├──── TEST01/
53 | ├──── 0000.png
54 | ├──── ...
55 | └──── 0032.png
56 | ├──── TEST02/
57 | ├──── 0000.png
58 | ├──── ...
59 | └──── 0032.png
60 | ├──── ...
61 | ├──── ...
62 | ```
63 | `Extended version of X-TEST` (see [issue#9](https://github.com/JihyongOh/XVFI/issues/9)):
64 | As described in our paper, we assume that the number of input frames for VFI is fixed to 2 in X-TEST. However, for the VFI methods that require more than 2 input frames, we provide an **extended version of X-TEST** which contains **8 input frames** (at a temporal distance of 32 frames) for each test sequence. The middle two adjacent frames among the 8 frames are the same input frames as in the original X-TEST. To sort .png files properly by their file names, we added 1000 to the frame indices (e.g. '0000.png' and '0032.png' in the original version of X-TEST correspond to '1000.png' and '1032.png', respectively, in the extended version of X-TEST). Please note that the extended one consists of input frames only, without the ground-truth intermediate frames ('1001.png'~'1031.png'). In addition, for the sequence 'TEST11_078_f4977', '1064.png', '1096.png' and '1128.png' are replicated frames since '1064.png' is the last frame of the raw video file.
65 | The **extended version of X-TEST** can be downloaded from the [link](https://www.dropbox.com/s/tjmxauo05axoi5v/png_test_8_input_frames.zip?dl=0).
66 |
67 |
68 | `X-TRAIN` consists of 4,408 clips from 110 scenes of various types. Each clip contains 65 frames at 1000 fps, and each frame is 768x768, cropped from a 4K frame. It follows the directory format below:
69 | ```
70 | ├──── YOUR_DIR/
71 | ├──── train/
72 | ├──── 002/
73 | ├──── occ008.320/
74 | ├──── 0000.png
75 | ├──── ...
76 | └──── 0064.png
77 | ├──── occ008.322/
78 | ├──── 0000.png
79 | ├──── ...
80 | └──── 0064.png
81 | ├──── ...
82 | ├──── ...
83 | ```
84 |
85 | After downloading the files from the link, decompress `encoded_test.tar.gz` and `encoded_train.tar.gz`. The resulting .mp4 files can be decoded into .png files by running `mp4_decoding.py`. Please follow the instructions written in `mp4_decoding.py`.
86 |
87 |
88 | ## Requirements
89 | Our code is implemented in PyTorch 1.7 and was tested under the following setting:
90 | * Python 3.7
91 | * PyTorch 1.7.1
92 | * CUDA 10.2
93 | * cuDNN 7.6.5
94 | * NVIDIA TITAN RTX GPU
95 | * Ubuntu 16.04 LTS
96 |
97 | **Caution**: since the "align_corners" option of "nn.functional.interpolate" and "nn.functional.grid_sample" in PyTorch 1.7 affects the results, we recommend following our settings.
98 | In particular, using other PyTorch versions may yield different performance.
99 |
100 |
101 |
102 | ## Test
103 | ### Quick Start for X-TEST (x8 Multi-Frame Interpolation as in Table 2)
104 | 1. Download the source codes in a directory of your choice **\<source_path\>**.
105 | 2. First download our X-TEST test dataset by following the above section 'X4K1000FPS'.
106 | 3. Download the pre-trained weights, which were trained on X-TRAIN, from [this link](https://www.dropbox.com/s/xj2ixvay0e5ldma/XVFInet_X4K1000FPS_exp1_latest.pt?dl=0) and place them in **\<source_path\>/checkpoint_dir/XVFInet_X4K1000FPS_exp1**.
107 | ```
108 | XVFI
109 | └── checkpoint_dir
110 | └── XVFInet_X4K1000FPS_exp1
111 | ├── XVFInet_X4K1000FPS_exp1_latest.pt
112 | ```
113 | 4. Run **main.py** with the following options in parse_args:
114 | ```bash
115 | python main.py --gpu 0 --phase 'test' --exp_num 1 --dataset 'X4K1000FPS' --module_scale_factor 4 --S_tst 5 --multiple 8
116 | ```
117 | ==> It would yield **(PSNR/SSIM/tOF) = (30.12/0.870/2.15)**.
118 | ```bash
119 | python main.py --gpu 0 --phase 'test' --exp_num 1 --dataset 'X4K1000FPS' --module_scale_factor 4 --S_tst 3 --multiple 8
120 | ```
121 | ==> It would yield **(PSNR/SSIM/tOF) = (28.86/0.858/2.67)**.
122 |
123 |
124 |
125 | ### Description
126 | * After running with the above test option, you can get the result images in **\<source_path\>/test_img_dir/XVFInet_X4K1000FPS_exp1**, then obtain the PSNR/SSIM/tOF results per test clip as "total_metrics.csv" in the same folder.
127 | * Our proposed XVFI-Net can start inference from any downscaled input by regulating '--S_tst', which adjusts the number of
128 | inference scales according to the input resolution or the motion magnitude.
129 | * You can get any Multi-Frame Interpolation (x M) result by regulating '--multiple'.
130 |
131 |
132 |
133 | ### Quick Start for Vimeo90K (as in Fig. 8)
134 | 1. Download the source codes in a directory of your choice **\<source_path\>**.
135 | 2. First download the Vimeo90K dataset from [this link](http://toflow.csail.mit.edu/) (including 'tri_trainlist.txt') and place it in **\<source_path\>/vimeo_triplet**.
136 | ```
137 | XVFI
138 | └── vimeo_triplet
139 | ├── sequences
140 | readme.txt
141 | tri_testlist.txt
142 | tri_trainlist.txt
143 | ```
144 | 3. Download the pre-trained weights (XVFI-Net_v), which were trained on Vimeo90K, from [this link](https://www.dropbox.com/s/5v4dp81bto4x9xy/XVFInet_Vimeo_exp1_latest.pt?dl=0) and place them in **\<source_path\>/checkpoint_dir/XVFInet_Vimeo_exp1**.
145 | ```
146 | XVFI
147 | └── checkpoint_dir
148 | └── XVFInet_Vimeo_exp1
149 | ├── XVFInet_Vimeo_exp1_latest.pt
150 | ```
151 | 4. Run **main.py** with the following options in parse_args:
152 | ```bash
153 | python main.py --gpu 0 --phase 'test' --exp_num 1 --dataset 'Vimeo' --module_scale_factor 2 --S_tst 1 --multiple 2
154 | ```
155 | ==> It would yield **PSNR = 35.07** on Vimeo90K.
156 |
157 | ### Description
158 | * After running with the above test option, you can get the result images in **\<source_path\>/test_img_dir/XVFInet_Vimeo_exp1**.
159 | * There are a few code lines in front of 'def main()' for convenience when running with the Vimeo option.
160 | * The SSIM result of 0.9760 as in Fig. 8 was measured with MATLAB's ssim function for a fair comparison after running the above guide, because other SOTA methods did so. We also upload the MATLAB file "compare_psnr_ssim.m" to obtain it.
161 | * ~~It should be noted that there is a typo "S_trn and S_tst are set to 2" in the current version of the XVFI paper,
162 | which should be modified to 1 (not 2); sorry for the inconvenience.~~ -> Updated in the latest arXiv version.
163 |
164 | ## Test_Custom
165 | ### Quick Start for your own video data ('--custom_path') for any Multi-Frame Interpolation (x M)
166 | 1. Download the source codes in a directory of your choice **\<source_path\>**.
167 | 2. First prepare your own video datasets in **\<source_path\>/custom_path** by following the hierarchy below:
168 | ```
169 | XVFI
170 | └── custom_path
171 | ├── scene1
172 | ├── 'xxx.png'
173 | ├── ...
174 | └── 'xxx.png'
175 | ...
176 |
177 | ├── sceneN
178 | ├── 'xxxxx.png'
179 | ├── ...
180 | └── 'xxxxx.png'
181 |
182 | ```
183 | 3. Download the pre-trained weights trained on [X-TRAIN](#quick-start-for-x-test-x8-multi-frame-interpolation-as-in-table-2) or [Vimeo90K](#quick-start-for-vimeo90k-as-in-fig-8) as described above.
184 |
185 | 4. Run **main.py** with the following options in parse_args (e.g., x8 Multi-Frame Interpolation):
186 | ```bash
187 | # For the model trained on X-TRAIN
188 | python main.py --gpu 0 --phase 'test_custom' --exp_num 1 --dataset 'X4K1000FPS' --module_scale_factor 4 --S_tst 5 --multiple 8 --custom_path './custom_path'
189 | ```
190 | ```bash
191 | # For the model trained on Vimeo90K
192 | python main.py --gpu 0 --phase 'test_custom' --exp_num 1 --dataset 'Vimeo' --module_scale_factor 2 --S_tst 1 --multiple 8 --custom_path './custom_path'
193 | ```
194 |
195 |
196 | ### Description
197 | * Our proposed XVFI-Net can start inference from any downscaled input by regulating '--S_tst', which adjusts the number of
198 | inference scales according to the input resolution or the motion magnitude.
199 | * You can get any Multi-Frame Interpolation (x M) result by regulating '--multiple'.
200 | * Only the '.png' format is supported.
201 | * Since we cannot cover the diverse possible naming rules for custom frames, please make sure your own frames sort properly by filename.
202 |
203 |
204 | ## Training
205 | ### Quick Start for X-TRAIN
206 | 1. Download the source codes in a directory of your choice **\<source_path\>**.
207 | 2. First download our X-TRAIN train/val/test datasets by following the above section 'X4K1000FPS' and place them as below:
208 | ```
209 | XVFI
210 | └── X4K1000FPS
211 | ├── train
212 | ├── 002
213 | ├── ...
214 | └── 172
215 | ├── val
216 | ├── Type1
217 | ├── Type2
218 | ├── Type3
219 | ├── test
220 | ├── Type1
221 | ├── Type2
222 | ├── Type3
223 |
224 | ```
225 | 3. Run **main.py** with the following options in parse_args:
226 | ```bash
227 | python main.py --phase 'train' --exp_num 1 --dataset 'X4K1000FPS' --module_scale_factor 4 --S_trn 3 --S_tst 5
228 | ```
229 | ### Quick Start for Vimeo90K
230 | 1. Download the source codes in a directory of your choice **\<source_path\>**.
231 | 2. First download the Vimeo90K dataset from [this link](http://toflow.csail.mit.edu/) (including 'tri_trainlist.txt') and place it in **\<source_path\>/vimeo_triplet**.
232 | ```
233 | XVFI
234 | └── vimeo_triplet
235 | ├── sequences
236 | readme.txt
237 | tri_testlist.txt
238 | tri_trainlist.txt
239 | ```
240 | 3. Run **main.py** with the following options in parse_args:
241 | ```bash
242 | python main.py --phase 'train' --exp_num 1 --dataset 'Vimeo' --module_scale_factor 2 --S_trn 1 --S_tst 1
243 | ```
244 | ### Description
245 | * You can freely regulate other arguments in the parser of **main.py**, [here](https://github.com/JihyongOh/XVFI/blob/484bdea1448c22459b10548a488909c268e1dde9/main.py#L12-L72)
246 |
247 | ## Collection_of_Visual_Results
248 | * We also provide all visual results (x8 Multi-Frame Interpolation) on X-TEST for easier comparison, as below. Each zip file is about 1~1.5 GB.
249 | * [AdaCoF*o*](https://www.dropbox.com/s/6ivl96nwrdl7oh1/AdaCoF_final_x8%20%28pretrained%2C%20original%29.zip?dl=0), [AdaCoF*f*](https://www.dropbox.com/s/3iqwzyns0jld2xp/AdaCoF_final_x8%20Retrain.zip?dl=0), [FeFlow*o*](https://www.dropbox.com/s/ukn8acqrim5vg7b/FeFlow_final_x8%20%28pretrained%2C%20original%29.zip?dl=0), [FeFlow*f*](https://www.dropbox.com/s/q26w3c9tm455jau/FeFlow_final_x8%20Retrain.zip?dl=0), [DAIN*o*](https://www.dropbox.com/s/yjtj4tvfhs2niqq/DAIN_final_x8%20%28pretrained%2C%20original%29.zip?dl=0), [DAIN*f*](https://www.dropbox.com/s/ftvimsx4czab5z4/DAIN_final_x8%20Retrain.zip?dl=0), [XVFI-Net](https://www.dropbox.com/s/3sbjjy226njk8by/XVFI-Net_final_Strn3_Stst3.zip?dl=0) (S*tst*=3), [XVFI-Net](https://www.dropbox.com/s/dgf61z08wab3jie/XVFI-Net_final_Strn3_Stst5.zip?dl=0) (S*tst*=5)
250 | * The quantitative comparisons (Table 2 and Figure 5) are attached below for reference.
251 | 
252 | \
253 |
254 |
255 | ## Reference
256 | > Hyeonjun Sim*, Jihyong Oh*, and Munchurl Kim "XVFI: eXtreme Video Frame Interpolation", In _ICCV_, 2021. (* *equal contribution*)
257 | >
258 | **BibTeX**
259 | ```bibtex
260 | @inproceedings{sim2021xvfi,
261 | title={XVFI: eXtreme Video Frame Interpolation},
262 | author={Sim, Hyeonjun and Oh, Jihyong and Kim, Munchurl},
263 | booktitle={Proceedings of the IEEE International Conference on Computer Vision (ICCV)},
264 | year={2021}
265 | }
266 | ```
267 |
268 |
269 |
270 | ## Contact
271 | If you have any questions, please send an email to either \
272 | [[Hyeonjun Sim](https://sites.google.com/view/hjsim)] - flhy5836@kaist.ac.kr or \
273 | [[Jihyong Oh](https://sites.google.com/view/ozbro)] - jhoh94@kaist.ac.kr.
274 |
275 | ## License
276 | The source codes and datasets can be freely used for research and education only. Any commercial use should get formal permission first.
277 |
--------------------------------------------------------------------------------
/scripts/XVFI-main/XVFInet.py:
--------------------------------------------------------------------------------
1 | import functools, random
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | import numpy as np
7 |
8 |
9 | class XVFInet(nn.Module):
10 |
11 | def __init__(self, args):
12 | super(XVFInet, self).__init__()
13 | self.args = args
14 | self.device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') # will be used as "x.to(device)"
15 | self.nf = args.nf
16 | self.scale = args.module_scale_factor
17 | self.vfinet = VFInet(args)
18 | self.lrelu = nn.ReLU()
19 | self.in_channels = 3
20 | self.channel_converter = nn.Sequential(
21 | nn.Conv3d(self.in_channels, self.nf, [1, 3, 3], [1, 1, 1], [0, 1, 1]),
22 | nn.ReLU())
23 |
24 | self.rec_ext_ds_module = [self.channel_converter]
25 | self.rec_ext_ds = nn.Conv3d(self.nf, self.nf, [1, 3, 3], [1, 2, 2], [0, 1, 1])
26 | for _ in range(int(np.log2(self.scale))):
27 | self.rec_ext_ds_module.append(self.rec_ext_ds)
28 | self.rec_ext_ds_module.append(nn.ReLU())
29 | self.rec_ext_ds_module.append(nn.Conv3d(self.nf, self.nf, [1, 3, 3], 1, [0, 1, 1]))
30 | self.rec_ext_ds_module.append(RResBlock2D_3D(args, T_reduce_flag=False))
31 | self.rec_ext_ds_module = nn.Sequential(*self.rec_ext_ds_module)
32 |
33 | self.rec_ctx_ds = nn.Conv3d(self.nf, self.nf, [1, 3, 3], [1, 2, 2], [0, 1, 1])
34 |
35 | print("The lowest scale depth for training (S_trn): ", self.args.S_trn)
36 | print("The lowest scale depth for test (S_tst): ", self.args.S_tst)
37 |
38 | def forward(self, x, t_value, is_training=True):
39 | '''
40 | x shape : [B,C,T,H,W]
41 | t_value shape : [B,1] ###############
42 | '''
43 | B, C, T, H, W = x.size()
44 | B2, C2 = t_value.size()
45 | assert C2 == 1, "t_value shape must be [B,1]"
46 | assert T % 2 == 0, "T must be an even number"
47 | t_value = t_value.view(B, 1, 1, 1)
48 |
49 | flow_l = None
50 | feat_x = self.rec_ext_ds_module(x)
51 | feat_x_list = [feat_x]
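# Build a feature pyramid by repeatedly applying the stride-2 context
# downsampler below; its depth is S_trn during training and S_tst at test
# time, and the VFI module is then applied coarse-to-fine over these levels.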
52 | self.lowest_depth_level = self.args.S_trn if is_training else self.args.S_tst
53 | for level in range(1, self.lowest_depth_level+1):
54 | feat_x = self.rec_ctx_ds(feat_x)
55 | feat_x_list.append(feat_x)
56 |
57 | if is_training:
58 | out_l_list = []
59 | flow_refine_l_list = []
60 | out_l, flow_l, flow_refine_l = self.vfinet(x, feat_x_list[self.args.S_trn], flow_l, t_value, level=self.args.S_trn, is_training=True)
61 | out_l_list.append(out_l)
62 | flow_refine_l_list.append(flow_refine_l)
63 | for level in range(self.args.S_trn-1, 0, -1): ## self.args.S_trn, self.args.S_trn-1, ..., 1. level 0 is not included
64 | out_l, flow_l = self.vfinet(x, feat_x_list[level], flow_l, t_value, level=level, is_training=True)
65 | out_l_list.append(out_l)
66 | out_l, flow_l, flow_refine_l, occ_0_l0 = self.vfinet(x, feat_x_list[0], flow_l, t_value, level=0, is_training=True)
67 | out_l_list.append(out_l)
68 | flow_refine_l_list.append(flow_refine_l)
69 | return out_l_list[::-1], flow_refine_l_list[::-1], occ_0_l0, torch.mean(x, dim=2) # out_l_list should be reversed. [out_l0, out_l1, ...]
70 |
71 | else: # Testing
72 | for level in range(self.args.S_tst, 0, -1): ## self.args.S_tst, self.args.S_tst-1, ..., 1. level 0 is not included
73 | flow_l = self.vfinet(x, feat_x_list[level], flow_l, t_value, level=level, is_training=False)
74 | out_l = self.vfinet(x, feat_x_list[0], flow_l, t_value, level=0, is_training=False)
75 | return out_l
76 |
77 |
78 | class VFInet(nn.Module):
79 |
80 | def __init__(self, args):
81 | super(VFInet, self).__init__()
82 | self.args = args
83 | self.device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') # will be used as "x.to(device)"
84 | self.nf = args.nf
85 | self.scale = args.module_scale_factor
86 | self.in_channels = 3
87 |
88 | self.conv_flow_bottom = nn.Sequential(
89 | nn.Conv2d(2*self.nf, 2*self.nf, [4,4], 2, [1,1]),
90 | nn.ReLU(),
91 | nn.Conv2d(2*self.nf, 4*self.nf, [4,4], 2, [1,1]),
92 | nn.ReLU(),
93 | nn.UpsamplingNearest2d(scale_factor=2),
94 | nn.Conv2d(4 * self.nf, 2 * self.nf, [3, 3], 1, [1, 1]),
95 | nn.ReLU(),
96 | nn.UpsamplingNearest2d(scale_factor=2),
97 | nn.Conv2d(2 * self.nf, self.nf, [3, 3], 1, [1, 1]),
98 | nn.ReLU(),
99 | nn.Conv2d(self.nf, 6, [3,3], 1, [1,1]),
100 | )
101 |
102 | self.conv_flow1 = nn.Conv2d(2*self.nf, self.nf, [3, 3], 1, [1, 1])
103 |
104 | self.conv_flow2 = nn.Sequential(
105 | nn.Conv2d(2*self.nf + 4, 2 * self.nf, [4, 4], 2, [1, 1]),
106 | nn.ReLU(),
107 | nn.Conv2d(2 * self.nf, 4 * self.nf, [4, 4], 2, [1, 1]),
108 | nn.ReLU(),
109 | nn.UpsamplingNearest2d(scale_factor=2),
110 | nn.Conv2d(4 * self.nf, 2 * self.nf, [3, 3], 1, [1, 1]),
111 | nn.ReLU(),
112 | nn.UpsamplingNearest2d(scale_factor=2),
113 | nn.Conv2d(2 * self.nf, self.nf, [3, 3], 1, [1, 1]),
114 | nn.ReLU(),
115 | nn.Conv2d(self.nf, 6, [3, 3], 1, [1, 1]),
116 | )
117 |
118 | self.conv_flow3 = nn.Sequential(
119 | nn.Conv2d(4 + self.nf * 4, self.nf, [1, 1], 1, [0, 0]),
120 | nn.ReLU(),
121 | nn.Conv2d(self.nf, 2 * self.nf, [4, 4], 2, [1, 1]),
122 | nn.ReLU(),
123 | nn.Conv2d(2 * self.nf, 4 * self.nf, [4, 4], 2, [1, 1]),
124 | nn.ReLU(),
125 | nn.UpsamplingNearest2d(scale_factor=2),
126 | nn.Conv2d(4 * self.nf, 2 * self.nf, [3, 3], 1, [1, 1]),
127 | nn.ReLU(),
128 | nn.UpsamplingNearest2d(scale_factor=2),
129 | nn.Conv2d(2 * self.nf, self.nf, [3, 3], 1, [1, 1]),
130 | nn.ReLU(),
131 | nn.Conv2d(self.nf, 4, [3, 3], 1, [1, 1]),
132 | )
133 |
134 | self.refine_unet = RefineUNet(args)
135 | self.lrelu = nn.ReLU()
136 |
137 | def forward(self, x, feat_x, flow_l_prev, t_value, level, is_training):
138 | '''
139 | x shape : [B,C,T,H,W]
140 | t_value shape : [B,1] ###############
141 | '''
142 | B, C, T, H, W = x.size()
143 | assert T % 2 == 0, "T must be an even number"
144 |
145 | ####################### For a single level
146 | l = 2 ** level
147 | x_l = x.permute(0,2,1,3,4)
148 | x_l = x_l.contiguous().view(B * T, C, H, W)
149 |
150 | if level == 0:
151 | pass
152 | else:
153 | x_l = F.interpolate(x_l, scale_factor=(1.0 / l, 1.0 / l), mode='bicubic', align_corners=False)
154 | '''
155 | Down pixel-shuffle
156 | '''
157 | x_l = x_l.view(B, T, C, H//l, W//l)
158 | x_l = x_l.permute(0,2,1,3,4)
159 |
160 | B, C, T, H, W = x_l.size()
161 |
162 | ## Feature extraction
163 | feat0_l = feat_x[:,:,0,:,:]
164 | feat1_l = feat_x[:,:,1,:,:]
165 |
166 | ## Flow estimation
167 | if flow_l_prev is None:
168 | flow_l_tmp = self.conv_flow_bottom(torch.cat((feat0_l, feat1_l), dim=1))
169 | flow_l = flow_l_tmp[:,:4,:,:]
170 | else:
171 | up_flow_l_prev = 2.0*F.interpolate(flow_l_prev.detach(), scale_factor=(2,2), mode='bilinear', align_corners=False)
172 | warped_feat1_l = self.bwarp(feat1_l, up_flow_l_prev[:,:2,:,:])
173 | warped_feat0_l = self.bwarp(feat0_l, up_flow_l_prev[:,2:,:,:])
174 | flow_l_tmp = self.conv_flow2(torch.cat([self.conv_flow1(torch.cat([feat0_l, warped_feat1_l],dim=1)), self.conv_flow1(torch.cat([feat1_l, warped_feat0_l],dim=1)), up_flow_l_prev],dim=1))
175 | flow_l = flow_l_tmp[:,:4,:,:] + up_flow_l_prev
176 |
177 | if not is_training and level!=0:
178 | return flow_l
179 |
180 | flow_01_l = flow_l[:,:2,:,:]
181 | flow_10_l = flow_l[:,2:,:,:]
182 | z_01_l = torch.sigmoid(flow_l_tmp[:,4:5,:,:])
183 | z_10_l = torch.sigmoid(flow_l_tmp[:,5:6,:,:])
184 |
185 | ## Complementary Flow Reversal (CFR)
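# Both bidirectional flows are forward-warped (z-weighted splatting) to time t,
# combined with complementary time weights, and then normalized by the splat
# weights; pixels that receive no splat (norm == 0) keep the unnormalized sum.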
186 | flow_forward, norm0_l = self.z_fwarp(flow_01_l, t_value * flow_01_l, z_01_l) ## Actually, F (t) -> (t+1). Translation only. Not normalized yet
187 | flow_backward, norm1_l = self.z_fwarp(flow_10_l, (1-t_value) * flow_10_l, z_10_l) ## Actually, F (1-t) -> (-t). Translation only. Not normalized yet
188 |
189 | flow_t0_l = -(1-t_value) * ((t_value)*flow_forward) + (t_value) * ((t_value)*flow_backward) # The numerator of Eq.(1) in the paper.
190 | flow_t1_l = (1-t_value) * ((1-t_value)*flow_forward) - (t_value) * ((1-t_value)*flow_backward) # The numerator of Eq.(2) in the paper.
191 |
192 | norm_l = (1-t_value)*norm0_l + t_value*norm1_l
193 | mask_ = (norm_l.detach() > 0).type(norm_l.type())
194 | flow_t0_l = (1-mask_) * flow_t0_l + mask_ * (flow_t0_l.clone() / (norm_l.clone() + (1-mask_))) # Divide the numerator with denominator in Eq.(1)
195 | flow_t1_l = (1-mask_) * flow_t1_l + mask_ * (flow_t1_l.clone() / (norm_l.clone() + (1-mask_))) # Divide the numerator with denominator in Eq.(2)
196 |
197 | ## Feature warping
198 | warped0_l = self.bwarp(feat0_l, flow_t0_l)
199 | warped1_l = self.bwarp(feat1_l, flow_t1_l)
200 |
201 | ## Flow refinement
202 | flow_refine_l = torch.cat([feat0_l, warped0_l, warped1_l, feat1_l, flow_t0_l, flow_t1_l], dim=1)
203 | flow_refine_l = self.conv_flow3(flow_refine_l) + torch.cat([flow_t0_l, flow_t1_l], dim=1)
204 | flow_t0_l = flow_refine_l[:, :2, :, :]
205 | flow_t1_l = flow_refine_l[:, 2:4, :, :]
206 |
207 | warped0_l = self.bwarp(feat0_l, flow_t0_l)
208 | warped1_l = self.bwarp(feat1_l, flow_t1_l)
209 |
210 | ## Flow upscale
211 | flow_t0_l = self.scale * F.interpolate(flow_t0_l, scale_factor=(self.scale, self.scale), mode='bilinear',align_corners=False)
212 | flow_t1_l = self.scale * F.interpolate(flow_t1_l, scale_factor=(self.scale, self.scale), mode='bilinear',align_corners=False)
213 |
214 | ## Image warping and blending
215 | warped_img0_l = self.bwarp(x_l[:,:,0,:,:], flow_t0_l)
216 | warped_img1_l = self.bwarp(x_l[:,:,1,:,:], flow_t1_l)
217 |
218 | refine_out = self.refine_unet(torch.cat([F.pixel_shuffle(torch.cat([feat0_l, feat1_l, warped0_l, warped1_l],dim=1), self.scale), x_l[:,:,0,:,:], x_l[:,:,1,:,:], warped_img0_l, warped_img1_l, flow_t0_l, flow_t1_l],dim=1))
219 | occ_0_l = torch.sigmoid(refine_out[:, 0:1, :, :])
220 | occ_1_l = 1-occ_0_l
221 |
222 | out_l = (1-t_value)*occ_0_l*warped_img0_l + t_value*occ_1_l*warped_img1_l
223 | out_l = out_l / ( (1-t_value)*occ_0_l + t_value*occ_1_l ) + refine_out[:, 1:4, :, :]
224 |
225 | if not is_training and level==0:
226 | return out_l
227 |
228 | if is_training:
229 | if flow_l_prev is None:
230 | # if level == self.args.S_trn:
231 | return out_l, flow_l, flow_refine_l[:, 0:4, :, :]
232 | elif level != 0:
233 | return out_l, flow_l
234 | else: # level==0
235 | return out_l, flow_l, flow_refine_l[:, 0:4, :, :], occ_0_l
236 |
237 | def bwarp(self, x, flo):
238 | '''
239 | x: [B, C, H, W] (im2)
240 | flo: [B, 2, H, W] flow
241 | '''
242 | B, C, H, W = x.size()
243 | # mesh grid
244 | xx = torch.arange(0, W).view(1, 1, 1, W).expand(B, 1, H, W)
245 | yy = torch.arange(0, H).view(1, 1, H, 1).expand(B, 1, H, W)
246 | grid = torch.cat((xx, yy), 1).float()
247 |
248 | if x.is_cuda:
249 | grid = grid.to(self.device)
250 | vgrid = torch.autograd.Variable(grid) + flo
251 |
252 | # scale grid to [-1,1]
253 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
254 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0
255 |
256 | vgrid = vgrid.permute(0, 2, 3, 1) # [B,H,W,2]
257 | output = nn.functional.grid_sample(x, vgrid, align_corners=True)
258 | mask = torch.autograd.Variable(torch.ones(x.size())).to(self.device)
259 | mask = nn.functional.grid_sample(mask, vgrid, align_corners=True)
260 |
261 | # mask[mask<0.9999] = 0
262 | # mask[mask>0] = 1
263 | mask = mask.masked_fill_(mask < 0.999, 0)
264 | mask = mask.masked_fill_(mask > 0, 1)
265 |
266 | return output * mask
267 |
268 | def fwarp(self, img, flo):
269 |
270 | """
271 | -img: image (N, C, H, W)
272 | -flo: optical flow (N, 2, H, W)
273 | elements of flo is in [0, H] and [0, W] for dx, dy
274 | https://github.com/lyh-18/EQVI/blob/EQVI-master/models/forward_warp_gaussian.py
275 | """
276 |
277 | # (x1, y1) (x1, y2)
278 | # +---------------+
279 | # | |
280 | # | o(x, y) |
281 | # | |
282 | # | |
283 | # | |
284 | # | |
285 | # +---------------+
286 | # (x2, y1) (x2, y2)
287 |
288 | N, C, _, _ = img.size()
289 |
290 | # translate start-point optical flow to end-point optical flow
291 | y = flo[:, 0:1, :, :]
292 | x = flo[:, 1:2, :, :]
293 |
294 | x = x.repeat(1, C, 1, 1)
295 | y = y.repeat(1, C, 1, 1)
296 |
297 | # Four points of the square: (x1, y1), (x1, y2), (x2, y1), (x2, y2)
298 | x1 = torch.floor(x)
299 | x2 = x1 + 1
300 | y1 = torch.floor(y)
301 | y2 = y1 + 1
302 |
303 | # firstly, get gaussian weights
304 | w11, w12, w21, w22 = self.get_gaussian_weights(x, y, x1, x2, y1, y2)
305 |
306 | # secondly, sample each weighted corner
307 | img11, o11 = self.sample_one(img, x1, y1, w11)
308 | img12, o12 = self.sample_one(img, x1, y2, w12)
309 | img21, o21 = self.sample_one(img, x2, y1, w21)
310 | img22, o22 = self.sample_one(img, x2, y2, w22)
311 |
312 | imgw = img11 + img12 + img21 + img22
313 | o = o11 + o12 + o21 + o22
314 |
315 | return imgw, o
316 |
317 |
318 | def z_fwarp(self, img, flo, z):
319 | """
320 | -img: image (N, C, H, W)
321 | -flo: optical flow (N, 2, H, W)
322 | elements of flo are in [0, H] and [0, W] for dx, dy
323 | modified from https://github.com/lyh-18/EQVI/blob/EQVI-master/models/forward_warp_gaussian.py
324 | """
325 |
326 | # (x1, y1) (x1, y2)
327 | # +---------------+
328 | # | |
329 | # | o(x, y) |
330 | # | |
331 | # | |
332 | # | |
333 | # | |
334 | # +---------------+
335 | # (x2, y1) (x2, y2)
336 |
337 | N, C, _, _ = img.size()
338 |
339 | # translate start-point optical flow to end-point optical flow
340 | y = flo[:, 0:1, :, :]
341 | x = flo[:, 1:2, :, :]
342 |
343 | x = x.repeat(1, C, 1, 1)
344 | y = y.repeat(1, C, 1, 1)
345 |
346 | # Four corners of the square: (x1, y1), (x1, y2), (x2, y1), (x2, y2)
347 | x1 = torch.floor(x)
348 | x2 = x1 + 1
349 | y1 = torch.floor(y)
350 | y2 = y1 + 1
351 |
352 | # firstly, get gaussian weights
353 | w11, w12, w21, w22 = self.get_gaussian_weights(x, y, x1, x2, y1, y2, z+1e-5)
354 |
355 | # secondly, sample each weighted corner
356 | img11, o11 = self.sample_one(img, x1, y1, w11)
357 | img12, o12 = self.sample_one(img, x1, y2, w12)
358 | img21, o21 = self.sample_one(img, x2, y1, w21)
359 | img22, o22 = self.sample_one(img, x2, y2, w22)
360 |
361 | imgw = img11 + img12 + img21 + img22
362 | o = o11 + o12 + o21 + o22
363 |
364 | return imgw, o
365 |
366 |
367 | def get_gaussian_weights(self, x, y, x1, x2, y1, y2, z=1.0):
368 | # z 0.0 ~ 1.0
369 | w11 = z * torch.exp(-((x - x1) ** 2 + (y - y1) ** 2))
370 | w12 = z * torch.exp(-((x - x1) ** 2 + (y - y2) ** 2))
371 | w21 = z * torch.exp(-((x - x2) ** 2 + (y - y1) ** 2))
372 | w22 = z * torch.exp(-((x - x2) ** 2 + (y - y2) ** 2))
373 |
374 | return w11, w12, w21, w22
375 |
376 | def sample_one(self, img, shiftx, shifty, weight):
377 | """
378 | Input:
379 | -img (N, C, H, W)
380 | -shiftx, shifty (N, c, H, W)
381 | """
382 |
383 | N, C, H, W = img.size()
384 |
385 | # flatten all (all restored as Tensors)
386 | flat_shiftx = shiftx.view(-1)
387 | flat_shifty = shifty.view(-1)
388 | flat_basex = torch.arange(0, H, requires_grad=False).view(-1, 1)[None, None].to(self.device).long().repeat(N, C,1,W).view(-1)
389 | flat_basey = torch.arange(0, W, requires_grad=False).view(1, -1)[None, None].to(self.device).long().repeat(N, C,H,1).view(-1)
390 | flat_weight = weight.view(-1)
391 | flat_img = img.contiguous().view(-1)
392 |
393 | # The corresponding positions in I1
394 | idxn = torch.arange(0, N, requires_grad=False).view(N, 1, 1, 1).to(self.device).long().repeat(1, C, H, W).view(-1)
395 | idxc = torch.arange(0, C, requires_grad=False).view(1, C, 1, 1).to(self.device).long().repeat(N, 1, H, W).view(-1)
396 | idxx = flat_shiftx.long() + flat_basex
397 | idxy = flat_shifty.long() + flat_basey
398 |
399 | # record which shifted positions land inside the image bounds
400 | mask = idxx.ge(0) & idxx.lt(H) & idxy.ge(0) & idxy.lt(W)
401 |
402 | # Mask off points out of boundaries
403 | ids = (idxn * C * H * W + idxc * H * W + idxx * W + idxy)
404 | ids_mask = torch.masked_select(ids, mask).clone().to(self.device)
405 |
406 | # Note: the accumulate flag must be True in put_ for proper backpropagation
407 | img_warp = torch.zeros([N * C * H * W, ]).to(self.device)
408 | img_warp.put_(ids_mask, torch.masked_select(flat_img * flat_weight, mask), accumulate=True)
409 |
410 | one_warp = torch.zeros([N * C * H * W, ]).to(self.device)
411 | one_warp.put_(ids_mask, torch.masked_select(flat_weight, mask), accumulate=True)
412 |
413 | return img_warp.view(N, C, H, W), one_warp.view(N, C, H, W)
414 |
415 | class RefineUNet(nn.Module):
416 | def __init__(self, args):
417 | super(RefineUNet, self).__init__()
418 | self.args = args
419 | self.scale = args.module_scale_factor
420 | self.nf = args.nf
421 | self.conv1 = nn.Conv2d(self.nf, self.nf, [3,3], 1, [1,1])
422 | self.conv2 = nn.Conv2d(self.nf, self.nf, [3,3], 1, [1,1])
423 | self.lrelu = nn.ReLU()
424 | self.NN = nn.UpsamplingNearest2d(scale_factor=2)
425 |
426 | self.enc1 = nn.Conv2d((4*self.nf)//self.scale//self.scale + 4*args.img_ch + 4, self.nf, [4, 4], 2, [1, 1])
427 | self.enc2 = nn.Conv2d(self.nf, 2*self.nf, [4, 4], 2, [1, 1])
428 | self.enc3 = nn.Conv2d(2*self.nf, 4*self.nf, [4, 4], 2, [1, 1])
429 | self.dec0 = nn.Conv2d(4*self.nf, 4*self.nf, [3, 3], 1, [1, 1])
430 | self.dec1 = nn.Conv2d(4*self.nf + 2*self.nf, 2*self.nf, [3, 3], 1, [1, 1]) ## input concatenated with enc2
431 | self.dec2 = nn.Conv2d(2*self.nf + self.nf, self.nf, [3, 3], 1, [1, 1]) ## input concatenated with enc1
432 | self.dec3 = nn.Conv2d(self.nf, 1+args.img_ch, [3, 3], 1, [1, 1]) ## input added with warped image
433 |
434 | def forward(self, concat):
435 | enc1 = self.lrelu(self.enc1(concat))
436 | enc2 = self.lrelu(self.enc2(enc1))
437 | out = self.lrelu(self.enc3(enc2))
438 |
439 | out = self.lrelu(self.dec0(out))
440 | out = self.NN(out)
441 |
442 | out = torch.cat((out,enc2),dim=1)
443 | out = self.lrelu(self.dec1(out))
444 |
445 | out = self.NN(out)
446 | out = torch.cat((out,enc1),dim=1)
447 | out = self.lrelu(self.dec2(out))
448 |
449 | out = self.NN(out)
450 | out = self.dec3(out)
451 | return out
452 |
453 | class ResBlock2D_3D(nn.Module):
454 | ## Shape of input [B,C,T,H,W]
455 | ## Shape of output [B,C,T,H,W]
456 | def __init__(self, args):
457 | super(ResBlock2D_3D, self).__init__()
458 | self.args = args
459 | self.nf = args.nf
460 |
461 | self.conv3x3_1 = nn.Conv3d(self.nf, self.nf, [1,3,3], 1, [0,1,1])
462 | self.conv3x3_2 = nn.Conv3d(self.nf, self.nf, [1,3,3], 1, [0,1,1])
463 | self.lrelu = nn.ReLU()
464 |
465 | def forward(self, x):
466 | '''
467 | x shape : [B,C,T,H,W]
468 | '''
469 | B, C, T, H, W = x.size()
470 |
471 | out = self.conv3x3_2(self.lrelu(self.conv3x3_1(x)))
472 |
473 | return x + out
474 |
475 | class RResBlock2D_3D(nn.Module):
476 |
477 | def __init__(self, args, T_reduce_flag=False):
478 | super(RResBlock2D_3D, self).__init__()
479 | self.args = args
480 | self.nf = args.nf
481 | self.T_reduce_flag = T_reduce_flag
482 | self.resblock1 = ResBlock2D_3D(self.args)
483 | self.resblock2 = ResBlock2D_3D(self.args)
484 | if T_reduce_flag:
485 | self.reduceT_conv = nn.Conv3d(self.nf, self.nf, [3,1,1], 1, [0,0,0])
486 |
487 | def forward(self, x):
488 | '''
489 | x shape : [B,C,T,H,W]
490 | '''
491 | out = self.resblock1(x)
492 | out = self.resblock2(out)
493 | if self.T_reduce_flag:
494 | return self.reduceT_conv(out + x)
495 | else:
496 | return out + x
497 |
--------------------------------------------------------------------------------
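The `bwarp` method above is standard backward warping: build a pixel grid, shift it by the flow, normalize coordinates to [-1, 1], and sample with `grid_sample`, masking pixels whose sources fall outside the frame. A minimal self-contained sketch of the same idea (function name and shapes are illustrative; zero flow serves as a sanity check):

```python
import torch
import torch.nn.functional as F

def backward_warp(x, flow):
    """Warp x (B,C,H,W) by flow (B,2,H,W); channel 0 is dx, channel 1 is dy."""
    B, C, H, W = x.shape
    yy, xx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    grid = torch.stack((xx, yy), dim=0).float().to(x.device)  # (2,H,W) pixel grid
    vgrid = grid.unsqueeze(0) + flow                          # shifted sampling positions
    # normalize coordinates to [-1, 1] as required by grid_sample
    vx = 2.0 * vgrid[:, 0] / max(W - 1, 1) - 1.0
    vy = 2.0 * vgrid[:, 1] / max(H - 1, 1) - 1.0
    return F.grid_sample(x, torch.stack((vx, vy), dim=-1), align_corners=True)

x = torch.rand(1, 3, 8, 8)
assert torch.allclose(backward_warp(x, torch.zeros(1, 2, 8, 8)), x, atol=1e-6)
```

`fwarp` goes the other way: it splats each source pixel onto the four neighbors of its target location with Gaussian weights, and the returned occupancy map `o` lets the caller normalize the accumulated colors.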
/scripts/XVFI-main/__pycache__/XVFInet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/__pycache__/XVFInet.cpython-38.pyc
--------------------------------------------------------------------------------
/scripts/XVFI-main/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/scripts/XVFI-main/checkpoint_dir/XVFInet_Vimeo_exp1/info.txt:
--------------------------------------------------------------------------------
1 | *place the downloaded pre-trained weights 'XVFInet_Vimeo_exp1_latest.pt' in this folder*
--------------------------------------------------------------------------------
/scripts/XVFI-main/checkpoint_dir/XVFInet_X4K1000FPS_exp1/XVFInet_X4K1000FPS_exp1_latest.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/checkpoint_dir/XVFInet_X4K1000FPS_exp1/XVFInet_X4K1000FPS_exp1_latest.pt
--------------------------------------------------------------------------------
/scripts/XVFI-main/checkpoint_dir/XVFInet_X4K1000FPS_exp1/info.txt:
--------------------------------------------------------------------------------
1 | *place the downloaded pre-trained weights 'XVFInet_X4K1000FPS_exp1_latest.pt' in this folder*
--------------------------------------------------------------------------------
/scripts/XVFI-main/compare_psnr_ssim.m:
--------------------------------------------------------------------------------
1 | close all;
2 | clear;clc;
3 | fid = fopen('./vimeo_triplet/tri_testlist.txt','r');
4 | pred_path = './test_img_dir/XVFInet_Vimeo_exp1/epoch_00199_final_x2_S_tst1/';
5 |
6 | gt_path = './vimeo_triplet/sequences/';
7 | sample = fgetl(fid);
8 | cnt = 0;
9 |
10 | while ischar(sample)
11 | cnt = cnt+1;
12 |
13 | pred = imread(strcat(pred_path, sample, '/im2.png'));
14 | gt = imread(strcat(gt_path, sample, '/im2.png'));
15 |
16 |
17 | [h,w,c] = size(pred);
18 |
19 | pred_y = pred;
20 | gt_y = gt;
21 |
22 | total_psnr1(cnt) = psnr(pred, gt);
23 | total_ssim1(cnt) = ssim(pred_y, gt_y);
24 | fprintf('%s : %f, %f \n', sample, total_psnr1(cnt), total_ssim1(cnt))
25 |
26 | sample = fgetl(fid);
27 | end
28 | cnt
29 | mean(total_psnr1)
30 | mean(total_ssim1)
31 |
--------------------------------------------------------------------------------
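For reference, a rough Python counterpart to `compare_psnr_ssim.m` above (a sketch only, assuming scikit-image is installed; the paths mirror the MATLAB script):

```python
import cv2
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

pred_path = './test_img_dir/XVFInet_Vimeo_exp1/epoch_00199_final_x2_S_tst1/'
gt_path = './vimeo_triplet/sequences/'

psnrs, ssims = [], []
with open('./vimeo_triplet/tri_testlist.txt') as f:
    for sample in (line.strip() for line in f):
        if not sample:
            continue
        pred = cv2.imread(pred_path + sample + '/im2.png')
        gt = cv2.imread(gt_path + sample + '/im2.png')
        psnrs.append(peak_signal_noise_ratio(gt, pred, data_range=255))
        ssims.append(structural_similarity(gt, pred, channel_axis=-1))
        print(f'{sample} : {psnrs[-1]:f}, {ssims[-1]:f}')
print(len(psnrs), np.mean(psnrs), np.mean(ssims))
```

Note that the MATLAB script computes SSIM on the full BGR images (`pred_y`/`gt_y` are plain copies), so `channel_axis=-1` keeps the comparison equivalent.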
/scripts/XVFI-main/figures/.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/scripts/XVFI-main/figures/003.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/figures/003.gif
--------------------------------------------------------------------------------
/scripts/XVFI-main/figures/004.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/figures/004.gif
--------------------------------------------------------------------------------
/scripts/XVFI-main/figures/045.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/figures/045.gif
--------------------------------------------------------------------------------
/scripts/XVFI-main/figures/078.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/figures/078.gif
--------------------------------------------------------------------------------
/scripts/XVFI-main/figures/081.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/figures/081.gif
--------------------------------------------------------------------------------
/scripts/XVFI-main/figures/146.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/figures/146.gif
--------------------------------------------------------------------------------
/scripts/XVFI-main/figures/Figure5.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/figures/Figure5.PNG
--------------------------------------------------------------------------------
/scripts/XVFI-main/figures/Table2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/figures/Table2.PNG
--------------------------------------------------------------------------------
/scripts/XVFI-main/figures/results_045_resized_768.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/figures/results_045_resized_768.gif
--------------------------------------------------------------------------------
/scripts/XVFI-main/figures/results_079_resized_768.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/figures/results_079_resized_768.gif
--------------------------------------------------------------------------------
/scripts/XVFI-main/figures/results_158_resized_768.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/figures/results_158_resized_768.gif
--------------------------------------------------------------------------------
/scripts/XVFI-main/main.py:
--------------------------------------------------------------------------------
1 | import argparse, os, shutil, time, random, torch, cv2, datetime, torch.utils.data, math
2 | import torch.backends.cudnn as cudnn
3 | import torch.optim as optim
4 | import numpy as np
5 |
6 | from torch.autograd import Variable
7 | from utils import *
8 | from XVFInet import *
9 | from collections import Counter
10 |
11 |
12 | def parse_args():
13 | desc = "PyTorch implementation for XVFI"
14 | parser = argparse.ArgumentParser(description=desc)
15 | parser.add_argument('--gpu', type=int, default=0, help='gpu index')
16 | parser.add_argument('--net_type', type=str, default='XVFInet', choices=['XVFInet'], help='The type of Net')
17 | parser.add_argument('--net_object', default=XVFInet, choices=[XVFInet], help='The type of Net')
18 | parser.add_argument('--exp_num', type=int, default=1, help='The experiment number')
19 | parser.add_argument('--phase', type=str, default='test', choices=['train', 'test', 'test_custom', 'metrics_evaluation',])
20 | parser.add_argument('--continue_training', action='store_true', default=False, help='continue the training')
21 |
22 | """ Information of directories """
23 | parser.add_argument('--test_img_dir', type=str, default='./test_img_dir', help='test_img_dir path')
24 | parser.add_argument('--text_dir', type=str, default='./text_dir', help='text_dir path')
25 | parser.add_argument('--checkpoint_dir', type=str, default='./checkpoint_dir', help='checkpoint_dir')
26 | parser.add_argument('--log_dir', type=str, default='./log_dir', help='Directory name to save training logs')
27 |
28 | parser.add_argument('--dataset', default='X4K1000FPS', choices=['X4K1000FPS', 'Vimeo'],
29 | help='Training/test Dataset')
30 |
31 | # parser.add_argument('--train_data_path', type=str, default='./X4K1000FPS/train')
32 | # parser.add_argument('--val_data_path', type=str, default='./X4K1000FPS/val')
33 | # parser.add_argument('--test_data_path', type=str, default='./X4K1000FPS/test')
34 | parser.add_argument('--train_data_path', type=str, default='../Datasets/VIC_4K_1000FPS/train')
35 | parser.add_argument('--val_data_path', type=str, default='../Datasets/VIC_4K_1000FPS/val')
36 | parser.add_argument('--test_data_path', type=str, default='../Datasets/VIC_4K_1000FPS/test')
37 |
38 |
39 | parser.add_argument('--vimeo_data_path', type=str, default='./vimeo_triplet')
40 |
41 | """ Hyperparameters for Training (when [phase=='train']) """
42 | parser.add_argument('--epochs', type=int, default=200, help='The number of epochs to run')
43 | parser.add_argument('--freq_display', type=int, default=100, help='The number of iterations frequency for display')
44 | parser.add_argument('--save_img_num', type=int, default=4,
45 | help='The number of images saved during training for visualization. It should be smaller than the batch_size')
46 | parser.add_argument('--init_lr', type=float, default=1e-4, help='The initial learning rate')
47 | parser.add_argument('--lr_dec_fac', type=float, default=0.25, help='step - lr_decreasing_factor')
48 | parser.add_argument('--lr_milestones', type=int, default=[100, 150, 180])
49 | parser.add_argument('--lr_dec_start', type=int, default=0,
50 | help='When scheduler is StepLR, lr decreases from epoch at lr_dec_start')
51 | parser.add_argument('--batch_size', type=int, default=8, help='The size of batch size.')
52 | parser.add_argument('--weight_decay', type=float, default=0, help='for optim., weight decay (default: 0)')
53 |
54 | parser.add_argument('--need_patch', default=True, help='get patch from image while training')
55 | parser.add_argument('--img_ch', type=int, default=3, help='base number of channels for image')
56 | parser.add_argument('--nf', type=int, default=64, help='base number of channels for feature maps') # 64
57 | parser.add_argument('--module_scale_factor', type=int, default=4, help='spatial reduction for pixelshuffle')
58 | parser.add_argument('--patch_size', type=int, default=384, help='patch size')
59 | parser.add_argument('--num_thrds', type=int, default=4, help='number of threads for data loading')
60 | parser.add_argument('--loss_type', default='L1', choices=['L1', 'MSE', 'L1_Charbonnier_loss'], help='Loss type')
61 |
62 | parser.add_argument('--S_trn', type=int, default=3, help='The lowest scale depth for training')
63 | parser.add_argument('--S_tst', type=int, default=5, help='The lowest scale depth for test')
64 |
65 | """ Weighting Parameters Lambda for Losses (when [phase=='train']) """
66 | parser.add_argument('--rec_lambda', type=float, default=1.0, help='Lambda for Reconstruction Loss')
67 |
68 | """ Settings for Testing (when [phase=='test' or 'test_custom']) """
69 | parser.add_argument('--saving_flow_flag', default=False)
70 | parser.add_argument('--multiple', type=int, default=8, help='Due to the file-name indexing scheme, we recommend using a power of 2 (e.g. 2, 4, 8, 16 ...). CAUTION : For the provided X-TEST, multiple should be one of [2, 4, 8, 16, 32].')
71 | parser.add_argument('--metrics_types', type=list, default=["PSNR", "SSIM", "tOF"], choices=["PSNR", "SSIM", "tOF"])
72 |
73 | """ Settings for test_custom (when [phase=='test_custom']) """
74 | parser.add_argument('--custom_path', type=str, default='./custom_path', help='path for custom video containing frames')
75 |
76 | return check_args(parser.parse_args())
77 |
78 |
79 | def check_args(args):
80 | # --checkpoint_dir
81 | check_folder(args.checkpoint_dir)
82 |
83 | # --text_dir
84 | check_folder(args.text_dir)
85 |
86 | # --log_dir
87 | check_folder(args.log_dir)
88 |
89 | # --test_img_dir
90 | check_folder(args.test_img_dir)
91 |
92 | return args
93 |
94 |
95 | def main():
96 | args = parse_args()
97 | if args.dataset == 'Vimeo':
98 | if args.phase != 'test_custom':
99 | args.multiple = 2
100 | args.S_trn = 1
101 | args.S_tst = 1
102 | args.module_scale_factor = 2
103 | args.patch_size = 256
104 | args.batch_size = 16
105 | print('vimeo triplet data dir : ', args.vimeo_data_path)
106 |
107 | print("Exp:", args.exp_num)
108 | args.model_dir = args.net_type + '_' + args.dataset + '_exp' + str(
109 | args.exp_num) # ex) model_dir = "XVFInet_X4K1000FPS_exp1"
110 |
111 | if args is None:
112 | exit()
113 | for arg in vars(args):
114 | print('# {} : {}'.format(arg, getattr(args, arg)))
115 | device = torch.device(
116 | 'cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') # will be used as "x.to(device)"
117 | torch.cuda.set_device(device) # change allocation of current GPU
118 | # caution!!!! if not "torch.cuda.set_device()":
119 | # RuntimeError: grid_sampler(): expected input and grid to be on same device, but input is on cuda:1 and grid is on cuda:0
120 | print('Available devices: ', torch.cuda.device_count())
121 | print('Current cuda device: ', torch.cuda.current_device())
122 | print('Current cuda device name: ', torch.cuda.get_device_name(device))
123 | if args.gpu is not None:
124 | print("Use GPU: {} is used".format(args.gpu))
125 |
126 | SM = save_manager(args)
127 |
128 | """ Initialize a model """
129 | model_net = args.net_object(args).apply(weights_init).to(device)
130 | criterion = [set_rec_loss(args).to(device), set_smoothness_loss().to(device)]
131 |
132 | # to enable the inbuilt cudnn auto-tuner
133 | # to find the best algorithm to use for your hardware.
134 | cudnn.benchmark = True
135 |
136 | if args.phase == "train":
137 | train(model_net, criterion, device, SM, args)
138 | epoch = args.epochs - 1
139 |
140 | elif args.phase == "test" or args.phase == "metrics_evaluation" or args.phase == 'test_custom':
141 | checkpoint = SM.load_model()
142 | model_net.load_state_dict(checkpoint['state_dict_Model'])
143 | epoch = checkpoint['last_epoch']
144 |
145 | postfix = '_final_x' + str(args.multiple) + '_S_tst' + str(args.S_tst)
146 | if args.phase != "metrics_evaluation":
147 | print("\n-------------------------------------- Final Test starts -------------------------------------- ")
148 | print('Evaluate on test set (final test) with multiple = %d ' % (args.multiple))
149 |
150 | final_test_loader = get_test_data(args, multiple=args.multiple,
151 | validation=False) # multiple is only used for X4K1000FPS
152 |
153 | testLoss, testPSNR, testSSIM, final_pred_save_path = test(final_test_loader, model_net,
154 | criterion, epoch,
155 | args, device,
156 | multiple=args.multiple,
157 | postfix=postfix, validation=False)
158 | SM.write_info('Final 4k frames PSNR : {:.4}\n'.format(testPSNR))
159 |
160 | if args.dataset == 'X4K1000FPS' and args.phase != 'test_custom':
161 | final_pred_save_path = os.path.join(args.test_img_dir, args.model_dir, 'epoch_' + str(epoch).zfill(5)) + postfix
162 | metrics_evaluation_X_Test(final_pred_save_path, args.test_data_path, args.metrics_types,
163 | flow_flag=args.saving_flow_flag, multiple=args.multiple)
164 |
165 |
166 |
167 | print("------------------------- Test has been ended. -------------------------\n")
168 | print("Exp:", args.exp_num)
169 |
170 |
171 | def train(model_net, criterion, device, save_manager, args):
172 | SM = save_manager
173 | multi_scale_recon_loss = criterion[0]
174 | smoothness_loss = criterion[1]
175 |
176 | optimizer = optim.Adam(model_net.parameters(), lr=args.init_lr, betas=(0.9, 0.999),
177 | weight_decay=args.weight_decay) # optimizer
178 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_milestones, gamma=args.lr_dec_fac)
179 |
180 | last_epoch = 0
181 | best_PSNR = 0.0
182 |
183 | if args.continue_training:
184 | checkpoint = SM.load_model()
185 | last_epoch = checkpoint['last_epoch'] + 1
186 | best_PSNR = checkpoint['best_PSNR']
187 | model_net.load_state_dict(checkpoint['state_dict_Model'])
188 | optimizer.load_state_dict(checkpoint['state_dict_Optimizer'])
189 | scheduler.load_state_dict(checkpoint['state_dict_Scheduler'])
190 | print("Optimizer and Scheduler have been reloaded. ")
191 | scheduler.milestones = Counter(args.lr_milestones)
192 | scheduler.gamma = args.lr_dec_fac
193 | print("scheduler.milestones : {}, scheduler.gamma : {}".format(scheduler.milestones, scheduler.gamma))
194 | start_epoch = last_epoch
195 |
196 | # switch to train mode
197 | model_net.train()
198 |
199 | start_time = time.time()
200 |
201 | SM.write_info('Epoch\ttrainLoss\ttestPSNR\tbest_PSNR\n')
202 | print("[*] Training starts")
203 |
204 | # Main training loop for total epochs (start from 'epoch=0')
205 | valid_loader = get_test_data(args, multiple=4, validation=True) # multiple is only used for X4K1000FPS
206 |
207 | for epoch in range(start_epoch, args.epochs):
208 | train_loader = get_train_data(args,
209 | max_t_step_size=32) # max_t_step_size (temporal distance) is only used for X4K1000FPS
210 |
211 | batch_time = AverageClass('batch_time[s]:', ':6.3f')
212 | losses = AverageClass('Loss:', ':.4e')
213 | progress = ProgressMeter(len(train_loader), batch_time, losses, prefix="Epoch: [{}]".format(epoch))
214 |
215 | print('Start epoch {} at [{:s}], learning rate : [{}]'.format(epoch, (str(datetime.datetime.now())[:-7]),
216 | optimizer.param_groups[0]['lr']))
217 |
218 | # train for one epoch
219 | for trainIndex, (frames, t_value) in enumerate(train_loader):
220 |
221 | input_frames = frames[:, :, :-1, :, :] # [B, C, T, H, W]
222 | frameT = frames[:, :, -1, :, :] # [B, C, H, W]
223 |
224 | # Getting the input and the target from the training set
225 | input_frames = Variable(input_frames.to(device))
226 | frameT = Variable(frameT.to(device)) # ground truth for frameT
227 | t_value = Variable(t_value.to(device)) # [B,1]
228 |
229 | optimizer.zero_grad()
230 | # compute output
231 | pred_frameT_pyramid, pred_flow_pyramid, occ_map, simple_mean = model_net(input_frames, t_value)
232 | rec_loss = 0.0
233 | smooth_loss = 0.0
234 | for l, pred_frameT_l in enumerate(pred_frameT_pyramid):
235 | rec_loss += args.rec_lambda * multi_scale_recon_loss(pred_frameT_l,
236 | F.interpolate(frameT, scale_factor=1 / (2 ** l),
237 | mode='bicubic', align_corners=False))
238 | smooth_loss += 0.5 * smoothness_loss(pred_flow_pyramid[0],
239 | F.interpolate(frameT, scale_factor=1 / args.module_scale_factor,
240 | mode='bicubic',
241 | align_corners=False)) # Apply 1st order edge-aware smoothness loss to the finest level
242 | rec_loss /= len(pred_frameT_pyramid)
243 | pred_frameT = pred_frameT_pyramid[0] # final result I^0_t at original scale (s=0)
244 | pred_coarse_flow = 2 ** (args.S_trn) * F.interpolate(pred_flow_pyramid[-1], scale_factor=2 ** (
245 | args.S_trn) * args.module_scale_factor, mode='bicubic', align_corners=False)
246 | pred_fine_flow = F.interpolate(pred_flow_pyramid[0], scale_factor=args.module_scale_factor, mode='bicubic',
247 | align_corners=False)
248 |
249 | total_loss = rec_loss + smooth_loss
250 |
251 | # compute gradient and do SGD step
252 | total_loss.backward() # Backpropagate
253 | optimizer.step() # Optimizer update
254 |
255 | # measure accumulated time and update average "batch" time consumptions via "AverageClass"
256 | # update average values via "AverageClass"
257 | losses.update(total_loss.item(), 1)
258 | batch_time.update(time.time() - start_time)
259 | start_time = time.time()
260 |
261 | if trainIndex % args.freq_display == 0:
262 | progress.print(trainIndex)
263 | batch_images = get_batch_images(args, save_img_num=args.save_img_num,
264 | save_images=[pred_frameT, pred_coarse_flow, pred_fine_flow, frameT,
265 | simple_mean, occ_map])
266 | cv2.imwrite(os.path.join(args.log_dir, '{:03d}_{:04d}_training.png'.format(epoch, trainIndex)), batch_images)
267 |
268 |
269 |
270 | if epoch >= args.lr_dec_start:
271 | scheduler.step()
272 |
273 | # if (epoch + 1) % 10 == 0 or epoch==0:
274 | val_multiple = 4 if args.dataset == 'X4K1000FPS' else 2
275 | print('\nEvaluate on test set (validation while training) with multiple = {}'.format(val_multiple))
276 | postfix = '_val_' + str(val_multiple) + '_S_tst' + str(args.S_tst)
277 | testLoss, testPSNR, testSSIM, final_pred_save_path = test(valid_loader, model_net, criterion, epoch, args,
278 | device, multiple=val_multiple, postfix=postfix,
279 | validation=True)
280 |
281 | # remember the best PSNR so far and save a checkpoint
282 | print("best_PSNR : {:.3f}, testPSNR : {:.3f}".format(best_PSNR, testPSNR))
283 | best_PSNR_flag = testPSNR > best_PSNR
284 | best_PSNR = max(testPSNR, best_PSNR)
285 | # save checkpoint.
286 | combined_state_dict = {
287 | 'net_type': args.net_type,
288 | 'last_epoch': epoch,
289 | 'batch_size': args.batch_size,
290 | 'trainLoss': losses.avg,
291 | 'testLoss': testLoss,
292 | 'testPSNR': testPSNR,
293 | 'best_PSNR': best_PSNR,
294 | 'state_dict_Model': model_net.state_dict(),
295 | 'state_dict_Optimizer': optimizer.state_dict(),
296 | 'state_dict_Scheduler': scheduler.state_dict()}
297 |
298 | SM.save_best_model(combined_state_dict, best_PSNR_flag)
299 |
300 | if (epoch + 1) % 10 == 0:
301 | SM.save_epc_model(combined_state_dict, epoch)
302 | SM.write_info('{}\t{:.4}\t{:.4}\t{:.4}\n'.format(epoch, losses.avg, testPSNR, best_PSNR))
303 |
304 | print("------------------------- Training has been ended. -------------------------\n")
305 | print("information of model:", args.model_dir)
306 | print("best_PSNR of model:", best_PSNR)
307 |
308 |
309 | def test(test_loader, model_net, criterion, epoch, args, device, multiple, postfix, validation):
310 | batch_time = AverageClass('Time:', ':6.3f')
311 | losses = AverageClass('testLoss:', ':.4e')
312 | PSNRs = AverageClass('testPSNR:', ':.4e')
313 | SSIMs = AverageClass('testSSIM:', ':.4e')
314 | args.divide = 2 ** (args.S_tst) * args.module_scale_factor * 4
315 |
316 | # progress = ProgressMeter(len(test_loader), batch_time, accm_time, losses, PSNRs, SSIMs, prefix='Test after Epoch[{}]: '.format(epoch))
317 | progress = ProgressMeter(len(test_loader), PSNRs, SSIMs, prefix='Test after Epoch[{}]: '.format(epoch))
318 |
319 | multi_scale_recon_loss = criterion[0]
320 |
321 | # switch to evaluate mode
322 | model_net.eval()
323 |
324 | print("------------------------------------------- Test ----------------------------------------------")
325 | with torch.no_grad():
326 | start_time = time.time()
327 | for testIndex, (frames, t_value, scene_name, frameRange) in enumerate(test_loader):
328 | # Shape of 'frames' : [1,C,T+1,H,W]
329 | frameT = frames[:, :, -1, :, :] # [1,C,H,W]
330 | It_Path, I0_Path, I1_Path = frameRange
331 |
332 | frameT = Variable(frameT.to(device)) # ground truth for frameT
333 | t_value = Variable(t_value.to(device))
334 |
335 | if (testIndex % (multiple - 1)) == 0:
336 | input_frames = frames[:, :, :-1, :, :] # [1,C,T,H,W]
337 | input_frames = Variable(input_frames.to(device))
338 |
339 | B, C, T, H, W = input_frames.size()
340 | H_padding = (args.divide - H % args.divide) % args.divide
341 | W_padding = (args.divide - W % args.divide) % args.divide
342 | if H_padding != 0 or W_padding != 0:
343 | input_frames = F.pad(input_frames, (0, W_padding, 0, H_padding), "constant")
344 |
345 |
346 | pred_frameT = model_net(input_frames, t_value, is_training=False)
347 |
348 | if H_padding != 0 or W_padding != 0:
349 | pred_frameT = pred_frameT[:, :, :H, :W]
350 |
351 |
352 |
353 | if args.phase != 'test_custom':
354 | test_loss = args.rec_lambda * multi_scale_recon_loss(pred_frameT, frameT)
355 |
356 | pred_frameT = np.squeeze(pred_frameT.detach().cpu().numpy())
357 | frameT = np.squeeze(frameT.detach().cpu().numpy())
358 |
359 | """ compute PSNR & SSIM """
360 | output_img = np.around(denorm255_np(np.transpose(pred_frameT, [1, 2, 0]))) # [h,w,c] and [-1,1] to [0,255]
361 | target_img = denorm255_np(np.transpose(frameT, [1, 2, 0])) # [h,w,c] and [-1,1] to [0,255]
362 |
363 | test_psnr = psnr(target_img, output_img)
364 | test_ssim = ssim_bgr(target_img, output_img) ############### CAUTION: calculation for BGR
365 |
366 | """ save frame0 & frame1 """
367 | if validation:
368 | epoch_save_path = os.path.join(args.test_img_dir, args.model_dir, 'latest' + postfix)
369 | else:
370 | epoch_save_path = os.path.join(args.test_img_dir, args.model_dir,
371 | 'epoch_' + str(epoch).zfill(5) + postfix)
372 | check_folder(epoch_save_path)
373 | scene_save_path = os.path.join(epoch_save_path, scene_name[0])
374 | check_folder(scene_save_path)
375 |
376 | if (testIndex % (multiple - 1)) == 0:
377 | save_input_frames = frames[:, :, :-1, :, :]
378 | cv2.imwrite(os.path.join(scene_save_path, I0_Path[0]),
379 | np.transpose(np.squeeze(denorm255_np(save_input_frames[:, :, 0, :, :].detach().numpy())),
380 | [1, 2, 0]).astype(np.uint8))
381 | cv2.imwrite(os.path.join(scene_save_path, I1_Path[0]),
382 | np.transpose(np.squeeze(denorm255_np(save_input_frames[:, :, 1, :, :].detach().numpy())),
383 | [1, 2, 0]).astype(np.uint8))
384 |
385 | cv2.imwrite(os.path.join(scene_save_path, It_Path[0]), output_img.astype(np.uint8))
386 |
387 | # measure
388 | losses.update(test_loss.item(), 1)
389 | PSNRs.update(test_psnr, 1)
390 | SSIMs.update(test_ssim, 1)
391 |
392 | # measure elapsed time
393 | batch_time.update(time.time() - start_time)
394 | start_time = time.time()
395 |
396 | if (testIndex % (multiple - 1)) == multiple - 2:
397 | progress.print(testIndex)
398 |
399 | else:
400 | epoch_save_path = args.custom_path
401 | scene_save_path = os.path.join(epoch_save_path, scene_name[0])
402 | pred_frameT = np.squeeze(pred_frameT.detach().cpu().numpy())
403 | output_img = np.around(denorm255_np(np.transpose(pred_frameT, [1, 2, 0]))) # [h,w,c] and [-1,1] to [0,255]
404 | print(os.path.join(scene_save_path, It_Path[0]))
405 | cv2.imwrite(os.path.join(scene_save_path, It_Path[0]), output_img.astype(np.uint8))
406 |
407 | losses.update(0.0, 1)
408 | PSNRs.update(0.0, 1)
409 | SSIMs.update(0.0, 1)
410 |
411 | print("-----------------------------------------------------------------------------------------------")
412 |
413 | return losses.avg, PSNRs.avg, SSIMs.avg, epoch_save_path
414 |
415 |
416 | if __name__ == '__main__':
417 | main()
418 |
--------------------------------------------------------------------------------
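In `test()` above, inputs are padded so that height and width are multiples of `args.divide = 2**S_tst * module_scale_factor * 4`, and the prediction is cropped back afterwards. The pad-then-crop pattern in isolation (values are illustrative):

```python
import torch
import torch.nn.functional as F

divide = 2 ** 5 * 4 * 4              # mirrors args.divide with S_tst=5, scale factor 4
x = torch.rand(1, 3, 2, 720, 1280)   # [B, C, T, H, W] test clip

H, W = x.shape[-2:]
h_pad = (divide - H % divide) % divide
w_pad = (divide - W % divide) % divide
x_pad = F.pad(x, (0, w_pad, 0, h_pad), 'constant')  # pads the last two dims
print(x_pad.shape)                   # H and W are now divisible by `divide`

# after inference, crop the prediction back to the original size:
# pred = pred[:, :, :H, :W]
```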
/scripts/XVFI-main/run.sh:
--------------------------------------------------------------------------------
1 | python main.py --gpu 0 --phase test_custom --exp_num 1 --dataset X4K1000FPS --module_scale_factor 4 --S_tst 5 --multiple 8 --custom_path D:\Code\SpikeCamera\Nerf\blend_files\dataset\raw_data
2 |
--------------------------------------------------------------------------------
/scripts/XVFI-main/test.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/XVFI-main/test.gif
--------------------------------------------------------------------------------
/scripts/XVFI-main/text_dir/XVFInet_X4K1000FPS_exp1.txt:
--------------------------------------------------------------------------------
1 | ----- Model parameters -----
2 | 2023-11-14 19:45:33
3 | gpu : 0
4 | net_type : XVFInet
5 | net_object :
6 | exp_num : 1
7 | phase : test_custom
8 | continue_training : False
9 | test_img_dir : ./test_img_dir
10 | text_dir : ./text_dir
11 | checkpoint_dir : ./checkpoint_dir
12 | log_dir : ./log_dir
13 | dataset : X4K1000FPS
14 | train_data_path : ../Datasets/VIC_4K_1000FPS/train
15 | val_data_path : ../Datasets/VIC_4K_1000FPS/val
16 | test_data_path : ../Datasets/VIC_4K_1000FPS/test
17 | vimeo_data_path : ./vimeo_triplet
18 | epochs : 200
19 | freq_display : 100
20 | save_img_num : 4
21 | init_lr : 0.0001
22 | lr_dec_fac : 0.25
23 | lr_milestones : [100, 150, 180]
24 | lr_dec_start : 0
25 | batch_size : 8
26 | weight_decay : 0
27 | need_patch : True
28 | img_ch : 3
29 | nf : 64
30 | module_scale_factor : 4
31 | patch_size : 384
32 | num_thrds : 4
33 | loss_type : L1
34 | S_trn : 3
35 | S_tst : 5
36 | rec_lambda : 1.0
37 | saving_flow_flag : False
38 | multiple : 8
39 | metrics_types : ['PSNR', 'SSIM', 'tOF']
40 | custom_path : /d/Code/SpikeCamera/Nerf/blend_files/dataset/train_bg
41 | model_dir : XVFInet_X4K1000FPS_exp1
42 | Final 4k frames PSNR : 0.0
43 | Final 4k frames PSNR : 0.0
44 | Final 4k frames PSNR : 0.0
45 | Final 4k frames PSNR : 0.0
46 | Final 4k frames PSNR : 0.0
47 | Final 4k frames PSNR : 0.0
48 | Final 4k frames PSNR : 0.0
49 | Final 4k frames PSNR : 0.0
50 | Final 4k frames PSNR : 0.0
51 | Final 4k frames PSNR : 0.0
52 | Final 4k frames PSNR : 0.0
53 | Final 4k frames PSNR : 0.0
54 | Final 4k frames PSNR : 0.0
55 | Final 4k frames PSNR : 0.0
56 | Final 4k frames PSNR : 0.0
57 | Final 4k frames PSNR : 0.0
58 | Final 4k frames PSNR : 0.0
59 | Final 4k frames PSNR : 0.0
60 | Final 4k frames PSNR : 0.0
61 | Final 4k frames PSNR : 0.0
62 | Final 4k frames PSNR : 0.0
63 |
--------------------------------------------------------------------------------
/scripts/__pycache__/generate_spike.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/__pycache__/generate_spike.cpython-38.pyc
--------------------------------------------------------------------------------
/scripts/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenkang455/S-SDM/a98d8a1c5ddde019f507c68f2f70c585504d7edc/scripts/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/scripts/blur_syn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from tqdm import trange
5 | import argparse
6 | from torchvision import transforms
7 |
8 | #? convert the imgs under raw_folder into blurry imgs under blur_folder
9 | #? Structure as:
10 | #? base_folder
11 | #? ├── blur_folder
12 | #? ├── raw_folder
13 |
14 | # main function
15 | if __name__ == '__main__':
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--raw_folder', type=str, default='GOPRO/test/raw_data')
18 | parser.add_argument('--img_type', type=str, default='.png')
19 | parser.add_argument('--overlap_len', type=int, default= 7,
20 | help = 'Overlap length between two blurry images. Assume frames [0-12] -> blur_6 and [13-25] -> blur_19; overlap_len is the number of interpolated frames between 12.png and 13.png, i.e., 12_0.png, ..., 12_6.png')
21 | parser.add_argument('--height', type=int, default= 720)
22 | parser.add_argument('--width', type=int, default= 1280)
23 | parser.add_argument('--use_resize', action='store_true', help = 'Resize the image size to the half.')
24 | parser.add_argument('--blur_num', type=int, default= 13,help = 'Number of images before interpolation to synthesize one blurry frame.')
25 |
26 | opt = parser.parse_args()
27 | width = opt.width
28 | height = opt.height
29 | use_resize = opt.use_resize
30 | raw_folder = opt.raw_folder
31 | base_folder = os.path.dirname(raw_folder)
32 | blur_folder = os.path.join(base_folder,'blur_data')
33 | os.makedirs(blur_folder,exist_ok = True)
34 |
35 | for dirpath, sub_dirs, sub_files in os.walk(raw_folder):
36 | if len(sub_files) == 0 or not sub_files[0].endswith(opt.img_type):
37 | continue
38 | print(dirpath)
39 | output_folder = dirpath.replace(raw_folder,blur_folder)
40 | os.makedirs(output_folder,exist_ok = True)
41 | sub_files = sorted(sub_files)
42 | bais = 0 # running offset for the overlap between two blurry imgs
43 | num_blur_raw = opt.blur_num # number of original sharp imgs (before interpolation) per blurry one
44 | num_inter = 7 # number of interpolated imgs between two imgs
45 | num_blur = (num_blur_raw - 1) * (num_inter + 1) + 1 # total number of frames (after interpolation) per blurry one
46 | str_len = len(sub_files[0].split('.')[0])
47 | imgs = []
48 | start = 0
49 | for i in trange(len(sub_files)):
50 | if i + bais >= len(sub_files):
51 | break
52 | file_name = sub_files[i + bais]
53 | img = cv2.imread(os.path.join(dirpath,file_name))
54 | imgs.append(img)
55 | # synthesize the blurry image
56 | if i % (num_blur) == num_blur - 1:
57 | blur_img = np.mean(np.stack(imgs,axis = 0),axis = 0)
58 | end = i + bais
59 | # (num_blur_raw - 1 + (opt.overlap_len + 1) // (num_inter + 1)): number of imgs per blur
60 | # ((i + 1) // num_blur - 1): index of the blurry frame, starting at 0
61 | # num_blur_raw // 2: middle frame
62 | if use_resize:
63 | blur_img = cv2.resize(blur_img,(width // 2,height//2),interpolation=cv2.INTER_LINEAR)
64 | # ! the former denotes the renamed output index, while the latter denotes the file name in the raw image folder.
65 | print(f"Synthesize blurry {str((num_blur_raw - 1 + (opt.overlap_len + 1) // (num_inter + 1)) * ((i + 1) // num_blur - 1) + num_blur_raw // 2).zfill(str_len)}.png from {sub_files[start]} to {sub_files[end]}")
66 | cv2.imwrite(os.path.join(output_folder,f"{str((num_blur_raw - 1 + (opt.overlap_len + 1) // (num_inter + 1)) * ((i + 1) // num_blur - 1) + num_blur_raw // 2).zfill(str_len)}.png"),blur_img)
67 | imgs = []
68 | bais += opt.overlap_len
69 | start = i + bais + 1
70 |
--------------------------------------------------------------------------------
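`blur_syn.py` synthesizes each blurry frame as the temporal mean of `num_blur` consecutive interpolated sharp frames, approximating the exposure integral of a long capture. The core operation in isolation (the frame stack here is synthetic; 97 = (13 - 1) * (7 + 1) + 1 matches the script defaults):

```python
import numpy as np

# a stack of consecutive sharp frames (T, H, W, 3); random stand-in data
frames = np.random.randint(0, 256, size=(97, 64, 64, 3)).astype(np.float32)

# averaging over time approximates a long-exposure (blurry) capture
blur = frames.mean(axis=0)
print(blur.shape)  # (64, 64, 3), ready for cv2.imwrite after casting
```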
/scripts/run.sh:
--------------------------------------------------------------------------------
1 | # todo GOPRO Data
2 | # 1. Download Data
3 |
4 | #? Structure:
5 | #? GOPRO
6 | #? ├── train
7 | #? │ ├── raw_data
8 | #? │ │ ├── GOPR0372_07_00
9 | #? │ │ ├── ...
10 | #? │ │ └── GOPR0884_11_00
11 | #? │
12 | #? ├── test
13 | #? │ ├── raw_data
14 | #? │ │ ├── GOPR0384_11_00
15 | #? │ │ ├── ...
16 | #? │ │ └── GOPR0881_11_01
17 |
18 | # 2. Interpolate frames
19 | cd XVFI-main/
20 | python main.py --custom_path ../GOPRO/test/raw_data --gpu 0 --phase test_custom --exp_num 1 --dataset X4K1000FPS --module_scale_factor 4 --S_tst 5 --multiple 8
21 | python main.py --custom_path ../GOPRO/train/raw_data --gpu 0 --phase test_custom --exp_num 1 --dataset X4K1000FPS --module_scale_factor 4 --S_tst 5 --multiple 8
22 | cd ..
23 |
24 | # 3. Synthesize the blurry image
25 | python blur_syn.py --raw_folder GOPRO/test/raw_data --overlap_len 7 --blur_num 13
26 | python blur_syn.py --raw_folder GOPRO/train/raw_data --overlap_len 7 --blur_num 13
27 |
28 | # 4. Simulate the spike
29 | python spike_simulate.py --raw_folder GOPRO/test/raw_data --overlap_len 7 --use_resize --blur_num 13
30 | python spike_simulate.py --raw_folder GOPRO/train/raw_data --overlap_len 7 --use_resize --blur_num 13
31 |
32 | # 5. extract GT from raw_folder
33 | ## single frame
34 | python sharp_extract.py --raw_folder GOPRO/test/raw_data --overlap_len 7 --blur_num 13
35 | python sharp_extract.py --raw_folder GOPRO/train/raw_data --overlap_len 7 --blur_num 13
36 |
37 | ## sequence
38 | # python sharp_extract.py --raw_folder GOPRO/test/raw_data --overlap_len 7 --blur_num 13 --multi
39 | # python sharp_extract.py --raw_folder GOPRO/train/raw_data --overlap_len 7 --blur_num 13 --multi
--------------------------------------------------------------------------------
/scripts/sharp_extract.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from tqdm import trange
5 | import argparse
6 |
7 | #? extract the imgs under raw_folder into a sharp sequence under sharp_folder
8 | #? Structure as:
9 | #? base_folder
10 | #? ├── blur_folder
11 | #? ├── raw_folder
12 | #? └── spike_folder
13 |
14 | # main function
15 | if __name__ == '__main__':
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--raw_folder', type=str, default='GOPRO/test/raw_data')
18 | parser.add_argument('--img_type', type=str, default='.png')
19 | parser.add_argument('--overlap_len', type=int, default= 7,
20 | help = 'Overlap length between two blurry images. Assume frames [0-12] -> blur_6 and [13-25] -> blur_19; overlap_len is the number of interpolated frames between 12.png and 13.png, i.e., 12_0.png, ..., 12_6.png')
21 | parser.add_argument('--height', type=int, default= 720)
22 | parser.add_argument('--width', type=int, default= 1280)
23 | parser.add_argument('--use_resize', action='store_true', help = 'Resize the image size to the half.')
24 | parser.add_argument('--blur_num', type=int, default= 13,help = 'Number of images before interpolation to synthesize one blurry frame.')
25 | parser.add_argument('--multi', action='store_true',default = False,help= 'extract gt sequence from raw_folder')
26 |
27 | opt = parser.parse_args()
28 | width = opt.width
29 | height = opt.height
30 | use_resize = opt.use_resize
31 | raw_folder = opt.raw_folder
32 | base_folder = os.path.dirname(raw_folder)
33 | blur_folder = os.path.join(base_folder,'blur_data')
34 | sharp_folder = os.path.join(base_folder,'sharp_data')
35 |
36 | for dirpath, sub_dirs, sub_files in os.walk(raw_folder):
37 | if len(sub_files) == 0 or not sub_files[0].endswith(opt.img_type):
38 | continue
39 | print(dirpath)
40 | output_folder = dirpath.replace(raw_folder,sharp_folder)
41 | os.makedirs(output_folder,exist_ok = True)
42 | sub_files = sorted(sub_files)
43 | bais = 0 # running offset for the overlap between two blurry imgs
44 | num_blur_raw = opt.blur_num # number of original sharp imgs (before interpolation) per blurry one
45 | num_inter = 7 # number of interpolated imgs between two imgs
46 | num_blur = num_blur_raw * (num_inter + 1) + 1 # total number of frames (after interpolation) per blurry one
47 | str_len = len(sub_files[0].split('.')[0])
48 | imgs = []
49 | idx = 0
50 | loop_bais = 0
51 | for i in trange(len(sub_files)):
52 | if i % (num_inter + 1) == 0:
53 | if idx % num_blur_raw == num_blur_raw // 2 or (opt.multi == True and ((idx - loop_bais) % 12) % 2 == 0): # % 2 is set to extract 7 images
54 | file_name = sub_files[i]
55 | print(file_name)
56 | img = cv2.imread(os.path.join(dirpath,file_name))
57 | if use_resize:
58 | img = cv2.resize(img,(width // 2,height//2),interpolation=cv2.INTER_LINEAR)
59 | cv2.imwrite(os.path.join(output_folder,f"{str(idx).zfill(str_len)}.png"),img)
60 | if idx % num_blur_raw == num_blur_raw - 1:
61 | loop_bais += 1
62 | idx += 1
63 |
--------------------------------------------------------------------------------
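In single-frame mode, `sharp_extract.py` keeps only the original (non-interpolated) frames, i.e. indices divisible by `num_inter + 1`, and from each group of `blur_num` originals it keeps the middle one as ground truth. The selection logic, condensed (the file count is hypothetical):

```python
blur_num, num_inter = 13, 7  # script defaults
total_files = 300            # hypothetical number of frames after interpolation

kept = []
idx = 0
for i in range(total_files):
    if i % (num_inter + 1) == 0:             # an original, non-interpolated frame
        if idx % blur_num == blur_num // 2:  # middle frame of each blur group
            kept.append(i)
        idx += 1
print(kept)  # raw-frame indices that become ground-truth sharp images
```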
/scripts/spike_simulate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from tqdm import trange
5 | from utils import *
6 | from utils_spike import *
7 | import argparse
8 |
9 | #? convert the imgs under raw_folder into a spike sequence under spike_folder
10 | #? Structure as:
11 | #? base_folder
12 | #? ├── blur_folder
13 | #? ├── raw_folder
14 | #? └── spike_folder
15 |
16 | # main function
17 | if __name__ == '__main__':
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--raw_folder', type=str, default=r'GOPRO/test/raw_data')
20 | parser.add_argument('--img_type', type=str, default='.png')
21 | parser.add_argument('--overlap_len', type=int, default= 7,
22 | help = 'Overlap length between two blurry images. Assume frames [0-12] -> blur_6 and [13-25] -> blur_19; overlap_len is the number of interpolated frames between 12.png and 13.png, i.e., 12_0.png, ..., 12_6.png')
23 | parser.add_argument('--height', type=int, default= 720)
24 | parser.add_argument('--width', type=int, default= 1280)
25 | parser.add_argument('--use_resize', action='store_true', help = 'Resize the image size to the half.')
26 | parser.add_argument('--blur_num', type=int, default= 13,help = 'Number of images before interpolation to synthesize one blurry frame.')
27 | parser.add_argument('--spike_add', type=int, default= 20,help = 'Additional spike out of the exposure period.')
28 |
29 | opt = parser.parse_args()
30 | width = opt.width
31 | height = opt.height
32 | resize = opt.use_resize
33 | raw_folder = opt.raw_folder
34 | base_folder = os.path.dirname(raw_folder)
35 | spike_folder = os.path.join(base_folder,'spike_data')
36 | os.makedirs(spike_folder,exist_ok = True)
37 | for dirpath, sub_dirs, sub_files in os.walk(raw_folder):
38 | if len(sub_files) == 0 or not sub_files[0].endswith(opt.img_type):
39 | continue
40 | print(dirpath)
41 | output_folder = dirpath.replace(raw_folder,spike_folder)
42 | os.makedirs(output_folder,exist_ok = True)
43 | sub_files = sorted(sub_files)
44 | bais = 0 # running offset for the overlap between two blurry imgs
45 | num_blur_raw = opt.blur_num # number of original sharp imgs (before interpolation) per blurry one
46 | num_inter = 7 # number of interpolated imgs between two imgs
47 | num_blur = (num_blur_raw - 1) * (num_inter + 1) + 1 # total number of frames (after interpolation) per blurry one
48 | spike_add = opt.spike_add # additional spike frames outside the exposure period
49 | num_omit = 1 # reduce the spike sequence length to [num / num_omit]
50 | str_len = len(sub_files[0].split('.')[0])
51 | imgs = []
52 | start = 0
53 | for i in trange(len(sub_files)):
54 | if i + bais >= len(sub_files):
55 | break
56 | file_name = sub_files[i + bais]
57 | img = cv2.imread(os.path.join(dirpath,file_name))
58 | img = img.astype(np.float32) / 255
59 | if resize == True:
60 | img = cv2.resize(img,(width // 4,height // 4),interpolation=cv2.INTER_LINEAR)
61 | # GRAY=0.3*R+0.59*G+0.11*B
62 | img = 0.11 * img[...,0] + 0.59 * img[...,1] + 0.3 * img[...,2]
63 | imgs.append(img)
64 | # simulate the spike sequence during the exposure
65 | if i % (num_blur) == num_blur - 1:
66 | end = i + bais
67 | # skip the first blurry image
68 | if start == 0:
69 | imgs = []
70 | bais += opt.overlap_len
71 | start = i + bais + 1
72 | continue
73 | if end + spike_add >= len(sub_files):
74 | break
75 | # add the spike data out of the exposure period
76 | for jj in range(1,spike_add + 1):
77 | img_start = cv2.imread(os.path.join(dirpath,sub_files[start - jj]))
78 | if resize == True:
79 | img_start = cv2.resize(img_start,(width // 4,height // 4),interpolation=cv2.INTER_LINEAR )
80 | img_start = img_start.astype(np.float32) / 255
81 | img_start = 0.11 * img_start[...,0] + 0.59 * img_start[...,1] + 0.3 * img_start[...,2]
82 | img_end = cv2.imread(os.path.join(dirpath,sub_files[end + jj]))
83 | img_end = img_end.astype(np.float32) / 255
84 | if resize == True:
85 | img_end = cv2.resize(img_end,(width // 4,height // 4),interpolation=cv2.INTER_LINEAR )
86 | img_end = 0.11 * img_end[...,0] + 0.59 * img_end[...,1] + 0.3 * img_end[...,2]
87 | imgs.append(img_end)
88 | imgs.insert(0,img_start)
89 | #! subsample the spike frames by a factor of num_omit
90 | imgs = imgs[::num_omit]
91 | # todo from zjy
92 | spike = SimulationSimple_Video(imgs)
93 | noise = Inherent_Noise_fast_torch(spike.shape[0], H=spike.shape[1], W=spike.shape[2])
94 | spike = torch.bitwise_or(spike.permute((1,2,0)), noise)
95 | spike = spike.permute((2, 0, 1))
96 | # todo from csy
97 | # spike = v2s_interface(imgs,threshold=2)
98 |
99 | # save spike
100 | print(spike.shape)
101 | SpikeToRaw(os.path.join(output_folder, str((num_blur_raw - 1 + (opt.overlap_len + 1) // (num_inter + 1)) * ((i + 1) // num_blur - 1) + num_blur_raw // 2).zfill(str_len) + ".dat"), spike)
102 | print(f"Generating spikes from {sub_files[start - spike_add]} to {sub_files[end + spike_add]}")
103 | print(f"Blur area ranges from {sub_files[start]} to {sub_files[end]}")
104 | imgs = []
105 | bais += opt.overlap_len
106 | start = i + bais + 1
--------------------------------------------------------------------------------
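Both simulators used above (`SimulationSimple_Video` and the commented-out `v2s_interface`) follow the same integrate-and-fire principle: each pixel accumulates intensity, and whenever the accumulator crosses a threshold it emits a spike and the threshold is subtracted (or the residual wraps). A compact sketch of that loop (input and threshold are illustrative):

```python
import numpy as np

def images_to_spikes(imgs, threshold=2.0):
    """imgs: list of (H, W) float arrays in [0, 1] -> (T, H, W) uint8 spike matrix."""
    H, W = imgs[0].shape
    integral = np.random.random((H, W)) * threshold  # random initial charge
    spikes = np.zeros((len(imgs), H, W), np.uint8)
    for t, frame in enumerate(imgs):
        integral += frame                 # accumulate incoming intensity
        fired = integral >= threshold     # pixels that cross the threshold fire
        integral[fired] -= threshold
        spikes[t][fired] = 1
    return spikes

spikes = images_to_spikes([np.full((4, 4), 0.5)] * 20)
print(spikes.mean())  # ~0.25: the firing rate tracks intensity / threshold
```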
/scripts/utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import numpy as np
4 | import imageio
5 | # from moviepy.editor import ImageSequenceClip
6 | import os
7 |
8 |
9 | def save_gif(image_list, gif_path = 'test', duration = 2,RGB = True,nor = False):
10 | imgs = []
11 | os.makedirs('Video',exist_ok = True)
12 | with imageio.get_writer(os.path.join('Video',gif_path + '.gif'), mode='I',duration = 1000 * duration / len(image_list),loop=0) as writer:
13 | for i in range(len(image_list)):
14 | img = normal_img(image_list[i],RGB,nor)
15 | writer.append_data(img)
16 |
17 | def save_video(image_list,path = 'test',duration = 2,RGB = True,nor = False):
18 | os.makedirs('Video',exist_ok = True)
19 | img_size = (image_list[0].shape[1], image_list[0].shape[0])
20 | fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
21 | videowriter = cv2.VideoWriter(os.path.join('Video',path + '.avi'), fourcc, len(image_list) / duration, img_size)
22 | for i in range(len(image_list)):
23 | img = normal_img(image_list[i],RGB,nor)
24 | videowriter.write(img)
25 |
26 |
27 | def normal_img(img,RGB = True,nor = True):
28 | if nor:
29 | img = 255 * ((img - img.min()) / (img.max() - img.min()))
30 | if isinstance(img,torch.Tensor):
31 | img = np.array(img.detach().cpu())
32 | if len(img.shape) == 2:
33 | img = img[...,None]
34 | if img.shape[-1] == 1:
35 | img = np.repeat(img,3,axis = -1)
36 | img = img.astype(np.uint8)
37 | if RGB == False:
38 | img = img[...,::-1]
39 | return img
40 |
41 | def save_img(path = 'test.png',img = None,nor = True):
42 | if nor:
43 | img = 255 * ((img - img.min()) / (img.max() - img.min()))
44 | if isinstance(img,torch.Tensor):
45 | img = np.array(img.detach().cpu())
46 | img = img.astype(np.uint8)
47 | cv2.imwrite(os.path.join(os.getcwd(),'imgs',path),img)
48 |
49 | def make_folder(path):
50 | os.makedirs(path,exist_ok = True)
51 |
52 | def video_to_spike(
53 | sourefolder=None,
54 | imgs = None,
55 | savefolder_debug=None,
56 | threshold=5.0,
57 | init_noise=True,
58 | format="png",
59 | ):
60 | """
61 | 函数说明
62 | :param 参数名: 参数说明
63 | :return: 返回值名: 返回值说明
64 | """
65 | if sourefolder != None:
66 | filelist = sorted(os.listdir(sourefolder))
67 | datas = [fn for fn in filelist if fn.endswith(format)]
68 |
69 | T = len(datas)
70 |
71 | frame0 = cv2.imread(os.path.join(sourefolder, datas[0]))
72 | H, W, C = frame0.shape
73 |
74 | frame0 = cv2.cvtColor(frame0, cv2.COLOR_BGR2GRAY)
75 |
76 | spikematrix = np.zeros([T, H, W], np.uint8)
77 |
78 | if init_noise:
79 | integral = np.random.random(size=([H,W])) * threshold
80 | else:
81 | integral = np.zeros([H, W])
82 |
83 | Thr = np.ones_like(integral).astype(np.float32) * threshold
84 |
85 | for t in range(0, T):
86 | frame = cv2.imread(os.path.join(sourefolder, datas[t]))
87 | if C > 1:
88 | gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
89 | gray = gray / 255.0
90 | integral += gray
91 | fire = (integral - Thr) >= 0
92 | fire_pos = fire.nonzero()
93 |
94 | integral[fire_pos] -= threshold
95 | spikematrix[t][fire_pos] = 1
96 |
97 | if savefolder_debug:
98 | np.save(os.path.join(savefolder_debug, "spike_debug.npy"), spikematrix)
99 | elif imgs != None:
100 | frame0 = imgs[0]
101 | H, W, C = frame0.shape
102 | T = len(imgs)
103 | spikematrix = np.zeros([T, H, W], np.uint8)
104 |
105 | if init_noise:
106 | integral = np.random.random(size=([H,W])) * threshold
107 | else:
108 | integral = np.zeros([H, W])
109 |
110 | Thr = np.ones_like(integral).astype(np.float32) * threshold
111 |
112 | for t in range(0, T):
113 | frame = imgs[t]
114 | if C > 1:
115 | gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
116 | gray = gray / 255.0
117 | integral += gray
118 | fire = (integral - Thr) >= 0
119 | fire_pos = fire.nonzero()
120 |
121 | integral[fire_pos] -= threshold
122 | spikematrix[t][fire_pos] = 1
123 | return spikematrix
124 |
125 |
126 | def load_vidar_dat(filename, left_up=(0, 0), window=None, frame_cnt = None, height = 800,width = 800):
127 | if isinstance(filename, str):
128 | array = np.fromfile(filename, dtype=np.uint8)
129 | elif isinstance(filename, (list, tuple)):
130 | l = []
131 | for name in filename:
132 | a = np.fromfile(name, dtype=np.uint8)
133 | l.append(a)
134 | array = np.concatenate(l)
135 | else:
136 | raise NotImplementedError
137 | if window is None:
138 | window = (height - left_up[0], width - left_up[1])
139 | len_per_frame = height * width // 8
140 | framecnt = frame_cnt if frame_cnt is not None else len(array) // len_per_frame
141 | spikes = []
142 | for i in range(framecnt):
143 | compr_frame = array[i * len_per_frame: (i + 1) * len_per_frame]
144 | blist = []
145 | for b in range(8):
146 | blist.append(np.right_shift(np.bitwise_and(compr_frame, np.left_shift(1, b)), b))
147 | frame_ = np.stack(blist).transpose()
148 | frame_ = np.flipud(frame_.reshape((height, width), order='C'))
149 | if window is not None:
150 | spk = frame_[left_up[0]:left_up[0] + window[0], left_up[1]:left_up[1] + window[1]]
151 | else:
152 | spk = frame_
153 | spk = torch.from_numpy(spk.copy().astype(np.float32))
154 | spikes.append(spk)
155 | return torch.stack(spikes,dim = 0)
--------------------------------------------------------------------------------
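`load_vidar_dat` above (and `SpikeToRaw` in utils_spike.py below) agree on a simple binary format: each frame's H*W spikes are packed eight per byte, least-significant bit first. A round-trip sketch of that packing (dimensions are illustrative and chosen so H*W is divisible by 8; the vertical flip that `SpikeToRaw` applies is omitted):

```python
import numpy as np

H, W = 4, 4
spikes = (np.random.rand(H, W) > 0.5).astype(np.uint8)

# pack: 8 spikes per byte, LSB first (matches base = 2**[0..7] in SpikeToRaw)
packed = np.packbits(spikes.reshape(-1, 8), axis=1, bitorder='little').flatten()

# unpack: shift each bit back out, as load_vidar_dat does with bitwise ops
unpacked = np.stack([(packed >> b) & 1 for b in range(8)], axis=1).flatten()
assert np.array_equal(unpacked, spikes.flatten())
```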
/scripts/utils_spike.py:
--------------------------------------------------------------------------------
1 | from tqdm import trange
2 | import torch
3 | import os
4 | import numpy as np
5 | from PIL import Image
6 | import cv2
7 | import torchvision.transforms as transforms
8 |
9 |
10 | def Inherent_Noise_fast_torch(T, mu=140.0, std=50.0, H=250, W=400):
11 | """
12 | Generate Gaussian distributed inherent noise.
13 | args:
14 | - T: Simulation time length
15 | - mu: Mean
16 | - std: Standard deviation
17 | return:
18 | - The array records the location of noise
19 | """
20 | shape = [H, W, T]
21 | size = H * W * T
22 | gaussian = torch.normal(mu, std, size=(size,))
23 | noise = torch.zeros(size, dtype=torch.int16)
24 | keys = torch.cumsum(gaussian, dim=0)
25 | keys = keys[keys < size]
62 | vol_to_Thr = vol > Vth
63 | syn_rec[n][vol_to_Thr] = 1
64 | vol[vol_to_Thr] = vol[vol_to_Thr] % Vth
65 | # for t in range(250):
66 | # # print(g.shape, noise_1[t].shape)
67 | # vol += g * K
68 | # vol_to_Thr = vol >= Vth
69 | # syn_rec[n][vol_to_Thr] = 1
70 | # vol[vol_to_Thr] = 0
71 | return syn_rec
72 |
73 | def v2s_interface(imgs, savefolder=None, threshold=5.0):
74 | T = len(imgs)
75 | # frame0[..., 0:2] = 0
76 | # cv2.imshow('red', frame0)
77 | # cv2.waitKey(0)
78 | H, W = imgs[0].shape
79 | # exit(0)
80 |
81 | spikematrix = np.zeros([T, H, W], np.uint8)
82 | # integral = np.array(frame0gray).astype(np.float32)
83 | integral = np.random.random(size=([H,W])) * threshold
84 | Thr = np.ones_like(integral).astype(np.float32) * threshold
85 |
86 | for t in range(0, T):
87 | # print('spike frame %s' % datas[t])
88 | frame = imgs[t]
89 | # gray = cv2.resize(frame, None, fx=0.25, fy=0.25, interpolation=cv2.INTER_AREA)
90 | gray = frame
91 | integral += gray
92 | fire = (integral - Thr) >= 0
93 | fire_pos = fire.nonzero()
94 | integral[fire_pos] -= threshold
95 | # integral[fire_pos] = 0.0
96 | spikematrix[t][fire_pos] = 1
97 | return spikematrix
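# Editor note (firing rate): v2s_interface does not normalize `frame`, so
# for a constant input of 1.0 per frame and threshold 5.0 a pixel fires
# every 5 frames, i.e. mean(spikematrix, axis=0) ≈ intensity / threshold.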
98 |
99 | def SimulationSimple_Video(I):
100 | # Initialize Sensor Parameters
101 | Vth = 2.0
102 | Eta = 10**(-13)*1.09
103 | Lambda = 10**(-4)*1.83
104 | Cpd = 10.0**(-15)*15
105 | CLK = 10.0**(6)*10
106 | delta_t = 2 / CLK
107 | K = delta_t * Eta / (Lambda * Cpd)
108 | # print(K)
109 | # vol = 0
110 | T = len(I)
111 | H, W = I[0].shape
112 | vol = torch.zeros(size=(H,W))
113 | syn_rec = torch.zeros(size=(T+1,H,W), dtype=torch.int16)
114 |
115 | for n in range(T+1):
116 | g = torch.rand(size=(H,W)) if n == 0 else I[n-1]
117 | # g = I
118 | # print(noise_1.shape)
119 |
120 | # print(g_all[:, 100,100])
121 | # flag, vol = inner_clock(g * K, Vth, vol)
122 | # print(torch.max(flag), torch.max(vol))
123 | # syn_rec[n] = flag
124 | vol += g * K * 250
125 | vol_to_Thr = vol > Vth
126 | syn_rec[n][vol_to_Thr] = 1
127 | vol[vol_to_Thr] = vol[vol_to_Thr] % Vth
128 | # for t in range(250):
129 | # # print(g.shape, noise_1[t].shape)
130 | # vol += g * K
131 | # vol_to_Thr = vol >= Vth
132 | # syn_rec[n][vol_to_Thr] = 1
133 | # vol[vol_to_Thr] = 0
134 | return syn_rec[1:]
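# Editor note (circuit constants): K = delta_t * Eta / (Lambda * Cpd)
# ≈ 7.94e-3, so K * 250 ≈ 1.99 ≈ Vth. A pixel at full intensity (g = 1)
# therefore integrates up to the threshold roughly once per simulated frame.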
135 |
136 |
137 | def SpikeToRaw(save_path, SpikeSeq, flipud=True, delete_if_exists=True):
138 | """
139 | save spike sequence to .dat file
140 | save_path: full saving path (string)
141 | SpikeSeq: Numpy array (T x H x W)
142 | Rui Zhao
143 | """
144 | if delete_if_exists:
145 | if os.path.exists(save_path):
146 | os.remove(save_path)
147 |
148 | sfn, h, w = SpikeSeq.shape
149 | remainder = int((h * w) % 8)
150 | # assert (h * w) % 8 == 0
151 | base = np.power(2, np.linspace(0, 7, 8))
152 | fid = open(save_path, 'ab')
153 | for img_id in range(sfn):
154 | if flipud:
155 | # simulate the inverted (upside-down) image formed on the camera sensor
156 | spike = np.flipud(SpikeSeq[img_id, :, :])
157 | else:
158 | spike = SpikeSeq[img_id, :, :]
159 | # numpy flattens in row-major order by default, matching how the data is stored
160 | # spike = spike.flatten()
161 | if remainder == 0:
162 | spike = spike.flatten()
163 | else:
164 | spike = np.concatenate([spike.flatten(), np.array([0]*(8-remainder))])
165 | spike = spike.reshape([-1, 8])  # -1 handles the zero-padded case where h*w is not a multiple of 8
166 | data = spike * base
167 | data = np.sum(data, axis=1).astype(np.uint8)
168 | fid.write(data.tobytes())
169 | fid.close()
170 | return
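# Editor note (packing): base = [1, 2, 4, ..., 128], so a group of 8
# row-major spikes [1, 0, 1, 0, 0, 0, 0, 0] packs to 1 + 4 = 5, the exact
# inverse of the LSB-first unpacking in load_vidar_dat below.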
171 |
172 | def load_vidar_dat(filename, frame_cnt=None, width=640, height=480, reverse_spike=True):
173 | '''
174 | output: (frame_cnt, height, width) {0,1} float32
175 | '''
176 | array = np.fromfile(filename, dtype=np.uint8)
177 |
178 | len_per_frame = height * width // 8
179 | framecnt = frame_cnt if frame_cnt is not None else len(array) // len_per_frame
180 |
181 | spikes = []
182 | for i in range(framecnt):
183 | compr_frame = array[i * len_per_frame: (i + 1) * len_per_frame]
184 | blist = []
185 | for b in range(8):
186 | blist.append(np.right_shift(np.bitwise_and(
187 | compr_frame, np.left_shift(1, b)), b))
188 |
189 | frame_ = np.stack(blist).transpose()
190 | frame_ = frame_.reshape((height, width), order='C')
191 | if reverse_spike:
192 | frame_ = np.flipud(frame_)
193 | spikes.append(frame_)
194 |
195 | return np.array(spikes).astype(np.float32)
196 |
197 | if __name__ == "__main__":
198 | # main()
199 | imgs = np.ones((100,300,300))
200 | print(v2s_interface(imgs).shape)
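# Round-trip sketch (editor note; 'demo.dat' is a hypothetical path):
# spikes = (np.random.rand(4, 16, 16) > 0.5).astype(np.uint8)
# SpikeToRaw('demo.dat', spikes)
# assert np.array_equal(load_vidar_dat('demo.dat', width=16, height=16), spikes)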
201 | # spike = load_vidar_dat('I:\\Datasets\\REDS\\train\\train_spike\\000\\00000007.dat', width=260, height=640)
202 | # cv2.imwrite('test_spike_load.png', np.mean(spike, axis=0)[..., None] * 255)
203 |
--------------------------------------------------------------------------------
/train_bsn.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | import sys
7 | sys.path.append("..")
8 | import os
9 | import argparse
10 | from datetime import datetime
11 | from torch.optim import Adam
12 | from torchvision import transforms
13 | from torch.utils.tensorboard import SummaryWriter
14 | from tqdm import tqdm
15 | import shutil
16 | from torch.optim.lr_scheduler import CosineAnnealingLR,MultiStepLR
17 | from codes.utils import *
18 | from codes.dataset import *
19 | from codes.metrics import compute_img_metric
20 | from codes.model.bsn_model import BSN
21 |
22 |
23 | # SDM model
24 | def sdm_deblur(blur,spike,spike_idx):
25 | """ Basic SDM model.
26 |
27 | Args:
28 | blur (tensor): blurry input. [bs,3,w,h]
29 | spike (tensor): spike sequence. [bs,137,w,h]
30 | spike_idx (int): central idx of the short-exposure spike stream
31 |
32 | Returns:
33 | rgb_sdm: deblur result. [bs,3,w,h]
34 | """
35 | global spike_bsn_len
36 | spike_sum = torch.sum(spike[:,20:-20],dim = 1,keepdim = True)
37 | spike_bsn = spike[:,spike_idx - spike_bsn_len // 2:spike_idx + spike_bsn_len // 2 + 1,:,:]
38 | rgb_sdm = blur / spike_sum
39 | rgb_sdm[spike_sum.repeat(1,3,1,1) == 0] = 0
40 | rgb_sdm = rgb_sdm * torch.sum(spike_bsn,dim = 1,keepdim = True) * 97 / spike_bsn_len
41 | rgb_sdm = rgb_sdm.clip(0,1)
42 | return rgb_sdm
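# Editor note: spike_sum counts fires over the whole 97-frame exposure,
# while torch.sum(spike_bsn) * 97 / spike_bsn_len rescales the short-window
# count to the same scale, so rgb_sdm ≈ blur * (short-term rate / long-term
# rate), i.e. the blurry image re-weighted toward the instant at spike_idx.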
43 |
44 | # TFP model
45 | def cal_tfp(spike,spike_idx,tfp_len):
46 | """TPF Model
47 |
48 | Args:
49 | spike (tensor): spike sequence. [bs,137,w,h]
50 | spike_idx (int): central idx of the virtual exposure window
51 | tfp_len (int): length of the virtual exposure window. 97 for long-TFP, [7,9,11] for short-TFP.
52 |
53 | Returns:
54 | tfp_pred: tfp result
55 | """
56 | spike = spike[:,spike_idx - tfp_len // 2:spike_idx + tfp_len // 2 + 1,:,:]
57 | tfp_pred = torch.mean(spike,dim = 1,keepdim = True)
58 | return tfp_pred
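# Editor sketch (hypothetical values): for the 137-frame stream,
# cal_tfp(spike, spike_idx=68, tfp_len=97) averages frames 20..116 (the
# whole exposure window) into a [bs, 1, w, h] long-exposure intensity estimate.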
59 |
60 | if __name__ == '__main__':
61 | # parameters
62 | parser = argparse.ArgumentParser()
63 | parser.add_argument('--base_folder', type=str,default=r'GOPRO/',help = 'base folder of the GOPRO dataset')
64 | parser.add_argument('--save_folder', type=str,default='exp/BSN', help = 'experimental results save folder')
65 | parser.add_argument('--data_type',type=str, default='GOPRO' ,help = 'dataset type')
66 | parser.add_argument('--exp_name', type=str,default='test', help = 'experiment name')
67 | parser.add_argument('--epochs', type=int, default=1001)
68 | parser.add_argument('--lr', type=float, default=3e-4)
69 | parser.add_argument('--seed', type=int, default=42)
70 | parser.add_argument('--bsn_len', type=int, default=9, help = 'spike length for BSN input TFP image')
71 | parser.add_argument('--width', type=int, default=1280)
72 | parser.add_argument('--height', type=int, default=720)
73 | parser.add_argument('--use_small', action='store_true',default = False,help='train on the small GOPRO dataset for debugging')
74 | parser.add_argument('--spike_full', action='store_true',default = False,help='train BSN under high resolution spike stream (1280 * 720)')
75 | parser.add_argument('--test_mode', action='store_true',default = False, help='test the metric')
76 | parser.add_argument('--bsn_path', type=str,default='model/BSN_1000.pth')
77 | opt = parser.parse_args()
78 |
79 | # prepare
80 | ckpt_folder = f"{opt.save_folder}/{opt.exp_name}/ckpts"
81 | img_folder = f"{opt.save_folder}/{opt.exp_name}/imgs"
82 | os.makedirs(ckpt_folder,exist_ok= True)
83 | os.makedirs(img_folder,exist_ok= True)
84 | set_random_seed(opt.seed)
85 | save_opt(opt,f"{opt.save_folder}/{opt.exp_name}/opt.txt")
86 | log_file = f"{opt.save_folder}/{opt.exp_name}/results.txt"
87 | logger = setup_logging(log_file)
88 | if os.path.exists(f'{opt.save_folder}/{opt.exp_name}/tensorboard'):
89 | shutil.rmtree(f'{opt.save_folder}/{opt.exp_name}/tensorboard')
90 | writer = SummaryWriter(f'{opt.save_folder}/{opt.exp_name}/tensorboard')
91 | resize_method = transforms.Resize((opt.height // 4,opt.width // 4),interpolation=transforms.InterpolationMode.BILINEAR)
92 | logger.info(opt)
93 |
94 | # train and test data splitting
95 | train_dataset = SpikeData(opt.base_folder,opt.data_type,'train',use_roi = True,
96 | roi_size = [128 * 4,128 * 4],use_small = opt.use_small,spike_full = opt.spike_full)
97 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True,num_workers=4,pin_memory=True)
98 | test_dataset = SpikeData(opt.base_folder,opt.data_type,'test',use_roi = False,
99 | use_small = opt.use_small,spike_full = opt.spike_full)
100 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False,num_workers=1,pin_memory=True)
101 |
102 | # config for network and training parameters
103 | bsn_teacher = BSN(n_channels = 1,n_output = 1).cuda()
104 | if opt.test_mode:
105 | bsn_teacher.load_state_dict(torch.load(opt.bsn_path))
106 | optim = Adam(bsn_teacher.parameters(), lr=opt.lr)
107 | scheduler = CosineAnnealingLR(optim, T_max=opt.epochs)
108 |
109 | criterion = nn.MSELoss()
110 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
111 | spike_bsn_len = opt.bsn_len
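# Editor note: training is self-supervised -- the blind-spot architecture of
# BSN cannot copy the center pixel, so minimizing MSE(tfp, BSN(tfp)) below
# removes the (assumed pixel-wise independent) spike noise from the TFP images.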
112 | # -------------------- train ----------------------
113 | train_start = datetime.now()
114 | logger.info("Start Training!")
115 | for epoch in range(opt.epochs):
116 | if opt.test_mode == False:
117 | train_loss = AverageMeter()
118 | for batch_idx, (blur,spike,sharp) in enumerate(tqdm(train_loader)):
119 | # read the data
120 | blur,spike = blur.to(device),spike.to(device)
121 | # reconstruct the initial result
122 | for spike_idx in range(20,117,3 * 8):
123 | # TFP Part
124 | tfp = cal_tfp(spike,spike_idx,spike_bsn_len)
125 | tfp_bsn = bsn_teacher(tfp).clip(0,1)
126 | loss = criterion(tfp,tfp_bsn)
127 | optim.zero_grad()
128 | loss.backward()
129 | optim.step()
130 | writer.add_scalar('Training Loss', loss.item())
131 | train_loss.update(loss.item())
132 | logger.info(f"EPOCH {epoch}/{opt.epochs}: Train Loss: {train_loss.avg}")
133 | writer.add_scalar('Epoch Loss', train_loss.avg, epoch)
134 | scheduler.step()
135 | # visualization result
136 | if epoch % 100 == 0:
137 | with torch.no_grad():
138 | # save the network
139 | save_network(bsn_teacher, f"{ckpt_folder}/BSN_{epoch}.pth")
140 | # visualization
141 | for batch_idx, (blur,spike,sharp) in enumerate(tqdm(test_loader)):
142 | if batch_idx in [int(i) for i in np.linspace(0,len(test_loader),5)]:
143 | blur,spike = blur.to(device),spike.to(device)
144 | spike_idx = len(spike[0]) // 2
145 | # TFP Part
146 | tfp = cal_tfp(spike,spike_idx,spike_bsn_len)
147 | tfp_bsn = bsn_teacher(tfp).clip(0,1)
148 | # visualization
149 | save_img(img = normal_img(blur[0]),path = f'{img_folder}/{epoch:03}_{batch_idx:04}_blur.png')
150 | save_img(img = normal_img(sharp[0]),path = f'{img_folder}/{epoch:03}_{batch_idx:04}_sharp.png')
151 | save_img(img = normal_img(tfp[0]),path = f'{img_folder}/{epoch:03}_{batch_idx:04}_tfp.png')
152 | save_img(img = normal_img(tfp_bsn[0]),path = f'{img_folder}/{epoch:03}_{batch_idx:04}_tfp_bsn.png')
153 | else:
154 | continue
155 | # save metric result
156 | if epoch % 100 == 0:
157 | with torch.no_grad():
158 | # calculate the metric
159 | metrics = {}
160 | method_list = ['TFP','BSN']
161 | metric_list = ['mse','ssim','psnr','lpips']
162 | for method_name in method_list:
163 | metrics[method_name] = {} # initialize the metric dict for each method
164 | for metric_name in metric_list:
165 | metrics[method_name][metric_name] = AverageMeter()
166 | for batch_idx, (blur,spike,sharp) in enumerate(tqdm(test_loader)):
167 | blur,spike = blur.to(device),spike.to(device)
168 | sharp = 0.11 * sharp[:,0:1] + 0.59 * sharp[:,1:2] + 0.3 * sharp[:,2:3]
169 | sharp = resize_method(sharp)
170 | spike_idx = len(spike[0]) // 2
171 | # TFP and BSN
172 | tfp = cal_tfp(spike,spike_idx,spike_bsn_len)
173 | tfp_bsn = bsn_teacher(tfp).clip(0,1)
174 | # Metric
175 | for key in metric_list :
176 | metrics['TFP'][key].update(compute_img_metric(tfp,sharp,key))
177 | metrics['BSN'][key].update(compute_img_metric(tfp_bsn,sharp,key))
178 | # Print all results
179 | for method_name in method_list:
180 | re_msg = ''
181 | for metric_name in metric_list:
182 | re_msg += metric_name + ": " + "{:.3f}".format(metrics[method_name][metric_name].avg) + " "
183 | logger.info(f"{method_name}: " + re_msg)
184 | writer.add_scalar(f'{method_name}/{metric_name}', metrics[method_name][metric_name].avg, epoch)
185 | # test mode
186 | if opt.test_mode:
187 | break
188 | writer.close()
189 |
190 |
--------------------------------------------------------------------------------
/train_deblur.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | import sys
7 | sys.path.append("..")
8 | import os
9 | import argparse
10 | from datetime import datetime
11 | from torch.optim import Adam
12 | from codes.utils import *
13 | from codes.metrics import compute_img_metric
14 | import glob
15 | from torch.utils.tensorboard import SummaryWriter
16 | from tqdm import tqdm
17 | import shutil
18 | from codes.dataset import *
19 | from codes.model.ldn_model import Deblur_Net
20 | from codes.model.bsn_model import BSN
21 | from codes.model.edsr_model import EDSR
22 | import lpips
23 | from pytorch_msssim import SSIM
24 | from torch.optim.lr_scheduler import CosineAnnealingLR
25 |
26 | # SDM model
27 | def sdm_deblur(blur,tfp_long,tfp_short):
28 | """ General SDM model.
29 |
30 | Args:
31 | blur (tensor): blurry input. [bs,3,w,h]
32 | tfp_long (tensor): long tfp corresponding to the blurry input. [bs,1,w,h]
33 | tfp_short (tensor): short tfp corresponding to the short-exposure image. [bs,1,w,h]
34 |
35 | Returns:
36 | rgb_sdm: deblur result. [bs,3,w,h]
37 | """
38 | tfp_long = tfp_long.repeat(1,3,1,1)
39 | tfp_short = tfp_short.repeat(1,3,1,1)
40 | rgb_sdm = blur / tfp_long
41 | rgb_sdm[tfp_long == 0] = 0
42 | rgb_sdm = rgb_sdm * tfp_short
43 | return rgb_sdm
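# Editor note: this implements the SDM relation sharp ≈ blur * (tfp_short /
# tfp_long); both TFP maps are broadcast to 3 channels first, and pixels
# where tfp_long == 0 are zeroed to avoid dividing by zero.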
44 |
45 |
46 | # TFP model
47 | def cal_tfp(spike,spike_idx,tfp_len):
48 | """TPF Model
49 |
50 | Args:
51 | spike (tensor): spike sequence. [bs,137,w,h]
52 | spike_idx (int): central idx of the virtual exposure window
53 | tfp_len (int): length of the virtual exposure window. 97 for long-TFP, [7,9,11] for short-TFP.
54 |
55 | Returns:
56 | tfp_pred: tfp result
57 | """
58 | spike = spike[:,spike_idx - tfp_len // 2:spike_idx + tfp_len // 2 + 1,:,:]
59 | tfp_pred = torch.mean(spike,dim = 1,keepdim = True)
60 | return tfp_pred
61 |
62 | if __name__ == '__main__':
63 | # parameters
64 | parser = argparse.ArgumentParser()
65 | parser.add_argument('--base_folder', default='GOPRO/')
66 | parser.add_argument('--save_folder', default='exp/Deblur')
67 | parser.add_argument('--data_type', default='GOPRO')
68 | parser.add_argument('--exp_name', default='test')
69 | parser.add_argument('--bsn_path', default='model/BSN_1000.pth')
70 | parser.add_argument('--sr_path', default='model/SR_70.pth')
71 | parser.add_argument('--deblur_path', default='model/DeblurNet_100.pth')
72 | parser.add_argument('--epochs', type=int, default=101)
73 | parser.add_argument('--lr', type=float, default=1e-3)
74 | parser.add_argument('--seed', type=int, default=42)
75 | parser.add_argument('--spike_deblur_len', type=int, default=21)
76 | parser.add_argument('--spike_bsn_len', type=int, default=9)
77 | parser.add_argument('--lambda_tea', type=float, default=1)
78 | parser.add_argument('--lambda_reblur', type=float, default=100)
79 | parser.add_argument('--blur_step', type=int, default=24)
80 | parser.add_argument('--use_small', action='store_true',default = False)
81 | parser.add_argument('--test_mode', action='store_true',default = False)
82 | parser.add_argument('--use_ssim', action='store_true',default = False, help= 'use ssim loss or not')
83 | parser.add_argument('--bs', type=int, default=4)
84 | parser.add_argument('--roi_size', type=int, default= 512)
85 | opt = parser.parse_args()
86 |
87 | # prepare
88 | ckpt_folder = f"{opt.save_folder}/{opt.exp_name}/ckpts"
89 | img_folder = f"{opt.save_folder}/{opt.exp_name}/imgs"
90 | os.makedirs(ckpt_folder,exist_ok= True)
91 | os.makedirs(img_folder,exist_ok= True)
92 | set_random_seed(opt.seed)
93 | save_opt(opt,f"{opt.save_folder}/{opt.exp_name}/opt.txt")
94 | log_file = f"{opt.save_folder}/{opt.exp_name}/results.txt"
95 | logger = setup_logging(log_file)
96 | if os.path.exists(f'{opt.save_folder}/{opt.exp_name}/tensorboard'):
97 | shutil.rmtree(f'{opt.save_folder}/{opt.exp_name}/tensorboard')
98 | writer = SummaryWriter(f'{opt.save_folder}/{opt.exp_name}/tensorboard')
99 | logger.info(opt)
100 |
101 | # train and test data splitting
102 | train_dataset = SpikeData(opt.base_folder,opt.data_type,'test',use_roi = True, roi_size = [opt.roi_size,opt.roi_size],use_small= opt.use_small)
103 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True,num_workers=4,pin_memory=True)
104 | test_dataset = SpikeData(opt.base_folder,opt.data_type,'test',use_roi = False,use_small= opt.use_small)
105 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False,num_workers=1,pin_memory=True)
106 |
107 | # config for network and training parameters
108 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
109 | # BSN
110 | bsn_net = BSN(n_channels = 1,n_output = 1).to(device)
111 | bsn_net.load_state_dict(torch.load(opt.bsn_path))
112 | for param in bsn_net.parameters():
113 | param.requires_grad = False
114 | # SR
115 | sr_net = EDSR(color_num = 1).to(device)
116 | sr_net.load_state_dict(torch.load(opt.sr_path))
117 | for param in sr_net.parameters():
118 | param.requires_grad = False
119 | # Deblur
120 | spike_bsn_len = opt.spike_bsn_len
121 | spike_deblur_len = opt.spike_deblur_len
122 | deblur_net = Deblur_Net(spike_dim = spike_deblur_len).to(device)
123 | if opt.test_mode == True:
124 | deblur_net.load_state_dict(torch.load(opt.deblur_path))
125 | # other settings
126 | optim = Adam(deblur_net.parameters(), lr=opt.lr)
127 | scheduler = CosineAnnealingLR(optim, T_max=opt.epochs)
128 | loss_lpips = lpips.LPIPS(net='vgg').to(device)
129 | loss_ssim = SSIM(data_range=1.0, size_average=True, channel=3).to(device)
130 | loss_mse = nn.MSELoss()
131 | # -------------------- train ----------------------
132 | train_start = datetime.now()
133 | logger.info("Start Training!")
134 | for epoch in range(opt.epochs):
135 | # loss definition
136 | train_loss_all = AverageMeter()
137 | tea_loss_all = AverageMeter()
138 | reblur_loss_all = AverageMeter()
139 | if opt.test_mode == False:
140 | for batch_idx, (blur,spike,sharp) in enumerate(tqdm(train_loader)):
141 | # read the data
142 | blur,spike = blur.to(device),spike.to(device)
143 | # reconstruct the initial result
144 | train_loss = 0
145 | tea_loss = 0
146 | reblur_loss = 0
147 | reblur = []
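# Editor note: with the defaults below (range(20, 117, 24)) spike_idx visits
# {20, 44, 68, 92, 116}, so spike_num = 5 sharp predictions are averaged into
# `reblur` and compared against the blurry input (the re-blurring loss).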
148 | spike_start,spike_end,spike_step = 20,117,opt.blur_step
149 | spike_num = (spike_end - spike_start - 1) / spike_step + 1
150 | for spike_idx in range(spike_start,spike_end,spike_step):
151 | # Long-TFP Part
152 | tfp_long = cal_tfp(spike,len(spike[0]) // 2 ,97)
153 | tfp_long_sr = sr_net(tfp_long).clip(0,1)
154 | # Short-TFP Part
155 | tfp = cal_tfp(spike,spike_idx,spike_bsn_len)
156 | tfp_bsn = bsn_net(tfp).clip(0,1)
157 | tfp_bsn_sr = sr_net(tfp_bsn).clip(0,1)
158 | deblur_tea = sdm_deblur(blur,tfp_long_sr,tfp_bsn_sr).clip(0,1)
159 | # Deblur_Net Part
160 | spike_roi = spike[:,spike_idx - spike_deblur_len // 2:spike_idx + spike_deblur_len // 2 + 1]
161 | tfp_long = cal_tfp(spike,len(spike[0]) // 2 ,97)
162 | deblur_pred = deblur_net(blur,spike_roi)
163 | deblur_pred = deblur_pred.clip(0,1)
164 | # Loss
165 | if opt.use_ssim:
166 | tea_loss += opt.lambda_tea * (1 - loss_ssim(deblur_tea,deblur_pred)) / spike_num
167 | tea_loss += opt.lambda_tea * torch.mean(loss_lpips(deblur_tea,deblur_pred) ) / spike_num
168 | reblur.append(deblur_pred)
169 | reblur = torch.mean(torch.stack(reblur,dim = 0),dim = 0,keepdim = False)
170 | reblur_loss += opt.lambda_reblur * loss_mse(blur,reblur)
171 | train_loss = tea_loss + reblur_loss
172 | # optimize
173 | optim.zero_grad()
174 | train_loss.backward()
175 | optim.step()
176 | scheduler.step()
177 | writer.add_scalar('Training Loss', train_loss.item())
178 | # update
179 | train_loss_all.update(train_loss.item())
180 | tea_loss_all.update(tea_loss.item())
181 | reblur_loss_all.update(reblur_loss.item())
182 | logger.info(f"EPOCH {epoch}/{opt.epochs}: Total Train Loss: {train_loss_all.avg}, Tea Loss: {tea_loss_all.avg}, Reblur Loss: {reblur_loss_all.avg}")
183 | writer.add_scalar('Epoch Loss', train_loss_all.avg, epoch)
184 | # visualization result
185 | if epoch % 5 == 0:
186 | with torch.no_grad():
187 | # save the network
188 | save_network(deblur_net, f"{ckpt_folder}/DeblurNet_{epoch}.pth")
189 | # visualization
190 | for batch_idx, (blur,spike,sharp) in enumerate(tqdm(test_loader)):
191 | blur,spike = blur.to(device),spike.to(device)
192 | spike_idx = len(spike[0]) // 2
193 | # Long-TFP Part
194 | tfp_long = cal_tfp(spike,len(spike[0]) // 2 ,97)
195 | tfp_long_sr = sr_net(tfp_long).clip(0,1)
196 | # Short-TFP Part
197 | tfp = cal_tfp(spike,spike_idx,spike_bsn_len)
198 | tfp_bsn = bsn_net(tfp).clip(0,1)
199 | tfp_bsn_sr = sr_net(tfp_bsn).clip(0,1)
200 | deblur_tea = sdm_deblur(blur,tfp_long_sr,tfp_bsn_sr).clip(0,1)
201 | # Deblur_Net Part
202 | spike_roi = spike[:,spike_idx - spike_deblur_len // 2:spike_idx + spike_deblur_len // 2 + 1]
203 | deblur_pred = deblur_net(blur,spike_roi)
204 | deblur_pred = deblur_pred.clip(0,1)
205 | # visualization
206 | if batch_idx in [int(i) for i in np.linspace(0,len(test_loader),5)]:
207 | save_img(img = normal_img(deblur_tea[0]),path = f'{img_folder}/{epoch:03}_{batch_idx:04}_tea.png')
208 | save_img(img = normal_img(blur[0]),path = f'{img_folder}/{epoch:03}_{batch_idx:04}_blur.png')
209 | save_img(img = normal_img(sharp[0]),path = f'{img_folder}/{epoch:03}_{batch_idx:04}_sharp.png')
210 | save_img(img = normal_img(deblur_pred[0]),path = f'{img_folder}/{epoch:03}_{batch_idx:04}_deblur.png')
211 |
212 | # save metric result
213 | if epoch % 50 == 0:
214 | with torch.no_grad():
215 | # calculate the metric
216 | metrics = {}
217 | method_list = ['SDM','Deblur_Net']
218 | # metric_list = ['mse','ssim','psnr','lpips']
219 | metric_list = ['ssim','psnr']
220 | for method_name in method_list:
221 | metrics[method_name] = {} # initialize the metric dict for each method
222 | for metric_name in metric_list:
223 | metrics[method_name][metric_name] = AverageMeter()
224 | for batch_idx, (blur,spike,sharp) in enumerate(tqdm(test_loader)):
225 | blur,spike = blur.to(device),spike.to(device)
226 | spike_idx = len(spike[0]) // 2  # Long-TFP Part (recompute the central index; it was stale from the loop above)
227 | tfp_long = cal_tfp(spike,len(spike[0]) // 2 ,97)
228 | tfp_long_sr = sr_net(tfp_long).clip(0,1)
229 | # Short-TFP Part
230 | tfp = cal_tfp(spike,spike_idx,spike_bsn_len)
231 | tfp_bsn = bsn_net(tfp).clip(0,1)
232 | tfp_bsn_sr = sr_net(tfp_bsn).clip(0,1)
233 | deblur_tea = sdm_deblur(blur,tfp_long_sr,tfp_bsn_sr).clip(0,1)
234 | # Deblur_Net Part
235 | spike_roi = spike[:,spike_idx - spike_deblur_len // 2:spike_idx + spike_deblur_len // 2 + 1]
236 | tfp_long = cal_tfp(spike,len(spike[0]) // 2 ,97)
237 | deblur_pred = deblur_net(blur,spike_roi)
238 | deblur_pred = deblur_pred.clip(0,1)
239 | # Metric
240 | for key in metric_list :
241 | metrics['SDM'][key].update(compute_img_metric(deblur_tea,sharp,key))
242 | metrics['Deblur_Net'][key].update(compute_img_metric(deblur_pred,sharp,key))
243 | # Print all results
244 | for method_name in method_list:
245 | re_msg = ''
246 | for metric_name in metric_list:
247 | re_msg += metric_name + ": " + "{:.3f}".format(metrics[method_name][metric_name].avg) + " "
248 | logger.info(f"{method_name}: " + re_msg)
249 | writer.add_scalar(f'{method_name}/{metric_name}', metrics[method_name][metric_name].avg, epoch)
250 | writer.close()
251 |
--------------------------------------------------------------------------------
/train_sr.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | import sys
7 | sys.path.append("..")
8 | import os
9 | import argparse
10 | from datetime import datetime
11 | from torch.optim import Adam
12 | from torchvision import transforms
13 | from torch.utils.tensorboard import SummaryWriter
14 | from tqdm import tqdm
15 | import shutil
16 |
17 | from codes.dataset import *
18 | from codes.utils import *
19 | from codes.metrics import compute_img_metric
20 | from codes.model.bsn_model import BSN
21 | from codes.model.edsr_model import EDSR
22 |
23 |
24 | # SDM model
25 | def sdm_deblur(blur,tfp_long,tfp_short):
26 | """ General SDM model.
27 |
28 | Args:
29 | blur (tensor): blurry input. [bs,3,w,h]
30 | tfp_long (tensor): long tfp corresponding to the blurry input. [bs,1,w,h]
31 | tfp_short (tensor): short tfp corresponding to the short-exposure image. [bs,1,w,h]
32 |
33 | Returns:
34 | rgb_sdm: deblur result. [bs,3,w,h]
35 | """
36 | tfp_long = tfp_long.repeat(1,3,1,1)
37 | tfp_short = tfp_short.repeat(1,3,1,1)
38 | rgb_sdm = blur / tfp_long
39 | rgb_sdm[tfp_long == 0] = 0
40 | rgb_sdm = rgb_sdm * tfp_short
41 | return rgb_sdm
42 |
43 | # TFP model
44 | def cal_tfp(spike,spike_idx,tfp_len):
45 | """TPF Model
46 |
47 | Args:
48 | spike (tensor): spike sequence. [bs,137,w,h]
49 | spike_idx (int): central idx of the virtual exposure window
50 | tfp_len (int): length of the virtual exposure window. 97 for long-TFP, [7,9,11] for short-TFP.
51 |
52 | Returns:
53 | tfp_pred: tfp result
54 | """
55 | spike = spike[:,spike_idx - tfp_len // 2:spike_idx + tfp_len // 2 + 1,:,:]
56 | tfp_pred = torch.mean(spike,dim = 1,keepdim = True)
57 | return tfp_pred
58 |
59 | # main function
60 | if __name__ == '__main__':
61 | # parameters
62 | parser = argparse.ArgumentParser()
63 | parser.add_argument('--base_folder', type=str,default=r'GOPRO')
64 | parser.add_argument('--save_folder', type=str,default='exp/SR')
65 | parser.add_argument('--data_type',type=str, default='GOPRO')
66 | parser.add_argument('--exp_name', type=str,default='test')
67 | parser.add_argument('--bsn_path', type=str,default='model/BSN_1000.pth')
68 | parser.add_argument('--sr_path', type=str,default='model/SR_70.pth')
69 | parser.add_argument('--epochs', type=int, default=101)
70 | parser.add_argument('--lr', type=float, default=2e-4)
71 | parser.add_argument('--seed', type=int, default=42)
72 | parser.add_argument('--bsn_len', type=int, default=9)
73 | parser.add_argument('--use_small', action='store_true',default = False)
74 | parser.add_argument('--test_mode', action='store_true',default = False)
75 | opt = parser.parse_args()
76 |
77 | # prepare
78 | ckpt_folder = f"{opt.save_folder}/{opt.exp_name}/ckpts"
79 | img_folder = f"{opt.save_folder}/{opt.exp_name}/imgs"
80 | os.makedirs(ckpt_folder,exist_ok= True)
81 | os.makedirs(img_folder,exist_ok= True)
82 | set_random_seed(opt.seed)
83 | save_opt(opt,f"{opt.save_folder}/{opt.exp_name}/opt.txt")
84 | log_file = f"{opt.save_folder}/{opt.exp_name}/results.txt"
85 | logger = setup_logging(log_file)
86 | if os.path.exists(f'{opt.save_folder}/{opt.exp_name}/tensorboard'):
87 | shutil.rmtree(f'{opt.save_folder}/{opt.exp_name}/tensorboard')
88 | writer = SummaryWriter(f'{opt.save_folder}/{opt.exp_name}/tensorboard')
89 | logger.info(opt)
90 |
91 | # train and test data splitting
92 | train_dataset = SpikeData(opt.base_folder,opt.data_type,'train',use_roi = True,
93 | roi_size = [128 * 4,128 * 4],use_small = opt.use_small)
94 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True,num_workers=4,pin_memory=True)
95 | test_dataset = SpikeData(opt.base_folder,opt.data_type,'test',use_roi = False,use_small = opt.use_small)
96 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False,num_workers=1,pin_memory=True)
97 |
98 | # config for network and training parameters
99 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
100 | sr_net = EDSR(color_num = 1).to(device)
101 | if opt.test_mode:
102 | sr_net.load_state_dict(torch.load(opt.sr_path))
103 | optim = Adam(sr_net.parameters(), lr=opt.lr)
104 | bsn_net = BSN(n_channels=1, n_output=1).to(device)
105 | bsn_net.load_state_dict(torch.load(opt.bsn_path))
106 | for param in bsn_net.parameters():
107 | param.requires_grad = False
108 | criterion = nn.MSELoss()
109 | spike_bsn_len = opt.bsn_len
110 | resize_method = transforms.Resize((720,1280),interpolation=transforms.InterpolationMode.NEAREST)
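# Editor note: the nearest-neighbor resize to (720, 1280) is the non-learned
# 4x upsampling baseline (the '*_Resize' metrics below) that the EDSR sr_net
# (the '*_SR' metrics) is compared against; sr_net itself is trained so that
# the super-resolved long TFP matches the grayscale blurry image, which
# shares its exposure window.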
111 | # -------------------- train ----------------------
112 | train_start = datetime.now()
113 | logger.info("Start Training!")
114 | for epoch in range(opt.epochs):
115 | train_loss = AverageMeter()
116 | if opt.test_mode == False:
117 | for batch_idx, (blur,spike,sharp) in enumerate(tqdm(train_loader)):
118 | # read the data
119 | blur = 0.11 * blur[:,0:1] + 0.59 * blur[:,1:2] + 0.3 * blur[:,2:3]
120 | blur,spike = blur.to(device),spike.to(device)
121 | for spike_idx in [len(spike[0]) // 2]:
122 | # Long-TFP Part
123 | tfp_long = cal_tfp(spike,spike_idx,97)
124 | sr_tfp = sr_net(tfp_long).clip(0,1)
125 | loss = criterion(sr_tfp,blur)
126 | optim.zero_grad()
127 | loss.backward()
128 | optim.step()
129 | writer.add_scalar('Training Loss', loss.item())
130 | train_loss.update(loss.item())
131 | logger.info(f"EPOCH {epoch}/{opt.epochs}: Train Loss: {train_loss.avg}")
132 | writer.add_scalar('Epoch Loss', train_loss.avg, epoch)
133 | # visualization result
134 | if epoch % 5 == 0:
135 | with torch.no_grad():
136 | # save the network
137 | save_network(sr_net, f"{ckpt_folder}/SR_{epoch}.pth")
138 | # visualization
139 | for batch_idx, (blur,spike,sharp) in enumerate(tqdm(test_loader)):
140 | if batch_idx in [i for i in range(0,len(test_loader),200)]:
141 | blur,spike = blur.to(device),spike.to(device)
142 | spike_idx = len(spike[0]) // 2
143 | # Long-TFP Part
144 | tfp_long = cal_tfp(spike,spike_idx,97)
145 | tfp_long_sr = sr_net(tfp_long).clip(0,1)
146 | save_img(img = normal_img(blur[0]),path = f'{img_folder}/{epoch:03}_{batch_idx:04}_blur.png')
147 | save_img(img = normal_img(sharp[0]),path = f'{img_folder}/{epoch:03}_{batch_idx:04}_sharp.png')
148 | save_img(img = normal_img(tfp_long[0]),path = f'{img_folder}/{epoch:03}_{batch_idx:04}_tfp_long.png')
149 | save_img(img = normal_img(tfp_long_sr[0]),path = f'{img_folder}/{epoch:03}_{batch_idx:04}_tfp_long_sr.png')
150 | # Short-TFP Part
151 | tfp = cal_tfp(spike,spike_idx,spike_bsn_len)
152 | tfp_resize = resize_method(tfp).clip(0,1)
153 | tfp_bsn = bsn_net(tfp).clip(0,1)
154 | tfp_bsn_resize = resize_method(tfp_bsn).clip(0,1)
155 | tfp_bsn_sr = sr_net(tfp_bsn).clip(0,1)
156 | deblur = sdm_deblur(blur,tfp_long_sr,tfp_bsn_sr)
157 | save_img(img = normal_img(tfp[0]),path = f'{img_folder}/{epoch:03}_{batch_idx:04}_tfp_short.png')
158 | save_img(img = normal_img(tfp_bsn[0]),path = f'{img_folder}/{epoch:03}_{batch_idx:04}_tfp_short_bsn.png')
159 | save_img(img = normal_img(tfp_bsn_sr[0]),path = f'{img_folder}/{epoch:03}_{batch_idx:04}_tfp_short_bsn_sr.png')
160 | save_img(img = normal_img(tfp_bsn_resize[0]),path = f'{img_folder}/{epoch:03}_{batch_idx:04}_tfp_short_bsn_resize.png')
161 | save_img(img = normal_img(tfp_resize[0]),path = f'{img_folder}/{epoch:03}_{batch_idx:04}_tfp_short_resize.png')
162 | save_img(img = normal_img(deblur[0]),path = f'{img_folder}/{epoch:03}_{batch_idx:04}_deblur.png')
163 | if opt.test_mode:
164 | break
165 | else:
166 | continue
167 | # save metric result
168 | if epoch % 10 == 0:
169 | with torch.no_grad():
170 | # calculate the metric
171 | metrics = {}
172 | method_list = ['SDM','Blur_SR','Blur_Resize','BSN_SR','BSN_Resize','TFP_Resize']
173 | # metric_list = ['mse','ssim','psnr','lpips']
174 | metric_list = ['ssim','psnr']
175 | for method_name in method_list:
176 | metrics[method_name] = {} # initialize the metric dict for each method
177 | for metric_name in metric_list:
178 | metrics[method_name][metric_name] = AverageMeter()
179 | for batch_idx, (blur,spike,sharp) in enumerate(tqdm(test_loader)):
180 | blur,spike = blur.to(device),spike.to(device)
181 | blur_gray = 0.11 * blur[:,0:1] + 0.59 * blur[:,1:2] + 0.3 * blur[:,2:3]
182 | sharp_gray = 0.11 * sharp[:,0:1] + 0.59 * sharp[:,1:2] + 0.3 * sharp[:,2:3]
183 | spike_idx = len(spike[0]) // 2
184 | # Metric
185 | # Long-TFP Part
186 | tfp_long = cal_tfp(spike,spike_idx,97)
187 | tfp_long_resize = resize_method(tfp_long).clip(0,1)
188 | tfp_long_sr = sr_net(tfp_long).clip(0,1)
189 | # Short-TFP Part
190 | tfp = cal_tfp(spike,spike_idx,spike_bsn_len)
191 | tfp_resize = resize_method(tfp).clip(0,1)
192 | tfp_bsn = bsn_net(tfp).clip(0,1)
193 | tfp_bsn_resize = resize_method(tfp_bsn).clip(0,1)
194 | tfp_bsn_sr = sr_net(tfp_bsn).clip(0,1)
195 | deblur = sdm_deblur(blur,tfp_long_sr,tfp_bsn_sr)
196 | for key in metric_list :
197 | # SDM
198 | metrics['SDM'][key].update(compute_img_metric(deblur,sharp,key))
199 | # BLUR
200 | metrics['Blur_SR'][key].update(compute_img_metric(tfp_long_sr,blur_gray,key))
201 | metrics['Blur_Resize'][key].update(compute_img_metric(tfp_long_resize,blur_gray,key))
202 | # BSN
203 | metrics['BSN_SR'][key].update(compute_img_metric(tfp_bsn_sr,sharp_gray,key))
204 | metrics['BSN_Resize'][key].update(compute_img_metric(tfp_bsn_resize,sharp_gray,key))
205 | # TFP
206 | metrics['TFP_Resize'][key].update(compute_img_metric(tfp_resize,sharp_gray,key))
207 |
208 | # Print all results
209 | for method_name in method_list:
210 | re_msg = ''
211 | for metric_name in metric_list:
212 | re_msg += metric_name + ": " + "{:.3f}".format(metrics[method_name][metric_name].avg) + " "
213 | logger.info(f"{method_name}: " + re_msg)
214 | writer.add_scalar(f'{method_name}/{metric_name}', metrics[method_name][metric_name].avg, epoch)
215 | # stop
216 | if opt.test_mode:
217 | break
218 | writer.close()
--------------------------------------------------------------------------------