├── .DS_Store
├── README.md
├── checkpoints
│   ├── .DS_Store
│   └── SSRDEF_4xSR_epoch80.pth.tar
├── common.py
├── data
│   ├── .DS_Store
│   └── test
│       └── .DS_Store
├── datasets
│   ├── __init__.py
│   ├── data_io.py
│   ├── kitti_dataset.py
│   └── sceneflow_dataset.py
├── figs
│   ├── .DS_Store
│   ├── Model.png
│   ├── Results.png
│   └── Visual.png
├── loss.py
├── metric.py
├── model.py
├── result.txt
├── test_disp.py
├── test_sr.py
├── train.py
└── utils.py

/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIVRC/SSRDEFNet-PyTorch/4d95c175fc60526fecbb42172a211fbd4472adf6/.DS_Store

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SSRDEFNet_PyTorch
2 | ### This repository is an official PyTorch implementation of the paper "Feedback Network for Mutually Boosted Stereo Image Super-Resolution and Disparity Estimation". (ACM MM 2021)
3 | 
4 | ### The paper can be downloaded from SSRDE-FNet
5 | 
6 | Under stereo settings, the problems of image super-resolution (SR) and disparity estimation are interrelated, in that the result of each can help to solve the other. The effective exploitation of correspondence between different views facilitates SR performance, while high-resolution (HR) features with richer details benefit correspondence estimation. Motivated by this, we propose a Stereo Super-Resolution and Disparity Estimation Feedback Network (SSRDE-FNet), which handles stereo image super-resolution and disparity estimation simultaneously in a unified framework and lets the two tasks interact to further improve their performance. Specifically, the SSRDE-FNet is composed of two dual recursive sub-networks for left and right views.
Besides the cross-view information exploitation in the low-resolution (LR) space, HR representations produced by the SR process are utilized to perform HR disparity estimation with higher accuracy, through which the HR features can be aggregated to generate a finer SR result. Afterward, the proposed HR Disparity Information Feedback (HRDIF) mechanism delivers information carried by HR disparity back to previous layers to further refine the SR image reconstruction. Extensive experiments demonstrate the effectiveness and advancement of SSRDE-FNet.
7 | 
8 | 
9 | 
10 | 
11 | 
12 | The pre-trained model for x4 SR is in ./checkpoints/SSRDEF_4xSR_epoch80.pth.tar
13 | 
14 | All reconstructed images for x4 SR can be downloaded from SSRDEFNet_Results
15 | 
16 | All test datasets (preprocessed HR images) can be downloaded from SSRDEFNet_Test
17 | 
18 | Extract the datasets and put them into ./data/test/.
19 | 
20 | 
21 | ## Prerequisites:
22 | 1. Python 3.6
23 | 2. PyTorch >= 0.4.0
24 | 3. numpy
25 | 4. skimage
26 | 5. imageio
27 | 6. matplotlib
28 | 
29 | 
30 | ## Dataset
31 | 
32 | We used the Flickr1024 and Middlebury datasets to train our model, exactly as in iPASSR. Please refer to their homepage and download the datasets following their guidance.
33 | 
34 | Extract the datasets and put them into ./data/train/.
35 | 
36 | ## Training
37 | 
38 | ```bash
39 | 
40 | python train.py --scale_factor 4
41 | 
42 | ```
43 | 
44 | ## Testing
45 | 
46 | ```bash
47 | 
48 | # Testing stereo SR performance
49 | 
50 | python test_sr.py
51 | 
52 | # Testing disparity estimation performance
53 | 
54 | python test_disp.py
55 | 
56 | ```
57 | 
58 | ## Performance
59 | 
60 |

61 | 62 |

63 | 64 |

65 | 66 |
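Before running the training or test scripts listed above, it can help to confirm that the checkpoint and dataset folders are where the README expects them. The following is a minimal sketch, not part of the repository: the paths come from the README and the file tree, while the helper name `find_missing` and the `EXPECTED_PATHS` table are assumptions made for illustration.

```python
from pathlib import Path

# Locations named in the README and the repository tree.
# The table and the helper below are hypothetical, not part of the repository.
EXPECTED_PATHS = {
    "checkpoint": Path("checkpoints") / "SSRDEF_4xSR_epoch80.pth.tar",
    "test_data": Path("data") / "test",
    "train_data": Path("data") / "train",
}


def find_missing(repo_root="."):
    """Return the sorted names of expected files or folders missing under repo_root."""
    root = Path(repo_root)
    return sorted(
        name for name, rel in EXPECTED_PATHS.items() if not (root / rel).exists()
    )


if __name__ == "__main__":
    missing = find_missing()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("Layout looks complete.")
```

Run from the repository root, the script prints the names of any missing entries; `train_data` only matters if you intend to train.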

67 | 
68 | 
69 | 
70 | ```
71 | @inproceedings{dai2021feedback,
72 |   title={Feedback Network for Mutually Boosted Stereo Image Super-Resolution and Disparity Estimation},
73 |   author={Dai, Qinyan and Li, Juncheng and Yi, Qiaosi and Fang, Faming and Zhang, Guixu},
74 |   booktitle={Proceedings of the 29th ACM International Conference on Multimedia},
75 |   pages={1985--1993},
76 |   year={2021}
77 | }
78 | ```
79 | 
80 | This implementation is for non-commercial research use only.
81 | If you find this code useful in your research, please cite the above paper.

--------------------------------------------------------------------------------
/checkpoints/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIVRC/SSRDEFNet-PyTorch/4d95c175fc60526fecbb42172a211fbd4472adf6/checkpoints/.DS_Store

--------------------------------------------------------------------------------
/checkpoints/SSRDEF_4xSR_epoch80.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIVRC/SSRDEFNet-PyTorch/4d95c175fc60526fecbb42172a211fbd4472adf6/checkpoints/SSRDEF_4xSR_epoch80.pth.tar

--------------------------------------------------------------------------------
/common.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | 
7 | def convbn(in_planes, out_planes, kernel_size, stride, pad, dilation):
8 | 
9 |     return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=dilation if dilation > 1 else pad, dilation = dilation, bias=False),
10 |                          nn.BatchNorm2d(out_planes))
11 | 
12 | 
13 | def default_conv(in_channels, out_channels, kernel_size,stride=1, bias=True):
14 |     return nn.Conv2d(
15 |         in_channels, out_channels, kernel_size,
16 |         padding=(kernel_size//2),stride=stride, bias=bias)
17 | 
18 | 
class MeanShift(nn.Conv2d): 19 | def __init__( 20 | self, rgb_range, 21 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 22 | 23 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 24 | std = torch.Tensor(rgb_std) 25 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 26 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 27 | for p in self.parameters(): 28 | p.requires_grad = False 29 | 30 | class BasicBlock(nn.Sequential): 31 | def __init__( 32 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=True, 33 | bn=False, act=nn.PReLU()): 34 | 35 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)] 36 | if bn: 37 | m.append(nn.BatchNorm2d(out_channels)) 38 | if act is not None: 39 | m.append(act) 40 | 41 | super(BasicBlock, self).__init__(*m) 42 | 43 | class ResBlock(nn.Module): 44 | def __init__( 45 | self, conv, n_feats, kernel_size, 46 | bias=True, bn=False, act=nn.PReLU()): 47 | 48 | super(ResBlock, self).__init__() 49 | m = [] 50 | for i in range(2): 51 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 52 | if bn: 53 | m.append(nn.BatchNorm2d(n_feats)) 54 | if i == 0: 55 | m.append(act) 56 | 57 | self.body = nn.Sequential(*m) 58 | 59 | def forward(self, x): 60 | res = self.body(x) 61 | res += x 62 | 63 | return res 64 | 65 | class ResBlockdilat(nn.Module): 66 | def __init__( 67 | self, n_feats, kernel_size, dilrate=1, 68 | bias=True, bn=False, act=nn.PReLU()): 69 | 70 | super(ResBlockdilat, self).__init__() 71 | m = [] 72 | for i in range(2): 73 | m.append(nn.Conv2d(n_feats, n_feats, kernel_size, 1, dilrate, dilrate, bias=bias)) 74 | if bn: 75 | m.append(nn.BatchNorm2d(n_feats)) 76 | if i == 0: 77 | m.append(act) 78 | 79 | self.body = nn.Sequential(*m) 80 | 81 | def forward(self, x): 82 | res = self.body(x) 83 | res += x 84 | 85 | return res 86 | 87 | class Upsampler(nn.Sequential): 88 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 89 | 
90 | m = [] 91 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 92 | for _ in range(int(math.log(scale, 2))): 93 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 94 | m.append(nn.PixelShuffle(2)) 95 | if bn: 96 | m.append(nn.BatchNorm2d(n_feats)) 97 | if act == 'relu': 98 | m.append(nn.ReLU(True)) 99 | elif act == 'prelu': 100 | m.append(nn.PReLU(n_feats)) 101 | 102 | elif scale == 3: 103 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 104 | m.append(nn.PixelShuffle(3)) 105 | if bn: 106 | m.append(nn.BatchNorm2d(n_feats)) 107 | if act == 'relu': 108 | m.append(nn.ReLU(True)) 109 | elif act == 'prelu': 110 | m.append(nn.PReLU(n_feats)) 111 | else: 112 | raise NotImplementedError 113 | 114 | super(Upsampler, self).__init__(*m) 115 | 116 | class ConvBlock(torch.nn.Module): 117 | def __init__(self, input_size, output_size, kernel_size=3, stride=1, padding=1, bias=True, activation='prelu', norm=None): 118 | super(ConvBlock, self).__init__() 119 | self.conv = torch.nn.Conv2d(input_size, output_size, kernel_size, stride, padding, bias=bias) 120 | 121 | self.norm = norm 122 | if self.norm =='batch': 123 | self.bn = torch.nn.BatchNorm2d(output_size) 124 | elif self.norm == 'instance': 125 | self.bn = torch.nn.InstanceNorm2d(output_size) 126 | 127 | self.activation = activation 128 | if self.activation == 'relu': 129 | self.act = torch.nn.ReLU(True) 130 | elif self.activation == 'prelu': 131 | self.act = torch.nn.PReLU() 132 | elif self.activation == 'lrelu': 133 | self.act = torch.nn.LeakyReLU(0.2, True) 134 | elif self.activation == 'tanh': 135 | self.act = torch.nn.Tanh() 136 | elif self.activation == 'sigmoid': 137 | self.act = torch.nn.Sigmoid() 138 | 139 | def forward(self, x): 140 | if self.norm is not None: 141 | out = self.bn(self.conv(x)) 142 | else: 143 | out = self.conv(x) 144 | 145 | if self.activation is not None: 146 | return self.act(out) 147 | else: 148 | return out 149 | 150 | 151 | class DeconvBlock(torch.nn.Module): 152 | def __init__(self, input_size, 
output_size, kernel_size=4, stride=2, padding=1, bias=True, activation='prelu', norm=None): 153 | super(DeconvBlock, self).__init__() 154 | self.deconv = torch.nn.ConvTranspose2d(input_size, output_size, kernel_size, stride, padding, bias=bias) 155 | 156 | self.norm = norm 157 | if self.norm == 'batch': 158 | self.bn = torch.nn.BatchNorm2d(output_size) 159 | elif self.norm == 'instance': 160 | self.bn = torch.nn.InstanceNorm2d(output_size) 161 | 162 | self.activation = activation 163 | if self.activation == 'relu': 164 | self.act = torch.nn.ReLU(True) 165 | elif self.activation == 'prelu': 166 | self.act = torch.nn.PReLU() 167 | elif self.activation == 'lrelu': 168 | self.act = torch.nn.LeakyReLU(0.2, True) 169 | elif self.activation == 'tanh': 170 | self.act = torch.nn.Tanh() 171 | elif self.activation == 'sigmoid': 172 | self.act = torch.nn.Sigmoid() 173 | 174 | def forward(self, x): 175 | if self.norm is not None: 176 | out = self.bn(self.deconv(x)) 177 | else: 178 | out = self.deconv(x) 179 | 180 | if self.activation is not None: 181 | return self.act(out) 182 | else: 183 | return out 184 | 185 | 186 | class Bottle2neck(nn.Module): 187 | 188 | def __init__(self, inplanes, planes, stride=1, downsample=None, scale = 4, stype='normal'): 189 | """ Constructor 190 | Args: 191 | inplanes: input channel dimensionality 192 | planes: output channel dimensionality 193 | stride: conv stride. Replaces pooling layer. 194 | downsample: None when stride = 1 195 | baseWidth: basic width of conv3x3 196 | scale: number of scale. 197 | type: 'normal': normal set. 'stage': first block of a new stage. 
198 | """ 199 | super(Bottle2neck, self).__init__() 200 | 201 | width = planes//scale 202 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 203 | 204 | if scale == 1: 205 | self.nums = 1 206 | else: 207 | self.nums = scale -1 208 | if stype == 'stage': 209 | self.pool = nn.AvgPool2d(kernel_size=3, stride = stride, padding=1) 210 | convs = [] 211 | for i in range(self.nums): 212 | convs.append(nn.Conv2d(width, width, kernel_size=3, stride = stride, padding=1, bias=False)) 213 | self.convs = nn.ModuleList(convs) 214 | 215 | self.conv3 = nn.Conv2d(width*scale, planes, kernel_size=1, bias=False) 216 | 217 | self.relu = nn.PReLU() 218 | self.downsample = downsample 219 | self.stype = stype 220 | self.scale = scale 221 | self.width = width 222 | 223 | def forward(self, x): 224 | residual = x 225 | out = self.conv1(x) 226 | out = self.relu(out) 227 | 228 | spx = torch.split(out, self.width, 1) 229 | for i in range(self.nums): 230 | if i==0 or self.stype=='stage': 231 | sp = spx[i] 232 | else: 233 | sp = sp + spx[i] 234 | sp = self.convs[i](sp) 235 | sp = self.relu(sp) 236 | if i==0: 237 | out = sp 238 | else: 239 | out = torch.cat((out, sp), 1) 240 | if self.scale != 1 and self.stype=='normal': 241 | out = torch.cat((out, spx[self.nums]),1) 242 | elif self.scale != 1 and self.stype=='stage': 243 | out = torch.cat((out, self.pool(spx[self.nums])),1) 244 | 245 | out = self.conv3(out) 246 | 247 | out += residual 248 | out = self.relu(out) 249 | 250 | return out -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIVRC/SSRDEFNet-PyTorch/4d95c175fc60526fecbb42172a211fbd4472adf6/data/.DS_Store -------------------------------------------------------------------------------- /data/test/.DS_Store: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MIVRC/SSRDEFNet-PyTorch/4d95c175fc60526fecbb42172a211fbd4472adf6/data/test/.DS_Store -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .kitti_dataset import KITTIDataset 2 | from .sceneflow_dataset import SceneFlowDatset 3 | 4 | __datasets__ = { 5 | "sceneflow": SceneFlowDatset, 6 | "kitti": KITTIDataset 7 | } 8 | -------------------------------------------------------------------------------- /datasets/data_io.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | import sys 4 | import torchvision.transforms as transforms 5 | import torch 6 | 7 | 8 | def get_transform(): 9 | mean = [0.485, 0.456, 0.406] 10 | std = [0.229, 0.224, 0.225] 11 | 12 | return transforms.Compose([ 13 | transforms.ToTensor(), 14 | #transforms.Normalize(mean=mean, std=std), 15 | ]) 16 | 17 | 18 | # def get_transform_kitti(): 19 | # mean = [0.485, 0.456, 0.406] 20 | # std = [0.229, 0.224, 0.225] 21 | # 22 | # return transforms.Compose([ 23 | # transforms.ToTensor(), 24 | # RandomPhotometric( 25 | # noise_stddev=0.0, 26 | # min_contrast=-0.3, 27 | # max_contrast=0.3, 28 | # brightness_stddev=0.02, 29 | # min_color=0.9, 30 | # max_color=1.1, 31 | # min_gamma=0.7, 32 | # max_gamma=1.5), 33 | # transforms.Normalize(mean=mean, std=std), 34 | # ]) 35 | 36 | 37 | # read all lines in a file 38 | def read_all_lines(filename): 39 | with open(filename) as f: 40 | lines = [line.rstrip() for line in f.readlines()] 41 | return lines 42 | 43 | 44 | # read an .pfm file into numpy array, used to load SceneFlow disparity files 45 | def pfm_imread(filename): 46 | file = open(filename, 'rb') 47 | color = None 48 | width = None 49 | height = None 50 | scale = None 51 | endian = None 52 | 53 | header = file.readline().decode('utf-8').rstrip() 54 | if header == 'PF': 55 
| color = True 56 | elif header == 'Pf': 57 | color = False 58 | else: 59 | raise Exception('Not a PFM file.') 60 | 61 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 62 | if dim_match: 63 | width, height = map(int, dim_match.groups()) 64 | else: 65 | raise Exception('Malformed PFM header.') 66 | 67 | scale = float(file.readline().rstrip()) 68 | if scale < 0: # little-endian 69 | endian = '<' 70 | scale = -scale 71 | else: 72 | endian = '>' # big-endian 73 | 74 | data = np.fromfile(file, endian + 'f') 75 | shape = (height, width, 3) if color else (height, width) 76 | 77 | data = np.reshape(data, shape) 78 | data = np.flipud(data) 79 | return data, scale 80 | 81 | 82 | def writePFM(file, image, scale=1): 83 | file = open(file, 'wb') 84 | 85 | color = None 86 | 87 | if image.dtype.name != 'float32': 88 | raise Exception('Image dtype must be float32.') 89 | 90 | image = np.flipud(image) 91 | 92 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 93 | color = True 94 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale 95 | color = False 96 | else: 97 | raise Exception( 98 | 'Image must have H x W x 3, H x W x 1 or H x W dimensions.') 99 | 100 | file.write('PF\n'.encode('utf-8') if color else 'Pf\n'.encode('utf-8')) 101 | file.write('%d %d\n'.encode('utf-8') % (image.shape[1], image.shape[0])) 102 | 103 | endian = image.dtype.byteorder 104 | 105 | if endian == '<' or endian == '=' and sys.byteorder == 'little': 106 | scale = -scale 107 | 108 | file.write('%f\n'.encode('utf-8') % scale) 109 | 110 | image.tofile(file) 111 | 112 | 113 | class RandomPhotometric(object): 114 | """Applies photometric augmentations to a list of image tensors. 115 | Each image in the list is augmented in the same way. 116 | 117 | Args: 118 | ims: list of 3-channel images normalized to [0, 1]. 119 | 120 | Returns: 121 | normalized images with photometric augmentations. Has the same 122 | shape as the input. 
123 | """ 124 | def __init__(self, 125 | noise_stddev=0.0, 126 | min_contrast=0.0, 127 | max_contrast=0.0, 128 | brightness_stddev=0.0, 129 | min_color=1.0, 130 | max_color=1.0, 131 | min_gamma=1.0, 132 | max_gamma=1.0): 133 | self.noise_stddev = noise_stddev 134 | self.min_contrast = min_contrast 135 | self.max_contrast = max_contrast 136 | self.brightness_stddev = brightness_stddev 137 | self.min_color = min_color 138 | self.max_color = max_color 139 | self.min_gamma = min_gamma 140 | self.max_gamma = max_gamma 141 | 142 | def __call__(self, ims): 143 | contrast = np.random.uniform(self.min_contrast, self.max_contrast) 144 | gamma = np.random.uniform(self.min_gamma, self.max_gamma) 145 | gamma_inv = 1.0 / gamma 146 | color = torch.from_numpy( 147 | np.random.uniform(self.min_color, self.max_color, (3))).float() 148 | if self.noise_stddev > 0.0: 149 | noise = np.random.normal(scale=self.noise_stddev) 150 | else: 151 | noise = 0 152 | if self.brightness_stddev > 0.0: 153 | brightness = np.random.normal(scale=self.brightness_stddev) 154 | else: 155 | brightness = 0 156 | 157 | im_re = ims.permute(1, 2, 0) 158 | im_re = (im_re * (contrast + 1.0) + brightness) * color 159 | im_re = torch.clamp(im_re, min=0.0, max=1.0) 160 | im_re = torch.pow(im_re, gamma_inv) 161 | im_re += noise 162 | 163 | out = im_re.permute(2, 0, 1) 164 | 165 | return out -------------------------------------------------------------------------------- /datasets/kitti_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | from torch.utils.data import Dataset 5 | from PIL import Image 6 | import numpy as np 7 | from datasets.data_io import get_transform, read_all_lines 8 | 9 | 10 | class KITTIDataset(Dataset): 11 | def __init__(self, datapath, list_filename, training): 12 | self.datapath = datapath 13 | self.left_filenames, self.right_filenames, self.disp_occ_filenames, self.disp_noc_filenames = 
self.load_path(list_filename) 14 | self.training = training 15 | if self.training: 16 | assert self.disp_occ_filenames is not None 17 | 18 | def load_path(self, list_filename): 19 | lines = read_all_lines(list_filename) 20 | splits = [line.split() for line in lines] 21 | left_images = [x[0] for x in splits] 22 | right_images = [x[1] for x in splits] 23 | if len(splits[0]) == 2: # ground truth not available 24 | return left_images, right_images, None, None 25 | else: 26 | disp_occ_images = [x[2] for x in splits] 27 | disp_noc_images = [x[2].replace('occ', 'noc') for x in splits] 28 | return left_images, right_images, disp_occ_images, disp_noc_images 29 | 30 | def load_image(self, filename): 31 | return Image.open(filename).convert('RGB') 32 | 33 | def load_disp(self, filename): 34 | data = Image.open(filename) 35 | data = np.array(data, dtype=np.float32) / 256. 36 | return data 37 | 38 | def __len__(self): 39 | return len(self.left_filenames) 40 | 41 | def __getitem__(self, index): 42 | left_img = self.load_image(os.path.join(self.datapath, self.left_filenames[index])) 43 | right_img = self.load_image(os.path.join(self.datapath, self.right_filenames[index])) 44 | 45 | if self.disp_occ_filenames: # has disparity ground truth 46 | disp_occ = self.load_disp(os.path.join(self.datapath, self.disp_occ_filenames[index])) 47 | disp_noc = self.load_disp(os.path.join(self.datapath, self.disp_noc_filenames[index])) 48 | else: 49 | disp_occ = None 50 | 51 | if self.training: 52 | w, h = left_img.size 53 | crop_w, crop_h = 512, 256 54 | 55 | x1 = random.randint(0, w - crop_w) 56 | y1 = random.randint(0, h - crop_h) 57 | 58 | # random crop 59 | left_img = left_img.crop((x1, y1, x1 + crop_w, y1 + crop_h)) 60 | right_img = right_img.crop((x1, y1, x1 + crop_w, y1 + crop_h)) 61 | disp_occ = disp_occ[y1:y1 + crop_h, x1:x1 + crop_w] 62 | disp_noc = disp_noc[y1:y1 + crop_h, x1:x1 + crop_w] 63 | occ_mask = ((disp_occ - disp_noc) > 0).astype(np.float32) 64 | 65 | # to tensor, normalize 66 
| processed = get_transform() 67 | left_img = processed(left_img) 68 | right_img = processed(right_img) 69 | 70 | # # augmentation 71 | # if random.random() < 0.5: 72 | # left_img = torch.flip(left_img, [1]) 73 | # right_img = torch.flip(right_img, [1]) 74 | # disp_occ = np.ascontiguousarray(np.flip(disp_occ, 0)) 75 | 76 | return {"left": left_img, 77 | "right": right_img, 78 | "left_disp": disp_occ, 79 | "occ_mask": occ_mask} 80 | else: 81 | w, h = left_img.size 82 | 83 | # normalize 84 | processed = get_transform() 85 | left_img = processed(left_img).numpy() 86 | right_img = processed(right_img).numpy() 87 | 88 | # pad to size 1244x380 (top and right padding only) 89 | top_pad = 380 - h 90 | right_pad = 1244 - w 91 | assert top_pad > 0 and right_pad > 0 92 | # pad images by replicating edge pixels 93 | left_img = np.pad(left_img, ((0, 0), (top_pad, 0), (0, right_pad)), mode='edge') 94 | right_img = np.pad(right_img, ((0, 0), (top_pad, 0), (0, right_pad)), mode='edge') 95 | 96 | 97 | # pad disparity gt with zeros 98 | if self.disp_occ_filenames is not None: 99 | # assert len(self.disp_occ_filenames.shape) == 2 100 | disp_occ = np.pad(disp_occ, ((top_pad, 0), (0, right_pad)), mode='constant', constant_values=0) 101 | 102 | if self.disp_occ_filenames is not None: 103 | return {"HR_left": left_img, 104 | "HR_right": right_img, 105 | "left_disp": disp_occ, 106 | "top_pad": top_pad, 107 | "right_pad": right_pad 108 | } 109 | else: 110 | return {"HR_left": left_img, 111 | "HR_right": right_img, 112 | "top_pad": top_pad, 113 | "right_pad": right_pad, 114 | "left_filename": self.left_filenames[index], 115 | "right_filename": self.right_filenames[index] 116 | } 117 | -------------------------------------------------------------------------------- /datasets/sceneflow_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from torch.utils.data import Dataset 7 | from PIL import
Image 8 | from datasets.data_io import get_transform, read_all_lines, pfm_imread 9 | 10 | 11 | class SceneFlowDatset(Dataset): 12 | def __init__(self, datapath, list_filename, training): 13 | self.datapath = datapath 14 | self.left_filenames, self.right_filenames, self.left_disp_filenames, self.right_disp_filenames = self.load_path(list_filename) 15 | self.training = training 16 | 17 | def load_path(self, list_filename): 18 | lines = read_all_lines(list_filename) 19 | splits = [line.split() for line in lines] 20 | left_images = [x[0] for x in splits] 21 | right_images = [x[1] for x in splits] 22 | left_disp = [x[2] for x in splits] 23 | right_disp = [x[2][:-13]+'right/'+x[2][-8:] for x in splits] 24 | 25 | return left_images, right_images, left_disp, right_disp 26 | 27 | def load_image(self, filename): 28 | return Image.open(filename).convert('RGB') 29 | 30 | def load_disp(self, filename): 31 | data, scale = pfm_imread(filename) 32 | data = np.ascontiguousarray(data, dtype=np.float32) 33 | return data 34 | 35 | def __len__(self): 36 | return len(self.left_filenames) 37 | 38 | def __getitem__(self, index): 39 | left_img = self.load_image(os.path.join(self.datapath, self.left_filenames[index])) 40 | right_img = self.load_image(os.path.join(self.datapath, self.right_filenames[index])) 41 | left_disp = self.load_disp(os.path.join(self.datapath, self.left_disp_filenames[index])) 42 | right_disp = self.load_disp(os.path.join(self.datapath, self.right_disp_filenames[index])) 43 | 44 | if self.training: 45 | w, h = left_img.size 46 | crop_w, crop_h = 512, 256 47 | 48 | x1 = random.randint(0, w - crop_w) 49 | y1 = random.randint(0, h - crop_h) 50 | 51 | # random crop 52 | left_img = left_img.crop((x1, y1, x1 + crop_w, y1 + crop_h)) 53 | right_img = right_img.crop((x1, y1, x1 + crop_w, y1 + crop_h)) 54 | left_disp = left_disp[y1:y1 + crop_h, x1:x1 + crop_w] 55 | right_disp = right_disp[y1:y1 + crop_h, x1:x1 + crop_w] 56 | 57 | # to tensor, normalize 58 | processed = 
get_transform() 59 | left_img = processed(left_img) 60 | right_img = processed(right_img) 61 | 62 | # augumentation 63 | # if random.random()<0.5: 64 | # left_img = torch.flip(left_img, [1]) 65 | # right_img = torch.flip(right_img, [1]) 66 | # left_disp = np.ascontiguousarray(np.flip(left_disp, 0)) 67 | # right_disp = np.ascontiguousarray(np.flip(right_disp, 0)) 68 | 69 | return {"left": left_img, 70 | "right": right_img, 71 | "left_disp": left_disp, 72 | "right_disp": right_disp} 73 | else: 74 | w, h = left_img.size 75 | crop_w, crop_h = 960, 512 76 | 77 | left_img = left_img.crop((w - crop_w, h - crop_h, w, h)) 78 | right_img = right_img.crop((w - crop_w, h - crop_h, w, h)) 79 | disparity = left_disp[h - crop_h:h, w - crop_w: w] 80 | disparity_right = right_disp[h - crop_h:h, w - crop_w: w] 81 | 82 | processed = get_transform() 83 | left_img = processed(left_img) 84 | right_img = processed(right_img) 85 | 86 | return {"left": left_img, 87 | "right": right_img, 88 | "left_disp": disparity, 89 | "right_disp": disparity_right, 90 | "top_pad": 0, 91 | "right_pad": 0} 92 | -------------------------------------------------------------------------------- /figs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIVRC/SSRDEFNet-PyTorch/4d95c175fc60526fecbb42172a211fbd4472adf6/figs/.DS_Store -------------------------------------------------------------------------------- /figs/Model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIVRC/SSRDEFNet-PyTorch/4d95c175fc60526fecbb42172a211fbd4472adf6/figs/Model.png -------------------------------------------------------------------------------- /figs/Results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIVRC/SSRDEFNet-PyTorch/4d95c175fc60526fecbb42172a211fbd4472adf6/figs/Results.png 
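The SceneFlow loader above reads disparity maps through `pfm_imread` in `datasets/data_io.py`, and the same file provides `writePFM`. The file layout those two functions agree on is: a `PF` (colour) or `Pf` (greyscale) header line, a `width height` line, a scale line whose sign encodes endianness (negative means little-endian), then raw float32 rows stored bottom-to-top. The sketch below is a self-contained round trip under those rules. It is not the repository's code: `write_pfm` and `read_pfm` are hypothetical names, and unlike the originals they close the file with a context manager and do not handle the H x W x 1 case.

```python
import re
import sys

import numpy as np


def write_pfm(path, image, scale=1.0):
    """Write a float32 H x W or H x W x 3 array as a PFM file (hypothetical helper)."""
    if image.dtype != np.float32:
        raise ValueError("image dtype must be float32")
    if image.ndim == 3 and image.shape[2] == 3:
        header = b"PF\n"  # colour
    elif image.ndim == 2:
        header = b"Pf\n"  # greyscale
    else:
        raise ValueError("image must be H x W or H x W x 3")

    # A negative scale marks little-endian data.
    byteorder = image.dtype.byteorder
    little = byteorder == "<" or (byteorder == "=" and sys.byteorder == "little")

    with open(path, "wb") as f:
        f.write(header)
        f.write(b"%d %d\n" % (image.shape[1], image.shape[0]))
        f.write(b"%f\n" % (-scale if little else scale))
        # PFM stores rows bottom-to-top, so flip before writing.
        f.write(np.flipud(image).tobytes())


def read_pfm(path):
    """Read a PFM file and return (array, scale) (hypothetical helper)."""
    with open(path, "rb") as f:
        header = f.readline().decode("utf-8").rstrip()
        if header not in ("PF", "Pf"):
            raise ValueError("not a PFM file")
        dim_match = re.match(r"^(\d+)\s(\d+)\s$", f.readline().decode("utf-8"))
        if not dim_match:
            raise ValueError("malformed PFM header")
        width, height = map(int, dim_match.groups())
        scale = float(f.readline().decode("utf-8").rstrip())
        endian = "<" if scale < 0 else ">"
        data = np.frombuffer(f.read(), dtype=endian + "f4")

    shape = (height, width, 3) if header == "PF" else (height, width)
    # Undo the bottom-to-top row order.
    return np.flipud(data.reshape(shape)), abs(scale)
```

Writing an array with `write_pfm` and reading it back with `read_pfm` should return the same values and the same scale, which is a quick way to check a disparity file before pointing the dataset class at it.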
-------------------------------------------------------------------------------- /figs/Visual.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIVRC/SSRDEFNet-PyTorch/4d95c175fc60526fecbb42172a211fbd4472adf6/figs/Visual.png -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from math import exp 5 | from utils import * 6 | 7 | 8 | def L1Loss(input, target): 9 | return (input - target).abs().mean() 10 | 11 | 12 | def loss_disp_unsupervised(img_left, img_right, disp, valid_mask=None): 13 | b, _, h, w = img_left.shape 14 | image_warped = warp_disp(img_right, disp) 15 | 16 | valid_mask = torch.ones(b, 1, h, w).to(img_left.device) if valid_mask is None else valid_mask 17 | 18 | loss = 0.15 * L1Loss(image_warped * valid_mask, img_left * valid_mask) + \ 19 | 0.85 * (valid_mask * (1 - ssim(img_left, image_warped)) / 2).mean() 20 | return loss 21 | 22 | 23 | def loss_disp_smoothness(disp, img): 24 | img_grad_x = img[:, :, :, :-1] - img[:, :, :, 1:] 25 | img_grad_y = img[:, :, :-1, :] - img[:, :, 1:, :] 26 | weight_x = torch.exp(-torch.abs(img_grad_x).mean(1).unsqueeze(1)) 27 | weight_y = torch.exp(-torch.abs(img_grad_y).mean(1).unsqueeze(1)) 28 | 29 | loss = (((disp[:, :, :, :-1] - disp[:, :, :, 1:]).abs() * weight_x).sum() + 30 | ((disp[:, :, :-1, :] - disp[:, :, 1:, :]).abs() * weight_y).sum()) / \ 31 | (weight_x.sum() + weight_y.sum()) 32 | 33 | return loss 34 | 35 | 36 | def loss_pam_photometric(img_left, img_right, att, valid_mask, mask=None): 37 | weight = [0.2, 0.3, 0.5] 38 | loss = torch.zeros(1).to(img_left.device) 39 | 40 | for idx_scale in range(len(att)): 41 | scale = img_left.size()[2] // valid_mask[idx_scale][0].size()[2] 42 | b, c, h, w = valid_mask[idx_scale][0].size() 43 | 44 | 
att_right2left = att[idx_scale][0] # b * h * w * w 45 | att_left2right = att[idx_scale][1] 46 | valid_mask_left = valid_mask[idx_scale][0] # b * 1 * h * w 47 | valid_mask_right = valid_mask[idx_scale][1] 48 | 49 | if mask is not None: 50 | valid_mask_left = valid_mask_left * (nn.AvgPool2d(scale)(mask[0].float()) > 0).float() 51 | valid_mask_right = valid_mask_right * (nn.AvgPool2d(scale)(mask[1].float()) > 0).float() 52 | 53 | img_left_scale = F.interpolate(img_left, scale_factor=1/scale, mode='bilinear') 54 | img_right_scale = F.interpolate(img_right, scale_factor=1/scale, mode='bilinear') 55 | 56 | img_right_warp = torch.matmul(att_right2left, img_right_scale.permute(0, 2, 3, 1).contiguous()) 57 | img_right_warp = img_right_warp.permute(0, 3, 1, 2) 58 | img_left_warp = torch.matmul(att_left2right, img_left_scale.permute(0, 2, 3, 1).contiguous()) 59 | img_left_warp = img_left_warp.permute(0, 3, 1, 2) 60 | 61 | loss_scale = L1Loss(img_left_scale * valid_mask_left, img_right_warp * valid_mask_left) + \ 62 | L1Loss(img_right_scale * valid_mask_right, img_left_warp * valid_mask_right) 63 | 64 | loss = loss + weight[idx_scale] * loss_scale 65 | 66 | return loss 67 | 68 | 69 | def loss_pam_cycle(att_cycle, valid_mask): 70 | weight = [0.2, 0.3, 0.5] 71 | loss = torch.zeros(1).to(att_cycle[0][0].device) 72 | 73 | for idx_scale in range(len(att_cycle)): 74 | b, c, h, w = valid_mask[idx_scale][0].shape 75 | I = torch.eye(w, w).repeat(b, h, 1, 1).to(att_cycle[0][0].device) 76 | 77 | att_left2right2left = att_cycle[idx_scale][0] 78 | att_right2left2right = att_cycle[idx_scale][1] 79 | valid_mask_left = valid_mask[idx_scale][0] 80 | valid_mask_right = valid_mask[idx_scale][1] 81 | 82 | loss_scale = L1Loss(att_left2right2left * valid_mask_left.permute(0, 2, 3, 1), I * valid_mask_left.permute(0, 2, 3, 1)) + \ 83 | L1Loss(att_right2left2right * valid_mask_right.permute(0, 2, 3, 1), I * valid_mask_right.permute(0, 2, 3, 1)) 84 | 85 | loss = loss + weight[idx_scale] * loss_scale 86 
| 87 | return loss 88 | 89 | 90 | def loss_pam_smoothness(att): 91 | weight = [0.2, 0.3, 0.5] 92 | loss = torch.zeros(1).to(att[0][0].device) 93 | 94 | for idx_scale in range(len(att)): 95 | att_right2left = att[idx_scale][0] 96 | att_left2right = att[idx_scale][1] 97 | 98 | loss_scale = L1Loss(att_right2left[:, :-1, :, :], att_right2left[:, 1:, :, :]) + \ 99 | L1Loss(att_left2right[:, :-1, :, :], att_left2right[:, 1:, :, :]) + \ 100 | L1Loss(att_right2left[:, :, :-1, :-1], att_right2left[:, :, 1:, 1:]) + \ 101 | L1Loss(att_left2right[:, :, :-1, :-1], att_left2right[:, :, 1:, 1:]) 102 | 103 | loss = loss + weight[idx_scale] * loss_scale 104 | 105 | return loss 106 | 107 | 108 | def warp_disp(img, disp): 109 | ''' 110 | Borrowed from: https://github.com/OniroAI/MonoDepth-PyTorch 111 | ''' 112 | b, _, h, w = img.size() 113 | 114 | # Original coordinates of pixels 115 | x_base = torch.linspace(0, 1, w).repeat(b, h, 1).type_as(img) 116 | y_base = torch.linspace(0, 1, h).repeat(b, w, 1).transpose(1, 2).type_as(img) 117 | 118 | # Apply shift in X direction 119 | x_shifts = disp[:, 0, :, :] / w 120 | flow_field = torch.stack((x_shifts, y_base), dim=3) 121 | 122 | # In grid_sample coordinates are assumed to be between -1 and 1 123 | output = F.grid_sample(img, 2 * flow_field - 1, mode='bilinear', padding_mode='border') 124 | 125 | return output 126 | 127 | 128 | def gaussian(window_size, sigma): 129 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 130 | return gauss / gauss.sum() 131 | 132 | 133 | def create_window(window_size, channel): 134 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 135 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 136 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 137 | return window 138 | 139 | 140 | def _ssim(img1, img2, window, window_size, channel): 141 | mu1 = F.conv2d(img1, window, padding=window_size // 2, 
groups=channel) 142 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 143 | 144 | mu1_sq = mu1.pow(2) 145 | mu2_sq = mu2.pow(2) 146 | mu1_mu2 = mu1 * mu2 147 | 148 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 149 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 150 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 151 | 152 | C1 = 0.01 ** 2 153 | C2 = 0.03 ** 2 154 | 155 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 156 | 157 | return ssim_map 158 | 159 | 160 | def ssim(img1, img2, window_size=11): 161 | _, channel, h, w = img1.size() 162 | window = create_window(window_size, channel) 163 | if img1.is_cuda: 164 | window = window.cuda(img1.get_device()) 165 | window = window.type_as(img1) 166 | return _ssim(img1, img2, window, window_size, channel) 167 | -------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | EPSILON = 1e-8 5 | 6 | 7 | def epe_metric(d_est, d_gt, mask, use_np=False): 8 | d_est, d_gt = d_est[mask], d_gt[mask] 9 | if use_np: 10 | epe = np.mean(np.abs(d_est - d_gt)) 11 | else: 12 | epe = torch.mean(torch.abs(d_est - d_gt)) 13 | 14 | return epe 15 | 16 | 17 | def d1_metric(d_est, d_gt, mask, err, use_np=False): 18 | d_est, d_gt = d_est[mask], d_gt[mask] 19 | 20 | if use_np: 21 | e = np.abs(d_gt - d_est) 22 | else: 23 | e = torch.abs(d_gt - d_est) 24 | err_mask = (e > err) & (e / d_gt > 0.05) 25 | 26 | if use_np: 27 | mean = np.mean(err_mask.astype('float')) 28 | else: 29 | mean = torch.mean(err_mask.float()) 30 | 31 | return mean 32 | 33 | 34 | def thres_metric(d_est, d_gt, mask, thres, use_np=False): 35 | assert isinstance(thres, (int, float)) 36 | d_est, d_gt = 
d_est[mask], d_gt[mask] 37 | if use_np: 38 | e = np.abs(d_gt - d_est) 39 | else: 40 | e = torch.abs(d_gt - d_est) 41 | err_mask = e > thres 42 | 43 | if use_np: 44 | mean = np.mean(err_mask.astype('float')) 45 | else: 46 | mean = torch.mean(err_mask.float()) 47 | 48 | return mean 49 | 50 | def compute_depth_errors(gt, pred): 51 | """Computation of error metrics between predicted and ground truth depths 52 | """ 53 | thresh = torch.max((gt / pred), (pred / gt)) 54 | a1 = (thresh < 1.25 ).float().mean() 55 | a2 = (thresh < 1.25 ** 2).float().mean() 56 | a3 = (thresh < 1.25 ** 3).float().mean() 57 | 58 | rmse = (gt - pred) ** 2 59 | rmse = torch.sqrt(rmse.mean()) 60 | 61 | rmse_log = (torch.log(gt) - torch.log(pred)) ** 2 62 | rmse_log = torch.sqrt(rmse_log.mean()) 63 | 64 | abs_rel = torch.mean(torch.abs(gt - pred) / gt) 65 | 66 | sq_rel = torch.mean((gt - pred) ** 2 / gt) 67 | 68 | return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3 -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from skimage import morphology 7 | from torchvision import transforms 8 | from common import * 9 | 10 | class SSRDEFNet(nn.Module): 11 | def __init__(self, upscale_factor): 12 | super(SSRDEFNet, self).__init__() 13 | self.upscale_factor = upscale_factor 14 | if upscale_factor == 2: 15 | kernel = 6 16 | stride = 2 17 | padding = 2 18 | elif upscale_factor == 4: 19 | kernel = 8 20 | stride = 4 21 | padding = 2 22 | self.init_feature = nn.Conv2d(3, 64, 3, 1, 1, bias=True) 23 | self.deep_feature = RDG(G0=64, C=4, G=24, n_RDB=4) 24 | self.pam = PAM(64) 25 | self.transition = nn.Sequential( 26 | nn.BatchNorm2d(64), 27 | ResB(64) 28 | ) 29 | self.transition2 = nn.Sequential( 30 | nn.BatchNorm2d(64), 31 | ResB(64) 32 | ) 33 | 
self.StereoFea = Stereo_feature() 34 | self.StereoFeaHigh = hStereo_feature() 35 | 36 | self.encode = RDB(G0=64, C=6, G=24) 37 | self.encoder2 = RDB(G0=64, C=6, G=24) 38 | self.CALayer2 = CALayer(64, 8) 39 | self.CALayer = CALayer(64, 8) 40 | 41 | self.reconstruct = RDG(G0=64, C=4, G=24, n_RDB=4) 42 | self.upscale = nn.Sequential( 43 | nn.Conv2d(64, 64 * upscale_factor ** 2, 1, 1, 0, bias=True), 44 | nn.PixelShuffle(upscale_factor)) 45 | self.final = nn.Conv2d(64, 3, 3, 1, 1, bias=True) 46 | self.final2 = nn.Conv2d(64, 3, 3, 1, 1, bias=True) 47 | 48 | self.get_cv = GetCostVolume(24) 49 | self.dres0 = nn.Sequential(convbn(24, 24, 3, 1, 1, 1), 50 | nn.PReLU(), 51 | convbn(24, 24, 3, 1, 1, 1), 52 | nn.PReLU()) 53 | 54 | self.dres1 = nn.Sequential(convbn(24, 24, 3, 1, 1, 1), 55 | nn.PReLU(), 56 | convbn(24, 24, 3, 1, 1, 1)) 57 | self.dres2 = hourglass(24) 58 | self.softmax = nn.Softmax(1) 59 | 60 | self.att1 = nn.Conv2d(64, 32, 1, 1, 0) 61 | self.att2 = nn.Conv2d(64, 32, 1, 1, 0) 62 | 63 | self.backatt1 = nn.Conv2d(64, 32, 1, 1, 0) 64 | self.backatt2 = nn.Conv2d(64, 32, 1, 1, 0) 65 | 66 | self.feedbackatt = FeedbackBlock(64, 64, kernel, stride, padding) 67 | self.down = ConvBlock(64, 64, kernel, stride, padding, activation='prelu', norm=None) 68 | self.compress_in = nn.Conv2d(64*2, 64, 1, bias=True) 69 | self.resblock = RDB(G0=64, C=2, G=24) 70 | 71 | def forward(self, x_left, x_right, is_training): 72 | x_left_upscale = F.interpolate(x_left, scale_factor=self.upscale_factor, mode='bicubic', align_corners=False) 73 | x_right_upscale = F.interpolate(x_right, scale_factor=self.upscale_factor, mode='bicubic', align_corners=False) 74 | buffer_left = self.init_feature(x_left) 75 | buffer_right = self.init_feature(x_right) 76 | buffer_left = self.deep_feature(buffer_left) 77 | buffer_right = self.deep_feature(buffer_right) 78 | 79 | stereo_left = self.StereoFea(self.transition(buffer_left)) 80 | stereo_right = self.StereoFea(self.transition(buffer_right)) 81 | 82 | b,c,h,w = 
buffer_left.shape 83 | 84 | cost = [ 85 | torch.zeros(b*h, w, w).to(buffer_left.device), 86 | torch.zeros(b*h, w, w).to(buffer_left.device) 87 | ] 88 | 89 | buffer_leftT, buffer_rightT, disp1, disp2, (M_right_to_left, M_left_to_right), (V_left, V_right)\ 90 | = self.pam(buffer_left, buffer_right, stereo_left, stereo_right, cost, is_training) 91 | 92 | 93 | buffer_leftF = self.CALayer(buffer_left + self.encode(buffer_leftT - buffer_left)) 94 | buffer_rightF = self.CALayer(buffer_right + self.encode(buffer_rightT - buffer_right)) 95 | 96 | 97 | buffer_leftF = self.reconstruct(buffer_leftF) 98 | buffer_rightF = self.reconstruct(buffer_rightF) 99 | feat_left = self.upscale(buffer_leftF) 100 | feat_right = self.upscale(buffer_rightF) 101 | out1_left = self.final(feat_left)+x_left_upscale 102 | out1_right = self.final(feat_right)+x_right_upscale 103 | 104 | hstereo_left = self.StereoFeaHigh(self.transition2(feat_left)) 105 | hstereo_right = self.StereoFeaHigh(self.transition2(feat_right)) 106 | 107 | 108 | disp1 = F.interpolate(disp1 * self.upscale_factor, scale_factor=self.upscale_factor, mode='bilinear', align_corners=False) 109 | disp2 = F.interpolate(disp2 * self.upscale_factor, scale_factor=self.upscale_factor, mode='bilinear', align_corners=False) 110 | maxdisp = x_left.shape[3]*self.upscale_factor 111 | 112 | 113 | disp_range_samples1 = get_disp_range_samples(cur_disp=disp1.detach().squeeze(1), ndisp=24, 114 | shape=[x_left.shape[0], x_left.shape[2]*self.upscale_factor, x_left.shape[3]*self.upscale_factor], 115 | max_disp=maxdisp) 116 | disp_range_samples2 = get_disp_range_samples(cur_disp=disp2.detach().squeeze(1), ndisp=24, 117 | shape=[x_left.shape[0], x_left.shape[2]*self.upscale_factor, x_left.shape[3]*self.upscale_factor], 118 | max_disp=maxdisp) 119 | 120 | cost1, cost2 = self.get_cv(hstereo_left, hstereo_right, disp_range_samples1, disp_range_samples2, 24) 121 | 122 | cost1 = cost1.contiguous() 123 | cost2 = cost2.contiguous() 124 | 125 | cost1_0 = 
self.dres0(cost1) 126 | cost1_0 = self.dres1(cost1_0) + cost1_0 127 | cost2_0 = self.dres0(cost2) 128 | cost2_0 = self.dres1(cost2_0) + cost2_0 129 | 130 | out1 = self.dres2(cost1_0, None, None) 131 | cost1_1 = out1+cost1_0 132 | out2 = self.dres2(cost2_0, None, None) 133 | cost2_1 = out2+cost2_0 134 | 135 | cost_prob1 = self.softmax(cost1_1) 136 | cost_prob2 = self.softmax(cost2_1) 137 | 138 | disp1_high = torch.sum(disp_range_samples1 * cost_prob1, dim=1).unsqueeze(1) 139 | disp2_high = torch.sum(disp_range_samples2 * cost_prob2, dim=1).unsqueeze(1) 140 | 141 | feat_leftW = dispwarpfeature(feat_right, disp1_high) 142 | feat_rightW = dispwarpfeature(feat_left, disp2_high) 143 | 144 | geoerror_left = torch.abs(disp1_high - dispwarpfeature(disp2_high, disp1_high)).detach() 145 | geoerror_right = torch.abs(disp2_high - dispwarpfeature(disp1_high, disp2_high)).detach() 146 | 147 | V_left2 = 1 - torch.tanh(0.1*geoerror_left) 148 | V_right2 = 1 - torch.tanh(0.1*geoerror_right) 149 | 150 | V_left2 = torch.max(F.interpolate(V_left, scale_factor=self.upscale_factor, mode='nearest'), V_left2) 151 | V_right2 = torch.max(F.interpolate(V_right, scale_factor=self.upscale_factor, mode='nearest'), V_right2) 152 | 153 | left_att = self.att1(feat_left) 154 | leftW_att = self.att2(feat_leftW) 155 | corrleft = (torch.tanh(5*torch.sum(left_att*leftW_att, 1).unsqueeze(1))+1)/2 156 | 157 | right_att = self.att1(feat_right) 158 | rightW_att = self.att2(feat_rightW) 159 | corrright = (torch.tanh(5*torch.sum(right_att*rightW_att, 1).unsqueeze(1))+1)/2 160 | 161 | err1 = self.encoder2((feat_leftW - feat_left)*corrleft) 162 | buffer_leftF2 = self.CALayer2(err1 + feat_left) #high resolution feature that contains high resolution information of the other image through high res stereo matching 163 | 164 | err2 = self.encoder2((feat_rightW - feat_right)*corrright) 165 | buffer_rightF2 = self.CALayer2(err2 + feat_right) 166 | 167 | out2_left = self.final2(buffer_leftF2)+x_left_upscale 168 | 
out2_right = self.final2(buffer_rightF2)+x_right_upscale 169 | 170 | #feedback start 171 | 172 | left_back = self.down(buffer_leftF2) 173 | right_back = self.down(buffer_rightF2) 174 | 175 | att1 = self.feedbackatt(left_back) 176 | att2 = self.feedbackatt(right_back) 177 | 178 | left_back = left_back + 0.1*(left_back*att1) 179 | right_back = right_back + 0.1*(right_back*att2) 180 | 181 | 182 | bufferleft_att = self.backatt1(buffer_left) 183 | bufferright_att = self.backatt1(buffer_right) 184 | 185 | 186 | for ii in range(self.upscale_factor): 187 | for jj in range(self.upscale_factor): 188 | draft_l = dispwarpfeature(buffer_right, disp1_high[:, :, ii::self.upscale_factor, jj::self.upscale_factor]/self.upscale_factor) 189 | draft_r = dispwarpfeature(buffer_left, disp2_high[:, :, ii::self.upscale_factor, jj::self.upscale_factor]/self.upscale_factor) 190 | draftl_att = self.backatt2(draft_l) 191 | draftr_att = self.backatt2(draft_r) 192 | corrleft = (torch.tanh(5*torch.sum(bufferleft_att*draftl_att, 1).unsqueeze(1))+1)/2 193 | corrright = (torch.tanh(5*torch.sum(bufferright_att*draftr_att, 1).unsqueeze(1))+1)/2 194 | draft_l = (1-corrleft)*buffer_left+corrleft*draft_l 195 | draft_r = (1-corrright)*buffer_right+corrright*draft_r 196 | if ii==0 and jj==0: 197 | draft_left = buffer_left + self.resblock(draft_l - buffer_left) 198 | draft_right = buffer_right + self.resblock(draft_r - buffer_right) 199 | else: 200 | draft_left += buffer_left + self.resblock(draft_l - buffer_left) 201 | draft_right += buffer_right + self.resblock(draft_r - buffer_right) 202 | 203 | draft_left = draft_left/(self.upscale_factor**2) 204 | draft_right = draft_right/(self.upscale_factor**2) 205 | 206 | buffer_left = self.compress_in(torch.cat([draft_left, left_back], 1)) 207 | buffer_right = self.compress_in(torch.cat([draft_right, right_back], 1)) 208 | 209 | stereo_left = self.StereoFea(self.transition(buffer_left)) 210 | stereo_right = self.StereoFea(self.transition(buffer_right)) 211 | 212 | 
cost = [ 213 | torch.zeros(b*h, w, w).to(buffer_left.device), 214 | torch.zeros(b*h, w, w).to(buffer_left.device) 215 | ] 216 | 217 | buffer_leftT, buffer_rightT, disp1_3, disp2_3, (M_right_to_left3, M_left_to_right3), (V_left3, V_right3)\ 218 | = self.pam(buffer_left, buffer_right, stereo_left, stereo_right, cost, is_training) 219 | 220 | buffer_leftF = self.CALayer(buffer_left + self.encode(buffer_leftT - buffer_left)) 221 | buffer_rightF = self.CALayer(buffer_right + self.encode(buffer_rightT - buffer_right)) 222 | 223 | 224 | buffer_leftF = self.reconstruct(buffer_leftF) 225 | buffer_rightF = self.reconstruct(buffer_rightF) 226 | feat_left = self.upscale(buffer_leftF) 227 | feat_right = self.upscale(buffer_rightF) 228 | out3_left = self.final(feat_left)+x_left_upscale 229 | out3_right = self.final(feat_right)+x_right_upscale 230 | 231 | hstereo_left = self.StereoFeaHigh(self.transition2(feat_left)) 232 | hstereo_right = self.StereoFeaHigh(self.transition2(feat_right)) 233 | 234 | 235 | disp1_3 = F.interpolate(disp1_3 * self.upscale_factor, scale_factor=self.upscale_factor, mode='bilinear', align_corners=False) 236 | disp2_3 = F.interpolate(disp2_3 * self.upscale_factor, scale_factor=self.upscale_factor, mode='bilinear', align_corners=False) 237 | maxdisp = x_left.shape[3]*self.upscale_factor 238 | 239 | 240 | disp_range_samples1 = get_disp_range_samples(cur_disp=disp1_3.detach().squeeze(1), ndisp=24, 241 | shape=[x_left.shape[0], x_left.shape[2]*self.upscale_factor, x_left.shape[3]*self.upscale_factor], 242 | max_disp=maxdisp) 243 | disp_range_samples2 = get_disp_range_samples(cur_disp=disp2_3.detach().squeeze(1), ndisp=24, 244 | shape=[x_left.shape[0], x_left.shape[2]*self.upscale_factor, x_left.shape[3]*self.upscale_factor], 245 | max_disp=maxdisp) 246 | 247 | cost1, cost2 = self.get_cv(hstereo_left, hstereo_right, disp_range_samples1, disp_range_samples2, 24) 248 | cost1 = cost1.contiguous() 249 | cost2 = cost2.contiguous() 250 | 251 | cost1_0 = 
self.dres0(cost1) 252 | cost1_0 = self.dres1(cost1_0) + cost1_0 253 | cost2_0 = self.dres0(cost2) 254 | cost2_0 = self.dres1(cost2_0) + cost2_0 255 | 256 | out1 = self.dres2(cost1_0, None, None) 257 | cost1_1 = out1+cost1_0 258 | out2 = self.dres2(cost2_0, None, None) 259 | cost2_1 = out2+cost2_0 260 | 261 | 262 | cost_prob1_2 = self.softmax(cost1_1) 263 | cost_prob2_2 = self.softmax(cost2_1) 264 | 265 | disp1_high2 = torch.sum(disp_range_samples1 * cost_prob1_2, dim=1).unsqueeze(1) 266 | disp2_high2 = torch.sum(disp_range_samples2 * cost_prob2_2, dim=1).unsqueeze(1) 267 | 268 | feat_leftW = dispwarpfeature(feat_right, disp1_high2) 269 | feat_rightW = dispwarpfeature(feat_left, disp2_high2) 270 | 271 | geoerror_left = torch.abs(disp1_high2 - dispwarpfeature(disp2_high2, disp1_high2)).detach() 272 | geoerror_right = torch.abs(disp2_high2 - dispwarpfeature(disp1_high2, disp2_high2)).detach() 273 | 274 | V_left4 = 1 - torch.tanh(0.1*geoerror_left) 275 | V_right4 = 1 - torch.tanh(0.1*geoerror_right) 276 | 277 | V_left4 = torch.max(F.interpolate(V_left3, scale_factor=self.upscale_factor, mode='nearest'), V_left4) 278 | V_right4 = torch.max(F.interpolate(V_right3, scale_factor=self.upscale_factor, mode='nearest'), V_right4) 279 | 280 | left_att = self.att1(feat_left) 281 | leftW_att = self.att2(feat_leftW) 282 | corrleft = (torch.tanh(5*torch.sum(left_att*leftW_att, 1).unsqueeze(1))+1)/2 283 | 284 | right_att = self.att1(feat_right) 285 | rightW_att = self.att2(feat_rightW) 286 | corrright = (torch.tanh(5*torch.sum(right_att*rightW_att, 1).unsqueeze(1))+1)/2 287 | 288 | err1 = self.encoder2((feat_leftW - feat_left)*corrleft) 289 | buffer_leftF2 = self.CALayer2(err1 + feat_left) 290 | 291 | err2 = self.encoder2((feat_rightW - feat_right)*corrright) 292 | buffer_rightF2 = self.CALayer2(err2 + feat_right) 293 | 294 | out4_left = self.final2(buffer_leftF2)+x_left_upscale 295 | out4_right = self.final2(buffer_rightF2)+x_right_upscale 296 | 297 | if is_training == 0: 298 | 
index=(torch.arange(w*self.upscale_factor).view(1, 1, w*self.upscale_factor).repeat(1,h*self.upscale_factor,1)).to(buffer_left.device) 299 | 300 | disp1 = index-disp1.view(1,h*self.upscale_factor,w*self.upscale_factor) 301 | disp2 = disp2.view(1,h*self.upscale_factor,w*self.upscale_factor)-index 302 | disp1[disp1<0]=0 303 | disp2[disp2<0]=0 304 | disp1[disp1>192]=192 305 | disp2[disp2>192]=192 306 | 307 | disp1_3 = index-disp1_3.view(1,h*self.upscale_factor,w*self.upscale_factor) 308 | disp2_3 = disp2_3.view(1,h*self.upscale_factor,w*self.upscale_factor)-index 309 | disp1_3[disp1_3<0]=0 310 | disp2_3[disp2_3<0]=0 311 | disp1_3[disp1_3>192]=192 312 | disp2_3[disp2_3>192]=192 313 | 314 | disp1_high = index-disp1_high.view(1,h*self.upscale_factor,w*self.upscale_factor) 315 | disp2_high = disp2_high.view(1,h*self.upscale_factor,w*self.upscale_factor)-index 316 | disp1_high[disp1_high<0]=0 317 | disp2_high[disp2_high<0]=0 318 | disp1_high[disp1_high>192]=192 319 | disp2_high[disp2_high>192]=192 320 | 321 | disp1_high2 = index-disp1_high2.view(1,h*self.upscale_factor,w*self.upscale_factor) 322 | disp2_high2 = disp2_high2.view(1,h*self.upscale_factor,w*self.upscale_factor)-index 323 | disp1_high2[disp1_high2<0]=0 324 | disp2_high2[disp2_high2<0]=0 325 | disp1_high2[disp1_high2>192]=192 326 | disp2_high2[disp2_high2>192]=192 327 | 328 | return out1_left, out1_right, out2_left, out2_right, out3_left, out3_right, out4_left, out4_right, (disp1, disp2), (disp1_3, disp2_3), (disp1_high, disp2_high), (disp1_high2, disp2_high2) 329 | 330 | if is_training == 1: 331 | return out1_left, out1_right, out2_left, out2_right, out3_left, out3_right, out4_left, out4_right,\ 332 | (M_right_to_left, M_left_to_right), (disp1, disp2), (V_left, V_right), (V_left2, V_right2), (disp1_high, disp2_high),\ 333 | (M_right_to_left3, M_left_to_right3), (disp1_3, disp2_3), (V_left3, V_right3), (V_left4, V_right4), (disp1_high2, disp2_high2) 334 | 335 | 336 | 337 | def get_disp_range_samples(cur_disp, 
ndisp, shape, max_disp): 338 | #shape, (B, H, W) 339 | #cur_disp: (B, H, W) 340 | #return disp_range_samples: (B, D, H, W) 341 | cur_disp_min = (cur_disp - ndisp / 2).clamp(min=0.0) # (B, H, W) 342 | cur_disp_max = (cur_disp + ndisp / 2).clamp(max=max_disp) 343 | 344 | assert cur_disp.shape == torch.Size(shape), "cur_disp:{}, input shape:{}".format(cur_disp.shape, shape) 345 | new_interval = (cur_disp_max - cur_disp_min) / (ndisp - 1) # (B, H, W) 346 | 347 | disp_range_samples = cur_disp_min.unsqueeze(1) + (torch.arange(0, ndisp, device=cur_disp.device, 348 | dtype=cur_disp.dtype, 349 | requires_grad=False).reshape(1, -1, 1, 350 | 1) * new_interval.unsqueeze(1)) 351 | return disp_range_samples 352 | 353 | 354 | class Stereo_feature(nn.Module): 355 | def __init__(self): 356 | super(Stereo_feature, self).__init__() 357 | 358 | self.branch1 = nn.Sequential(nn.AvgPool2d((30, 30), stride=(30,30)), 359 | convbn(64, 32, 1, 1, 0, 1), 360 | nn.PReLU()) 361 | 362 | self.branch2 = nn.Sequential(nn.AvgPool2d((10, 10), stride=(10,10)), 363 | convbn(64, 32, 1, 1, 0, 1), 364 | nn.PReLU()) 365 | 366 | self.branch3 = nn.Sequential(nn.AvgPool2d((5, 5), stride=(5,5)), 367 | convbn(64, 32, 1, 1, 0, 1), 368 | nn.PReLU()) 369 | 370 | self.lastconv = nn.Sequential(convbn(160, 64, 3, 1, 1, 1), 371 | nn.PReLU(), 372 | nn.Conv2d(64, 64, kernel_size=1, padding=0, stride = 1, bias=False)) 373 | 374 | def forward(self, output_skip): 375 | 376 | output_branch1 = self.branch1(output_skip) 377 | output_branch1 = F.upsample(output_branch1, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear') 378 | 379 | output_branch2 = self.branch2(output_skip) 380 | output_branch2 = F.upsample(output_branch2, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear') 381 | 382 | output_branch3 = self.branch3(output_skip) 383 | output_branch3 = F.upsample(output_branch3, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear') 384 | 385 | output_feature = torch.cat((output_skip, 
output_branch3, output_branch2, output_branch1), 1) 386 | output_feature = self.lastconv(output_feature) 387 | 388 | return output_feature 389 | 390 | 391 | class hourglass(nn.Module): 392 | def __init__(self, inplanes): 393 | super(hourglass, self).__init__() 394 | 395 | self.conv1 = nn.Sequential(convbn(inplanes, inplanes, kernel_size=3, stride=2, pad=1, dilation=1), 396 | nn.PReLU()) 397 | 398 | self.conv2 = convbn(inplanes, inplanes, kernel_size=3, stride=1, pad=1, dilation=1) 399 | 400 | self.conv3 = nn.Sequential(convbn(inplanes, inplanes, kernel_size=3, stride=2, pad=1, dilation=1), 401 | nn.PReLU()) 402 | 403 | self.conv4 = nn.Sequential(convbn(inplanes, inplanes, kernel_size=3, stride=1, pad=1, dilation=1), 404 | nn.PReLU()) 405 | 406 | self.conv5 = nn.Sequential(nn.ConvTranspose2d(inplanes, inplanes, kernel_size=3, padding=1, output_padding=1, stride=2,bias=False), 407 | nn.BatchNorm2d(inplanes)) #+conv2 408 | 409 | self.conv6 = nn.Sequential(nn.ConvTranspose2d(inplanes, inplanes, kernel_size=3, padding=1, output_padding=1, stride=2,bias=False), 410 | nn.BatchNorm2d(inplanes)) #+x 411 | 412 | self.prelu = nn.PReLU() 413 | 414 | def forward(self, x ,presqu, postsqu): 415 | 416 | out = self.conv1(x) #in:1/4 out:1/8 417 | pre = self.conv2(out) #in:1/8 out:1/8 418 | if postsqu is not None: 419 | pre = self.prelu(pre + postsqu) 420 | else: 421 | pre = self.prelu(pre) 422 | 423 | out = self.conv3(pre) #in:1/8 out:1/16 424 | out = self.conv4(out) #in:1/16 out:1/16 425 | 426 | if presqu is not None: 427 | post = self.prelu(self.conv5(out)+presqu) #in:1/16 out:1/8 428 | else: 429 | post = self.prelu(self.conv5(out)+pre) 430 | 431 | out = self.conv6(post) #in:1/8 out:1/4 432 | 433 | return out 434 | 435 | 436 | class hStereo_feature(nn.Module): 437 | def __init__(self): 438 | super(hStereo_feature, self).__init__() 439 | 440 | self.branch2 = nn.Sequential(nn.AvgPool2d((16, 16), stride=(16,16)), 441 | convbn(64, 24, 1, 1, 0, 1), 442 | nn.PReLU()) 443 | 444 | 
self.branch3 = nn.Sequential(nn.AvgPool2d((8, 8), stride=(8,8)), 445 | convbn(64, 24, 1, 1, 0, 1), 446 | nn.PReLU()) 447 | 448 | self.lastconv = nn.Sequential(convbn(48+64, 64, 3, 1, 1, 1), 449 | nn.PReLU(), 450 | nn.Conv2d(64, 24, kernel_size=1, padding=0, stride = 1, bias=False)) 451 | 452 | def forward(self, output_skip): 453 | 454 | output_branch2 = self.branch2(output_skip) 455 | output_branch2 = F.upsample(output_branch2, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear') 456 | 457 | output_branch3 = self.branch3(output_skip) 458 | output_branch3 = F.upsample(output_branch3, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear') 459 | 460 | output_feature = torch.cat((output_skip, output_branch3, output_branch2), 1) 461 | output_feature = self.lastconv(output_feature) 462 | 463 | return output_feature 464 | 465 | class FeedbackBlock(torch.nn.Module): 466 | def __init__(self, in_filter, num_filter, kernel_size=8, stride=4, padding=2, bias=True, activation='prelu', norm=None): 467 | super(FeedbackBlock, self).__init__() 468 | #self.conv1 = ConvBlock(in_filter, num_filter, 1, 1, 0, activation='prelu', norm=None) 469 | self.avgpool_1 = torch.nn.AvgPool2d(4, 4, 0) 470 | self.up_1 = DeconvBlock(num_filter, num_filter , 8, 4, 2, activation='prelu', norm=None) 471 | self.act_1 = torch.nn.ReLU(True) 472 | 473 | def forward(self, x): 474 | 475 | #x = self.conv1(x) 476 | p1 = self.avgpool_1(x) 477 | l00 = self.up_1(p1) 478 | l00 = F.upsample(l00, x.size()[2:], mode='bilinear') 479 | act1 = self.act_1(x - l00) 480 | return act1 481 | 482 | def dispwarpfeature(feat, disp): 483 | bs, channels, height, width = feat.size() 484 | mh,_ = torch.meshgrid([torch.arange(0, height, dtype=feat.dtype, device=feat.device), torch.arange(0, width, dtype=feat.dtype, device=feat.device)]) # (H *W) 485 | mh = mh.reshape(1, 1, height, width).repeat(bs, 1, 1, 1) 486 | 487 | cur_disp_coords_y = mh 488 | cur_disp_coords_x = disp 489 | 490 | coords_x = cur_disp_coords_x / 
((width - 1.0) / 2.0) - 1.0 # trans to -1 - 1 491 | coords_y = cur_disp_coords_y / ((height - 1.0) / 2.0) - 1.0 492 | grid = torch.stack([coords_x, coords_y], dim=4).view(bs, height, width, 2) #(B, D, H, W, 2)->(B, D*H, W, 2) 493 | 494 | #warped = F.grid_sample(feat, grid.view(bs, ndisp * height, width, 2), mode='bilinear', padding_mode='zeros').view(bs, channels, ndisp, height, width) 495 | warped_feat = F.grid_sample(feat, grid, mode='bilinear', padding_mode='zeros').view(bs, channels,height, width) 496 | 497 | return warped_feat 498 | 499 | def warpfeature(feat, disp_range_samples, cost_prob, ndisp): 500 | bs, channels, height, width = feat.size() 501 | mh,_ = torch.meshgrid([torch.arange(0, height, dtype=feat.dtype, device=feat.device), torch.arange(0, width, dtype=feat.dtype, device=feat.device)]) # (H *W) 502 | mh = mh.reshape(1, 1, height, width).repeat(bs, ndisp, 1, 1) 503 | 504 | cur_disp_coords_y = mh 505 | cur_disp_coords_x = disp_range_samples 506 | 507 | coords_x = cur_disp_coords_x / ((width - 1.0) / 2.0) - 1.0 # trans to -1 - 1 508 | coords_y = cur_disp_coords_y / ((height - 1.0) / 2.0) - 1.0 509 | grid = torch.stack([coords_x, coords_y], dim=4).view(bs, ndisp * height, width, 2) #(B, D, H, W, 2)->(B, D*H, W, 2) 510 | 511 | #warped = F.grid_sample(feat, grid.view(bs, ndisp * height, width, 2), mode='bilinear', padding_mode='zeros').view(bs, channels, ndisp, height, width) 512 | warped_feat = cost_prob[:, 0, :, :].unsqueeze(1) * F.grid_sample(feat, grid[:, :height, :, :], mode='bilinear', padding_mode='zeros').view(bs, channels,height, width) 513 | for i in range(1, ndisp): 514 | warped_feat += cost_prob[:, i, :, :].unsqueeze(1) * F.grid_sample(feat, grid[:, i*height:(i+1)*height, :, :], mode='bilinear', padding_mode='zeros').view(bs, channels,height, width) 515 | 516 | return warped_feat 517 | 518 | class GetCostVolume(nn.Module): 519 | def __init__(self, channels): 520 | super(GetCostVolume, self).__init__() 521 | self.query = nn.Conv2d(channels, 
channels, 1, 1, 0, bias=True) 522 | self.key = nn.Conv2d(channels, channels, 1, 1, 0, bias=True) 523 | 524 | def forward(self, x, y, disp_range_samples1, disp_range_samples2, ndisp): 525 | assert (x.is_contiguous() == True) 526 | 527 | Q = self.query(x) 528 | K = self.key(y) 529 | 530 | bs, channels, height, width = x.size() 531 | cost1 = x.new().resize_(bs, ndisp, height, width).zero_() 532 | cost2 = x.new().resize_(bs, ndisp, height, width).zero_() 533 | # cost = y.unsqueeze(2).repeat(1, 2, ndisp, 1, 1) #(B, D, H, W) 534 | 535 | mh, mw = torch.meshgrid([torch.arange(0, height, dtype=x.dtype, device=x.device), 536 | torch.arange(0, width, dtype=x.dtype, device=x.device)]) # (H *W) 537 | mh = mh.reshape(1, 1, height, width).repeat(bs, ndisp, 1, 1) 538 | mw = mw.reshape(1, 1, height, width).repeat(bs, ndisp, 1, 1) # (B, D, H, W) 539 | 540 | cur_disp_coords_y = mh 541 | cur_disp_coords_x = disp_range_samples1 542 | 543 | coords_x = cur_disp_coords_x / ((width - 1.0) / 2.0) - 1.0 # trans to -1 - 1 544 | coords_y = cur_disp_coords_y / ((height - 1.0) / 2.0) - 1.0 545 | grid = torch.stack([coords_x, coords_y], dim=4).view(bs, ndisp * height, width, 2) #(B, D, H, W, 2) 546 | 547 | for i in range(ndisp): 548 | cost1[:, i, :, :] = (Q * F.grid_sample(K, grid[:, i*height:(i+1)*height, :, :], mode='bilinear', padding_mode='zeros').view(bs, channels,height, width)).mean(dim=1) 549 | 550 | Q = self.query(y) 551 | K = self.key(x) 552 | cur_disp_coords_x = disp_range_samples2 553 | coords_x = cur_disp_coords_x / ((width - 1.0) / 2.0) - 1.0 # x-coords from disp_range_samples2; coords_y is unchanged 554 | grid = torch.stack([coords_x, coords_y], dim=4).view(bs, ndisp * height, width, 2) 555 | 556 | for i in range(ndisp): 557 | cost2[:, i, :, :] = (Q * F.grid_sample(K, grid[:, i*height:(i+1)*height, :, :], mode='bilinear', padding_mode='zeros').view(bs, channels,height, width)).mean(dim=1) 558 | 559 | return cost1, cost2 560 | 561 | class one_conv(nn.Module): 562 | def __init__(self, G0, G): 563 | super(one_conv, self).__init__() 564 | 
self.conv = nn.Conv2d(G0, G, kernel_size=3, stride=1, padding=1, bias=True) 565 | #self.relu = nn.LeakyReLU(0.1, inplace=True) 566 | self.relu = nn.PReLU() 567 | def forward(self, x): 568 | output = self.relu(self.conv(x)) 569 | return torch.cat((x, output), dim=1) 570 | 571 | 572 | class RDB(nn.Module): 573 | def __init__(self, G0, C, G): 574 | super(RDB, self).__init__() 575 | convs = [] 576 | for i in range(C): 577 | convs.append(one_conv(G0+i*G, G)) 578 | self.conv = nn.Sequential(*convs) 579 | self.LFF = nn.Conv2d(G0+C*G, G0, kernel_size=1, stride=1, padding=0, bias=True) 580 | def forward(self, x): 581 | out = self.conv(x) 582 | lff = self.LFF(out) 583 | return lff + x 584 | 585 | 586 | class RDG(nn.Module): 587 | def __init__(self, G0, C, G, n_RDB): 588 | super(RDG, self).__init__() 589 | self.n_RDB = n_RDB 590 | RDBs = [] 591 | for i in range(n_RDB): 592 | RDBs.append(RDB(G0, C, G)) 593 | self.RDB = nn.Sequential(*RDBs) 594 | self.conv = nn.Conv2d(G0*n_RDB, G0, kernel_size=1, stride=1, padding=0, bias=True) 595 | 596 | def forward(self, x): 597 | buffer = x 598 | temp = [] 599 | for i in range(self.n_RDB): 600 | buffer = self.RDB[i](buffer) 601 | temp.append(buffer) 602 | buffer_cat = torch.cat(temp, dim=1) 603 | out = self.conv(buffer_cat) 604 | return out 605 | 606 | 607 | class CALayer(nn.Module): 608 | def __init__(self, channel, reduction): 609 | super(CALayer, self).__init__() 610 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 611 | self.conv_du = nn.Sequential( 612 | nn.Conv2d(channel, channel//reduction, 1, padding=0, bias=True), 613 | #nn.LeakyReLU(0.1, inplace=True), 614 | nn.PReLU(), 615 | nn.Conv2d(channel//reduction, channel, 1, padding=0, bias=True), 616 | nn.Sigmoid()) 617 | 618 | def forward(self, x): 619 | y = self.avg_pool(x) 620 | y = self.conv_du(y) 621 | return x * y 622 | 623 | 624 | class ResB(nn.Module): 625 | def __init__(self, channels): 626 | super(ResB, self).__init__() 627 | self.body = nn.Sequential( 628 | nn.Conv2d(channels, 
channels, 3, 1, 1, groups=4, bias=True), 629 | #nn.LeakyReLU(0.1, inplace=True), 630 | nn.PReLU(), 631 | nn.Conv2d(channels, channels, 3, 1, 1, groups=4, bias=True), 632 | ) 633 | def forward(self, x): 634 | out = self.body(x) 635 | return out + x 636 | 637 | 638 | class PAM(nn.Module): 639 | def __init__(self, channels): 640 | super(PAM, self).__init__() 641 | self.pab1 = PAB(channels) 642 | self.pab2 = PAB(channels) 643 | self.pab3 = PAB(channels) 644 | self.pab4 = PAB(channels) 645 | self.softmax = nn.Softmax(-1) 646 | 647 | def forward(self, x_left, x_right, fea_left, fea_right, cost, is_training): 648 | b, c, h, w = fea_left.shape 649 | fea_left, fea_right, cost = self.pab1(fea_left, fea_right, cost) 650 | fea_left, fea_right, cost = self.pab2(fea_left, fea_right, cost) 651 | fea_left, fea_right, cost = self.pab3(fea_left, fea_right, cost) 652 | fea_left, fea_right, cost = self.pab4(fea_left, fea_right, cost) 653 | 654 | M_right_to_left = self.softmax(cost[0]) # (B*H) * Wl * Wr 655 | M_left_to_right = self.softmax(cost[1]) # (B*H) * Wr * Wl 656 | 657 | 658 | M_right_to_leftp = M_right_to_left.view(b,h,w,w).permute(0,3,1,2).contiguous() 659 | M_left_to_rightp = M_left_to_right.view(b,h,w,w).permute(0,3,1,2).contiguous() 660 | 661 | M_right_to_left_relaxed = M_Relax(M_right_to_left, num_pixels=2) 662 | V_left = torch.bmm(M_right_to_left_relaxed.contiguous().view(-1, w).unsqueeze(1), 663 | M_left_to_right.permute(0, 2, 1).contiguous().view(-1, w).unsqueeze(2) 664 | ).detach().contiguous().view(b, 1, h, w) # (B*H*Wr) * Wl * 1 665 | M_left_to_right_relaxed = M_Relax(M_left_to_right, num_pixels=2) 666 | V_right = torch.bmm(M_left_to_right_relaxed.contiguous().view(-1, w).unsqueeze(1), # (B*H*Wl) * 1 * Wr 667 | M_right_to_left.permute(0, 2, 1).contiguous().view(-1, w).unsqueeze(2) 668 | ).detach().contiguous().view(b, 1, h, w) # (B*H*Wr) * Wl * 1 669 | 670 | V_left_tanh = torch.tanh(5 * V_left) 671 | V_right_tanh = torch.tanh(5 * V_right) 672 | 673 | x_leftT = 
torch.bmm(M_right_to_left, x_right.permute(0, 2, 3, 1).contiguous().view(-1, w, c) 674 | ).contiguous().view(b, h, w, c).permute(0, 3, 1, 2) # B, C0, H0, W0 675 | x_rightT = torch.bmm(M_left_to_right, x_left.permute(0, 2, 3, 1).contiguous().view(-1, w, c) 676 | ).contiguous().view(b, h, w, c).permute(0, 3, 1, 2) # B, C0, H0, W0 677 | out_left = x_left * (1 - V_left_tanh.repeat(1, c, 1, 1)) + x_leftT * V_left_tanh.repeat(1, c, 1, 1) 678 | out_right = x_right * (1 - V_right_tanh.repeat(1, c, 1, 1)) + x_rightT * V_right_tanh.repeat(1, c, 1, 1) 679 | 680 | index = torch.arange(w).view(1, 1, 1, w).to(M_right_to_left.device).float() # index: 1*1*1*w 681 | disp1 = torch.sum(M_right_to_left * index, dim=-1).view(b, 1, h, w) # x axis of the corresponding point 682 | disp2 = torch.sum(M_left_to_right * index, dim=-1).view(b, 1, h, w) 683 | 684 | 685 | return out_left, out_right, disp1, disp2,\ 686 | (M_right_to_left.contiguous().view(b, h, w, w), M_left_to_right.contiguous().view(b, h, w, w)),\ 687 | (V_left_tanh, V_right_tanh) 688 | 689 | class PAB(nn.Module): 690 | def __init__(self, channels): 691 | super(PAB, self).__init__() 692 | self.head = nn.Sequential( 693 | nn.Conv2d(channels, channels//2, 3, 1, 1, bias=True), 694 | nn.BatchNorm2d(channels//2), 695 | nn.PReLU(), 696 | nn.Conv2d(channels//2, channels, 3, 1, 1, bias=True), 697 | nn.BatchNorm2d(channels), 698 | nn.PReLU(), 699 | ) 700 | self.bq = nn.Conv2d(channels, channels, 1, 1, 0, bias=True) 701 | self.bs = nn.Conv2d(channels, channels, 1, 1, 0, bias=True) 702 | 703 | def forward(self, fea_left, fea_right, cost): 704 | b, c0, h0, w0 = fea_left.shape 705 | fea_left1 = self.head(fea_left) 706 | fea_right1 = self.head(fea_right) 707 | Q = self.bq(fea_left1) 708 | b, c, h, w = Q.shape 709 | Q = Q - torch.mean(Q, 3).unsqueeze(3).repeat(1, 1, 1, w) 710 | K = self.bs(fea_right1) 711 | K = K - torch.mean(K, 3).unsqueeze(3).repeat(1, 1, 1, w) 712 | 713 | score = torch.bmm(Q.permute(0, 2, 3, 1).contiguous().view(-1, w, 
c), # (B*H) * Wl * C 714 | K.permute(0, 2, 1, 3).contiguous().view(-1, c, w)) 715 | 716 | cost[0] += score 717 | cost[1] += score.permute(0, 2, 1) 718 | return fea_left+fea_left1, fea_right+fea_right1, cost 719 | 720 | 721 | def M_Relax(M, num_pixels): 722 | _, u, v = M.shape 723 | M_list = [] 724 | M_list.append(M.unsqueeze(1)) 725 | for i in range(num_pixels): 726 | pad = nn.ZeroPad2d(padding=(0, 0, i+1, 0)) 727 | pad_M = pad(M[:, :-1-i, :]) 728 | M_list.append(pad_M.unsqueeze(1)) 729 | for i in range(num_pixels): 730 | pad = nn.ZeroPad2d(padding=(0, 0, 0, i+1)) 731 | pad_M = pad(M[:, i+1:, :]) 732 | M_list.append(pad_M.unsqueeze(1)) 733 | M_relaxed = torch.sum(torch.cat(M_list, 1), dim=1) 734 | return M_relaxed 735 | 736 | if __name__ == "__main__": 737 | net = SSRDEFNet(4) # same class and positional scale factor that train.py and the test scripts use 738 | total = sum([param.nelement() for param in net.parameters()]) 739 | print(' Number of params: %.2fM' % (total / 1e6)) -------------------------------------------------------------------------------- /result.txt: -------------------------------------------------------------------------------- 1 | StereoSR: 2 | 3 | KITTI2012 mean psnr left: 26.619422161710588 4 | KITTI2012 mean psnr average: 26.712393708455398 5 | 6 | KITTI2015 mean psnr left: 25.74107406904548 7 | KITTI2015 mean psnr average: 26.464446701876494 8 | 9 | Middlebury mean psnr left: 29.308098231282155 10 | Middlebury mean psnr average: 29.381088935551993 11 | 12 | Flickr1024 mean psnr left: 23.50759679238374 13 | Flickr1024 mean psnr average: 23.585143575772204 14 | 15 | 16 | Disparity estimation: 17 | 18 | K2015 disp high left all1: 4.089844226837158 19 | K2015 disp high left noc1: 3.322787284851074 20 | K2015 disp high left all0: 5.026285648345947 21 | K2015 disp high left noc0: 4.3886919021606445 22 | K2015 disp left all1: 4.741995811462402 23 | K2015 disp left noc1: 3.98885440826416 24 | K2015 disp left all0: 5.614530086517334 25 | K2015 disp left noc0: 4.997827529907227 26 | 27 | K2012 disp high left all1: 
4.963388919830322 28 | K2012 disp high left noc1: 3.6856789588928223 29 | K2012 disp high left all0: 6.467651844024658 30 | K2012 disp high left noc0: 5.434603214263916 31 | K2012 disp left all1: 5.762098789215088 32 | K2012 disp left noc1: 4.506109237670898 33 | K2012 disp left all0: 7.1608195304870605 34 | K2012 disp left noc0: 6.164617538452148 35 | -------------------------------------------------------------------------------- /test_disp.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | from PIL import Image 3 | from torchvision.transforms import ToTensor 4 | import argparse 5 | import os 6 | from model import * 7 | from utils import * 8 | import numpy as np 9 | import torch.nn.functional as F 10 | import re 11 | from metric import * 12 | import imageio 13 | 14 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 15 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--testset_dir', type=str, default='./data/test/') 20 | parser.add_argument('--scale_factor', type=int, default=4) 21 | parser.add_argument('--device', type=str, default='cuda') 22 | parser.add_argument('--model_name', type=str, default='SSRDEF_4xSR') 23 | return parser.parse_args() 24 | 25 | def read_disp(filename, subset=False): 26 | # Scene Flow dataset 27 | if filename.endswith('pfm'): 28 | # For finalpass and cleanpass, gt disparity is positive, subset is negative 29 | disp = np.ascontiguousarray(_read_pfm(filename)[0]) 30 | if subset: 31 | disp = -disp 32 | # KITTI 33 | elif filename.endswith('png'): 34 | disp = _read_kitti_disp(filename) 35 | elif filename.endswith('npy'): 36 | disp = np.load(filename) 37 | else: 38 | raise Exception('Invalid disparity file format!') 39 | return disp # [H, W] 40 | 41 | def _read_pfm(file): 42 | file = open(file, 'rb') 43 | 44 | color = None 45 | width = None 46 | height = None 47 | scale = None 48 | 
endian = None 49 | 50 | header = file.readline().rstrip() 51 | if header.decode("ascii") == 'PF': 52 | color = True 53 | elif header.decode("ascii") == 'Pf': 54 | color = False 55 | else: 56 | raise Exception('Not a PFM file.') 57 | 58 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii")) 59 | if dim_match: 60 | width, height = list(map(int, dim_match.groups())) 61 | else: 62 | raise Exception('Malformed PFM header.') 63 | 64 | scale = float(file.readline().decode("ascii").rstrip()) 65 | if scale < 0: # little-endian 66 | endian = '<' 67 | scale = -scale 68 | else: 69 | endian = '>' # big-endian 70 | 71 | data = np.fromfile(file, endian + 'f') 72 | shape = (height, width, 3) if color else (height, width) 73 | 74 | data = np.reshape(data, shape) 75 | data = np.flipud(data) 76 | return data, scale 77 | 78 | def _read_kitti_disp(filename): 79 | depth = np.array(Image.open(filename)) 80 | depth = depth.astype(np.float32)/256. 81 | return depth 82 | 83 | def test(cfg, loadname, net): 84 | print(loadname) 85 | psnr_list = [] 86 | psnr_list_r = [] 87 | psnr_list_m = [] 88 | psnr_list_r_m = [] 89 | disphigh_occ0 = [] 90 | disphigh_noc0 = [] 91 | disp_occ0 = [] 92 | disp_noc0 = [] 93 | disphigh_occ1 = [] 94 | disphigh_noc1 = [] 95 | disp_occ1 = [] 96 | disp_noc1 = [] 97 | 98 | file_list = os.listdir(cfg.testset_dir + cfg.dataset + '/hr') 99 | for idx in range(len(file_list)): 100 | LR_left = Image.open(cfg.testset_dir + cfg.dataset + '/lr_x' + str(cfg.scale_factor) + '/' + file_list[idx] + '/lr0.png') 101 | LR_right = Image.open(cfg.testset_dir + cfg.dataset + '/lr_x' + str(cfg.scale_factor) + '/' + file_list[idx] + '/lr1.png') 102 | HR_left = Image.open(cfg.testset_dir + cfg.dataset + '/hr/' + file_list[idx] + '/hr0.png') 103 | HR_right = Image.open(cfg.testset_dir + cfg.dataset + '/hr/' + file_list[idx] + '/hr1.png') 104 | disp_left = read_disp(cfg.testset_dir + cfg.dataset + '/hr/' + file_list[idx] + '/dispocc0.png') 105 | disp_leftall = 
read_disp(cfg.testset_dir + cfg.dataset + '/hr/' + file_list[idx] + '/dispnoc0.png') 106 | 107 | LR_left, LR_right, HR_left, HR_right, disp_left, disp_leftall = ToTensor()(LR_left), ToTensor()(LR_right), ToTensor()(HR_left), ToTensor()(HR_right), ToTensor()(disp_left), ToTensor()(disp_leftall) 108 | LR_left, LR_right, HR_left, HR_right, disp_left, disp_leftall = LR_left.unsqueeze(0), LR_right.unsqueeze(0), HR_left.unsqueeze(0), HR_right.unsqueeze(0), disp_left.unsqueeze(0), disp_leftall.unsqueeze(0) 109 | LR_left, LR_right, HR_left, HR_right, disp_left, disp_leftall = Variable(LR_left).cuda(), Variable(LR_right).cuda(), Variable(HR_left).cuda(), Variable(HR_right).cuda(), Variable(disp_left).cuda(), Variable(disp_leftall).cuda() 110 | scene_name = file_list[idx] 111 | 112 | _,_,h,w=disp_left.shape 113 | 114 | 115 | disp_left=disp_left.view(1,h,w) 116 | disp_leftall=disp_leftall.view(1,h,w) 117 | 118 | mask0 = (disp_left > 0) & (disp_left < 192) 119 | mask1 = (disp_leftall > 0) & (disp_leftall < 192) 120 | _,h,w=mask0.shape 121 | 122 | 123 | #print('Running Scene ' + scene_name + ' of ' + cfg.dataset + ' Dataset......') 124 | with torch.no_grad(): 125 | _,_,_,_,_,_,SR_left, SR_right, (disp1, disp2), (disp1_3, disp2_3), (disp1_high, disp2_high), (disp1_high2, disp2_high2) = net(LR_left, LR_right, is_training=0) 126 | 127 | SR_left, SR_right = torch.clamp(SR_left, 0, 1), torch.clamp(SR_right, 0, 1) 128 | 129 | psnr_list.append(cal_psnr(HR_left[:,:,:,64:].data.cpu(), SR_left[:,:,:,64:].data.cpu())) 130 | psnr_list_r.append(cal_psnr(HR_right.data.cpu(), SR_right.data.cpu())) 131 | psnr_list_r.append(cal_psnr(HR_left.data.cpu(), SR_left.data.cpu())) 132 | ''' 133 | psnr_list_m.append(cal_psnr(HR_left[:,:,:,64:].data.cpu(), SR_left[0][:,:,:,64:].data.cpu())) 134 | psnr_list_r_m.append(cal_psnr(HR_right.data.cpu(), SR_right[0].data.cpu())) 135 | psnr_list_r_m.append(cal_psnr(HR_left.data.cpu(), SR_left[0].data.cpu())) 136 | ''' 137 | 138 | 
disphigh_occ0.append((disp1_high[mask0].cpu()-disp_left[mask0].cpu()).abs().mean()) 139 | disphigh_noc0.append((disp1_high[mask1].cpu()-disp_leftall[mask1].cpu()).abs().mean()) 140 | disphigh_occ1.append((disp1_high2[mask0].cpu()-disp_left[mask0].cpu()).abs().mean()) 141 | disphigh_noc1.append((disp1_high2[mask1].cpu()-disp_leftall[mask1].cpu()).abs().mean()) 142 | 143 | disp_occ0.append((disp1[mask0].cpu()-disp_left[mask0].cpu()).abs().mean()) 144 | disp_noc0.append((disp1[mask1].cpu()-disp_leftall[mask1].cpu()).abs().mean()) 145 | disp_occ1.append((disp1_3[mask0].cpu()-disp_left[mask0].cpu()).abs().mean()) 146 | disp_noc1.append((disp1_3[mask1].cpu()-disp_leftall[mask1].cpu()).abs().mean()) 147 | 148 | 149 | #psnr_list_m.append(cal_psnr(HR_left[:,:,:,64:].data.cpu(), out1_left[:,:,:,64:].data.cpu())) 150 | #psnr_list_r_m.append(cal_psnr(HR_right.data.cpu(), out1_right.data.cpu())) 151 | #psnr_list_r_m.append(cal_psnr(HR_left.data.cpu(), out1_left.data.cpu())) 152 | #print(torch.mean(V_left2)) 153 | 154 | save_path = './results/' + cfg.model_name + '/' + cfg.dataset 155 | if not os.path.exists(save_path): 156 | os.makedirs(save_path) 157 | SR_left_img = transforms.ToPILImage()(torch.squeeze(SR_left.data.cpu(), 0)) 158 | SR_left_img.save(save_path + '/' + scene_name + '_L.png') 159 | SR_right_img = transforms.ToPILImage()(torch.squeeze(SR_right.data.cpu(), 0)) 160 | SR_right_img.save(save_path + '/' + scene_name + '_R.png') 161 | 162 | imageio.imsave(save_path + '/' + scene_name + '_Lhdisp.png', torch.squeeze(disp1_high2.data.cpu(), 0)) 163 | imageio.imsave(save_path + '/' + scene_name + '_Rhdisp.png', torch.squeeze(disp2_high2.data.cpu(), 0)) 164 | 165 | print(cfg.dataset + ' mean psnr left: ', float(np.array(psnr_list).mean())) 166 | print(cfg.dataset + ' mean psnr average: ', float(np.array(psnr_list_r).mean())) 167 | #print(cfg.dataset + ' mean psnr left intermedite: ', float(np.array(psnr_list_m).mean())) 168 | #print(cfg.dataset + ' mean psnr average 
intermedite: ', float(np.array(psnr_list_r_m).mean())) 169 | 170 | print(cfg.dataset + ' disp high left all1: ', float(np.array(disphigh_occ1).mean())) 171 | print(cfg.dataset + ' disp high left noc1: ', float(np.array(disphigh_noc1).mean())) 172 | print(cfg.dataset + ' disp high left all0: ', float(np.array(disphigh_occ0).mean())) 173 | print(cfg.dataset + ' disp high left noc0: ', float(np.array(disphigh_noc0).mean())) 174 | 175 | print(cfg.dataset + ' disp left all1: ', float(np.array(disp_occ1).mean())) 176 | print(cfg.dataset + ' disp left noc1: ', float(np.array(disp_noc1).mean())) 177 | print(cfg.dataset + ' disp left all0: ', float(np.array(disp_occ0).mean())) 178 | print(cfg.dataset + ' disp left noc0: ', float(np.array(disp_noc0).mean())) 179 | 180 | 181 | 182 | if __name__ == '__main__': 183 | cfg = parse_args() 184 | dataset_list = ['K2012','K2015'] 185 | net = SSRDEFNet(cfg.scale_factor).cuda() 186 | net = torch.nn.DataParallel(net) 187 | net.eval() 188 | 189 | for j in range(80,81): 190 | loadname = './checkpoints/SSRDEF_4xSR_epoch' + str(j) + '.pth.tar' 191 | model = torch.load(loadname) 192 | net.load_state_dict(model['state_dict']) 193 | for i in range(len(dataset_list)): 194 | cfg.dataset = dataset_list[i] 195 | test(cfg, loadname, net) 196 | print('Finished!') 197 | -------------------------------------------------------------------------------- /test_sr.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | from PIL import Image 3 | from torchvision.transforms import ToTensor 4 | import argparse 5 | import os 6 | from model import * 7 | from utils import * 8 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 9 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--testset_dir', type=str, default='./data/test/') 14 | parser.add_argument('--scale_factor', type=int, default=4) 15 | 
parser.add_argument('--device', type=str, default='cuda') 16 | parser.add_argument('--model_name', type=str, default='SSRDEF_4xSR') 17 | return parser.parse_args() 18 | 19 | 20 | def test(cfg, loadname, net): 21 | print(loadname) 22 | psnr_list = [] 23 | psnr_list_r = [] 24 | psnr_list_m1 = [] 25 | psnr_list_r_m1 = [] 26 | psnr_list_m2 = [] 27 | psnr_list_r_m2 = [] 28 | psnr_list_m3 = [] 29 | psnr_list_r_m3 = [] 30 | 31 | file_list = os.listdir(cfg.testset_dir + cfg.dataset + '/hr') 32 | for idx in range(len(file_list)): 33 | LR_left = Image.open(cfg.testset_dir + cfg.dataset + '/lr_x' + str(cfg.scale_factor) + '/' + file_list[idx] + '/lr0.png') 34 | LR_right = Image.open(cfg.testset_dir + cfg.dataset + '/lr_x' + str(cfg.scale_factor) + '/' + file_list[idx] + '/lr1.png') 35 | HR_left = Image.open(cfg.testset_dir + cfg.dataset + '/hr/' + file_list[idx] + '/hr0.png') 36 | HR_right = Image.open(cfg.testset_dir + cfg.dataset + '/hr/' + file_list[idx] + '/hr1.png') 37 | 38 | LR_left, LR_right, HR_left, HR_right = ToTensor()(LR_left), ToTensor()(LR_right), ToTensor()(HR_left), ToTensor()(HR_right) 39 | LR_left, LR_right, HR_left, HR_right = LR_left.unsqueeze(0), LR_right.unsqueeze(0), HR_left.unsqueeze(0), HR_right.unsqueeze(0) 40 | LR_left, LR_right, HR_left, HR_right = Variable(LR_left).cuda(), Variable(LR_right).cuda(), Variable(HR_left).cuda(), Variable(HR_right).cuda() 41 | scene_name = file_list[idx] 42 | #print('Running Scene ' + scene_name + ' of ' + cfg.dataset + ' Dataset......') 43 | with torch.no_grad(): 44 | SR_left1, SR_right1, SR_left2, SR_right2, SR_left3, SR_right3, SR_left4, SR_right4,_,_,_,_ = net(LR_left, LR_right, is_training=0) 45 | SR_left1, SR_right1 = torch.clamp(SR_left1, 0, 1), torch.clamp(SR_right1, 0, 1) 46 | SR_left2, SR_right2 = torch.clamp(SR_left2, 0, 1), torch.clamp(SR_right2, 0, 1) 47 | SR_left3, SR_right3 = torch.clamp(SR_left3, 0, 1), torch.clamp(SR_right3, 0, 1) 48 | SR_left4, SR_right4 = torch.clamp(SR_left4, 0, 1), 
torch.clamp(SR_right4, 0, 1) 49 | 50 | psnr_list.append(cal_psnr(HR_left[:,:,:,64:].data.cpu(), SR_left4[:,:,:,64:].data.cpu())) 51 | psnr_list_r.append(cal_psnr(HR_right.data.cpu(), SR_right4.data.cpu())) 52 | psnr_list_r.append(cal_psnr(HR_left.data.cpu(), SR_left4.data.cpu())) 53 | 54 | psnr_l = cal_psnr(HR_left.data.cpu(), SR_left4.data.cpu()) 55 | psnr_r = cal_psnr(HR_right.data.cpu(), SR_right4.data.cpu()) 56 | 57 | psnr_l1 = cal_psnr(HR_left.data.cpu(), SR_left1.data.cpu()) 58 | psnr_r1 = cal_psnr(HR_right.data.cpu(), SR_right1.data.cpu()) 59 | 60 | psnr_l2 = cal_psnr(HR_left.data.cpu(), SR_left2.data.cpu()) 61 | psnr_r2 = cal_psnr(HR_right.data.cpu(), SR_right2.data.cpu()) 62 | 63 | psnr_l3 = cal_psnr(HR_left.data.cpu(), SR_left3.data.cpu()) 64 | psnr_r3 = cal_psnr(HR_right.data.cpu(), SR_right3.data.cpu()) 65 | 66 | psnr_list_m1.append(cal_psnr(HR_left[:,:,:,64:].data.cpu(), SR_left1[:,:,:,64:].data.cpu())) 67 | psnr_list_r_m1.append(cal_psnr(HR_right.data.cpu(), SR_right1.data.cpu())) 68 | psnr_list_r_m1.append(cal_psnr(HR_left.data.cpu(), SR_left1.data.cpu())) 69 | 70 | psnr_list_m2.append(cal_psnr(HR_left[:,:,:,64:].data.cpu(), SR_left2[:,:,:,64:].data.cpu())) 71 | psnr_list_r_m2.append(cal_psnr(HR_right.data.cpu(), SR_right2.data.cpu())) 72 | psnr_list_r_m2.append(cal_psnr(HR_left.data.cpu(), SR_left2.data.cpu())) 73 | 74 | psnr_list_m3.append(cal_psnr(HR_left[:,:,:,64:].data.cpu(), SR_left3[:,:,:,64:].data.cpu())) 75 | psnr_list_r_m3.append(cal_psnr(HR_right.data.cpu(), SR_right3.data.cpu())) 76 | psnr_list_r_m3.append(cal_psnr(HR_left.data.cpu(), SR_left3.data.cpu())) 77 | 78 | save_path = './results/' + cfg.model_name + '/' + cfg.dataset 79 | if not os.path.exists(save_path): 80 | os.makedirs(save_path) 81 | 82 | SR_left_img = transforms.ToPILImage()(torch.squeeze(SR_left4.data.cpu(), 0)) 83 | SR_left_img.save(save_path + '/' + scene_name + '_L%.2f.png'%psnr_l) 84 | SR_right_img = transforms.ToPILImage()(torch.squeeze(SR_right4.data.cpu(), 0)) 85 | 
SR_right_img.save(save_path + '/' + scene_name + '_R%.2f.png'%psnr_r) 86 | ''' 87 | SR_left_img = transforms.ToPILImage()(torch.squeeze(SR_left3.data.cpu(), 0)) 88 | SR_left_img.save(save_path + '/' + scene_name + '_L%.2f.png'%psnr_l3) 89 | SR_right_img = transforms.ToPILImage()(torch.squeeze(SR_right3.data.cpu(), 0)) 90 | SR_right_img.save(save_path + '/' + scene_name + '_R%.2f.png'%psnr_r3) 91 | 92 | SR_left_img = transforms.ToPILImage()(torch.squeeze(SR_left2.data.cpu(), 0)) 93 | SR_left_img.save(save_path + '/' + scene_name + '_L%.2f.png'%psnr_l2) 94 | SR_right_img = transforms.ToPILImage()(torch.squeeze(SR_right2.data.cpu(), 0)) 95 | SR_right_img.save(save_path + '/' + scene_name + '_R%.2f.png'%psnr_r2) 96 | 97 | SR_left_img = transforms.ToPILImage()(torch.squeeze(SR_left1.data.cpu(), 0)) 98 | SR_left_img.save(save_path + '/' + scene_name + '_L%.2f.png'%psnr_l1) 99 | SR_right_img = transforms.ToPILImage()(torch.squeeze(SR_right1.data.cpu(), 0)) 100 | SR_right_img.save(save_path + '/' + scene_name + '_R%.2f.png'%psnr_r1) 101 | ''' 102 | 103 | 104 | print(cfg.dataset + ' mean psnr left: ', float(np.array(psnr_list).mean())) 105 | print(cfg.dataset + ' mean psnr average: ', float(np.array(psnr_list_r).mean())) 106 | print(cfg.dataset + ' mean psnr left intermediate3: ', float(np.array(psnr_list_m3).mean())) 107 | print(cfg.dataset + ' mean psnr average intermediate3: ', float(np.array(psnr_list_r_m3).mean())) 108 | print(cfg.dataset + ' mean psnr left intermediate2: ', float(np.array(psnr_list_m2).mean())) 109 | print(cfg.dataset + ' mean psnr average intermediate2: ', float(np.array(psnr_list_r_m2).mean())) 110 | print(cfg.dataset + ' mean psnr left intermediate1: ', float(np.array(psnr_list_m1).mean())) 111 | print(cfg.dataset + ' mean psnr average intermediate1: ', float(np.array(psnr_list_r_m1).mean())) 112 | 113 | if __name__ == '__main__': 114 | cfg = parse_args() 115 | dataset_list = ['Middlebury', 'KITTI2012','KITTI2015','Flickr1024'] 116 | net = 
SSRDEFNet(cfg.scale_factor).cuda() 117 | net = torch.nn.DataParallel(net) 118 | net.eval() 119 | 120 | for j in range(80,81): 121 | loadname = './checkpoints/SSRDEF_4xSR_epoch' + str(j) + '.pth.tar' 122 | print(loadname) 123 | model = torch.load(loadname) 124 | net.load_state_dict(model['state_dict']) 125 | for i in range(len(dataset_list)): 126 | cfg.dataset = dataset_list[i] 127 | test(cfg, loadname, net) 128 | print('Finished!') 129 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | from torch.utils.data import DataLoader 3 | import torch.backends.cudnn as cudnn 4 | import argparse 5 | from utils import * 6 | from model import * 7 | from torchvision.transforms import ToTensor 8 | import os 9 | import torch.nn.functional as F 10 | from loss import * 11 | 12 | 13 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 14 | os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3' 15 | 16 | 17 | def get_parameter_number(net): 18 | total_num = sum(p.numel() for p in net.parameters()) 19 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) 20 | return {'Total': total_num, 'Trainable': trainable_num} 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--scale_factor", type=int, default=4) 25 | parser.add_argument('--device', type=str, default='cuda') 26 | parser.add_argument('--batch_size', type=int, default=12) 27 | parser.add_argument('--lr', type=float, default=2e-4, help='initial learning rate') 28 | parser.add_argument('--gamma', type=float, default=0.5, help='') 29 | parser.add_argument('--start_epoch', type=int, default=0, help='start epoch') 30 | parser.add_argument('--n_epochs', type=int, default=80, help='number of epochs to train') 31 | parser.add_argument('--n_steps', type=int, default=30, help='number of epochs to update learning rate') 32 | 
parser.add_argument('--trainset_dir', type=str, default='./data/train/Flickr1024_patches') 33 | parser.add_argument('--model_name', type=str, default='SSRDEF') 34 | parser.add_argument('--load_pretrain', action='store_true', help='resume from the checkpoint given by --model_path') # type=bool would treat any non-empty string, including "False", as True 35 | parser.add_argument('--model_path', type=str, default='') 36 | parser.add_argument('--testset_dir', type=str, default='./data/test/') 37 | return parser.parse_args() 38 | 39 | def warpfeature(feat, disp_range_samples, cost_prob, ndisp): 40 | bs, channels, height, width = feat.size() 41 | mh,_ = torch.meshgrid([torch.arange(0, height, dtype=feat.dtype, device=feat.device), torch.arange(0, width, dtype=feat.dtype, device=feat.device)]) # (H *W) 42 | mh = mh.reshape(1, 1, height, width).repeat(bs, ndisp, 1, 1) 43 | 44 | cur_disp_coords_y = mh 45 | cur_disp_coords_x = disp_range_samples 46 | 47 | coords_x = cur_disp_coords_x / ((width - 1.0) / 2.0) - 1.0 # trans to -1 - 1 48 | coords_y = cur_disp_coords_y / ((height - 1.0) / 2.0) - 1.0 49 | grid = torch.stack([coords_x, coords_y], dim=4).view(bs, ndisp * height, width, 2) #(B, D, H, W, 2)->(B, D*H, W, 2) 50 | 51 | #warped = F.grid_sample(feat, grid.view(bs, ndisp * height, width, 2), mode='bilinear', padding_mode='zeros').view(bs, channels, ndisp, height, width) 52 | warped_feat = cost_prob[:, 0, :, :].unsqueeze(1) * F.grid_sample(feat, grid[:, :height, :, :], mode='bilinear', padding_mode='zeros').view(bs, channels,height, width) 53 | for i in range(1, ndisp): 54 | warped_feat += cost_prob[:, i, :, :].unsqueeze(1) * F.grid_sample(feat, grid[:, i*height:(i+1)*height, :, :], mode='bilinear', padding_mode='zeros').view(bs, channels,height, width) 55 | 56 | return warped_feat 57 | 58 | def dispwarpfeature(feat, disp): 59 | bs, channels, height, width = feat.size() 60 | mh,_ = torch.meshgrid([torch.arange(0, height, dtype=feat.dtype, device=feat.device), torch.arange(0, width, dtype=feat.dtype, device=feat.device)]) # (H *W) 61 | mh = mh.reshape(1, 1, height, width).repeat(bs, 1, 1, 1) 62 | 
63 | cur_disp_coords_y = mh 64 | cur_disp_coords_x = disp 65 | 66 | coords_x = cur_disp_coords_x / ((width - 1.0) / 2.0) - 1.0 # trans to -1 - 1 67 | coords_y = cur_disp_coords_y / ((height - 1.0) / 2.0) - 1.0 68 | grid = torch.stack([coords_x, coords_y], dim=4).view(bs, height, width, 2) #(B, D, H, W, 2)->(B, D*H, W, 2) 69 | 70 | #warped = F.grid_sample(feat, grid.view(bs, ndisp * height, width, 2), mode='bilinear', padding_mode='zeros').view(bs, channels, ndisp, height, width) 71 | warped_feat = F.grid_sample(feat, grid, mode='bilinear', padding_mode='zeros').view(bs, channels,height, width) 72 | 73 | return warped_feat 74 | 75 | def cal_grad(image): 76 | """ 77 | Calculate the image-edge-aware second-order smoothness loss for flo 78 | """ 79 | 80 | def gradient(pred): 81 | D_dy = pred[:, :, 1:, :] - pred[:, :, :-1, :] 82 | D_dx = pred[:, :, :, 1:] - pred[:, :, :, :-1] 83 | D_dy = F.pad(D_dy, pad=(0,0,0,1), mode="constant", value=0) 84 | D_dx = F.pad(D_dx, pad=(0,1,0,0), mode="constant", value=0) 85 | return D_dx, D_dy 86 | 87 | 88 | img_grad_x, img_grad_y = gradient(image) 89 | weights_x = torch.exp(-10.0 * torch.mean(torch.abs(img_grad_x), 1, keepdim=True)) 90 | weights_y = torch.exp(-10.0 * torch.mean(torch.abs(img_grad_y), 1, keepdim=True)) 91 | 92 | return weights_x, weights_y 93 | 94 | def load_pretrain(model, pretrained_dict): 95 | torch_params = model.state_dict() 96 | for k,v in pretrained_dict.items(): 97 | print(k) 98 | pretrained_dict_1 = {k: v for k, v in pretrained_dict.items() if k in torch_params} 99 | torch_params.update(pretrained_dict_1) 100 | model.load_state_dict(torch_params) 101 | 102 | def train(train_loader, cfg): 103 | net = SSRDEFNet(cfg.scale_factor).cuda() 104 | print(get_parameter_number(net)) 105 | cudnn.benchmark = True 106 | scale = cfg.scale_factor 107 | 108 | net = torch.nn.DataParallel(net) 109 | 110 | if cfg.load_pretrain: 111 | if os.path.isfile(cfg.model_path): 112 | model = torch.load(cfg.model_path) 113 | 
net.load_state_dict(model['state_dict']) 114 | cfg.start_epoch = model["epoch"] 115 | else: 116 | print("=> no model found at '{}'".format(cfg.model_path)) # cfg has no load_model attribute; the checkpoint path is cfg.model_path 117 | 118 | 119 | # net = torch.nn.DataParallel(net, device_ids=[0, 1]) 120 | criterion_L1 = torch.nn.L1Loss().cuda() 121 | optimizer = torch.optim.Adam([paras for paras in net.parameters() if paras.requires_grad], lr=cfg.lr) 122 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=cfg.n_steps, gamma=cfg.gamma) 123 | 124 | loss_epoch = [] 125 | loss_list = [] 126 | psnr_epoch = [] 127 | psnr_epoch_r = [] 128 | psnr_epoch_m = [] 129 | psnr_epoch_r_m = [] 130 | 131 | for idx_epoch in range(cfg.start_epoch, cfg.n_epochs): 132 | 133 | for idx_iter, (HR_left, HR_right, LR_left, LR_right) in enumerate(train_loader): 134 | b, c, h, w = LR_left.shape 135 | _, _, h2, w2 = HR_left.shape 136 | HR_left, HR_right, LR_left, LR_right = Variable(HR_left).cuda(), Variable(HR_right).cuda(),\ 137 | Variable(LR_left).cuda(), Variable(LR_right).cuda() 138 | 139 | SR_left, SR_right, SR_left2, SR_right2, SR_left3, SR_right3, SR_left4, SR_right4,\ 140 | (M_right_to_left, M_left_to_right), (disp1, disp2), (V_left, V_right), (V_left2, V_right2), (disp1_high, disp2_high),\ 141 | (M_right_to_left3, M_left_to_right3), (disp1_3, disp2_3), (V_left3, V_right3), (V_left4, V_right4), (disp1_high_2, disp2_high_2)\ 142 | =net(LR_left, LR_right, is_training=1) 143 | 144 | ''' SR Loss ''' 145 | loss_SR = criterion_L1(SR_left, HR_left) + criterion_L1(SR_right, HR_right) + criterion_L1(SR_left2, HR_left) + criterion_L1(SR_right2, HR_right) +\ 146 | criterion_L1(SR_left3, HR_left) + criterion_L1(SR_right3, HR_right) + criterion_L1(SR_left4, HR_left) + criterion_L1(SR_right4, HR_right) 147 | 148 | loss_S = loss_disp_smoothness(disp1_high, HR_left) + loss_disp_smoothness(disp2_high, HR_right) + \ 149 | loss_disp_smoothness(disp1_high_2, HR_left) + loss_disp_smoothness(disp2_high_2, HR_right) 150 | 151 | loss_P = 
loss_disp_unsupervised(HR_left, HR_right, disp1, F.interpolate(V_left, scale_factor=4, mode='nearest')) + loss_disp_unsupervised(HR_right, HR_left, disp2, F.interpolate(V_right, scale_factor=4, mode='nearest')) +\ 152 | loss_disp_unsupervised(HR_left, HR_right, disp1_high, V_left2) + loss_disp_unsupervised(HR_right, HR_left, disp2_high, V_right2) + \ 153 | loss_disp_unsupervised(HR_left, HR_right, disp1_3, F.interpolate(V_left3, scale_factor=4, mode='nearest')) + loss_disp_unsupervised(HR_right, HR_left, disp2_3, F.interpolate(V_right3, scale_factor=4, mode='nearest')) +\ 154 | loss_disp_unsupervised(HR_left, HR_right, disp1_high_2, V_left4) + loss_disp_unsupervised(HR_right, HR_left, disp2_high_2, V_right4) 155 | 156 | ''' Photometric Loss ''' 157 | Res_left = torch.abs(HR_left - F.interpolate(LR_left, scale_factor=scale, mode='bicubic', align_corners=False)) 158 | Res_left_low = F.interpolate(Res_left, scale_factor=1 / scale, mode='bicubic', align_corners=False) 159 | Res_right = torch.abs(HR_right - F.interpolate(LR_right, scale_factor=scale, mode='bicubic', align_corners=False)) 160 | Res_right_low = F.interpolate(Res_right, scale_factor=1 / scale, mode='bicubic', align_corners=False) 161 | Res_leftT_low = torch.bmm(M_right_to_left.contiguous().view(b * h, w, w), Res_right_low.permute(0, 2, 3, 1).contiguous().view(b * h, w, c) 162 | ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2) 163 | Res_rightT_low = torch.bmm(M_left_to_right.contiguous().view(b * h, w, w), Res_left_low.permute(0, 2, 3, 1).contiguous().view(b * h, w, c) 164 | ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2) 165 | Res_leftT_low2 = torch.bmm(M_right_to_left3.contiguous().view(b * h, w, w), Res_right_low.permute(0, 2, 3, 1).contiguous().view(b * h, w, c) 166 | ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2) 167 | Res_rightT_low2 = torch.bmm(M_left_to_right3.contiguous().view(b * h, w, w), Res_left_low.permute(0, 2, 3, 1).contiguous().view(b * h, w, c) 168 | ).view(b, h, w, 
c).contiguous().permute(0, 3, 1, 2) 169 | Res_leftT = dispwarpfeature(Res_right, disp1_high) 170 | Res_rightT = dispwarpfeature(Res_left, disp2_high) 171 | Res_leftT2 = dispwarpfeature(Res_right, disp1_high_2) 172 | Res_rightT2 = dispwarpfeature(Res_left, disp2_high_2) 173 | 174 | loss_photo = criterion_L1(Res_left_low * V_left.repeat(1, 3, 1, 1), Res_leftT_low * V_left.repeat(1, 3, 1, 1)) + \ 175 | criterion_L1(Res_right_low * V_right.repeat(1, 3, 1, 1), Res_rightT_low * V_right.repeat(1, 3, 1, 1)) + \ 176 | criterion_L1(Res_left * V_left2.repeat(1, 3, 1, 1), Res_leftT * V_left2.repeat(1, 3, 1, 1)) + \ 177 | criterion_L1(Res_right * V_right2.repeat(1, 3, 1, 1), Res_rightT * V_right2.repeat(1, 3, 1, 1)) +\ 178 | criterion_L1(Res_left_low * V_left3.repeat(1, 3, 1, 1), Res_leftT_low2 * V_left3.repeat(1, 3, 1, 1)) + \ 179 | criterion_L1(Res_right_low * V_right3.repeat(1, 3, 1, 1), Res_rightT_low2 * V_right3.repeat(1, 3, 1, 1)) + \ 180 | criterion_L1(Res_left * V_left4.repeat(1, 3, 1, 1), Res_leftT2 * V_left4.repeat(1, 3, 1, 1)) + \ 181 | criterion_L1(Res_right * V_right4.repeat(1, 3, 1, 1), Res_rightT2 * V_right4.repeat(1, 3, 1, 1)) 182 | 183 | loss_h = criterion_L1(M_right_to_left[:, :-1, :, :], M_right_to_left[:, 1:, :, :]) + \ 184 | criterion_L1(M_left_to_right[:, :-1, :, :], M_left_to_right[:, 1:, :, :]) + \ 185 | criterion_L1(M_right_to_left3[:, :-1, :, :], M_right_to_left3[:, 1:, :, :]) + \ 186 | criterion_L1(M_left_to_right3[:, :-1, :, :], M_left_to_right3[:, 1:, :, :]) 187 | 188 | loss_w = criterion_L1(M_right_to_left[:, :, :-1, :-1], M_right_to_left[:, :, 1:, 1:]) + \ 189 | criterion_L1(M_left_to_right[:, :, :-1, :-1], M_left_to_right[:, :, 1:, 1:]) + \ 190 | criterion_L1(M_right_to_left3[:, :, :-1, :-1], M_right_to_left3[:, :, 1:, 1:]) + \ 191 | criterion_L1(M_left_to_right3[:, :, :-1, :-1], M_left_to_right3[:, :, 1:, 1:]) 192 | 193 | loss_smooth = loss_w + loss_h 194 | 195 | ''' Cycle Loss ''' 196 | Res_left_cycle_low = 
torch.bmm(M_right_to_left.contiguous().view(b * h, w, w), Res_rightT_low.permute(0, 2, 3, 1).contiguous().view(b * h, w, c) 197 | ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2) 198 | Res_right_cycle_low = torch.bmm(M_left_to_right.contiguous().view(b * h, w, w), Res_leftT_low.permute(0, 2, 3, 1).contiguous().view(b * h, w, c) 199 | ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2) 200 | Res_left_cycle = dispwarpfeature(Res_rightT, disp1_high) 201 | Res_right_cycle = dispwarpfeature(Res_leftT, disp2_high) 202 | 203 | Res_left_cycle_low2 = torch.bmm(M_right_to_left3.contiguous().view(b * h, w, w), Res_rightT_low2.permute(0, 2, 3, 1).contiguous().view(b * h, w, c) 204 | ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2) 205 | Res_right_cycle_low2 = torch.bmm(M_left_to_right3.contiguous().view(b * h, w, w), Res_leftT_low2.permute(0, 2, 3, 1).contiguous().view(b * h, w, c) 206 | ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2) 207 | Res_left_cycle2 = dispwarpfeature(Res_rightT2, disp1_high_2) 208 | Res_right_cycle2 = dispwarpfeature(Res_leftT2, disp2_high_2) 209 | 210 | loss_cycle = criterion_L1(Res_left_low * V_left.repeat(1, 3, 1, 1), Res_left_cycle_low * V_left.repeat(1, 3, 1, 1)) + \ 211 | criterion_L1(Res_right_low * V_right.repeat(1, 3, 1, 1), Res_right_cycle_low * V_right.repeat(1, 3, 1, 1)) + \ 212 | criterion_L1(Res_left * V_left2.repeat(1, 3, 1, 1), Res_left_cycle * V_left2.repeat(1, 3, 1, 1)) + \ 213 | criterion_L1(Res_right * V_right2.repeat(1, 3, 1, 1), Res_right_cycle * V_right2.repeat(1, 3, 1, 1)) +\ 214 | criterion_L1(Res_left_low * V_left3.repeat(1, 3, 1, 1), Res_left_cycle_low2 * V_left3.repeat(1, 3, 1, 1)) + \ 215 | criterion_L1(Res_right_low * V_right3.repeat(1, 3, 1, 1), Res_right_cycle_low2 * V_right3.repeat(1, 3, 1, 1)) + \ 216 | criterion_L1(Res_left * V_left4.repeat(1, 3, 1, 1), Res_left_cycle2 * V_left4.repeat(1, 3, 1, 1)) + \ 217 | criterion_L1(Res_right * V_right4.repeat(1, 3, 1, 1), Res_right_cycle2 * V_right4.repeat(1, 3, 1, 1)) 

            ''' Consistency Loss '''
            SR_left_res = F.interpolate(torch.abs(HR_left - SR_left), scale_factor=1 / scale, mode='bicubic', align_corners=False)
            SR_right_res = F.interpolate(torch.abs(HR_right - SR_right), scale_factor=1 / scale, mode='bicubic', align_corners=False)
            SR_left_res3 = F.interpolate(torch.abs(HR_left - SR_left3), scale_factor=1 / scale, mode='bicubic', align_corners=False)
            SR_right_res3 = F.interpolate(torch.abs(HR_right - SR_right3), scale_factor=1 / scale, mode='bicubic', align_corners=False)

            SR_left_res2 = torch.abs(HR_left - SR_left2)
            SR_right_res2 = torch.abs(HR_right - SR_right2)
            SR_left_res4 = torch.abs(HR_left - SR_left4)
            SR_right_res4 = torch.abs(HR_right - SR_right4)

            SR_left_resT = torch.bmm(M_right_to_left.detach().contiguous().view(b * h, w, w), SR_right_res.permute(0, 2, 3, 1).contiguous().view(b * h, w, c)
                                     ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2)
            SR_right_resT = torch.bmm(M_left_to_right.detach().contiguous().view(b * h, w, w), SR_left_res.permute(0, 2, 3, 1).contiguous().view(b * h, w, c)
                                      ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2)
            SR_left_resT2 = dispwarpfeature(SR_right_res2, disp1_high)
            SR_right_resT2 = dispwarpfeature(SR_left_res2, disp2_high)

            SR_left_resT3 = torch.bmm(M_right_to_left3.detach().contiguous().view(b * h, w, w), SR_right_res3.permute(0, 2, 3, 1).contiguous().view(b * h, w, c)
                                      ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2)
            SR_right_resT3 = torch.bmm(M_left_to_right3.detach().contiguous().view(b * h, w, w), SR_left_res3.permute(0, 2, 3, 1).contiguous().view(b * h, w, c)
                                       ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2)
            SR_left_resT4 = dispwarpfeature(SR_right_res4, disp1_high_2)
            SR_right_resT4 = dispwarpfeature(SR_left_res4, disp2_high_2)

            loss_cons = criterion_L1(SR_left_res * V_left.repeat(1, 3, 1, 1), SR_left_resT * V_left.repeat(1, 3, 1, 1)) + \
                        criterion_L1(SR_right_res * V_right.repeat(1, 3, 1, 1), SR_right_resT * V_right.repeat(1, 3, 1, 1)) + \
                        criterion_L1(SR_left_res2 * V_left2.repeat(1, 3, 1, 1), SR_left_resT2 * V_left2.repeat(1, 3, 1, 1)) + \
                        criterion_L1(SR_right_res2 * V_right2.repeat(1, 3, 1, 1), SR_right_resT2 * V_right2.repeat(1, 3, 1, 1)) + \
                        criterion_L1(SR_left_res3 * V_left3.repeat(1, 3, 1, 1), SR_left_resT3 * V_left3.repeat(1, 3, 1, 1)) + \
                        criterion_L1(SR_right_res3 * V_right3.repeat(1, 3, 1, 1), SR_right_resT3 * V_right3.repeat(1, 3, 1, 1)) + \
                        criterion_L1(SR_left_res4 * V_left4.repeat(1, 3, 1, 1), SR_left_resT4 * V_left4.repeat(1, 3, 1, 1)) + \
                        criterion_L1(SR_right_res4 * V_right4.repeat(1, 3, 1, 1), SR_right_resT4 * V_right4.repeat(1, 3, 1, 1))
            ''' Total Loss '''
            loss = loss_SR + 0.1 * loss_cons + 0.1 * (loss_photo + loss_smooth + loss_cycle) + 0.001*loss_S + 0.01*loss_P
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            psnr_epoch.append(cal_psnr(HR_left[:,:,:,64:].data.cpu(), SR_left4[:,:,:,64:].data.cpu()))
            psnr_epoch_r.append(cal_psnr(HR_right[:,:,:,:HR_right.shape[3]-64].data.cpu(), SR_right4[:,:,:,:HR_right.shape[3]-64].data.cpu()))

            psnr_epoch_m.append(cal_psnr(HR_left[:,:,:,64:].data.cpu(), SR_left2[:,:,:,64:].data.cpu()))
            psnr_epoch_r_m.append(cal_psnr(HR_right[:,:,:,:HR_right.shape[3]-64].data.cpu(), SR_right2[:,:,:,:HR_right.shape[3]-64].data.cpu()))
            loss_epoch.append(loss.data.cpu())
            if idx_iter%300==0:
                # Each term is logged with the weight it carries in the total loss above
                # (loss_P is weighted by 0.01 there, so it is logged with 0.01 as well).
                print("SRloss: {:.4f} Photoloss: {:.5f} Smoothloss: {:.5f} Cycleloss: {:.5f} Consloss: {:.5f} Ploss: {:.5f} Sloss: {:.5f}".format(loss_SR.item(), 0.1*loss_photo.item(), 0.1*loss_smooth.item(), 0.1*loss_cycle.item(), 0.1*loss_cons.item(), 0.01*loss_P.item(), 0.001*loss_S.item()))
                print(torch.mean(V_left2))
                print(torch.mean(V_left4))

        scheduler.step()


        if idx_epoch % 1 == 0:
            loss_list.append(float(np.array(loss_epoch).mean()))

            print('Epoch--%4d, loss--%f, loss_SR--%f, loss_photo--%f, loss_smooth--%f, loss_cycle--%f, loss_cons--%f' %
                  (idx_epoch + 1, float(np.array(loss_epoch).mean()), float(np.array(loss_SR.data.cpu()).mean()),
                   float(np.array(loss_photo.data.cpu()).mean()), float(np.array(loss_smooth.data.cpu()).mean()),
                   float(np.array(loss_cycle.data.cpu()).mean()), float(np.array(loss_cons.data.cpu()).mean())))
            print('PSNR left---%f, PSNR right---%f' % (float(np.array(psnr_epoch).mean()), float(np.array(psnr_epoch_r).mean())))
            print('intermediate PSNR left---%f, PSNR right---%f' % (float(np.array(psnr_epoch_m).mean()), float(np.array(psnr_epoch_r_m).mean())))
            loss_epoch = []
            psnr_epoch = []
            psnr_epoch_r = []
            psnr_epoch_m = []
            psnr_epoch_r_m = []

            torch.save({'epoch': idx_epoch + 1, 'state_dict': net.state_dict()},
                       'checkpoints/' + cfg.model_name + '_' + str(cfg.scale_factor) + 'xSR_epoch' + str(idx_epoch + 1) + '.pth.tar')


def main(cfg):
    train_set = TrainSetLoader(cfg)
    train_loader = DataLoader(dataset=train_set, num_workers=6, batch_size=cfg.batch_size, shuffle=True)
    train(train_loader, cfg)

if __name__ == '__main__':
    cfg = parse_args()
    main(cfg)

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
from PIL import Image
import os
import re  # used by _read_pfm to parse the PFM dimension line
from torch.utils.data.dataset import Dataset
import random
import torch
import torch.nn.functional as F  # used by EPE_metric
import numpy as np
from skimage import measure

def read_disp(filename, subset=False):
    # Scene Flow dataset
    if filename.endswith('pfm'):
        # For finalpass and cleanpass, gt disparity is positive, subset is negative
        disp = np.ascontiguousarray(_read_pfm(filename)[0])
        if subset:
            disp = -disp
    # KITTI
    elif filename.endswith('png'):
        disp = _read_kitti_disp(filename)
    elif filename.endswith('npy'):
        disp = np.load(filename)
    else:
        raise Exception('Invalid disparity file format!')
    return disp  # [H, W]


def _read_pfm(file):
    # Open in binary mode; the context manager closes the handle on exit.
    with open(file, 'rb') as f:
        # First header line: 'PF' = 3-channel colour, 'Pf' = single channel.
        header = f.readline().rstrip()
        if header.decode("ascii") == 'PF':
            color = True
        elif header.decode("ascii") == 'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')

        # Second header line: "<width> <height>".
        dim_match = re.match(r'^(\d+)\s(\d+)\s$', f.readline().decode("ascii"))
        if dim_match:
            width, height = list(map(int, dim_match.groups()))
        else:
            raise Exception('Malformed PFM header.')

        # Third header line: scale factor; its sign encodes the byte order.
        scale = float(f.readline().decode("ascii").rstrip())
        if scale < 0:  # little-endian
            endian = '<'
            scale = -scale
        else:
            endian = '>'  # big-endian

        data = np.fromfile(f, endian + 'f')

    shape = (height, width, 3) if color else (height, width)

    data = np.reshape(data, shape)
    data = np.flipud(data)  # PFM stores rows bottom-to-top
    return data, scale

def _read_kitti_disp(filename):
    depth = np.array(Image.open(filename))
    depth = depth.astype(np.float32) / 256.
    return depth

class TrainSetLoader(Dataset):
    def __init__(self, cfg):
        super(TrainSetLoader, self).__init__()
        self.dataset_dir = cfg.trainset_dir + '/patches_x' + str(cfg.scale_factor)
        self.file_list = os.listdir(self.dataset_dir)
    def __getitem__(self, index):
        # Each sample is a directory holding the HR/LR left (0) and right (1) patches.
        sample_dir = self.dataset_dir + '/' + self.file_list[index]
        img_hr_left = Image.open(sample_dir + '/hr0.png')
        img_hr_right = Image.open(sample_dir + '/hr1.png')
        img_lr_left = Image.open(sample_dir + '/lr0.png')
        img_lr_right = Image.open(sample_dir + '/lr1.png')
        img_hr_left = np.array(img_hr_left, dtype=np.float32)
        img_hr_right = np.array(img_hr_right, dtype=np.float32)
        img_lr_left = np.array(img_lr_left, dtype=np.float32)
        img_lr_right = np.array(img_lr_right, dtype=np.float32)
        img_hr_left, img_hr_right, img_lr_left, img_lr_right = augmentation(img_hr_left, img_hr_right, img_lr_left, img_lr_right)
        return toTensor(img_hr_left), toTensor(img_hr_right), toTensor(img_lr_left), toTensor(img_lr_right)

    def __len__(self):
        return len(self.file_list)

def cal_psnr(img1, img2):
    img1_np = np.array(img1)
    img2_np = np.array(img2)
    # measure.compare_psnr was removed in scikit-image 0.18; use the
    # skimage.metrics function when it is available and fall back otherwise.
    try:
        from skimage.metrics import peak_signal_noise_ratio
    except ImportError:
        return measure.compare_psnr(img1_np, img2_np)
    return peak_signal_noise_ratio(img1_np, img2_np)

def augmentation(hr_image_left, hr_image_right, lr_image_left, lr_image_right):
    if random.random()<0.5:     # flip horizontally; the views are swapped so the mirrored pair is still a valid left/right pair
        lr_image_left_ = lr_image_right[:, ::-1, :]
        lr_image_right_ = lr_image_left[:, ::-1, :]
        hr_image_left_ = hr_image_right[:, ::-1, :]
        hr_image_right_ = hr_image_left[:, ::-1, :]
        lr_image_left, lr_image_right = lr_image_left_, lr_image_right_
        hr_image_left, hr_image_right = hr_image_left_, hr_image_right_

    if random.random()<0.5:     # flip vertically
        lr_image_left = lr_image_left[::-1, :, :]
        lr_image_right = lr_image_right[::-1, :, :]
        hr_image_left = hr_image_left[::-1, :, :]
        hr_image_right = hr_image_right[::-1, :, :]

    return np.ascontiguousarray(hr_image_left), np.ascontiguousarray(hr_image_right), \
           np.ascontiguousarray(lr_image_left), np.ascontiguousarray(lr_image_right)

def toTensor(img):
    # HWC array in [0, 255] -> CHW float tensor in [0, 1]
    img = torch.from_numpy(img.transpose((2, 0, 1)))
    return img.float().div(255)

def D1_metric(D_est, D_gt, mask, threshold=3):
    mask = mask.byte()
    error = []
    for i in range(D_gt.size(0)):
        D_est_, D_gt_ = D_est[i,...][mask[i,...]], D_gt[i,...][mask[i,...]]
        if len(D_gt_) > 0:
            E = torch.abs(D_gt_ - D_est_)
            # A pixel is an outlier when its error exceeds both the absolute
            # threshold and 5% of the ground-truth disparity.
            err_mask = (E > threshold) & (E / D_gt_.abs() > 0.05)
            error.append(torch.mean(err_mask.float()).data.cpu())
    return error


def EPE_metric(D_est, D_gt, mask):
    mask = mask.byte()
    error = []
    for i in range(D_gt.size(0)):
        D_est_, D_gt_ = D_est[i,...][mask[i,...]], D_gt[i,...][mask[i,...]]
        if len(D_gt_) > 0:
            # Mean absolute error, i.e. the L1 loss with mean reduction.
            error.append(torch.abs(D_est_ - D_gt_).mean().data.cpu())
    return error

--------------------------------------------------------------------------------
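
The PFM parsing steps in `_read_pfm` (type line, dimension line, signed scale for byte order, bottom-to-top rows) can be exercised without a Scene Flow file. The following is a minimal standard-library sketch, not part of the repository: `write_pfm` and `parse_pfm` are hypothetical helpers introduced only for illustration, and the sketch works on in-memory bytes and nested lists instead of files and NumPy arrays.

```python
import io
import re
import struct


def write_pfm(rows, scale=1.0):
    """Serialize a 2-D list of floats as little-endian greyscale PFM bytes (hypothetical helper)."""
    height, width = len(rows), len(rows[0])
    # 'Pf' marks a single channel; a negative scale marks little-endian data.
    header = "Pf\n{} {}\n{}\n".format(width, height, -abs(scale)).encode("ascii")
    body = b"".join(
        struct.pack("<{}f".format(width), *row)
        for row in reversed(rows)  # PFM stores the bottom row first
    )
    return header + body


def parse_pfm(buf):
    """Parse PFM bytes into (rows, scale), following the same steps as _read_pfm (hypothetical helper)."""
    f = io.BytesIO(buf)
    header = f.readline().rstrip().decode("ascii")
    if header not in ("PF", "Pf"):
        raise ValueError("Not a PFM file.")
    channels = 3 if header == "PF" else 1

    # Same dimension regex as in utils.py; this is why `re` must be imported there.
    dim_match = re.match(r"^(\d+)\s(\d+)\s$", f.readline().decode("ascii"))
    if not dim_match:
        raise ValueError("Malformed PFM header.")
    width, height = map(int, dim_match.groups())

    scale = float(f.readline().decode("ascii").rstrip())
    endian = "<" if scale < 0 else ">"
    scale = abs(scale)

    count = width * height * channels
    values = struct.unpack("{}{}f".format(endian, count), f.read(4 * count))
    row_len = width * channels
    rows = [list(values[i * row_len:(i + 1) * row_len]) for i in range(height)]
    rows.reverse()  # same effect as np.flipud: restore top-to-bottom order
    return rows, scale


disp = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
rows, scale = parse_pfm(write_pfm(disp))
print(rows, scale)
```

The round trip returns the rows in their original top-to-bottom order, which is the reason `_read_pfm` ends with `np.flipud`.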