├── .gitignore
├── README.md
├── SPyNet.py
├── dataset.py
├── loss.py
├── model.py
├── modules.py
├── train.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | spynet_20210409-c6c1bd09.pth
3 | log_dir
4 | REDS
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [WIP]pytorch_basicVSR
2 | Unofficial PyTorch implementation of BasicVSR.
3 | 
4 | 
5 | This code is based on Open-MMLab's MMEditing (https://github.com/open-mmlab/mmediting).
6 | 
7 | # Dependencies
8 | * PyTorch 1.10.0
--------------------------------------------------------------------------------
/SPyNet.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is based on Open-MMLab's one.
3 | https://github.com/open-mmlab/mmediting
4 | """
5 | 
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from modules import flow_warp
10 | 
11 | class SPyNet(nn.Module):
12 |     """SPyNet network structure.
13 |     The differences to the SPyNet in [tof.py] are that
14 |     1. more SPyNetBasicModules are used in this version, and
15 |     2. no batch normalization is used in this version.
16 |     Paper:
17 |         Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
18 |     Note:
19 |         Pre-trained weights are loaded into this module by `basicVSR.get_spynet()` in model.py.
20 |     """
21 | 
22 |     def __init__(self):
23 |         super().__init__()
24 | 
25 |         self.basic_module = nn.ModuleList(
26 |             [SPyNetBasicModule() for _ in range(6)]
27 |         )
28 | 
29 |         #self.load_state_dict(torch.load('spynet_20210409-c6c1bd09.pth'))
30 | 
31 |         self.register_buffer(
32 |             'mean',
33 |             torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
34 |         self.register_buffer(
35 |             'std',
36 |             torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
37 | 
38 |     def compute_flow(self, ref, supp):
39 |         """Compute flow from ref to supp.
40 |         Note that in this function, the images are already resized to a
41 |         multiple of 32.
42 |         Args:
43 |             ref (Tensor): Reference image with shape of (n, 3, h, w).
44 |             supp (Tensor): Supporting image with shape of (n, 3, h, w).
45 |         Returns:
46 |             Tensor: Estimated optical flow: (n, 2, h, w).
47 |         """
48 |         n, _, h, w = ref.size()
49 | 
50 |         # normalize the input images
51 |         ref = [(ref - self.mean) / self.std]
52 |         supp = [(supp - self.mean) / self.std]
53 | 
54 |         # generate downsampled frames
55 |         for level in range(5):
56 |             ref.append(
57 |                 F.avg_pool2d(
58 |                     input=ref[-1],
59 |                     kernel_size=2,
60 |                     stride=2,
61 |                     count_include_pad=False
62 |                 )
63 |             )
64 |             supp.append(
65 |                 F.avg_pool2d(
66 |                     input=supp[-1],
67 |                     kernel_size=2,
68 |                     stride=2,
69 |                     count_include_pad=False
70 |                 )
71 |             )
72 |         ref = ref[::-1]
73 |         supp = supp[::-1]
74 | 
75 |         # flow computation
76 |         flow = ref[0].new_zeros(n, 2, h // 32, w // 32)
77 |         for level in range(len(ref)):
78 |             if level == 0:
79 |                 flow_up = flow
80 |             else:
81 |                 flow_up = F.interpolate(
82 |                     input=flow,
83 |                     scale_factor=2,
84 |                     mode='bilinear',
85 |                     align_corners=True) * 2.0
86 | 
87 |             # add the residue to the upsampled flow
88 |             flow = flow_up + self.basic_module[level](
89 |                 torch.cat([
90 |                     ref[level],
91 |                     flow_warp(
92 |                         supp[level],
93 |                         flow_up.permute(0, 2, 3, 1),
94 |                         padding_mode='border'), flow_up
95 |                 ], 1))
96 | 
97 |         return flow
98 | 
99 |     def forward(self, ref, supp):
100 |         """Forward function of SPyNet.
101 |         This function computes the optical flow from ref to supp.
102 |         Args:
103 |             ref (Tensor): Reference image with shape of (n, 3, h, w).
104 |             supp (Tensor): Supporting image with shape of (n, 3, h, w).
105 |         Returns:
106 |             Tensor: Estimated optical flow: (n, 2, h, w).
107 |         """
108 | 
109 |         # upsize to a multiple of 32
110 |         h, w = ref.shape[2:4]
111 |         w_up = w if (w % 32) == 0 else 32 * (w // 32 + 1)
112 |         h_up = h if (h % 32) == 0 else 32 * (h // 32 + 1)
113 |         ref = F.interpolate(
114 |             input=ref, size=(h_up, w_up), mode='bilinear', align_corners=False)
115 |         supp = F.interpolate(
116 |             input=supp,
117 |             size=(h_up, w_up),
118 |             mode='bilinear',
119 |             align_corners=False)
120 | 
121 |         # compute flow, and resize back to the original resolution
122 |         flow = F.interpolate(
123 |             input=self.compute_flow(ref, supp),
124 |             size=(h, w),
125 |             mode='bilinear',
126 |             align_corners=False)
127 | 
128 |         # adjust the flow values
129 |         flow[:, 0, :, :] *= float(w) / float(w_up)
130 |         flow[:, 1, :, :] *= float(h) / float(h_up)
131 | 
132 |         return flow
133 | 
134 | 
135 | class SPyNetBasicModule(nn.Module):
136 |     """Basic Module for SPyNet.
137 |     Paper:
138 |         Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
139 |     """
140 | 
141 |     def __init__(self):
142 |         super().__init__()
143 | 
144 |         self.basic_module = nn.Sequential(
145 |             nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3),
146 |             nn.ReLU(),
147 |             nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3),
148 |             nn.ReLU(),
149 |             nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3),
150 |             nn.ReLU(),
151 |             nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3),
152 |             nn.ReLU(),
153 |             nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)
154 |         )
155 | 
156 |     def forward(self, tensor_input):
157 |         """
158 |         Args:
159 |             tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
160 |                 8 channels contain:
161 |                 [reference image (3), neighbor image (3), initial flow (2)].
162 |         Returns:
163 |             Tensor: Refined flow with shape (b, 2, h, w)
164 |         """
165 |         return self.basic_module(tensor_input)
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is based on Open-MMLab's one.
3 | https://github.com/open-mmlab/mmediting
4 | """
5 | 
6 | import torch
7 | from torch import nn
8 | from torch.utils.data import Dataset
9 | import torchvision.transforms as T
10 | from torchvision.io import read_image
11 | from PIL import Image
12 | import glob
13 | import numpy as np
14 | import os
15 | import random
16 | 
17 | def generate_segment_indices(videopath1,videopath2,num_input_frames=10,filename_tmpl='{:08d}.png'):
18 |     """Randomly sample aligned segments of num_input_frames consecutive frames from two directories.
19 |     Args:
20 |         videopath1, videopath2 (str): input directories which contain sequential frames
21 |         filename_tmpl (str): template which represents sequential frames
22 |     Returns:
23 |         Tensor, Tensor: Output sequences with shape (t, c, h, w)
24 |     """
25 |     seq_length=len(glob.glob(f'{videopath1}/*.png'))
26 |     seq_length2=len(glob.glob(f'{videopath2}/*.png'))
27 | 
28 |     if seq_length!=seq_length2:
29 |         raise ValueError(f'videopath1 and videopath2 must have the same number of frames\nbut they have {seq_length} and {seq_length2}')
30 |     if num_input_frames > seq_length:
31 |         raise ValueError(f'num_input_frames ({num_input_frames}) must not exceed the number of frames in {videopath1} \n and {videopath2} ({seq_length})')
32 | 
33 |     start_frame_idx = np.random.randint(0, seq_length - num_input_frames + 1)  # high is exclusive, so +1 keeps the last window reachable
34 |     end_frame_idx = start_frame_idx + num_input_frames
35 |     segment1=[read_image(os.path.join(videopath1,filename_tmpl.format(i))) / 255. for i in range(start_frame_idx,end_frame_idx)]
36 |     segment2=[read_image(os.path.join(videopath2,filename_tmpl.format(i))) / 255. for i in range(start_frame_idx,end_frame_idx)]
37 |     return torch.stack(segment1),torch.stack(segment2)
38 | 
39 | def pair_random_crop_seq(hr_seq,lr_seq,patch_size,scale_factor=4):
40 |     """Crop an image pair for data augmentation.
41 |     Args:
42 |         hr_seq (Tensor): hr images with shape (t, c, 4h, 4w).
43 |         lr_seq (Tensor): lr images with shape (t, c, h, w).
44 |         patch_size (int): the size of the cropped lr patch (the hr patch is patch_size * scale_factor)
45 |     Returns:
46 |         Tensor, Tensor: cropped images (hr, lr)
47 |     """
48 |     seq_length=lr_seq.size(dim=0)
49 |     gt_transformed=torch.empty(seq_length,3,patch_size*scale_factor,patch_size*scale_factor)
50 |     lq_transformed=torch.empty(seq_length,3,patch_size,patch_size)
51 |     i,j,h,w=T.RandomCrop.get_params(lr_seq[0],output_size=(patch_size,patch_size))
52 |     gt_transformed=T.functional.crop(hr_seq,i*scale_factor,j*scale_factor,h*scale_factor,w*scale_factor)
53 |     lq_transformed=T.functional.crop(lr_seq,i,j,h,w)
54 |     return gt_transformed,lq_transformed
55 | 
56 | def pair_random_flip_seq(sequence1,sequence2,p=0.5,horizontal=True,vertical=True):
57 |     """Randomly flip an image pair for data augmentation.
58 |     Args:
59 |         sequence1 (Tensor): images with shape (t, c, h, w).
60 |         sequence2 (Tensor): images with shape (t, c, h, w).
61 |         p (float): probability of the image being flipped.
62 |             Default: 0.5
63 |         horizontal (bool): Set `False` to disable horizontal flipping.
64 |             Default: `True`.
65 |         vertical (bool): Set `False` to disable vertical flipping.
66 |             Default: `True`.
67 |     Returns:
68 |         Tensor, Tensor: flipped images
69 |     """
70 |     T_length=sequence1.size(dim=0)
71 |     # Random horizontal flipping
72 |     hfliped1=sequence1.clone()
73 |     hfliped2=sequence2.clone()
74 |     if horizontal and random.random() < p:
75 |         hfliped1 = T.functional.hflip(sequence1)
76 |         hfliped2 = T.functional.hflip(sequence2)
77 | 
78 |     # Random vertical flipping
79 |     vfliped1=hfliped1.clone()
80 |     vfliped2=hfliped2.clone()
81 |     if vertical and random.random() < p:
82 |         vfliped1 = T.functional.vflip(hfliped1)
83 |         vfliped2 = T.functional.vflip(hfliped2)
84 |     return vfliped1,vfliped2
85 | 
86 | def pair_random_transposeHW_seq(sequence1,sequence2,p=0.5):
87 |     """Randomly transpose H and W of an image pair for data augmentation.
88 |     Args:
89 |         sequence1 (Tensor): images with shape (t, c, h, w).
90 |         sequence2 (Tensor): images with shape (t, c, h, w).
91 |         p (float): probability of applying the transpose.
92 |             Default: 0.5
93 |     Returns:
94 |         Tensor, Tensor: transposed images
95 |     """
96 |     T_length=sequence1.size(dim=0)
97 |     transformed1=sequence1.clone()
98 |     transformed2=sequence2.clone()
99 |     if random.random() < p:
100 |         transformed1=torch.transpose(sequence1,2,3)
101 |         transformed2=torch.transpose(sequence2,2,3)
102 |     return transformed1,transformed2
103 | 
104 | class REDSDataset(Dataset):
105 |     """REDS dataset for video super resolution.
106 |     Args:
107 |         gt_dir (str): Path to a gt folder.
108 |         lq_dir (str): Path to a lq folder.
109 |         patch_size (int): the size of the cropped lq patch (the gt patch is patch_size * scale_factor)
110 |             Default: 256
111 |         is_test (bool): Set `True` when building the validation/test dataset.
112 |             Default: `False`.
113 |         max_keys (int): number of clips; keys run from '000' to f'{max_keys-1:03d}'
114 |             Default: 270 (keys '000' to '269')
115 |     """
116 |     def __init__(self, gt_dir, lq_dir,scale_factor=4, patch_size=256, num_input_frames=10, is_test=False,max_keys=270,filename_tmpl='{:08d}.png'):
117 |         val_keys=['000', '011', '015', '020']
118 |         if is_test:
119 |             self.keys = [f'{i:03d}' for i in range(0, max_keys) if f'{i:03d}' in val_keys]
120 |         else:
121 |             self.keys = [f'{i:03d}' for i in range(0, max_keys) if f'{i:03d}' not in val_keys]
122 |         self.gt_dir=gt_dir
123 |         self.lq_dir=lq_dir
124 |         self.scale_factor=scale_factor
125 |         self.patch_size=patch_size
126 |         self.num_input_frames=num_input_frames
127 |         self.is_test=is_test
128 |         self.gt_seq_paths=[os.path.join(self.gt_dir,k) for k in self.keys]
129 |         self.lq_seq_paths=[os.path.join(self.lq_dir,k) for k in self.keys]
130 |         self.filename_tmpl=filename_tmpl
131 | 
132 |     def transform(self,gt_seq,lq_seq):
133 |         gt_transformed,lq_transformed=pair_random_crop_seq(gt_seq,lq_seq,patch_size=self.patch_size)
134 |         gt_transformed,lq_transformed=pair_random_flip_seq(gt_transformed,lq_transformed,p=0.5)
135 |         gt_transformed,lq_transformed=pair_random_transposeHW_seq(gt_transformed,lq_transformed,p=0.5)
136 |         return gt_transformed,lq_transformed
137 | 
138 |     def __len__(self):
139 |         return len(self.keys)
140 | 
141 |     def __getitem__(self,idx):
142 |         gt_sequence, lq_sequence = generate_segment_indices(self.gt_seq_paths[idx],self.lq_seq_paths[idx],num_input_frames=self.num_input_frames,filename_tmpl=self.filename_tmpl)
143 |         if not self.is_test:
144 |             gt_sequence, lq_sequence = self.transform(gt_sequence,lq_sequence)
145 |         return gt_sequence,lq_sequence
146 | 
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is based on Open-MMLab's one.
3 | https://github.com/open-mmlab/mmediting
4 | """
5 | 
6 | import torch
7 | import torch.nn as nn
8 | 
9 | def charbonnier_loss(pred, target,weight=None,reduction='mean',sample_wise=False, eps=1e-12):
10 |     """Charbonnier loss (weight, reduction and sample_wise are accepted for API compatibility but unused here).
11 | 
12 |     Args:
13 |         pred (Tensor): Prediction Tensor with shape (n, c, h, w).
14 |         target (Tensor): Target Tensor with shape (n, c, h, w).
15 | 
16 |     Returns:
17 |         Tensor: Calculated Charbonnier loss.
18 |     """
19 |     return torch.sqrt((pred - target)**2 + eps).mean()
20 | 
21 | class CharbonnierLoss(nn.Module):
22 |     """Charbonnier loss (one variant of Robust L1Loss, a differentiable variant of L1Loss).
23 | 
24 |     Described in "Deep Laplacian Pyramid Networks for Fast and Accurate Super-Resolution".
25 | 
26 |     Args:
27 |         loss_weight (float): Loss weight for L1 loss. Default: 1.0.
28 |         reduction (str): Specifies the reduction to apply to the output.
29 |             Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
30 |         sample_wise (bool): Whether to calculate the loss sample-wise. This
31 |             argument only takes effect when `reduction` is 'mean' and `weight`
32 |             (argument of `forward()`) is not None. It first reduces the loss
33 |             per sample with 'mean', and then averages over all samples.
34 |             Default: False.
35 |         eps (float): A value used to control the curvature near zero.
36 |             Default: 1e-12.
37 |     """
38 | 
39 |     def __init__(self,
40 |                  loss_weight=1.0,
41 |                  reduction='mean',
42 |                  sample_wise=False,
43 |                  eps=1e-12):
44 |         super().__init__()
45 |         if reduction not in ['none', 'mean', 'sum']:
46 |             raise ValueError(f'Unsupported reduction mode: {reduction}. '
47 |                              f"Supported ones are: ['none', 'mean', 'sum']")
48 | 
49 |         self.loss_weight = loss_weight
50 |         self.reduction = reduction
51 |         self.sample_wise = sample_wise
52 |         self.eps = eps
53 | 
54 |     def forward(self, pred, target, weight=None, **kwargs):
55 |         """Forward Function.
56 | 
57 |         Args:
58 |             pred (Tensor): of shape (N, C, H, W). Predicted tensor.
59 |             target (Tensor): of shape (N, C, H, W). Ground truth tensor.
60 |             weight (Tensor, optional): of shape (N, C, H, W). Element-wise
61 |                 weights. Default: None.
62 |         """
63 |         return self.loss_weight * charbonnier_loss(
64 |             pred,
65 |             target,
66 |             weight,
67 |             eps=self.eps,
68 |             reduction=self.reduction,
69 |             sample_wise=self.sample_wise)
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is based on Open-MMLab's one.
3 | https://github.com/open-mmlab/mmediting
4 | """
5 | 
6 | import torch
7 | from torch import nn
8 | 
9 | from SPyNet import SPyNet
10 | from modules import PixelShuffle,ResidualBlocksWithInputConv,flow_warp
11 | 
12 | class basicVSR(nn.Module):
13 |     def __init__(self,scale_factor=4, mid_channels=64, num_blocks=30, spynet_pretrained=None):
14 |         super().__init__()
15 |         self.scale_factor=scale_factor
16 |         self.mid_channels = mid_channels
17 | 
18 |         # alignment (optical flow network)
19 |         self.spynet = self.get_spynet(spynet_pretrained)
20 | 
21 |         # propagation
22 |         self.backward_resblocks=ResidualBlocksWithInputConv(mid_channels + 3, mid_channels, num_blocks)
23 |         self.forward_resblocks=ResidualBlocksWithInputConv(mid_channels + 3, mid_channels, num_blocks)
24 | 
25 |         # upsample
26 |         self.fusion = nn.Conv2d(mid_channels * 2, mid_channels, 1, 1, 0, bias=True)
27 |         self.upsample1 = PixelShuffle(mid_channels, mid_channels, 2, upsample_kernel=3)
28 |         self.upsample2 = PixelShuffle(mid_channels, 64, 2, upsample_kernel=3)
29 |         self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
30 |         self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
31 |         self.img_upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)
32 | 
33 |         # activation function
34 |         self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
35 | 
36 |     def get_spynet(self,pretrained):  # build SPyNet and optionally copy pretrained weights into it by key order
37 |         model=SPyNet()
38 |         if pretrained:
39 |             model_p=model.state_dict()
40 |             pre_p=torch.load(pretrained)
41 |             ppl=list(pre_p)
42 |             for i,k in enumerate(model_p.keys()):
43 |                 if i<2:  # skip the first two keys (the 'mean'/'std' buffers registered in SPyNet.__init__); the rest are filled from the checkpoint in order
44 |                     continue
45 |                 model_p[k]=pre_p[ppl[i-2]]
46 |             model.load_state_dict(model_p)
47 |         return model
48 | 
49 |     def check_if_mirror_extended(self, lrs):
50 |         """Check whether the input is a mirror-extended sequence.
51 |         If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the
52 |         (t-1-i)-th frame.
53 |         Args:
54 |             lrs (tensor): Input LR images with shape (n, t, c, h, w)
55 |         """
56 | 
57 |         self.is_mirror_extended = False
58 |         if lrs.size(1) % 2 == 0:
59 |             lrs_1, lrs_2 = torch.chunk(lrs, 2, dim=1)
60 |             if torch.norm(lrs_1 - lrs_2.flip(1)) == 0:
61 |                 self.is_mirror_extended = True
62 | 
63 |     def compute_flow(self, lrs):
64 |         """Compute optical flow using SPyNet for feature warping.
65 |         Note that if the input is a mirror-extended sequence, 'flows_forward'
66 |         is not needed, since it is equal to 'flows_backward.flip(1)'.
67 |         Args:
68 |             lrs (tensor): Input LR images with shape (n, t, c, h, w)
69 |         Return:
70 |             tuple(Tensor): Optical flow. 'flows_forward' corresponds to the
71 |                 flows used for forward-time propagation (current to previous).
72 |                 'flows_backward' corresponds to the flows used for
73 |                 backward-time propagation (current to next).
74 |         """
75 | 
76 |         n, t, c, h, w = lrs.size()
77 |         lrs_1 = lrs[:, :-1, :, :, :].reshape(-1, c, h, w)
78 |         lrs_2 = lrs[:, 1:, :, :, :].reshape(-1, c, h, w)
79 | 
80 |         flows_backward = self.spynet(lrs_1, lrs_2).view(n, t - 1, 2, h, w)
81 | 
82 |         if self.is_mirror_extended:  # flows_forward = flows_backward.flip(1)
83 |             flows_forward = None
84 |         else:
85 |             flows_forward = self.spynet(lrs_2, lrs_1).view(n, t - 1, 2, h, w)
86 | 
87 |         return flows_forward, flows_backward
88 | 
89 |     def forward(self, lrs):
90 |         """Forward function for BasicVSR.
91 |         Args:
92 |             lrs (Tensor): Input LR sequence with shape (n, t, c, h, w).
93 |         Returns:
94 |             Tensor: Output HR sequence with shape (n, t, c, 4h, 4w) (if scale_factor=4).
95 |         """
96 | 
97 |         n, t, c, h, w = lrs.size()
98 |         assert h >= 64 and w >= 64, (
99 |             'The height and width of inputs should be at least 64, '
100 |             f'but got {h} and {w}.')
101 | 
102 |         # check whether the input is an extended sequence
103 |         self.check_if_mirror_extended(lrs)
104 | 
105 |         # compute optical flow
106 |         flows_forward, flows_backward = self.compute_flow(lrs)
107 | 
108 |         # backward-time propagation
109 |         outputs = []
110 |         feat_prop = lrs.new_zeros(n, self.mid_channels, h, w)
111 |         for i in range(t - 1, -1, -1):
112 |             if i < t - 1:  # no warping required for the last timestep
113 |                 flow = flows_backward[:, i, :, :, :]
114 |                 feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
115 | 
116 |             feat_prop = torch.cat([lrs[:, i, :, :, :], feat_prop], dim=1)
117 |             feat_prop = self.backward_resblocks(feat_prop)
118 | 
119 |             outputs.append(feat_prop)
120 |         outputs = outputs[::-1]
121 | 
122 |         # forward-time propagation and upsampling
123 |         feat_prop = torch.zeros_like(feat_prop)
124 |         for i in range(0, t):
125 |             lr_curr = lrs[:, i, :, :, :]
126 |             if i > 0:  # no warping required for the first timestep
127 |                 if flows_forward is not None:
128 |                     flow = flows_forward[:, i - 1, :, :, :]
129 |                 else:
130 |                     flow = flows_backward[:, -i, :, :, :]
131 |                 feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
132 | 
133 |             feat_prop = torch.cat([lr_curr, feat_prop], dim=1)
134 |             feat_prop = self.forward_resblocks(feat_prop)
135 | 
136 |             # upsampling given the backward and forward features
137 |             out = torch.cat([outputs[i], feat_prop], dim=1)
138 |             out = self.lrelu(self.fusion(out))
139 |             out = self.lrelu(self.upsample1(out))
140 |             out = self.lrelu(self.upsample2(out))
141 |             out = self.lrelu(self.conv_hr(out))
142 |             out = self.conv_last(out)
143 |             base = self.img_upsample(lr_curr)
144 |             out += base
145 |             outputs[i] = out
146 | 
147 |         return torch.stack(outputs, dim=1)
--------------------------------------------------------------------------------
/modules.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is based on Open-MMLab's one.
3 | https://github.com/open-mmlab/mmediting
4 | """
5 | 
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | 
10 | def flow_warp(x,
11 |               flow,
12 |               interpolation='bilinear',
13 |               padding_mode='zeros',
14 |               align_corners=True):
15 |     """Warp an image or a feature map with optical flow.
16 |     Args:
17 |         x (Tensor): Tensor with size (n, c, h, w).
18 |         flow (Tensor): Tensor with size (n, h, w, 2). The last dimension holds
19 |             two channels, the horizontal (width) and vertical (height) relative offsets.
20 |             Note that the values are not normalized to [-1, 1].
21 |         interpolation (str): Interpolation mode: 'nearest' or 'bilinear'.
22 |             Default: 'bilinear'.
23 |         padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'.
24 |             Default: 'zeros'.
25 |         align_corners (bool): Whether to align corners. Default: True.
26 |     Returns:
27 |         Tensor: Warped image or feature map.
28 |     """
29 |     if x.size()[-2:] != flow.size()[1:3]:
30 |         raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
31 |                          f'flow ({flow.size()[1:3]}) are not the same.')
32 |     _, _, h, w = x.size()
33 |     # create mesh grid
34 |     grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
35 |     grid = torch.stack((grid_x, grid_y), 2).type_as(x)  # (h, w, 2)
36 |     grid.requires_grad = False
37 | 
38 |     grid_flow = grid + flow
39 |     # scale grid_flow to [-1,1]
40 |     grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
41 |     grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
42 |     grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)
43 |     output = F.grid_sample(
44 |         x,
45 |         grid_flow,
46 |         mode=interpolation,
47 |         padding_mode=padding_mode,
48 |         align_corners=align_corners)
49 |     return output
50 | 
51 | def make_itrblocks(block, num_blocks, **kwarg):
52 |     """Make layers by stacking the same blocks.
53 |     Args:
54 |         block (nn.Module): nn.Module class for the basic block.
55 |         num_blocks (int): number of blocks.
56 |     Returns:
57 |         nn.Sequential: Stacked blocks in nn.Sequential.
58 |     """
59 |     layers = []
60 |     for _ in range(num_blocks):
61 |         layers.append(block(**kwarg))
62 |     return nn.Sequential(*layers)
63 | 
64 | class ResidualBlocksWithInputConv(nn.Module):
65 |     """Residual blocks with a convolution in front.
66 |     Args:
67 |         in_channels (int): Number of input channels of the first conv.
68 |         out_channels (int): Number of channels of the residual blocks.
69 |             Default: 64.
70 |         num_blocks (int): Number of residual blocks. Default: 30.
71 |     """
72 | 
73 |     def __init__(self, in_channels, out_channels=64, num_blocks=30):
74 |         super().__init__()
75 | 
76 |         layers = []
77 | 
78 |         # a convolution used to match the channels of the residual blocks
79 |         layers.append(nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=True))
80 |         layers.append(nn.LeakyReLU(negative_slope=0.1, inplace=True))
81 | 
82 |         # residual blocks
83 |         layers.append(make_itrblocks(ResidualBlockNoBN, num_blocks, mid_channels=out_channels))
84 | 
85 |         self.layers = nn.Sequential(*layers)
86 | 
87 |     def forward(self, feat):
88 |         """
89 |         Forward function for ResidualBlocksWithInputConv.
90 |         Args:
91 |             feat (Tensor): Input feature with shape (n, in_channels, h, w)
92 |         Returns:
93 |             Tensor: Output feature with shape (n, out_channels, h, w)
94 |         """
95 |         return self.layers(feat)
96 | 
97 | class ResidualBlockNoBN(nn.Module):
98 |     """Residual block without BN.
99 |     It has a style of:
100 |     ::
101 |         ---Conv-ReLU-Conv-+-
102 |          |________________|
103 |     Args:
104 |         mid_channels (int): Channel number of intermediate features.
105 |             Default: 64.
106 |         res_scale (float): Used to scale the residual before addition.
107 |             Default: 1.0.
108 |     """
109 | 
110 |     def __init__(self, mid_channels=64, res_scale=1.0):
111 |         super().__init__()
112 |         self.res_scale = res_scale
113 |         self.conv1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=True)
114 |         self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=True)
115 | 
116 |         self.relu = nn.ReLU(inplace=True)
117 | 
118 |         # if res_scale < 1.0, use the default initialization, as in EDSR.
119 |         # if res_scale = 1.0, use scaled kaiming_init, as in MSRResNet.
120 |         if res_scale == 1.0:
121 |             self.init_weights()
122 | 
123 |     def init_weights(self):
124 |         """Initialize weights for ResidualBlockNoBN.
125 |         Initialization methods like `kaiming_init` are for VGG-style
126 |         modules. For modules with residual paths, using smaller std is
127 |         better for stability and performance. We empirically use 0.1.
128 |         See more details in "ESRGAN: Enhanced Super-Resolution Generative
129 |         Adversarial Networks"
130 |         """
131 | 
132 |         for m in [self.conv1, self.conv2]:
133 |             nn.init.kaiming_uniform_(m.weight, a=0, mode='fan_in', nonlinearity='relu')
134 |             m.weight.data *= 0.1
135 |             nn.init.constant_(m.bias, 0)
136 | 
137 |     def forward(self, x):
138 |         """Forward function.
139 |         Args:
140 |             x (Tensor): Input tensor with shape (n, c, h, w).
141 |         Returns:
142 |             Tensor: Forward results.
143 |         """
144 | 
145 |         identity = x
146 |         x = self.conv1(x)
147 |         x = self.relu(x)
148 |         out = self.conv2(x)
149 |         return identity + out * self.res_scale
150 | 
151 | class PixelShuffle(nn.Module):
152 |     """ Pixel Shuffle upsample layer.
153 |     Args:
154 |         in_channels (int): Number of input channels.
155 |         out_channels (int): Number of output channels.
156 |         scale_factor (int): Upsample ratio.
157 |         upsample_kernel (int): Kernel size of Conv layer to expand channels.
158 |     Returns:
159 |         Upsampled feature map.
160 |     """
161 | 
162 |     def __init__(self, in_channels, out_channels, scale_factor, upsample_kernel):
163 |         super().__init__()
164 |         self.scale_factor = scale_factor
165 |         self.upsample_conv = nn.Conv2d(
166 |             in_channels,
167 |             out_channels * scale_factor * scale_factor,
168 |             upsample_kernel,
169 |             padding=(upsample_kernel - 1) // 2
170 |         )
171 |         self.init_weights()
172 | 
173 |     def init_weights(self):
174 |         """Initialize weights for PixelShuffle.
175 |         """
176 |         for m in [self.upsample_conv]:
177 |             nn.init.kaiming_uniform_(m.weight, a=0, mode='fan_in', nonlinearity='relu')
178 |             nn.init.constant_(m.bias, 0)
179 | 
180 |     def forward(self, x):
181 |         """Forward function for PixelShuffle.
182 |         Args:
183 |             x (Tensor): Input tensor with shape (n, c, h, w).
184 |         Returns:
185 |             Tensor: Forward results.
186 |         """
187 |         x = self.upsample_conv(x)
188 |         x = F.pixel_shuffle(x, self.scale_factor)
189 |         return x
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from collections import OrderedDict
4 | from tqdm import tqdm
5 | import matplotlib.pyplot as plt
6 | from math import log10
7 | from PIL import Image
8 | 
9 | import torch
10 | from torch import nn
11 | from torch.autograd import Variable
12 | from torchvision.utils import save_image
13 | from torch.utils.data import DataLoader
14 | from torch.optim.lr_scheduler import CosineAnnealingLR
15 | 
16 | from dataset import REDSDataset
17 | from model import basicVSR
18 | from loss import CharbonnierLoss
19 | from utils import resize_sequences
20 | 
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument('--gt_dir', default='./REDS/train_sharp')
23 | parser.add_argument('--lq_dir', default='./REDS/train_sharp_bicubic/X4')
24 | parser.add_argument('--log_dir', default='./log_dir')
25 | parser.add_argument('--spynet_pretrained', default='spynet_20210409-c6c1bd09.pth')
26 | parser.add_argument('--scale_factor', default=4,type=int)
27 | parser.add_argument('--batch_size', default=8,type=int)
28 | parser.add_argument('--patch_size', default=64,type=int)
29 | parser.add_argument('--epochs', default=300000,type=int)
30 | parser.add_argument('--num_input_frames', default=15,type=int)
31 | parser.add_argument('--val_interval', default=1000,type=int)
32 | parser.add_argument('--max_keys', default=270,type=int)
33 | parser.add_argument('--filename_tmpl', default='{:08d}.png')
34 | args = parser.parse_args()
35 | 
36 | train_set=REDSDataset(args.gt_dir,args.lq_dir,args.scale_factor,args.patch_size,args.num_input_frames,is_test=False,max_keys=args.max_keys,filename_tmpl=args.filename_tmpl)
37 | val_set=REDSDataset(args.gt_dir,args.lq_dir,args.scale_factor,args.patch_size,args.num_input_frames,is_test=True,max_keys=args.max_keys,filename_tmpl=args.filename_tmpl)
38 | 
39 | train_loader=DataLoader(train_set,batch_size=args.batch_size,shuffle=True,num_workers=os.cpu_count(),pin_memory=True)
40 | val_loader=DataLoader(val_set,batch_size=1,num_workers=os.cpu_count(),pin_memory=True)
41 | 
42 | model=basicVSR(spynet_pretrained=args.spynet_pretrained).cuda()
43 | 
44 | criterion=CharbonnierLoss().cuda()
45 | criterion_mse=nn.MSELoss().cuda()
46 | optimizer = torch.optim.Adam([
47 |     {'params': model.spynet.parameters(), 'lr': 2.5e-5},
48 |     {'params': model.backward_resblocks.parameters()},
49 |     {'params': model.forward_resblocks.parameters()},
50 |     {'params': model.fusion.parameters()},
51 |     {'params': model.upsample1.parameters()},
52 |     {'params': model.upsample2.parameters()},
53 |     {'params': model.conv_hr.parameters()},
54 |     {'params': model.conv_last.parameters()}
55 |     ], lr=2e-4, betas=(0.9,0.99)
56 | )
57 | 
58 | max_epoch=args.epochs
59 | scheduler=CosineAnnealingLR(optimizer,T_max=max_epoch,eta_min=1e-7)
60 | 
61 | os.makedirs(f'{args.log_dir}/models',exist_ok=True)
62 | os.makedirs(f'{args.log_dir}/images',exist_ok=True)
63 | train_loss=[]
64 | validation_loss=[]
65 | for epoch in range(max_epoch):
66 |     model.train()
67 |     # fix SPyNet (and EDVR, if present) for the first 5000 epochs
68 |     if epoch < 5000:
69 |         for k, v in model.named_parameters():
70 |             if 'spynet' in k or 'edvr' in k:
71 |                 v.requires_grad_(False)
72 |     elif epoch == 5000:
73 |         # train all the parameters
74 |         model.requires_grad_(True)
75 | 
76 |     epoch_loss = 0
77 |     with tqdm(train_loader, ncols=100) as pbar:
78 |         for idx, data in enumerate(pbar):
79 |             gt_sequences, lq_sequences = Variable(data[0]),Variable(data[1])
80 |             gt_sequences=gt_sequences.to('cuda:0')
81 |             lq_sequences=lq_sequences.to('cuda:0')
82 | 
83 |             optimizer.zero_grad()
84 |             pred_sequences = model(lq_sequences)
85 |             loss = criterion(pred_sequences, gt_sequences)
86 |             epoch_loss += loss.item()
87 |             #epoch_psnr += 10 * log10(1 / loss.data)
88 | 
89 |             loss.backward()
90 |             optimizer.step()
91 |             scheduler.step()
92 | 
93 |             pbar.set_description(f'[Epoch {epoch+1}]')
94 |             pbar.set_postfix(OrderedDict(loss=f'{loss.data:.3f}'))
95 | 
96 |     train_loss.append(epoch_loss/len(train_loader))
97 | 
98 |     if (epoch + 1) % args.val_interval != 0:
99 |         continue
100 | 
101 |     model.eval()
102 |     val_psnr,lq_psnr,val_loss = 0,0,0
103 |     os.makedirs(f'{args.log_dir}/images/epoch{epoch+1:05}',exist_ok=True)
104 |     with torch.no_grad():
105 |         for idx,data in enumerate(val_loader):
106 |             gt_sequences, lq_sequences = data
107 |             gt_sequences=gt_sequences.to('cuda:0')
108 |             lq_sequences=lq_sequences.to('cuda:0')
109 |             pred_sequences = model(lq_sequences)
110 |             lq_sequences=resize_sequences(lq_sequences,(gt_sequences.size(dim=3),gt_sequences.size(dim=4)))
111 |             val_mse = criterion_mse(pred_sequences, gt_sequences)
112 |             lq_mse = criterion_mse(lq_sequences,gt_sequences)
113 |             val_psnr += 10 * log10(1 / val_mse.data)
114 |             lq_psnr += 10 * log10(1 / lq_mse.data)
115 |             val_loss += criterion(pred_sequences, gt_sequences).item()
116 |             save_image(pred_sequences[0], f'{args.log_dir}/images/epoch{epoch+1:05}/{idx}_SR.png',nrow=5)
117 |             save_image(lq_sequences[0], f'{args.log_dir}/images/epoch{epoch+1:05}/{idx}_LQ.png',nrow=5)
118 |             save_image(gt_sequences[0], f'{args.log_dir}/images/epoch{epoch+1:05}/{idx}_GT.png',nrow=5)
119 | 
120 |     validation_loss.append(val_loss/len(val_loader))
121 | 
122 |     print(f'==[validation]== PSNR:{val_psnr / len(val_loader):.2f},(lq:{lq_psnr/len(val_loader):.2f})')
123 |     torch.save(model.state_dict(),f'{args.log_dir}/models/model_{epoch}.pth')
124 | 
125 | fig=plt.figure()
126 | # train_loss has one entry per epoch; validation_loss has one entry per val_interval epochs
127 | 
128 | x_train=list(range(len(train_loss)))
129 | x_val=[x for x in range(max_epoch) if (x + 1) % args.val_interval == 0]
130 | plt.plot(x_train,train_loss)
131 | plt.plot(x_val,validation_loss)
132 | 
133 | fig.savefig(f'{args.log_dir}/loss.png')
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms as T
3 | 
4 | def resize_sequences(sequences,target_size):
5 |     """Resize each frame in a batch of sequences to target_size.
6 |     Args:
7 |         sequences (Tensor): input sequence with shape (n, t, c, h, w)
8 |         target_size (tuple): the size of output sequence with shape (H, W)
9 |     Returns:
10 |         Tensor: Output sequences with shape (n, t, c, H, W)
11 |     """
12 |     seq_list=[]
13 |     for sequence in sequences:
14 |         img_list=[T.Resize(target_size,interpolation=T.InterpolationMode.BICUBIC)(lq_image) for lq_image in sequence]
15 |         seq_list.append(torch.stack(img_list))
16 | 
17 |     return torch.stack(seq_list)
--------------------------------------------------------------------------------
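
The repository contains no inference script; as a rough orientation only (not part of the original code), a trained checkpoint written by train.py could be loaded along the following lines. The checkpoint filename below is a placeholder, and the dummy input merely illustrates the (n, t, c, h, w) layout expected by basicVSR, whose forward pass requires h and w to be at least 64.

import torch
from model import basicVSR

# Sketch only: the checkpoint path is hypothetical and depends on how long training ran.
model = basicVSR(spynet_pretrained=None).cuda().eval()
model.load_state_dict(torch.load('./log_dir/models/model_999.pth'))

# basicVSR expects an LR sequence of shape (n, t, c, h, w) with h, w >= 64.
lr_seq = torch.rand(1, 5, 3, 64, 64).cuda()
with torch.no_grad():
    sr_seq = model(lr_seq)  # -> (1, 5, 3, 256, 256) when scale_factor=4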