├── .DS_Store
├── README.md
├── checkpoints
│   ├── .DS_Store
│   └── SSRDEF_4xSR_epoch80.pth.tar
├── common.py
├── data
│   ├── .DS_Store
│   └── test
│       └── .DS_Store
├── datasets
│   ├── __init__.py
│   ├── data_io.py
│   ├── kitti_dataset.py
│   └── sceneflow_dataset.py
├── figs
│   ├── .DS_Store
│   ├── Model.png
│   ├── Results.png
│   └── Visual.png
├── loss.py
├── metric.py
├── model.py
├── result.txt
├── test_disp.py
├── test_sr.py
├── train.py
└── utils.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIVRC/SSRDEFNet-PyTorch/4d95c175fc60526fecbb42172a211fbd4472adf6/.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SSRDEFNet_PyTorch
2 | ### This repository is an official PyTorch implementation of the paper "Feedback Network for Mutually Boosted Stereo Image Super-Resolution and Disparity Estimation" (ACM MM 2021).
3 |
4 | ### The paper can be downloaded from SSRDE-FNet
5 |
6 | Under stereo settings, the problems of image super-resolution (SR) and disparity estimation are interrelated: the result of each task can help solve the other. Effective exploitation of the correspondence between views improves SR performance, while high-resolution (HR) features with richer details benefit correspondence estimation. Motivated by this, we propose a Stereo Super-Resolution and Disparity Estimation Feedback Network (SSRDE-FNet), which handles stereo image super-resolution and disparity estimation simultaneously in a unified framework and lets the two tasks interact to further improve each other's performance. Specifically, SSRDE-FNet is composed of two dual recursive sub-networks for the left and right views. Besides exploiting cross-view information in the low-resolution (LR) space, HR representations produced by the SR process are used to perform more accurate HR disparity estimation, through which HR features can be aggregated to generate a finer SR result. Afterward, the proposed HR Disparity Information Feedback (HRDIF) mechanism delivers the information carried by the HR disparity back to earlier layers to further refine the SR image reconstruction. Extensive experiments demonstrate the effectiveness and superiority of SSRDE-FNet.
7 |
8 |
9 |
10 |
11 |
12 | The pre-trained model for x4 SR is in ./checkpoints/SSRDEF_4xSR_epoch80.pth.tar
13 |
14 | All reconstructed images for x4 SR can be downloaded from SSRDEFNet_Results
15 |
16 | All test datasets (preprocessed HR images) can be downloaded from SSRDEFNet_Test
17 |
18 | Extract the datasets and put them into ./data/test/.
19 |
20 |
21 | ## Prerequisites:
22 | 1. Python 3.6
23 | 2. PyTorch >= 0.4.0
24 | 3. numpy
25 | 4. skimage
26 | 5. imageio
27 | 6. matplotlib
28 |
29 |
30 | ## Dataset
31 |
32 | We use the Flickr1024 and Middlebury datasets to train our model, exactly the same training data as iPASSR. Please refer to their homepage and download the datasets following their guidance.
33 |
34 | Extract the datasets and put them into ./data/train/.
35 |
36 | ## Training
37 |
38 | ```python
39 |
40 | python train.py --scale_factor 4
41 |
42 | ```
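The network in model.py also defines an x2 configuration, so an x2 model can presumably be trained by only changing the flag (hypothetical usage; please check the argument list in train.py):

```python

python train.py --scale_factor 2

```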
43 |
44 | ## Testing
45 |
46 | ```python
47 |
48 | # Testing stereo sr performance
49 |
50 | python test_sr.py
51 |
52 | # Testing disparity estimation performance
53 |
54 | python test_disp.py
55 |
56 | ```
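For reference, below is a minimal sketch of loading the released checkpoint and running the network directly. This is not part of the original scripts; the 'state_dict' key and the GPU usage are assumptions, so please refer to test_sr.py for the exact loading and evaluation code.

```python

import torch
from model import SSRDEFNet

# Build the x4 network and load the released weights.
# The 'state_dict' key is an assumption; inspect the checkpoint to confirm.
net = SSRDEFNet(upscale_factor=4).cuda().eval()
ckpt = torch.load('./checkpoints/SSRDEF_4xSR_epoch80.pth.tar')
net.load_state_dict(ckpt['state_dict'])

# Toy low-resolution stereo pair (batch size 1, as in testing).
lr_left = torch.rand(1, 3, 32, 96).cuda()
lr_right = torch.rand(1, 3, 32, 96).cuda()

with torch.no_grad():
    outputs = net(lr_left, lr_right, is_training=0)

# outputs[6] and outputs[7] are the final SR left/right images (out4_left, out4_right).
sr_left, sr_right = outputs[6], outputs[7]

```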
57 |
58 | ## Performance
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 | ```
71 | @inproceedings{dai2021feedback,
72 | title={Feedback Network for Mutually Boosted Stereo Image Super-Resolution and Disparity Estimation},
73 | author={Dai, Qinyan and Li, Juncheng and Yi, Qiaosi and Fang, Faming and Zhang, Guixu},
74 | booktitle={Proceedings of the 29th ACM International Conference on Multimedia},
75 | pages={1985--1993},
76 | year={2021}
77 | }
78 | ```
79 |
80 | This implementation is for non-commercial research use only.
81 | If you find this code useful in your research, please cite the above paper.
--------------------------------------------------------------------------------
/checkpoints/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIVRC/SSRDEFNet-PyTorch/4d95c175fc60526fecbb42172a211fbd4472adf6/checkpoints/.DS_Store
--------------------------------------------------------------------------------
/checkpoints/SSRDEF_4xSR_epoch80.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIVRC/SSRDEFNet-PyTorch/4d95c175fc60526fecbb42172a211fbd4472adf6/checkpoints/SSRDEF_4xSR_epoch80.pth.tar
--------------------------------------------------------------------------------
/common.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | def convbn(in_planes, out_planes, kernel_size, stride, pad, dilation):
8 |
9 | return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=dilation if dilation > 1 else pad, dilation = dilation, bias=False),
10 | nn.BatchNorm2d(out_planes))
11 |
12 |
13 | def default_conv(in_channels, out_channels, kernel_size,stride=1, bias=True):
14 | return nn.Conv2d(
15 | in_channels, out_channels, kernel_size,
16 | padding=(kernel_size//2),stride=stride, bias=bias)
17 |
18 | class MeanShift(nn.Conv2d):
19 | def __init__(
20 | self, rgb_range,
21 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
22 |
23 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
24 | std = torch.Tensor(rgb_std)
25 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
26 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
27 | for p in self.parameters():
28 | p.requires_grad = False
29 |
30 | class BasicBlock(nn.Sequential):
31 | def __init__(
32 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=True,
33 | bn=False, act=nn.PReLU()):
34 |
35 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
36 | if bn:
37 | m.append(nn.BatchNorm2d(out_channels))
38 | if act is not None:
39 | m.append(act)
40 |
41 | super(BasicBlock, self).__init__(*m)
42 |
43 | class ResBlock(nn.Module):
44 | def __init__(
45 | self, conv, n_feats, kernel_size,
46 | bias=True, bn=False, act=nn.PReLU()):
47 |
48 | super(ResBlock, self).__init__()
49 | m = []
50 | for i in range(2):
51 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
52 | if bn:
53 | m.append(nn.BatchNorm2d(n_feats))
54 | if i == 0:
55 | m.append(act)
56 |
57 | self.body = nn.Sequential(*m)
58 |
59 | def forward(self, x):
60 | res = self.body(x)
61 | res += x
62 |
63 | return res
64 |
65 | class ResBlockdilat(nn.Module):
66 | def __init__(
67 | self, n_feats, kernel_size, dilrate=1,
68 | bias=True, bn=False, act=nn.PReLU()):
69 |
70 | super(ResBlockdilat, self).__init__()
71 | m = []
72 | for i in range(2):
73 | m.append(nn.Conv2d(n_feats, n_feats, kernel_size, 1, dilrate, dilrate, bias=bias))
74 | if bn:
75 | m.append(nn.BatchNorm2d(n_feats))
76 | if i == 0:
77 | m.append(act)
78 |
79 | self.body = nn.Sequential(*m)
80 |
81 | def forward(self, x):
82 | res = self.body(x)
83 | res += x
84 |
85 | return res
86 |
87 | class Upsampler(nn.Sequential):
88 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
89 |
90 | m = []
91 | if (scale & (scale - 1)) == 0: # Is scale = 2^n?
92 | for _ in range(int(math.log(scale, 2))):
93 | m.append(conv(n_feats, 4 * n_feats, 3, bias))
94 | m.append(nn.PixelShuffle(2))
95 | if bn:
96 | m.append(nn.BatchNorm2d(n_feats))
97 | if act == 'relu':
98 | m.append(nn.ReLU(True))
99 | elif act == 'prelu':
100 | m.append(nn.PReLU(n_feats))
101 |
102 | elif scale == 3:
103 | m.append(conv(n_feats, 9 * n_feats, 3, bias))
104 | m.append(nn.PixelShuffle(3))
105 | if bn:
106 | m.append(nn.BatchNorm2d(n_feats))
107 | if act == 'relu':
108 | m.append(nn.ReLU(True))
109 | elif act == 'prelu':
110 | m.append(nn.PReLU(n_feats))
111 | else:
112 | raise NotImplementedError
113 |
114 | super(Upsampler, self).__init__(*m)
115 |
116 | class ConvBlock(torch.nn.Module):
117 | def __init__(self, input_size, output_size, kernel_size=3, stride=1, padding=1, bias=True, activation='prelu', norm=None):
118 | super(ConvBlock, self).__init__()
119 | self.conv = torch.nn.Conv2d(input_size, output_size, kernel_size, stride, padding, bias=bias)
120 |
121 | self.norm = norm
122 | if self.norm =='batch':
123 | self.bn = torch.nn.BatchNorm2d(output_size)
124 | elif self.norm == 'instance':
125 | self.bn = torch.nn.InstanceNorm2d(output_size)
126 |
127 | self.activation = activation
128 | if self.activation == 'relu':
129 | self.act = torch.nn.ReLU(True)
130 | elif self.activation == 'prelu':
131 | self.act = torch.nn.PReLU()
132 | elif self.activation == 'lrelu':
133 | self.act = torch.nn.LeakyReLU(0.2, True)
134 | elif self.activation == 'tanh':
135 | self.act = torch.nn.Tanh()
136 | elif self.activation == 'sigmoid':
137 | self.act = torch.nn.Sigmoid()
138 |
139 | def forward(self, x):
140 | if self.norm is not None:
141 | out = self.bn(self.conv(x))
142 | else:
143 | out = self.conv(x)
144 |
145 | if self.activation is not None:
146 | return self.act(out)
147 | else:
148 | return out
149 |
150 |
151 | class DeconvBlock(torch.nn.Module):
152 | def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, bias=True, activation='prelu', norm=None):
153 | super(DeconvBlock, self).__init__()
154 | self.deconv = torch.nn.ConvTranspose2d(input_size, output_size, kernel_size, stride, padding, bias=bias)
155 |
156 | self.norm = norm
157 | if self.norm == 'batch':
158 | self.bn = torch.nn.BatchNorm2d(output_size)
159 | elif self.norm == 'instance':
160 | self.bn = torch.nn.InstanceNorm2d(output_size)
161 |
162 | self.activation = activation
163 | if self.activation == 'relu':
164 | self.act = torch.nn.ReLU(True)
165 | elif self.activation == 'prelu':
166 | self.act = torch.nn.PReLU()
167 | elif self.activation == 'lrelu':
168 | self.act = torch.nn.LeakyReLU(0.2, True)
169 | elif self.activation == 'tanh':
170 | self.act = torch.nn.Tanh()
171 | elif self.activation == 'sigmoid':
172 | self.act = torch.nn.Sigmoid()
173 |
174 | def forward(self, x):
175 | if self.norm is not None:
176 | out = self.bn(self.deconv(x))
177 | else:
178 | out = self.deconv(x)
179 |
180 | if self.activation is not None:
181 | return self.act(out)
182 | else:
183 | return out
184 |
185 |
186 | class Bottle2neck(nn.Module):
187 |
188 | def __init__(self, inplanes, planes, stride=1, downsample=None, scale = 4, stype='normal'):
189 | """ Constructor
190 | Args:
191 | inplanes: input channel dimensionality
192 | planes: output channel dimensionality
193 | stride: conv stride. Replaces pooling layer.
194 | downsample: None when stride = 1
195 |             scale: number of feature groups (Res2Net scale); each 3x3
196 |                 conv branch has width planes // scale
197 |             stype: 'normal': regular block. 'stage': first block of a new stage.
198 | """
199 | super(Bottle2neck, self).__init__()
200 |
201 | width = planes//scale
202 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
203 |
204 | if scale == 1:
205 | self.nums = 1
206 | else:
207 | self.nums = scale -1
208 | if stype == 'stage':
209 | self.pool = nn.AvgPool2d(kernel_size=3, stride = stride, padding=1)
210 | convs = []
211 | for i in range(self.nums):
212 | convs.append(nn.Conv2d(width, width, kernel_size=3, stride = stride, padding=1, bias=False))
213 | self.convs = nn.ModuleList(convs)
214 |
215 | self.conv3 = nn.Conv2d(width*scale, planes, kernel_size=1, bias=False)
216 |
217 | self.relu = nn.PReLU()
218 | self.downsample = downsample
219 | self.stype = stype
220 | self.scale = scale
221 | self.width = width
222 |
223 | def forward(self, x):
224 | residual = x
225 | out = self.conv1(x)
226 | out = self.relu(out)
227 |
228 | spx = torch.split(out, self.width, 1)
229 | for i in range(self.nums):
230 | if i==0 or self.stype=='stage':
231 | sp = spx[i]
232 | else:
233 | sp = sp + spx[i]
234 | sp = self.convs[i](sp)
235 | sp = self.relu(sp)
236 | if i==0:
237 | out = sp
238 | else:
239 | out = torch.cat((out, sp), 1)
240 | if self.scale != 1 and self.stype=='normal':
241 | out = torch.cat((out, spx[self.nums]),1)
242 | elif self.scale != 1 and self.stype=='stage':
243 | out = torch.cat((out, self.pool(spx[self.nums])),1)
244 |
245 | out = self.conv3(out)
246 |
247 | out += residual
248 | out = self.relu(out)
249 |
250 | return out
--------------------------------------------------------------------------------
/data/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIVRC/SSRDEFNet-PyTorch/4d95c175fc60526fecbb42172a211fbd4472adf6/data/.DS_Store
--------------------------------------------------------------------------------
/data/test/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIVRC/SSRDEFNet-PyTorch/4d95c175fc60526fecbb42172a211fbd4472adf6/data/test/.DS_Store
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .kitti_dataset import KITTIDataset
2 | from .sceneflow_dataset import SceneFlowDatset
3 |
4 | __datasets__ = {
5 | "sceneflow": SceneFlowDatset,
6 | "kitti": KITTIDataset
7 | }
8 |
--------------------------------------------------------------------------------
/datasets/data_io.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import re
3 | import sys
4 | import torchvision.transforms as transforms
5 | import torch
6 |
7 |
8 | def get_transform():
9 | mean = [0.485, 0.456, 0.406]
10 | std = [0.229, 0.224, 0.225]
11 |
12 | return transforms.Compose([
13 | transforms.ToTensor(),
14 | #transforms.Normalize(mean=mean, std=std),
15 | ])
16 |
17 |
18 | # def get_transform_kitti():
19 | # mean = [0.485, 0.456, 0.406]
20 | # std = [0.229, 0.224, 0.225]
21 | #
22 | # return transforms.Compose([
23 | # transforms.ToTensor(),
24 | # RandomPhotometric(
25 | # noise_stddev=0.0,
26 | # min_contrast=-0.3,
27 | # max_contrast=0.3,
28 | # brightness_stddev=0.02,
29 | # min_color=0.9,
30 | # max_color=1.1,
31 | # min_gamma=0.7,
32 | # max_gamma=1.5),
33 | # transforms.Normalize(mean=mean, std=std),
34 | # ])
35 |
36 |
37 | # read all lines in a file
38 | def read_all_lines(filename):
39 | with open(filename) as f:
40 | lines = [line.rstrip() for line in f.readlines()]
41 | return lines
42 |
43 |
44 | # read a .pfm file into a numpy array, used to load SceneFlow disparity files
45 | def pfm_imread(filename):
46 | file = open(filename, 'rb')
47 | color = None
48 | width = None
49 | height = None
50 | scale = None
51 | endian = None
52 |
53 | header = file.readline().decode('utf-8').rstrip()
54 | if header == 'PF':
55 | color = True
56 | elif header == 'Pf':
57 | color = False
58 | else:
59 | raise Exception('Not a PFM file.')
60 |
61 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
62 | if dim_match:
63 | width, height = map(int, dim_match.groups())
64 | else:
65 | raise Exception('Malformed PFM header.')
66 |
67 | scale = float(file.readline().rstrip())
68 | if scale < 0: # little-endian
69 | endian = '<'
70 | scale = -scale
71 | else:
72 | endian = '>' # big-endian
73 |
74 | data = np.fromfile(file, endian + 'f')
75 | shape = (height, width, 3) if color else (height, width)
76 |
77 | data = np.reshape(data, shape)
78 | data = np.flipud(data)
79 | return data, scale
80 |
81 |
82 | def writePFM(file, image, scale=1):
83 | file = open(file, 'wb')
84 |
85 | color = None
86 |
87 | if image.dtype.name != 'float32':
88 | raise Exception('Image dtype must be float32.')
89 |
90 | image = np.flipud(image)
91 |
92 | if len(image.shape) == 3 and image.shape[2] == 3: # color image
93 | color = True
94 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale
95 | color = False
96 | else:
97 | raise Exception(
98 | 'Image must have H x W x 3, H x W x 1 or H x W dimensions.')
99 |
100 | file.write('PF\n'.encode('utf-8') if color else 'Pf\n'.encode('utf-8'))
101 | file.write('%d %d\n'.encode('utf-8') % (image.shape[1], image.shape[0]))
102 |
103 | endian = image.dtype.byteorder
104 |
105 | if endian == '<' or endian == '=' and sys.byteorder == 'little':
106 | scale = -scale
107 |
108 | file.write('%f\n'.encode('utf-8') % scale)
109 |
110 | image.tofile(file)
111 |
112 |
113 | class RandomPhotometric(object):
114 | """Applies photometric augmentations to a list of image tensors.
115 | Each image in the list is augmented in the same way.
116 |
117 | Args:
118 | ims: list of 3-channel images normalized to [0, 1].
119 |
120 | Returns:
121 | normalized images with photometric augmentations. Has the same
122 | shape as the input.
123 | """
124 | def __init__(self,
125 | noise_stddev=0.0,
126 | min_contrast=0.0,
127 | max_contrast=0.0,
128 | brightness_stddev=0.0,
129 | min_color=1.0,
130 | max_color=1.0,
131 | min_gamma=1.0,
132 | max_gamma=1.0):
133 | self.noise_stddev = noise_stddev
134 | self.min_contrast = min_contrast
135 | self.max_contrast = max_contrast
136 | self.brightness_stddev = brightness_stddev
137 | self.min_color = min_color
138 | self.max_color = max_color
139 | self.min_gamma = min_gamma
140 | self.max_gamma = max_gamma
141 |
142 | def __call__(self, ims):
143 | contrast = np.random.uniform(self.min_contrast, self.max_contrast)
144 | gamma = np.random.uniform(self.min_gamma, self.max_gamma)
145 | gamma_inv = 1.0 / gamma
146 | color = torch.from_numpy(
147 | np.random.uniform(self.min_color, self.max_color, (3))).float()
148 | if self.noise_stddev > 0.0:
149 | noise = np.random.normal(scale=self.noise_stddev)
150 | else:
151 | noise = 0
152 | if self.brightness_stddev > 0.0:
153 | brightness = np.random.normal(scale=self.brightness_stddev)
154 | else:
155 | brightness = 0
156 |
157 | im_re = ims.permute(1, 2, 0)
158 | im_re = (im_re * (contrast + 1.0) + brightness) * color
159 | im_re = torch.clamp(im_re, min=0.0, max=1.0)
160 | im_re = torch.pow(im_re, gamma_inv)
161 | im_re += noise
162 |
163 | out = im_re.permute(2, 0, 1)
164 |
165 | return out
--------------------------------------------------------------------------------
/datasets/kitti_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import torch
4 | from torch.utils.data import Dataset
5 | from PIL import Image
6 | import numpy as np
7 | from datasets.data_io import get_transform, read_all_lines
8 |
9 |
10 | class KITTIDataset(Dataset):
11 | def __init__(self, datapath, list_filename, training):
12 | self.datapath = datapath
13 | self.left_filenames, self.right_filenames, self.disp_occ_filenames, self.disp_noc_filenames = self.load_path(list_filename)
14 | self.training = training
15 | if self.training:
16 | assert self.disp_occ_filenames is not None
17 |
18 | def load_path(self, list_filename):
19 | lines = read_all_lines(list_filename)
20 | splits = [line.split() for line in lines]
21 | left_images = [x[0] for x in splits]
22 | right_images = [x[1] for x in splits]
23 | if len(splits[0]) == 2: # ground truth not available
24 | return left_images, right_images, None, None
25 | else:
26 | disp_occ_images = [x[2] for x in splits]
27 | disp_noc_images = [x[2].replace('occ', 'noc') for x in splits]
28 | return left_images, right_images, disp_occ_images, disp_noc_images
29 |
30 | def load_image(self, filename):
31 | return Image.open(filename).convert('RGB')
32 |
33 | def load_disp(self, filename):
34 | data = Image.open(filename)
35 | data = np.array(data, dtype=np.float32) / 256.
36 | return data
37 |
38 | def __len__(self):
39 | return len(self.left_filenames)
40 |
41 | def __getitem__(self, index):
42 | left_img = self.load_image(os.path.join(self.datapath, self.left_filenames[index]))
43 | right_img = self.load_image(os.path.join(self.datapath, self.right_filenames[index]))
44 |
45 | if self.disp_occ_filenames: # has disparity ground truth
46 | disp_occ = self.load_disp(os.path.join(self.datapath, self.disp_occ_filenames[index]))
47 | disp_noc = self.load_disp(os.path.join(self.datapath, self.disp_noc_filenames[index]))
48 | else:
49 | disp_occ = None
50 |
51 | if self.training:
52 | w, h = left_img.size
53 | crop_w, crop_h = 512, 256
54 |
55 | x1 = random.randint(0, w - crop_w)
56 | y1 = random.randint(0, h - crop_h)
57 |
58 | # random crop
59 | left_img = left_img.crop((x1, y1, x1 + crop_w, y1 + crop_h))
60 | right_img = right_img.crop((x1, y1, x1 + crop_w, y1 + crop_h))
61 | disp_occ = disp_occ[y1:y1 + crop_h, x1:x1 + crop_w]
62 | disp_noc = disp_noc[y1:y1 + crop_h, x1:x1 + crop_w]
63 | occ_mask = ((disp_occ - disp_noc) > 0).astype(np.float32)
64 |
65 | # to tensor, normalize
66 | processed = get_transform()
67 | left_img = processed(left_img)
68 | right_img = processed(right_img)
69 |
70 |             # # augmentation
71 | # if random.random() < 0.5:
72 | # left_img = torch.flip(left_img, [1])
73 | # right_img = torch.flip(right_img, [1])
74 | # disp_occ = np.ascontiguousarray(np.flip(disp_occ, 0))
75 |
76 | return {"left": left_img,
77 | "right": right_img,
78 | "left_disp": disp_occ,
79 | "occ_mask": occ_mask}
80 | else:
81 | w, h = left_img.size
82 |
83 | # normalize
84 | processed = get_transform()
85 | left_img = processed(left_img).numpy()
86 | right_img = processed(right_img).numpy()
87 |
88 |             # pad to size 1244x380 (width x height)
89 | top_pad = 380 - h
90 | right_pad = 1244 - w
91 | assert top_pad > 0 and right_pad > 0
92 | # pad images
93 | left_img = np.lib.pad(left_img, ((0, 0), (top_pad, 0), (0, right_pad)), mode='edge')
94 | right_img = np.lib.pad(right_img, ((0, 0), (top_pad, 0), (0, right_pad)), mode='edge')
95 |
96 |
97 | # pad disparity gt
98 | if self.disp_occ_filenames is not None:
99 | # assert len(self.disp_occ_filenames.shape) == 2
100 | disp_occ = np.lib.pad(disp_occ, ((top_pad, 0), (0, right_pad)), mode='constant', constant_values=0)
101 |
102 | if self.disp_occ_filenames is not None:
103 | return {"HR_left": left_img,
104 | "HR_right": right_img,
105 | "left_disp": disp_occ,
106 | "top_pad": top_pad,
107 | "right_pad": right_pad
108 | }
109 | else:
110 | return {"HR_left": left_img,
111 | "HR_right": right_img,
112 | "top_pad": top_pad,
113 | "right_pad": right_pad,
114 | "left_filename": self.left_filenames[index],
115 | "right_filename": self.right_filenames[index]
116 | }
117 |
--------------------------------------------------------------------------------
/datasets/sceneflow_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import torch
4 | import numpy as np
5 | import torch.nn.functional as F
6 | from torch.utils.data import Dataset
7 | from PIL import Image
8 | from datasets.data_io import get_transform, read_all_lines, pfm_imread
9 |
10 |
11 | class SceneFlowDatset(Dataset):
12 | def __init__(self, datapath, list_filename, training):
13 | self.datapath = datapath
14 | self.left_filenames, self.right_filenames, self.left_disp_filenames, self.right_disp_filenames = self.load_path(list_filename)
15 | self.training = training
16 |
17 | def load_path(self, list_filename):
18 | lines = read_all_lines(list_filename)
19 | splits = [line.split() for line in lines]
20 | left_images = [x[0] for x in splits]
21 | right_images = [x[1] for x in splits]
22 | left_disp = [x[2] for x in splits]
23 | right_disp = [x[2][:-13]+'right/'+x[2][-8:] for x in splits]
24 |
25 | return left_images, right_images, left_disp, right_disp
26 |
27 | def load_image(self, filename):
28 | return Image.open(filename).convert('RGB')
29 |
30 | def load_disp(self, filename):
31 | data, scale = pfm_imread(filename)
32 | data = np.ascontiguousarray(data, dtype=np.float32)
33 | return data
34 |
35 | def __len__(self):
36 | return len(self.left_filenames)
37 |
38 | def __getitem__(self, index):
39 | left_img = self.load_image(os.path.join(self.datapath, self.left_filenames[index]))
40 | right_img = self.load_image(os.path.join(self.datapath, self.right_filenames[index]))
41 | left_disp = self.load_disp(os.path.join(self.datapath, self.left_disp_filenames[index]))
42 | right_disp = self.load_disp(os.path.join(self.datapath, self.right_disp_filenames[index]))
43 |
44 | if self.training:
45 | w, h = left_img.size
46 | crop_w, crop_h = 512, 256
47 |
48 | x1 = random.randint(0, w - crop_w)
49 | y1 = random.randint(0, h - crop_h)
50 |
51 | # random crop
52 | left_img = left_img.crop((x1, y1, x1 + crop_w, y1 + crop_h))
53 | right_img = right_img.crop((x1, y1, x1 + crop_w, y1 + crop_h))
54 | left_disp = left_disp[y1:y1 + crop_h, x1:x1 + crop_w]
55 | right_disp = right_disp[y1:y1 + crop_h, x1:x1 + crop_w]
56 |
57 | # to tensor, normalize
58 | processed = get_transform()
59 | left_img = processed(left_img)
60 | right_img = processed(right_img)
61 |
62 |             # augmentation
63 | # if random.random()<0.5:
64 | # left_img = torch.flip(left_img, [1])
65 | # right_img = torch.flip(right_img, [1])
66 | # left_disp = np.ascontiguousarray(np.flip(left_disp, 0))
67 | # right_disp = np.ascontiguousarray(np.flip(right_disp, 0))
68 |
69 | return {"left": left_img,
70 | "right": right_img,
71 | "left_disp": left_disp,
72 | "right_disp": right_disp}
73 | else:
74 | w, h = left_img.size
75 | crop_w, crop_h = 960, 512
76 |
77 | left_img = left_img.crop((w - crop_w, h - crop_h, w, h))
78 | right_img = right_img.crop((w - crop_w, h - crop_h, w, h))
79 | disparity = left_disp[h - crop_h:h, w - crop_w: w]
80 | disparity_right = right_disp[h - crop_h:h, w - crop_w: w]
81 |
82 | processed = get_transform()
83 | left_img = processed(left_img)
84 | right_img = processed(right_img)
85 |
86 | return {"left": left_img,
87 | "right": right_img,
88 | "left_disp": disparity,
89 | "right_disp": disparity_right,
90 | "top_pad": 0,
91 | "right_pad": 0}
92 |
--------------------------------------------------------------------------------
/figs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIVRC/SSRDEFNet-PyTorch/4d95c175fc60526fecbb42172a211fbd4472adf6/figs/.DS_Store
--------------------------------------------------------------------------------
/figs/Model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIVRC/SSRDEFNet-PyTorch/4d95c175fc60526fecbb42172a211fbd4472adf6/figs/Model.png
--------------------------------------------------------------------------------
/figs/Results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIVRC/SSRDEFNet-PyTorch/4d95c175fc60526fecbb42172a211fbd4472adf6/figs/Results.png
--------------------------------------------------------------------------------
/figs/Visual.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIVRC/SSRDEFNet-PyTorch/4d95c175fc60526fecbb42172a211fbd4472adf6/figs/Visual.png
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from math import exp
5 | from utils import *
6 |
7 |
8 | def L1Loss(input, target):
9 | return (input - target).abs().mean()
10 |
11 |
12 | def loss_disp_unsupervised(img_left, img_right, disp, valid_mask=None):
13 | b, _, h, w = img_left.shape
14 | image_warped = warp_disp(img_right, disp)
15 |
16 | valid_mask = torch.ones(b, 1, h, w).to(img_left.device) if valid_mask is None else valid_mask
17 |
18 | loss = 0.15 * L1Loss(image_warped * valid_mask, img_left * valid_mask) + \
19 | 0.85 * (valid_mask * (1 - ssim(img_left, image_warped)) / 2).mean()
20 | return loss
21 |
22 |
23 | def loss_disp_smoothness(disp, img):
24 | img_grad_x = img[:, :, :, :-1] - img[:, :, :, 1:]
25 | img_grad_y = img[:, :, :-1, :] - img[:, :, 1:, :]
26 | weight_x = torch.exp(-torch.abs(img_grad_x).mean(1).unsqueeze(1))
27 | weight_y = torch.exp(-torch.abs(img_grad_y).mean(1).unsqueeze(1))
28 |
29 | loss = (((disp[:, :, :, :-1] - disp[:, :, :, 1:]).abs() * weight_x).sum() +
30 | ((disp[:, :, :-1, :] - disp[:, :, 1:, :]).abs() * weight_y).sum()) / \
31 | (weight_x.sum() + weight_y.sum())
32 |
33 | return loss
34 |
35 |
36 | def loss_pam_photometric(img_left, img_right, att, valid_mask, mask=None):
37 | weight = [0.2, 0.3, 0.5]
38 | loss = torch.zeros(1).to(img_left.device)
39 |
40 | for idx_scale in range(len(att)):
41 | scale = img_left.size()[2] // valid_mask[idx_scale][0].size()[2]
42 | b, c, h, w = valid_mask[idx_scale][0].size()
43 |
44 | att_right2left = att[idx_scale][0] # b * h * w * w
45 | att_left2right = att[idx_scale][1]
46 | valid_mask_left = valid_mask[idx_scale][0] # b * 1 * h * w
47 | valid_mask_right = valid_mask[idx_scale][1]
48 |
49 | if mask is not None:
50 | valid_mask_left = valid_mask_left * (nn.AvgPool2d(scale)(mask[0].float()) > 0).float()
51 | valid_mask_right = valid_mask_right * (nn.AvgPool2d(scale)(mask[1].float()) > 0).float()
52 |
53 | img_left_scale = F.interpolate(img_left, scale_factor=1/scale, mode='bilinear')
54 | img_right_scale = F.interpolate(img_right, scale_factor=1/scale, mode='bilinear')
55 |
56 | img_right_warp = torch.matmul(att_right2left, img_right_scale.permute(0, 2, 3, 1).contiguous())
57 | img_right_warp = img_right_warp.permute(0, 3, 1, 2)
58 | img_left_warp = torch.matmul(att_left2right, img_left_scale.permute(0, 2, 3, 1).contiguous())
59 | img_left_warp = img_left_warp.permute(0, 3, 1, 2)
60 |
61 | loss_scale = L1Loss(img_left_scale * valid_mask_left, img_right_warp * valid_mask_left) + \
62 | L1Loss(img_right_scale * valid_mask_right, img_left_warp * valid_mask_right)
63 |
64 | loss = loss + weight[idx_scale] * loss_scale
65 |
66 | return loss
67 |
68 |
69 | def loss_pam_cycle(att_cycle, valid_mask):
70 | weight = [0.2, 0.3, 0.5]
71 | loss = torch.zeros(1).to(att_cycle[0][0].device)
72 |
73 | for idx_scale in range(len(att_cycle)):
74 | b, c, h, w = valid_mask[idx_scale][0].shape
75 | I = torch.eye(w, w).repeat(b, h, 1, 1).to(att_cycle[0][0].device)
76 |
77 | att_left2right2left = att_cycle[idx_scale][0]
78 | att_right2left2right = att_cycle[idx_scale][1]
79 | valid_mask_left = valid_mask[idx_scale][0]
80 | valid_mask_right = valid_mask[idx_scale][1]
81 |
82 | loss_scale = L1Loss(att_left2right2left * valid_mask_left.permute(0, 2, 3, 1), I * valid_mask_left.permute(0, 2, 3, 1)) + \
83 | L1Loss(att_right2left2right * valid_mask_right.permute(0, 2, 3, 1), I * valid_mask_right.permute(0, 2, 3, 1))
84 |
85 | loss = loss + weight[idx_scale] * loss_scale
86 |
87 | return loss
88 |
89 |
90 | def loss_pam_smoothness(att):
91 | weight = [0.2, 0.3, 0.5]
92 | loss = torch.zeros(1).to(att[0][0].device)
93 |
94 | for idx_scale in range(len(att)):
95 | att_right2left = att[idx_scale][0]
96 | att_left2right = att[idx_scale][1]
97 |
98 | loss_scale = L1Loss(att_right2left[:, :-1, :, :], att_right2left[:, 1:, :, :]) + \
99 | L1Loss(att_left2right[:, :-1, :, :], att_left2right[:, 1:, :, :]) + \
100 | L1Loss(att_right2left[:, :, :-1, :-1], att_right2left[:, :, 1:, 1:]) + \
101 | L1Loss(att_left2right[:, :, :-1, :-1], att_left2right[:, :, 1:, 1:])
102 |
103 | loss = loss + weight[idx_scale] * loss_scale
104 |
105 | return loss
106 |
107 |
108 | def warp_disp(img, disp):
109 | '''
110 | Borrowed from: https://github.com/OniroAI/MonoDepth-PyTorch
111 | '''
112 | b, _, h, w = img.size()
113 |
114 | # Original coordinates of pixels
115 | x_base = torch.linspace(0, 1, w).repeat(b, h, 1).type_as(img)
116 | y_base = torch.linspace(0, 1, h).repeat(b, w, 1).transpose(1, 2).type_as(img)
117 |
118 | # Apply shift in X direction
119 | x_shifts = disp[:, 0, :, :] / w
120 | flow_field = torch.stack((x_shifts, y_base), dim=3)
121 |
122 | # In grid_sample coordinates are assumed to be between -1 and 1
123 | output = F.grid_sample(img, 2 * flow_field - 1, mode='bilinear', padding_mode='border')
124 |
125 | return output
126 |
127 |
128 | def gaussian(window_size, sigma):
129 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
130 | return gauss / gauss.sum()
131 |
132 |
133 | def create_window(window_size, channel):
134 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
135 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
136 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
137 | return window
138 |
139 |
140 | def _ssim(img1, img2, window, window_size, channel):
141 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
142 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
143 |
144 | mu1_sq = mu1.pow(2)
145 | mu2_sq = mu2.pow(2)
146 | mu1_mu2 = mu1 * mu2
147 |
148 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
149 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
150 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
151 |
152 | C1 = 0.01 ** 2
153 | C2 = 0.03 ** 2
154 |
155 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
156 |
157 | return ssim_map
158 |
159 |
160 | def ssim(img1, img2, window_size=11):
161 | _, channel, h, w = img1.size()
162 | window = create_window(window_size, channel)
163 | if img1.is_cuda:
164 | window = window.cuda(img1.get_device())
165 | window = window.type_as(img1)
166 | return _ssim(img1, img2, window, window_size, channel)
167 |
--------------------------------------------------------------------------------
/metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | EPSILON = 1e-8
5 |
6 |
7 | def epe_metric(d_est, d_gt, mask, use_np=False):
8 | d_est, d_gt = d_est[mask], d_gt[mask]
9 | if use_np:
10 | epe = np.mean(np.abs(d_est - d_gt))
11 | else:
12 | epe = torch.mean(torch.abs(d_est - d_gt))
13 |
14 | return epe
15 |
16 |
17 | def d1_metric(d_est, d_gt, mask, err, use_np=False):
18 | d_est, d_gt = d_est[mask], d_gt[mask]
19 |
20 | if use_np:
21 | e = np.abs(d_gt - d_est)
22 | else:
23 | e = torch.abs(d_gt - d_est)
24 | err_mask = (e > err) & (e / d_gt > 0.05)
25 |
26 | if use_np:
27 | mean = np.mean(err_mask.astype('float'))
28 | else:
29 | mean = torch.mean(err_mask.float())
30 |
31 | return mean
32 |
33 |
34 | def thres_metric(d_est, d_gt, mask, thres, use_np=False):
35 | assert isinstance(thres, (int, float))
36 | d_est, d_gt = d_est[mask], d_gt[mask]
37 | if use_np:
38 | e = np.abs(d_gt - d_est)
39 | else:
40 | e = torch.abs(d_gt - d_est)
41 | err_mask = e > thres
42 |
43 | if use_np:
44 | mean = np.mean(err_mask.astype('float'))
45 | else:
46 | mean = torch.mean(err_mask.float())
47 |
48 | return mean
49 |
50 | def compute_depth_errors(gt, pred):
51 | """Computation of error metrics between predicted and ground truth depths
52 | """
53 | thresh = torch.max((gt / pred), (pred / gt))
54 | a1 = (thresh < 1.25 ).float().mean()
55 | a2 = (thresh < 1.25 ** 2).float().mean()
56 | a3 = (thresh < 1.25 ** 3).float().mean()
57 |
58 | rmse = (gt - pred) ** 2
59 | rmse = torch.sqrt(rmse.mean())
60 |
61 | rmse_log = (torch.log(gt) - torch.log(pred)) ** 2
62 | rmse_log = torch.sqrt(rmse_log.mean())
63 |
64 | abs_rel = torch.mean(torch.abs(gt - pred) / gt)
65 |
66 | sq_rel = torch.mean((gt - pred) ** 2 / gt)
67 |
68 | return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | from skimage import morphology
7 | from torchvision import transforms
8 | from common import *
9 |
10 | class SSRDEFNet(nn.Module):
11 | def __init__(self, upscale_factor):
12 | super(SSRDEFNet, self).__init__()
13 | self.upscale_factor = upscale_factor
14 | if upscale_factor == 2:
15 | kernel = 6
16 | stride = 2
17 | padding = 2
18 | elif upscale_factor == 4:
19 | kernel = 8
20 | stride = 4
21 | padding = 2
22 | self.init_feature = nn.Conv2d(3, 64, 3, 1, 1, bias=True)
23 | self.deep_feature = RDG(G0=64, C=4, G=24, n_RDB=4)
24 | self.pam = PAM(64)
25 | self.transition = nn.Sequential(
26 | nn.BatchNorm2d(64),
27 | ResB(64)
28 | )
29 | self.transition2 = nn.Sequential(
30 | nn.BatchNorm2d(64),
31 | ResB(64)
32 | )
33 | self.StereoFea = Stereo_feature()
34 | self.StereoFeaHigh = hStereo_feature()
35 |
36 | self.encode = RDB(G0=64, C=6, G=24)
37 | self.encoder2 = RDB(G0=64, C=6, G=24)
38 | self.CALayer2 = CALayer(64, 8)
39 | self.CALayer = CALayer(64, 8)
40 |
41 | self.reconstruct = RDG(G0=64, C=4, G=24, n_RDB=4)
42 | self.upscale = nn.Sequential(
43 | nn.Conv2d(64, 64 * upscale_factor ** 2, 1, 1, 0, bias=True),
44 | nn.PixelShuffle(upscale_factor))
45 | self.final = nn.Conv2d(64, 3, 3, 1, 1, bias=True)
46 | self.final2 = nn.Conv2d(64, 3, 3, 1, 1, bias=True)
47 |
48 | self.get_cv = GetCostVolume(24)
49 | self.dres0 = nn.Sequential(convbn(24, 24, 3, 1, 1, 1),
50 | nn.PReLU(),
51 | convbn(24, 24, 3, 1, 1, 1),
52 | nn.PReLU())
53 |
54 | self.dres1 = nn.Sequential(convbn(24, 24, 3, 1, 1, 1),
55 | nn.PReLU(),
56 | convbn(24, 24, 3, 1, 1, 1))
57 | self.dres2 = hourglass(24)
58 | self.softmax = nn.Softmax(1)
59 |
60 | self.att1 = nn.Conv2d(64, 32, 1, 1, 0)
61 | self.att2 = nn.Conv2d(64, 32, 1, 1, 0)
62 |
63 | self.backatt1 = nn.Conv2d(64, 32, 1, 1, 0)
64 | self.backatt2 = nn.Conv2d(64, 32, 1, 1, 0)
65 |
66 | self.feedbackatt = FeedbackBlock(64, 64, kernel, stride, padding)
67 | self.down = ConvBlock(64, 64, kernel, stride, padding, activation='prelu', norm=None)
68 | self.compress_in = nn.Conv2d(64*2, 64, 1, bias=True)
69 | self.resblock = RDB(G0=64, C=2, G=24)
70 |
71 | def forward(self, x_left, x_right, is_training):
72 | x_left_upscale = F.interpolate(x_left, scale_factor=self.upscale_factor, mode='bicubic', align_corners=False)
73 | x_right_upscale = F.interpolate(x_right, scale_factor=self.upscale_factor, mode='bicubic', align_corners=False)
74 | buffer_left = self.init_feature(x_left)
75 | buffer_right = self.init_feature(x_right)
76 | buffer_left = self.deep_feature(buffer_left)
77 | buffer_right = self.deep_feature(buffer_right)
78 |
79 | stereo_left = self.StereoFea(self.transition(buffer_left))
80 | stereo_right = self.StereoFea(self.transition(buffer_right))
81 |
82 | b,c,h,w = buffer_left.shape
83 |
84 | cost = [
85 | torch.zeros(b*h, w, w).to(buffer_left.device),
86 | torch.zeros(b*h, w, w).to(buffer_left.device)
87 | ]
88 |
89 | buffer_leftT, buffer_rightT, disp1, disp2, (M_right_to_left, M_left_to_right), (V_left, V_right)\
90 | = self.pam(buffer_left, buffer_right, stereo_left, stereo_right, cost, is_training)
91 |
92 |
93 | buffer_leftF = self.CALayer(buffer_left + self.encode(buffer_leftT - buffer_left))
94 | buffer_rightF = self.CALayer(buffer_right + self.encode(buffer_rightT - buffer_right))
95 |
96 |
97 | buffer_leftF = self.reconstruct(buffer_leftF)
98 | buffer_rightF = self.reconstruct(buffer_rightF)
99 | feat_left = self.upscale(buffer_leftF)
100 | feat_right = self.upscale(buffer_rightF)
101 | out1_left = self.final(feat_left)+x_left_upscale
102 | out1_right = self.final(feat_right)+x_right_upscale
103 |
104 | hstereo_left = self.StereoFeaHigh(self.transition2(feat_left))
105 | hstereo_right = self.StereoFeaHigh(self.transition2(feat_right))
106 |
107 |
108 | disp1 = F.interpolate(disp1 * self.upscale_factor, scale_factor=self.upscale_factor, mode='bilinear', align_corners=False)
109 | disp2 = F.interpolate(disp2 * self.upscale_factor, scale_factor=self.upscale_factor, mode='bilinear', align_corners=False)
110 | maxdisp = x_left.shape[3]*self.upscale_factor
111 |
112 |
113 | disp_range_samples1 = get_disp_range_samples(cur_disp=disp1.detach().squeeze(1), ndisp=24,
114 | shape=[x_left.shape[0], x_left.shape[2]*self.upscale_factor, x_left.shape[3]*self.upscale_factor],
115 | max_disp=maxdisp)
116 | disp_range_samples2 = get_disp_range_samples(cur_disp=disp2.detach().squeeze(1), ndisp=24,
117 | shape=[x_left.shape[0], x_left.shape[2]*self.upscale_factor, x_left.shape[3]*self.upscale_factor],
118 | max_disp=maxdisp)
119 |
120 | cost1, cost2 = self.get_cv(hstereo_left, hstereo_right, disp_range_samples1, disp_range_samples2, 24)
121 |
122 | cost1 = cost1.contiguous()
123 | cost2 = cost2.contiguous()
124 |
125 | cost1_0 = self.dres0(cost1)
126 | cost1_0 = self.dres1(cost1_0) + cost1_0
127 | cost2_0 = self.dres0(cost2)
128 | cost2_0 = self.dres1(cost2_0) + cost2_0
129 |
130 | out1 = self.dres2(cost1_0, None, None)
131 | cost1_1 = out1+cost1_0
132 | out2 = self.dres2(cost2_0, None, None)
133 | cost2_1 = out2+cost2_0
134 |
135 | cost_prob1 = self.softmax(cost1_1)
136 | cost_prob2 = self.softmax(cost2_1)
137 |
138 | disp1_high = torch.sum(disp_range_samples1 * cost_prob1, dim=1).unsqueeze(1)
139 | disp2_high = torch.sum(disp_range_samples2 * cost_prob2, dim=1).unsqueeze(1)
140 |
141 | feat_leftW = dispwarpfeature(feat_right, disp1_high)
142 | feat_rightW = dispwarpfeature(feat_left, disp2_high)
143 |
144 | geoerror_left = torch.abs(disp1_high - dispwarpfeature(disp2_high, disp1_high)).detach()
145 | geoerror_right = torch.abs(disp2_high - dispwarpfeature(disp1_high, disp2_high)).detach()
146 |
147 | V_left2 = 1 - torch.tanh(0.1*geoerror_left)
148 | V_right2 = 1 - torch.tanh(0.1*geoerror_right)
149 |
150 | V_left2 = torch.max(F.interpolate(V_left, scale_factor=self.upscale_factor, mode='nearest'), V_left2)
151 | V_right2 = torch.max(F.interpolate(V_right, scale_factor=self.upscale_factor, mode='nearest'), V_right2)
152 |
153 | left_att = self.att1(feat_left)
154 | leftW_att = self.att2(feat_leftW)
155 | corrleft = (torch.tanh(5*torch.sum(left_att*leftW_att, 1).unsqueeze(1))+1)/2
156 |
157 | right_att = self.att1(feat_right)
158 | rightW_att = self.att2(feat_rightW)
159 | corrright = (torch.tanh(5*torch.sum(right_att*rightW_att, 1).unsqueeze(1))+1)/2
160 |
161 | err1 = self.encoder2((feat_leftW - feat_left)*corrleft)
162 |         buffer_leftF2 = self.CALayer2(err1 + feat_left)  # HR feature aggregating HR information from the other view via HR stereo matching
163 |
164 | err2 = self.encoder2((feat_rightW - feat_right)*corrright)
165 | buffer_rightF2 = self.CALayer2(err2 + feat_right)
166 |
167 | out2_left = self.final2(buffer_leftF2)+x_left_upscale
168 | out2_right = self.final2(buffer_rightF2)+x_right_upscale
169 |
170 | #feedback start
171 |
172 | left_back = self.down(buffer_leftF2)
173 | right_back = self.down(buffer_rightF2)
174 |
175 | att1 = self.feedbackatt(left_back)
176 | att2 = self.feedbackatt(right_back)
177 |
178 | left_back = left_back + 0.1*(left_back*att1)
179 | right_back = right_back + 0.1*(right_back*att2)
180 |
181 |
182 | bufferleft_att = self.backatt1(buffer_left)
183 | bufferright_att = self.backatt1(buffer_right)
184 |
185 |
186 | for ii in range(self.upscale_factor):
187 | for jj in range(self.upscale_factor):
188 | draft_l = dispwarpfeature(buffer_right, disp1_high[:, :, ii::self.upscale_factor, jj::self.upscale_factor]/self.upscale_factor)
189 | draft_r = dispwarpfeature(buffer_left, disp2_high[:, :, ii::self.upscale_factor, jj::self.upscale_factor]/self.upscale_factor)
190 | draftl_att = self.backatt2(draft_l)
191 | draftr_att = self.backatt2(draft_r)
192 | corrleft = (torch.tanh(5*torch.sum(bufferleft_att*draftl_att, 1).unsqueeze(1))+1)/2
193 | corrright = (torch.tanh(5*torch.sum(bufferright_att*draftr_att, 1).unsqueeze(1))+1)/2
194 | draft_l = (1-corrleft)*buffer_left+corrleft*draft_l
195 | draft_r = (1-corrright)*buffer_right+corrright*draft_r
196 | if ii==0 and jj==0:
197 | draft_left = buffer_left + self.resblock(draft_l - buffer_left)
198 | draft_right = buffer_right + self.resblock(draft_r - buffer_right)
199 | else:
200 | draft_left += buffer_left + self.resblock(draft_l - buffer_left)
201 | draft_right += buffer_right + self.resblock(draft_r - buffer_right)
202 |
203 | draft_left = draft_left/(self.upscale_factor**2)
204 | draft_right = draft_right/(self.upscale_factor**2)
205 |
206 | buffer_left = self.compress_in(torch.cat([draft_left, left_back], 1))
207 | buffer_right = self.compress_in(torch.cat([draft_right, right_back], 1))
208 |
209 | stereo_left = self.StereoFea(self.transition(buffer_left))
210 | stereo_right = self.StereoFea(self.transition(buffer_right))
211 |
212 | cost = [
213 | torch.zeros(b*h, w, w).to(buffer_left.device),
214 | torch.zeros(b*h, w, w).to(buffer_left.device)
215 | ]
216 |
217 | buffer_leftT, buffer_rightT, disp1_3, disp2_3, (M_right_to_left3, M_left_to_right3), (V_left3, V_right3)\
218 | = self.pam(buffer_left, buffer_right, stereo_left, stereo_right, cost, is_training)
219 |
220 | buffer_leftF = self.CALayer(buffer_left + self.encode(buffer_leftT - buffer_left))
221 | buffer_rightF = self.CALayer(buffer_right + self.encode(buffer_rightT - buffer_right))
222 |
223 |
224 | buffer_leftF = self.reconstruct(buffer_leftF)
225 | buffer_rightF = self.reconstruct(buffer_rightF)
226 | feat_left = self.upscale(buffer_leftF)
227 | feat_right = self.upscale(buffer_rightF)
228 | out3_left = self.final(feat_left)+x_left_upscale
229 | out3_right = self.final(feat_right)+x_right_upscale
230 |
231 | hstereo_left = self.StereoFeaHigh(self.transition2(feat_left))
232 | hstereo_right = self.StereoFeaHigh(self.transition2(feat_right))
233 |
234 |
235 | disp1_3 = F.interpolate(disp1_3 * self.upscale_factor, scale_factor=self.upscale_factor, mode='bilinear', align_corners=False)
236 | disp2_3 = F.interpolate(disp2_3 * self.upscale_factor, scale_factor=self.upscale_factor, mode='bilinear', align_corners=False)
237 | maxdisp = x_left.shape[3]*self.upscale_factor
238 |
239 |
240 | disp_range_samples1 = get_disp_range_samples(cur_disp=disp1_3.detach().squeeze(1), ndisp=24,
241 | shape=[x_left.shape[0], x_left.shape[2]*self.upscale_factor, x_left.shape[3]*self.upscale_factor],
242 | max_disp=maxdisp)
243 | disp_range_samples2 = get_disp_range_samples(cur_disp=disp2_3.detach().squeeze(1), ndisp=24,
244 | shape=[x_left.shape[0], x_left.shape[2]*self.upscale_factor, x_left.shape[3]*self.upscale_factor],
245 | max_disp=maxdisp)
246 |
247 | cost1, cost2 = self.get_cv(hstereo_left, hstereo_right, disp_range_samples1, disp_range_samples2, 24)
248 | cost1 = cost1.contiguous()
249 | cost2 = cost2.contiguous()
250 |
251 | cost1_0 = self.dres0(cost1)
252 | cost1_0 = self.dres1(cost1_0) + cost1_0
253 | cost2_0 = self.dres0(cost2)
254 | cost2_0 = self.dres1(cost2_0) + cost2_0
255 |
256 | out1 = self.dres2(cost1_0, None, None)
257 | cost1_1 = out1+cost1_0
258 | out2 = self.dres2(cost2_0, None, None)
259 | cost2_1 = out2+cost2_0
260 |
261 |
262 | cost_prob1_2 = self.softmax(cost1_1)
263 | cost_prob2_2 = self.softmax(cost2_1)
264 |
265 | disp1_high2 = torch.sum(disp_range_samples1 * cost_prob1_2, dim=1).unsqueeze(1)
266 | disp2_high2 = torch.sum(disp_range_samples2 * cost_prob2_2, dim=1).unsqueeze(1)
267 |
268 | feat_leftW = dispwarpfeature(feat_right, disp1_high2)
269 | feat_rightW = dispwarpfeature(feat_left, disp2_high2)
270 |
271 | geoerror_left = torch.abs(disp1_high2 - dispwarpfeature(disp2_high2, disp1_high2)).detach()
272 | geoerror_right = torch.abs(disp2_high2 - dispwarpfeature(disp1_high2, disp2_high2)).detach()
273 |
274 | V_left4 = 1 - torch.tanh(0.1*geoerror_left)
275 | V_right4 = 1 - torch.tanh(0.1*geoerror_right)
276 |
277 | V_left4 = torch.max(F.interpolate(V_left3, scale_factor=self.upscale_factor, mode='nearest'), V_left4)
278 | V_right4 = torch.max(F.interpolate(V_right3, scale_factor=self.upscale_factor, mode='nearest'), V_right4)
279 |
280 | left_att = self.att1(feat_left)
281 | leftW_att = self.att2(feat_leftW)
282 | corrleft = (torch.tanh(5*torch.sum(left_att*leftW_att, 1).unsqueeze(1))+1)/2
283 |
284 | right_att = self.att1(feat_right)
285 | rightW_att = self.att2(feat_rightW)
286 | corrright = (torch.tanh(5*torch.sum(right_att*rightW_att, 1).unsqueeze(1))+1)/2
287 |
288 | err1 = self.encoder2((feat_leftW - feat_left)*corrleft)
289 | buffer_leftF2 = self.CALayer2(err1 + feat_left)
290 |
291 | err2 = self.encoder2((feat_rightW - feat_right)*corrright)
292 | buffer_rightF2 = self.CALayer2(err2 + feat_right)
293 |
294 | out4_left = self.final2(buffer_leftF2)+x_left_upscale
295 | out4_right = self.final2(buffer_rightF2)+x_right_upscale
296 |
297 | if is_training == 0:
298 | index=(torch.arange(w*self.upscale_factor).view(1, 1, w*self.upscale_factor).repeat(1,h*self.upscale_factor,1)).to(buffer_left.device)
299 |
300 | disp1 = index-disp1.view(1,h*self.upscale_factor,w*self.upscale_factor)
301 | disp2 = disp2.view(1,h*self.upscale_factor,w*self.upscale_factor)-index
302 | disp1[disp1<0]=0
303 | disp2[disp2<0]=0
304 | disp1[disp1>192]=192
305 | disp2[disp2>192]=192
306 |
307 | disp1_3 = index-disp1_3.view(1,h*self.upscale_factor,w*self.upscale_factor)
308 | disp2_3 = disp2_3.view(1,h*self.upscale_factor,w*self.upscale_factor)-index
309 | disp1_3[disp1_3<0]=0
310 | disp2_3[disp2_3<0]=0
311 | disp1_3[disp1_3>192]=192
312 | disp2_3[disp2_3>192]=192
313 |
314 | disp1_high = index-disp1_high.view(1,h*self.upscale_factor,w*self.upscale_factor)
315 | disp2_high = disp2_high.view(1,h*self.upscale_factor,w*self.upscale_factor)-index
316 | disp1_high[disp1_high<0]=0
317 | disp2_high[disp2_high<0]=0
318 | disp1_high[disp1_high>192]=192
319 | disp2_high[disp2_high>192]=192
320 |
321 | disp1_high2 = index-disp1_high2.view(1,h*self.upscale_factor,w*self.upscale_factor)
322 | disp2_high2 = disp2_high2.view(1,h*self.upscale_factor,w*self.upscale_factor)-index
323 | disp1_high2[disp1_high2<0]=0
324 | disp2_high2[disp2_high2<0]=0
325 | disp1_high2[disp1_high2>192]=192
326 | disp2_high2[disp2_high2>192]=192
327 |
328 | return out1_left, out1_right, out2_left, out2_right, out3_left, out3_right, out4_left, out4_right, (disp1, disp2), (disp1_3, disp2_3), (disp1_high, disp2_high), (disp1_high2, disp2_high2)
329 |
330 | if is_training == 1:
331 | return out1_left, out1_right, out2_left, out2_right, out3_left, out3_right, out4_left, out4_right,\
332 | (M_right_to_left, M_left_to_right), (disp1, disp2), (V_left, V_right), (V_left2, V_right2), (disp1_high, disp2_high),\
333 | (M_right_to_left3, M_left_to_right3), (disp1_3, disp2_3), (V_left3, V_right3), (V_left4, V_right4), (disp1_high2, disp2_high2)
334 |
335 |
336 |
337 | def get_disp_range_samples(cur_disp, ndisp, shape, max_disp):
338 | #shape, (B, H, W)
339 | #cur_disp: (B, H, W)
340 | #return disp_range_samples: (B, D, H, W)
341 | cur_disp_min = (cur_disp - ndisp / 2).clamp(min=0.0) # (B, H, W)
342 | cur_disp_max = (cur_disp + ndisp / 2).clamp(max=max_disp)
343 |
344 | assert cur_disp.shape == torch.Size(shape), "cur_disp:{}, input shape:{}".format(cur_disp.shape, shape)
345 | new_interval = (cur_disp_max - cur_disp_min) / (ndisp - 1) # (B, H, W)
346 |
347 | disp_range_samples = cur_disp_min.unsqueeze(1) + (torch.arange(0, ndisp, device=cur_disp.device,
348 | dtype=cur_disp.dtype,
349 | requires_grad=False).reshape(1, -1, 1,
350 | 1) * new_interval.unsqueeze(1))
351 | return disp_range_samples
352 |
353 |
354 | class Stereo_feature(nn.Module):
355 | def __init__(self):
356 | super(Stereo_feature, self).__init__()
357 |
358 | self.branch1 = nn.Sequential(nn.AvgPool2d((30, 30), stride=(30,30)),
359 | convbn(64, 32, 1, 1, 0, 1),
360 | nn.PReLU())
361 |
362 | self.branch2 = nn.Sequential(nn.AvgPool2d((10, 10), stride=(10,10)),
363 | convbn(64, 32, 1, 1, 0, 1),
364 | nn.PReLU())
365 |
366 | self.branch3 = nn.Sequential(nn.AvgPool2d((5, 5), stride=(5,5)),
367 | convbn(64, 32, 1, 1, 0, 1),
368 | nn.PReLU())
369 |
370 | self.lastconv = nn.Sequential(convbn(160, 64, 3, 1, 1, 1),
371 | nn.PReLU(),
372 | nn.Conv2d(64, 64, kernel_size=1, padding=0, stride = 1, bias=False))
373 |
374 | def forward(self, output_skip):
375 |
376 | output_branch1 = self.branch1(output_skip)
377 | output_branch1 = F.upsample(output_branch1, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear')
378 |
379 | output_branch2 = self.branch2(output_skip)
380 | output_branch2 = F.upsample(output_branch2, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear')
381 |
382 | output_branch3 = self.branch3(output_skip)
383 | output_branch3 = F.upsample(output_branch3, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear')
384 |
385 | output_feature = torch.cat((output_skip, output_branch3, output_branch2, output_branch1), 1)
386 | output_feature = self.lastconv(output_feature)
387 |
388 | return output_feature
389 |
390 |
391 | class hourglass(nn.Module):
392 | def __init__(self, inplanes):
393 | super(hourglass, self).__init__()
394 |
395 | self.conv1 = nn.Sequential(convbn(inplanes, inplanes, kernel_size=3, stride=2, pad=1, dilation=1),
396 | nn.PReLU())
397 |
398 | self.conv2 = convbn(inplanes, inplanes, kernel_size=3, stride=1, pad=1, dilation=1)
399 |
400 | self.conv3 = nn.Sequential(convbn(inplanes, inplanes, kernel_size=3, stride=2, pad=1, dilation=1),
401 | nn.PReLU())
402 |
403 | self.conv4 = nn.Sequential(convbn(inplanes, inplanes, kernel_size=3, stride=1, pad=1, dilation=1),
404 | nn.PReLU())
405 |
406 | self.conv5 = nn.Sequential(nn.ConvTranspose2d(inplanes, inplanes, kernel_size=3, padding=1, output_padding=1, stride=2,bias=False),
407 | nn.BatchNorm2d(inplanes)) #+conv2
408 |
409 | self.conv6 = nn.Sequential(nn.ConvTranspose2d(inplanes, inplanes, kernel_size=3, padding=1, output_padding=1, stride=2,bias=False),
410 | nn.BatchNorm2d(inplanes)) #+x
411 |
412 | self.prelu = nn.PReLU()
413 |
414 | def forward(self, x ,presqu, postsqu):
415 |
416 | out = self.conv1(x) #in:1/4 out:1/8
417 | pre = self.conv2(out) #in:1/8 out:1/8
418 | if postsqu is not None:
419 | pre = self.prelu(pre + postsqu)
420 | else:
421 | pre = self.prelu(pre)
422 |
423 | out = self.conv3(pre) #in:1/8 out:1/16
424 | out = self.conv4(out) #in:1/16 out:1/16
425 |
426 | if presqu is not None:
427 | post = self.prelu(self.conv5(out)+presqu) #in:1/16 out:1/8
428 | else:
429 | post = self.prelu(self.conv5(out)+pre)
430 |
431 | out = self.conv6(post) #in:1/8 out:1/4
432 |
433 | return out
434 |
435 |
436 | class hStereo_feature(nn.Module):
437 | def __init__(self):
438 | super(hStereo_feature, self).__init__()
439 |
440 | self.branch2 = nn.Sequential(nn.AvgPool2d((16, 16), stride=(16,16)),
441 | convbn(64, 24, 1, 1, 0, 1),
442 | nn.PReLU())
443 |
444 | self.branch3 = nn.Sequential(nn.AvgPool2d((8, 8), stride=(8,8)),
445 | convbn(64, 24, 1, 1, 0, 1),
446 | nn.PReLU())
447 |
448 | self.lastconv = nn.Sequential(convbn(48+64, 64, 3, 1, 1, 1),
449 | nn.PReLU(),
450 | nn.Conv2d(64, 24, kernel_size=1, padding=0, stride = 1, bias=False))
451 |
452 | def forward(self, output_skip):
453 |
454 | output_branch2 = self.branch2(output_skip)
455 | output_branch2 = F.upsample(output_branch2, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear')
456 |
457 | output_branch3 = self.branch3(output_skip)
458 | output_branch3 = F.upsample(output_branch3, (output_skip.size()[2],output_skip.size()[3]),mode='bilinear')
459 |
460 | output_feature = torch.cat((output_skip, output_branch3, output_branch2), 1)
461 | output_feature = self.lastconv(output_feature)
462 |
463 | return output_feature
464 |
465 | class FeedbackBlock(torch.nn.Module):
466 | def __init__(self, in_filter, num_filter, kernel_size=8, stride=4, padding=2, bias=True, activation='prelu', norm=None):
467 | super(FeedbackBlock, self).__init__()
468 | #self.conv1 = ConvBlock(in_filter, num_filter, 1, 1, 0, activation='prelu', norm=None)
469 | self.avgpool_1 = torch.nn.AvgPool2d(4, 4, 0)
470 | self.up_1 = DeconvBlock(num_filter, num_filter , 8, 4, 2, activation='prelu', norm=None)
471 | self.act_1 = torch.nn.ReLU(True)
472 |
473 | def forward(self, x):
474 |
475 | #x = self.conv1(x)
476 | p1 = self.avgpool_1(x)
477 | l00 = self.up_1(p1)
478 | l00 = F.upsample(l00, x.size()[2:], mode='bilinear')
479 | act1 = self.act_1(x - l00)
480 | return act1
481 |
482 | def dispwarpfeature(feat, disp):
483 | bs, channels, height, width = feat.size()
484 | mh,_ = torch.meshgrid([torch.arange(0, height, dtype=feat.dtype, device=feat.device), torch.arange(0, width, dtype=feat.dtype, device=feat.device)]) # (H *W)
485 | mh = mh.reshape(1, 1, height, width).repeat(bs, 1, 1, 1)
486 |
487 | cur_disp_coords_y = mh
488 | cur_disp_coords_x = disp
489 |
490 | coords_x = cur_disp_coords_x / ((width - 1.0) / 2.0) - 1.0 # trans to -1 - 1
491 | coords_y = cur_disp_coords_y / ((height - 1.0) / 2.0) - 1.0
492 | grid = torch.stack([coords_x, coords_y], dim=4).view(bs, height, width, 2) #(B, D, H, W, 2)->(B, D*H, W, 2)
493 |
494 | #warped = F.grid_sample(feat, grid.view(bs, ndisp * height, width, 2), mode='bilinear', padding_mode='zeros').view(bs, channels, ndisp, height, width)
495 | warped_feat = F.grid_sample(feat, grid, mode='bilinear', padding_mode='zeros').view(bs, channels,height, width)
496 |
497 | return warped_feat
498 |
499 | def warpfeature(feat, disp_range_samples, cost_prob, ndisp):
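"""Soft-warp `feat` over `ndisp` candidate correspondences.

For each candidate i, `feat` is sampled at the x-coordinates given by
`disp_range_samples[:, i]` and the warped copies are blended with the
per-pixel weights `cost_prob[:, i]`, yielding a probability-weighted warp
of shape (B, C, H, W).
"""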
500 | bs, channels, height, width = feat.size()
501 | mh,_ = torch.meshgrid([torch.arange(0, height, dtype=feat.dtype, device=feat.device), torch.arange(0, width, dtype=feat.dtype, device=feat.device)]) # (H *W)
502 | mh = mh.reshape(1, 1, height, width).repeat(bs, ndisp, 1, 1)
503 |
504 | cur_disp_coords_y = mh
505 | cur_disp_coords_x = disp_range_samples
506 |
507 | coords_x = cur_disp_coords_x / ((width - 1.0) / 2.0) - 1.0 # trans to -1 - 1
508 | coords_y = cur_disp_coords_y / ((height - 1.0) / 2.0) - 1.0
509 | grid = torch.stack([coords_x, coords_y], dim=4).view(bs, ndisp * height, width, 2) #(B, D, H, W, 2)->(B, D*H, W, 2)
510 |
511 | #warped = F.grid_sample(feat, grid.view(bs, ndisp * height, width, 2), mode='bilinear', padding_mode='zeros').view(bs, channels, ndisp, height, width)
512 | warped_feat = cost_prob[:, 0, :, :].unsqueeze(1) * F.grid_sample(feat, grid[:, :height, :, :], mode='bilinear', padding_mode='zeros').view(bs, channels,height, width)
513 | for i in range(1, ndisp):
514 | warped_feat += cost_prob[:, i, :, :].unsqueeze(1) * F.grid_sample(feat, grid[:, i*height:(i+1)*height, :, :], mode='bilinear', padding_mode='zeros').view(bs, channels,height, width)
515 |
516 | return warped_feat
517 |
518 | class GetCostVolume(nn.Module):
519 | def __init__(self, channels):
520 | super(GetCostVolume, self).__init__()
521 | self.query = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
522 | self.key = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
523 |
524 | def forward(self, x, y, disp_range_samples1, disp_range_samples2, ndisp):
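"""Build two correlation cost volumes of shape (B, ndisp, H, W).

cost1 correlates query(x) with key(y) sampled at the candidate x-coordinates
in `disp_range_samples1`; cost2 is the symmetric volume built from query(y),
key(x) and `disp_range_samples2`. Each entry is the channel-wise mean of the
element-wise product.
"""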
525 | assert x.is_contiguous()
526 |
527 | Q = self.query(x)
528 | K = self.key(y)
529 |
530 | bs, channels, height, width = x.size()
531 | cost1 = x.new().resize_(bs, ndisp, height, width).zero_()
532 | cost2 = x.new().resize_(bs, ndisp, height, width).zero_()
533 | # cost = y.unsqueeze(2).repeat(1, 2, ndisp, 1, 1) #(B, D, H, W)
534 |
535 | mh, mw = torch.meshgrid([torch.arange(0, height, dtype=x.dtype, device=x.device),
536 | torch.arange(0, width, dtype=x.dtype, device=x.device)]) # (H *W)
537 | mh = mh.reshape(1, 1, height, width).repeat(bs, ndisp, 1, 1)
538 | mw = mw.reshape(1, 1, height, width).repeat(bs, ndisp, 1, 1) # (B, D, H, W)
539 |
540 | cur_disp_coords_y = mh
541 | cur_disp_coords_x = disp_range_samples1
542 |
543 | coords_x = cur_disp_coords_x / ((width - 1.0) / 2.0) - 1.0 # trans to -1 - 1
544 | coords_y = cur_disp_coords_y / ((height - 1.0) / 2.0) - 1.0
545 | grid = torch.stack([coords_x, coords_y], dim=4).view(bs, ndisp * height, width, 2) #(B, D, H, W, 2)
546 |
547 | for i in range(ndisp):
548 | cost1[:, i, :, :] = (Q * F.grid_sample(K, grid[:, i*height:(i+1)*height, :, :], mode='bilinear', padding_mode='zeros').view(bs, channels,height, width)).mean(dim=1)
549 |
550 | Q = self.query(y)
551 | K = self.key(x)
552 | cur_disp_coords_x = disp_range_samples2
553 | coords_x = cur_disp_coords_x / ((width - 1.0) / 2.0) - 1.0 # recompute x-coords from disp_range_samples2 (coords_y is unchanged)
554 | grid = torch.stack([coords_x, coords_y], dim=4).view(bs, ndisp * height, width, 2)
555 |
556 | for i in range(ndisp):
557 | cost2[:, i, :, :] = (Q * F.grid_sample(K, grid[:, i*height:(i+1)*height, :, :], mode='bilinear', padding_mode='zeros').view(bs, channels,height, width)).mean(dim=1)
558 |
559 | return cost1, cost2
560 |
561 | class one_conv(nn.Module):
562 | def __init__(self, G0, G):
563 | super(one_conv, self).__init__()
564 | self.conv = nn.Conv2d(G0, G, kernel_size=3, stride=1, padding=1, bias=True)
565 | #self.relu = nn.LeakyReLU(0.1, inplace=True)
566 | self.relu = nn.PReLU()
567 | def forward(self, x):
568 | output = self.relu(self.conv(x))
569 | return torch.cat((x, output), dim=1)
570 |
571 |
572 | class RDB(nn.Module):
573 | def __init__(self, G0, C, G):
574 | super(RDB, self).__init__()
575 | convs = []
576 | for i in range(C):
577 | convs.append(one_conv(G0+i*G, G))
578 | self.conv = nn.Sequential(*convs)
579 | self.LFF = nn.Conv2d(G0+C*G, G0, kernel_size=1, stride=1, padding=0, bias=True)
580 | def forward(self, x):
581 | out = self.conv(x)
582 | lff = self.LFF(out)
583 | return lff + x
584 |
585 |
586 | class RDG(nn.Module):
587 | def __init__(self, G0, C, G, n_RDB):
588 | super(RDG, self).__init__()
589 | self.n_RDB = n_RDB
590 | RDBs = []
591 | for i in range(n_RDB):
592 | RDBs.append(RDB(G0, C, G))
593 | self.RDB = nn.Sequential(*RDBs)
594 | self.conv = nn.Conv2d(G0*n_RDB, G0, kernel_size=1, stride=1, padding=0, bias=True)
595 |
596 | def forward(self, x):
597 | buffer = x
598 | temp = []
599 | for i in range(self.n_RDB):
600 | buffer = self.RDB[i](buffer)
601 | temp.append(buffer)
602 | buffer_cat = torch.cat(temp, dim=1)
603 | out = self.conv(buffer_cat)
604 | return out
605 |
606 |
607 | class CALayer(nn.Module):
608 | def __init__(self, channel, reduction):
609 | super(CALayer, self).__init__()
610 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
611 | self.conv_du = nn.Sequential(
612 | nn.Conv2d(channel, channel//reduction, 1, padding=0, bias=True),
613 | #nn.LeakyReLU(0.1, inplace=True),
614 | nn.PReLU(),
615 | nn.Conv2d(channel//reduction, channel, 1, padding=0, bias=True),
616 | nn.Sigmoid())
617 |
618 | def forward(self, x):
619 | y = self.avg_pool(x)
620 | y = self.conv_du(y)
621 | return x * y
622 |
623 |
624 | class ResB(nn.Module):
625 | def __init__(self, channels):
626 | super(ResB, self).__init__()
627 | self.body = nn.Sequential(
628 | nn.Conv2d(channels, channels, 3, 1, 1, groups=4, bias=True),
629 | #nn.LeakyReLU(0.1, inplace=True),
630 | nn.PReLU(),
631 | nn.Conv2d(channels, channels, 3, 1, 1, groups=4, bias=True),
632 | )
633 | def forward(self, x):
634 | out = self.body(x)
635 | return out + x
636 |
637 |
638 | class PAM(nn.Module):
639 | def __init__(self, channels):
640 | super(PAM, self).__init__()
641 | self.pab1 = PAB(channels)
642 | self.pab2 = PAB(channels)
643 | self.pab3 = PAB(channels)
644 | self.pab4 = PAB(channels)
645 | self.softmax = nn.Softmax(-1)
646 |
647 | def forward(self, x_left, x_right, fea_left, fea_right, cost, is_training):
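"""Parallax attention over the four stacked PABs.

The accumulated costs are softmaxed into row-wise matching matrices
(M_right_to_left, M_left_to_right); valid masks V_left/V_right come from a
relaxed left-right cycle-consistency check; matched features are fused into
out_left/out_right; and disp1/disp2 are regressed as the expected
x-coordinate of the matching pixel in the other view.
"""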
648 | b, c, h, w = fea_left.shape
649 | fea_left, fea_right, cost = self.pab1(fea_left, fea_right, cost)
650 | fea_left, fea_right, cost = self.pab2(fea_left, fea_right, cost)
651 | fea_left, fea_right, cost = self.pab3(fea_left, fea_right, cost)
652 | fea_left, fea_right, cost = self.pab4(fea_left, fea_right, cost)
653 |
654 | M_right_to_left = self.softmax(cost[0]) # (B*H) * Wl * Wr
655 | M_left_to_right = self.softmax(cost[1]) # (B*H) * Wr * Wl
656 |
657 |
658 | M_right_to_leftp = M_right_to_left.view(b,h,w,w).permute(0,3,1,2).contiguous()
659 | M_left_to_rightp = M_left_to_right.view(b,h,w,w).permute(0,3,1,2).contiguous()
660 |
661 | M_right_to_left_relaxed = M_Relax(M_right_to_left, num_pixels=2)
662 | V_left = torch.bmm(M_right_to_left_relaxed.contiguous().view(-1, w).unsqueeze(1),
663 | M_left_to_right.permute(0, 2, 1).contiguous().view(-1, w).unsqueeze(2)
664 | ).detach().contiguous().view(b, 1, h, w) # (B*H*Wr) * Wl * 1
665 | M_left_to_right_relaxed = M_Relax(M_left_to_right, num_pixels=2)
666 | V_right = torch.bmm(M_left_to_right_relaxed.contiguous().view(-1, w).unsqueeze(1), # (B*H*Wl) * 1 * Wr
667 | M_right_to_left.permute(0, 2, 1).contiguous().view(-1, w).unsqueeze(2)
668 | ).detach().contiguous().view(b, 1, h, w) # (B*H*Wr) * Wl * 1
669 |
670 | V_left_tanh = torch.tanh(5 * V_left)
671 | V_right_tanh = torch.tanh(5 * V_right)
672 |
673 | x_leftT = torch.bmm(M_right_to_left, x_right.permute(0, 2, 3, 1).contiguous().view(-1, w, c)
674 | ).contiguous().view(b, h, w, c).permute(0, 3, 1, 2) # B, C0, H0, W0
675 | x_rightT = torch.bmm(M_left_to_right, x_left.permute(0, 2, 3, 1).contiguous().view(-1, w, c)
676 | ).contiguous().view(b, h, w, c).permute(0, 3, 1, 2) # B, C0, H0, W0
677 | out_left = x_left * (1 - V_left_tanh.repeat(1, c, 1, 1)) + x_leftT * V_left_tanh.repeat(1, c, 1, 1)
678 | out_right = x_right * (1 - V_right_tanh.repeat(1, c, 1, 1)) + x_rightT * V_right_tanh.repeat(1, c, 1, 1)
679 |
680 | index = torch.arange(w).view(1, 1, 1, w).to(M_right_to_left.device).float() # index: 1*1*1*w
681 | disp1 = torch.sum(M_right_to_left * index, dim=-1).view(b, 1, h, w) # x axis of the corresponding point
682 | disp2 = torch.sum(M_left_to_right * index, dim=-1).view(b, 1, h, w)
683 |
684 |
685 | return out_left, out_right, disp1, disp2,\
686 | (M_right_to_left.contiguous().view(b, h, w, w), M_left_to_right.contiguous().view(b, h, w, w)),\
687 | (V_left_tanh, V_right_tanh)
688 |
689 | class PAB(nn.Module):
690 | def __init__(self, channels):
691 | super(PAB, self).__init__()
692 | self.head = nn.Sequential(
693 | nn.Conv2d(channels, channels//2, 3, 1, 1, bias=True),
694 | nn.BatchNorm2d(channels//2),
695 | nn.PReLU(),
696 | nn.Conv2d(channels//2, channels, 3, 1, 1, bias=True),
697 | nn.BatchNorm2d(channels),
698 | nn.PReLU(),
699 | )
700 | self.bq = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
701 | self.bs = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
702 |
703 | def forward(self, fea_left, fea_right, cost):
704 | b, c0, h0, w0 = fea_left.shape
705 | fea_left1 = self.head(fea_left)
706 | fea_right1 = self.head(fea_right)
707 | Q = self.bq(fea_left1)
708 | b, c, h, w = Q.shape
709 | Q = Q - torch.mean(Q, 3).unsqueeze(3).repeat(1, 1, 1, w)
710 | K = self.bs(fea_right1)
711 | K = K - torch.mean(K, 3).unsqueeze(3).repeat(1, 1, 1, w)
712 |
713 | score = torch.bmm(Q.permute(0, 2, 3, 1).contiguous().view(-1, w, c), # (B*H) * Wl * C
714 | K.permute(0, 2, 1, 3).contiguous().view(-1, c, w))
715 |
716 | cost[0] += score
717 | cost[1] += score.permute(0, 2, 1)
718 | return fea_left+fea_left1, fea_right+fea_right1, cost
719 |
720 |
721 | def M_Relax(M, num_pixels):
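"""Relax a matching matrix by local summation.

Returns the sum of M and zero-padded copies of M shifted by ±1..num_pixels
along dim 1, so a correspondence is treated as consistent if it falls within
`num_pixels` of the exact position (used for the cycle-consistency masks).
"""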
722 | _, u, v = M.shape
723 | M_list = []
724 | M_list.append(M.unsqueeze(1))
725 | for i in range(num_pixels):
726 | pad = nn.ZeroPad2d(padding=(0, 0, i+1, 0))
727 | pad_M = pad(M[:, :-1-i, :])
728 | M_list.append(pad_M.unsqueeze(1))
729 | for i in range(num_pixels):
730 | pad = nn.ZeroPad2d(padding=(0, 0, 0, i+1))
731 | pad_M = pad(M[:, i+1:, :])
732 | M_list.append(pad_M.unsqueeze(1))
733 | M_relaxed = torch.sum(torch.cat(M_list, 1), dim=1)
734 | return M_relaxed
735 |
736 | if __name__ == "__main__":
737 | net = SSRDEFNet(4) # the full model instantiated in train.py / test_sr.py
738 | total = sum([param.nelement() for param in net.parameters()])
739 | print(' Number of params: %.2fM' % (total / 1e6))
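# A minimal forward-pass sketch (commented out): it assumes a CUDA device and
# example-only input sizes; the call signature matches the usage in test_sr.py.
# net = SSRDEFNet(4).cuda()
# lr_left = torch.rand(1, 3, 32, 96).cuda()
# lr_right = torch.rand(1, 3, 32, 96).cuda()
# with torch.no_grad():
#     outputs = net(lr_left, lr_right, is_training=0)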
--------------------------------------------------------------------------------
/result.txt:
--------------------------------------------------------------------------------
1 | StereoSR:
2 |
3 | KITTI2012 mean psnr left: 26.619422161710588
4 | KITTI2012 mean psnr average: 26.712393708455398
5 |
6 | KITTI2015 mean psnr left: 25.74107406904548
7 | KITTI2015 mean psnr average: 26.464446701876494
8 |
9 | Middlebury mean psnr left: 29.308098231282155
10 | Middlebury mean psnr average: 29.381088935551993
11 |
12 | Flickr1024 mean psnr left: 23.50759679238374
13 | Flickr1024 mean psnr average: 23.585143575772204
14 |
15 |
16 | Disparity estimation:
17 |
18 | K2015 disp high left all1: 4.089844226837158
19 | K2015 disp high left noc1: 3.322787284851074
20 | K2015 disp high left all0: 5.026285648345947
21 | K2015 disp high left noc0: 4.3886919021606445
22 | K2015 disp left all1: 4.741995811462402
23 | K2015 disp left noc1: 3.98885440826416
24 | K2015 disp left all0: 5.614530086517334
25 | K2015 disp left noc0: 4.997827529907227
26 |
27 | K2012 disp high left all1: 4.963388919830322
28 | K2012 disp high left noc1: 3.6856789588928223
29 | K2012 disp high left all0: 6.467651844024658
30 | K2012 disp high left noc0: 5.434603214263916
31 | K2012 disp left all1: 5.762098789215088
32 | K2012 disp left noc1: 4.506109237670898
33 | K2012 disp left all0: 7.1608195304870605
34 | K2012 disp left noc0: 6.164617538452148
35 |
--------------------------------------------------------------------------------
/test_disp.py:
--------------------------------------------------------------------------------
1 | from torch.autograd import Variable
2 | from PIL import Image
3 | from torchvision.transforms import ToTensor
4 | import argparse
5 | import os
6 | from model import *
7 | from utils import *
8 | import numpy as np
9 | import torch.nn.functional as F
10 | import re
11 | from metric import *
12 | import imageio
13 |
14 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
15 | os.environ["CUDA_VISIBLE_DEVICES"] = '0'
16 |
17 | def parse_args():
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--testset_dir', type=str, default='./data/test/')
20 | parser.add_argument('--scale_factor', type=int, default=4)
21 | parser.add_argument('--device', type=str, default='cuda')
22 | parser.add_argument('--model_name', type=str, default='SSRDEF_4xSR')
23 | return parser.parse_args()
24 |
25 | def read_disp(filename, subset=False):
26 | # Scene Flow dataset
27 | if filename.endswith('pfm'):
28 | # For finalpass and cleanpass, gt disparity is positive, subset is negative
29 | disp = np.ascontiguousarray(_read_pfm(filename)[0])
30 | if subset:
31 | disp = -disp
32 | # KITTI
33 | elif filename.endswith('png'):
34 | disp = _read_kitti_disp(filename)
35 | elif filename.endswith('npy'):
36 | disp = np.load(filename)
37 | else:
38 | raise Exception('Invalid disparity file format!')
39 | return disp # [H, W]
40 |
41 | def _read_pfm(file):
42 | file = open(file, 'rb')
43 |
44 | color = None
45 | width = None
46 | height = None
47 | scale = None
48 | endian = None
49 |
50 | header = file.readline().rstrip()
51 | if header.decode("ascii") == 'PF':
52 | color = True
53 | elif header.decode("ascii") == 'Pf':
54 | color = False
55 | else:
56 | raise Exception('Not a PFM file.')
57 |
58 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii"))
59 | if dim_match:
60 | width, height = list(map(int, dim_match.groups()))
61 | else:
62 | raise Exception('Malformed PFM header.')
63 |
64 | scale = float(file.readline().decode("ascii").rstrip())
65 | if scale < 0: # little-endian
66 | endian = '<'
67 | scale = -scale
68 | else:
69 | endian = '>' # big-endian
70 |
71 | data = np.fromfile(file, endian + 'f')
72 | shape = (height, width, 3) if color else (height, width)
73 |
74 | data = np.reshape(data, shape)
75 | data = np.flipud(data)
76 | return data, scale
77 |
78 | def _read_kitti_disp(filename):
79 | depth = np.array(Image.open(filename))
80 | depth = depth.astype(np.float32)/256.
81 | return depth
82 |
83 | def test(cfg, loadname, net):
84 | print(loadname)
85 | psnr_list = []
86 | psnr_list_r = []
87 | psnr_list_m = []
88 | psnr_list_r_m = []
89 | disphigh_occ0 = []
90 | disphigh_noc0 = []
91 | disp_occ0 = []
92 | disp_noc0 = []
93 | disphigh_occ1 = []
94 | disphigh_noc1 = []
95 | disp_occ1 = []
96 | disp_noc1 = []
97 |
98 | file_list = os.listdir(cfg.testset_dir + cfg.dataset + '/hr')
99 | for idx in range(len(file_list)):
100 | LR_left = Image.open(cfg.testset_dir + cfg.dataset + '/lr_x' + str(cfg.scale_factor) + '/' + file_list[idx] + '/lr0.png')
101 | LR_right = Image.open(cfg.testset_dir + cfg.dataset + '/lr_x' + str(cfg.scale_factor) + '/' + file_list[idx] + '/lr1.png')
102 | HR_left = Image.open(cfg.testset_dir + cfg.dataset + '/hr/' + file_list[idx] + '/hr0.png')
103 | HR_right = Image.open(cfg.testset_dir + cfg.dataset + '/hr/' + file_list[idx] + '/hr1.png')
104 | disp_left = read_disp(cfg.testset_dir + cfg.dataset + '/hr/' + file_list[idx] + '/dispocc0.png')
105 | disp_leftall = read_disp(cfg.testset_dir + cfg.dataset + '/hr/' + file_list[idx] + '/dispnoc0.png')
106 |
107 | LR_left, LR_right, HR_left, HR_right, disp_left, disp_leftall = ToTensor()(LR_left), ToTensor()(LR_right), ToTensor()(HR_left), ToTensor()(HR_right), ToTensor()(disp_left), ToTensor()(disp_leftall)
108 | LR_left, LR_right, HR_left, HR_right, disp_left, disp_leftall = LR_left.unsqueeze(0), LR_right.unsqueeze(0), HR_left.unsqueeze(0), HR_right.unsqueeze(0), disp_left.unsqueeze(0), disp_leftall.unsqueeze(0)
109 | LR_left, LR_right, HR_left, HR_right, disp_left, disp_leftall = Variable(LR_left).cuda(), Variable(LR_right).cuda(), Variable(HR_left).cuda(), Variable(HR_right).cuda(), Variable(disp_left).cuda(), Variable(disp_leftall).cuda()
110 | scene_name = file_list[idx]
111 |
112 | _,_,h,w=disp_left.shape
113 |
114 |
115 | disp_left=disp_left.view(1,h,w)
116 | disp_leftall=disp_leftall.view(1,h,w)
117 |
118 | mask0 = (disp_left > 0) & (disp_left < 192)
119 | mask1 = (disp_leftall > 0) & (disp_leftall < 192)
120 | _,h,w=mask0.shape
121 |
122 |
123 | #print('Running Scene ' + scene_name + ' of ' + cfg.dataset + ' Dataset......')
124 | with torch.no_grad():
125 | _,_,_,_,_,_,SR_left, SR_right, (disp1, disp2), (disp1_3, disp2_3), (disp1_high, disp2_high), (disp1_high2, disp2_high2) = net(LR_left, LR_right, is_training=0)
126 |
127 | SR_left, SR_right = torch.clamp(SR_left, 0, 1), torch.clamp(SR_right, 0, 1)
128 |
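# Left-view PSNR skips the leftmost 64 columns (the stereo border that
# generally has no counterpart in the right view); the "average" list
# accumulates full-frame PSNR of both views.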
129 | psnr_list.append(cal_psnr(HR_left[:,:,:,64:].data.cpu(), SR_left[:,:,:,64:].data.cpu()))
130 | psnr_list_r.append(cal_psnr(HR_right.data.cpu(), SR_right.data.cpu()))
131 | psnr_list_r.append(cal_psnr(HR_left.data.cpu(), SR_left.data.cpu()))
132 | '''
133 | psnr_list_m.append(cal_psnr(HR_left[:,:,:,64:].data.cpu(), SR_left[0][:,:,:,64:].data.cpu()))
134 | psnr_list_r_m.append(cal_psnr(HR_right.data.cpu(), SR_right[0].data.cpu()))
135 | psnr_list_r_m.append(cal_psnr(HR_left.data.cpu(), SR_left[0].data.cpu()))
136 | '''
137 |
138 | disphigh_occ0.append((disp1_high[mask0].cpu()-disp_left[mask0].cpu()).abs().mean())
139 | disphigh_noc0.append((disp1_high[mask1].cpu()-disp_leftall[mask1].cpu()).abs().mean())
140 | disphigh_occ1.append((disp1_high2[mask0].cpu()-disp_left[mask0].cpu()).abs().mean())
141 | disphigh_noc1.append((disp1_high2[mask1].cpu()-disp_leftall[mask1].cpu()).abs().mean())
142 |
143 | disp_occ0.append((disp1[mask0].cpu()-disp_left[mask0].cpu()).abs().mean())
144 | disp_noc0.append((disp1[mask1].cpu()-disp_leftall[mask1].cpu()).abs().mean())
145 | disp_occ1.append((disp1_3[mask0].cpu()-disp_left[mask0].cpu()).abs().mean())
146 | disp_noc1.append((disp1_3[mask1].cpu()-disp_leftall[mask1].cpu()).abs().mean())
147 |
148 |
149 | #psnr_list_m.append(cal_psnr(HR_left[:,:,:,64:].data.cpu(), out1_left[:,:,:,64:].data.cpu()))
150 | #psnr_list_r_m.append(cal_psnr(HR_right.data.cpu(), out1_right.data.cpu()))
151 | #psnr_list_r_m.append(cal_psnr(HR_left.data.cpu(), out1_left.data.cpu()))
152 | #print(torch.mean(V_left2))
153 |
154 | save_path = './results/' + cfg.model_name + '/' + cfg.dataset
155 | if not os.path.exists(save_path):
156 | os.makedirs(save_path)
157 | SR_left_img = transforms.ToPILImage()(torch.squeeze(SR_left.data.cpu(), 0))
158 | SR_left_img.save(save_path + '/' + scene_name + '_L.png')
159 | SR_right_img = transforms.ToPILImage()(torch.squeeze(SR_right.data.cpu(), 0))
160 | SR_right_img.save(save_path + '/' + scene_name + '_R.png')
161 |
162 | imageio.imsave(save_path + '/' + scene_name + '_Lhdisp.png', torch.squeeze(disp1_high2.data.cpu(), 0))
163 | imageio.imsave(save_path + '/' + scene_name + '_Rhdisp.png', torch.squeeze(disp2_high2.data.cpu(), 0))
164 |
165 | print(cfg.dataset + ' mean psnr left: ', float(np.array(psnr_list).mean()))
166 | print(cfg.dataset + ' mean psnr average: ', float(np.array(psnr_list_r).mean()))
167 | #print(cfg.dataset + ' mean psnr left intermedite: ', float(np.array(psnr_list_m).mean()))
168 | #print(cfg.dataset + ' mean psnr average intermedite: ', float(np.array(psnr_list_r_m).mean()))
169 |
170 | print(cfg.dataset + ' disp high left all1: ', float(np.array(disphigh_occ1).mean()))
171 | print(cfg.dataset + ' disp high left noc1: ', float(np.array(disphigh_noc1).mean()))
172 | print(cfg.dataset + ' disp high left all0: ', float(np.array(disphigh_occ0).mean()))
173 | print(cfg.dataset + ' disp high left noc0: ', float(np.array(disphigh_noc0).mean()))
174 |
175 | print(cfg.dataset + ' disp left all1: ', float(np.array(disp_occ1).mean()))
176 | print(cfg.dataset + ' disp left noc1: ', float(np.array(disp_noc1).mean()))
177 | print(cfg.dataset + ' disp left all0: ', float(np.array(disp_occ0).mean()))
178 | print(cfg.dataset + ' disp left noc0: ', float(np.array(disp_noc0).mean()))
179 |
180 |
181 |
182 | if __name__ == '__main__':
183 | cfg = parse_args()
184 | dataset_list = ['K2012','K2015']
185 | net = SSRDEFNet(cfg.scale_factor).cuda()
186 | net = torch.nn.DataParallel(net)
187 | net.eval()
188 |
189 | for j in range(80,81):
190 | loadname = './checkpoints/SSRDEF_4xSR_epoch' + str(j) + '.pth.tar'
191 | model = torch.load(loadname)
192 | net.load_state_dict(model['state_dict'])
193 | for i in range(len(dataset_list)):
194 | cfg.dataset = dataset_list[i]
195 | test(cfg, loadname, net)
196 | print('Finished!')
197 |
--------------------------------------------------------------------------------
/test_sr.py:
--------------------------------------------------------------------------------
1 | from torch.autograd import Variable
2 | from PIL import Image
3 | from torchvision.transforms import ToTensor
4 | import argparse
5 | import os
6 | from model import *
7 | from utils import *
8 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
9 | os.environ["CUDA_VISIBLE_DEVICES"] = '0'
10 |
11 | def parse_args():
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--testset_dir', type=str, default='./data/test/')
14 | parser.add_argument('--scale_factor', type=int, default=4)
15 | parser.add_argument('--device', type=str, default='cuda')
16 | parser.add_argument('--model_name', type=str, default='SSRDEF_4xSR')
17 | return parser.parse_args()
18 |
19 |
20 | def test(cfg, loadname, net):
21 | print(loadname)
22 | psnr_list = []
23 | psnr_list_r = []
24 | psnr_list_m1 = []
25 | psnr_list_r_m1 = []
26 | psnr_list_m2 = []
27 | psnr_list_r_m2 = []
28 | psnr_list_m3 = []
29 | psnr_list_r_m3 = []
30 |
31 | file_list = os.listdir(cfg.testset_dir + cfg.dataset + '/hr')
32 | for idx in range(len(file_list)):
33 | LR_left = Image.open(cfg.testset_dir + cfg.dataset + '/lr_x' + str(cfg.scale_factor) + '/' + file_list[idx] + '/lr0.png')
34 | LR_right = Image.open(cfg.testset_dir + cfg.dataset + '/lr_x' + str(cfg.scale_factor) + '/' + file_list[idx] + '/lr1.png')
35 | HR_left = Image.open(cfg.testset_dir + cfg.dataset + '/hr/' + file_list[idx] + '/hr0.png')
36 | HR_right = Image.open(cfg.testset_dir + cfg.dataset + '/hr/' + file_list[idx] + '/hr1.png')
37 |
38 | LR_left, LR_right, HR_left, HR_right = ToTensor()(LR_left), ToTensor()(LR_right), ToTensor()(HR_left), ToTensor()(HR_right)
39 | LR_left, LR_right, HR_left, HR_right = LR_left.unsqueeze(0), LR_right.unsqueeze(0), HR_left.unsqueeze(0), HR_right.unsqueeze(0)
40 | LR_left, LR_right, HR_left, HR_right = Variable(LR_left).cuda(), Variable(LR_right).cuda(), Variable(HR_left).cuda(), Variable(HR_right).cuda()
41 | scene_name = file_list[idx]
42 | #print('Running Scene ' + scene_name + ' of ' + cfg.dataset + ' Dataset......')
43 | with torch.no_grad():
44 | SR_left1, SR_right1, SR_left2, SR_right2, SR_left3, SR_right3, SR_left4, SR_right4,_,_,_,_ = net(LR_left, LR_right, is_training=0)
45 | SR_left1, SR_right1 = torch.clamp(SR_left1, 0, 1), torch.clamp(SR_right1, 0, 1)
46 | SR_left2, SR_right2 = torch.clamp(SR_left2, 0, 1), torch.clamp(SR_right2, 0, 1)
47 | SR_left3, SR_right3 = torch.clamp(SR_left3, 0, 1), torch.clamp(SR_right3, 0, 1)
48 | SR_left4, SR_right4 = torch.clamp(SR_left4, 0, 1), torch.clamp(SR_right4, 0, 1)
49 |
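# As in test_disp.py, left-view PSNR skips the leftmost 64 columns, while the
# "average" lists accumulate full-frame PSNR of both views (for the final and
# the three intermediate SR outputs).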
50 | psnr_list.append(cal_psnr(HR_left[:,:,:,64:].data.cpu(), SR_left4[:,:,:,64:].data.cpu()))
51 | psnr_list_r.append(cal_psnr(HR_right.data.cpu(), SR_right4.data.cpu()))
52 | psnr_list_r.append(cal_psnr(HR_left.data.cpu(), SR_left4.data.cpu()))
53 |
54 | psnr_l = cal_psnr(HR_left.data.cpu(), SR_left4.data.cpu())
55 | psnr_r = cal_psnr(HR_right.data.cpu(), SR_right4.data.cpu())
56 |
57 | psnr_l1 = cal_psnr(HR_left.data.cpu(), SR_left1.data.cpu())
58 | psnr_r1 = cal_psnr(HR_right.data.cpu(), SR_right1.data.cpu())
59 |
60 | psnr_l2 = cal_psnr(HR_left.data.cpu(), SR_left2.data.cpu())
61 | psnr_r2 = cal_psnr(HR_right.data.cpu(), SR_right2.data.cpu())
62 |
63 | psnr_l3 = cal_psnr(HR_left.data.cpu(), SR_left3.data.cpu())
64 | psnr_r3 = cal_psnr(HR_right.data.cpu(), SR_right3.data.cpu())
65 |
66 | psnr_list_m1.append(cal_psnr(HR_left[:,:,:,64:].data.cpu(), SR_left1[:,:,:,64:].data.cpu()))
67 | psnr_list_r_m1.append(cal_psnr(HR_right.data.cpu(), SR_right1.data.cpu()))
68 | psnr_list_r_m1.append(cal_psnr(HR_left.data.cpu(), SR_left1.data.cpu()))
69 |
70 | psnr_list_m2.append(cal_psnr(HR_left[:,:,:,64:].data.cpu(), SR_left2[:,:,:,64:].data.cpu()))
71 | psnr_list_r_m2.append(cal_psnr(HR_right.data.cpu(), SR_right2.data.cpu()))
72 | psnr_list_r_m2.append(cal_psnr(HR_left.data.cpu(), SR_left2.data.cpu()))
73 |
74 | psnr_list_m3.append(cal_psnr(HR_left[:,:,:,64:].data.cpu(), SR_left3[:,:,:,64:].data.cpu()))
75 | psnr_list_r_m3.append(cal_psnr(HR_right.data.cpu(), SR_right3.data.cpu()))
76 | psnr_list_r_m3.append(cal_psnr(HR_left.data.cpu(), SR_left3.data.cpu()))
77 |
78 | save_path = './results/' + cfg.model_name + '/' + cfg.dataset
79 | if not os.path.exists(save_path):
80 | os.makedirs(save_path)
81 |
82 | SR_left_img = transforms.ToPILImage()(torch.squeeze(SR_left4.data.cpu(), 0))
83 | SR_left_img.save(save_path + '/' + scene_name + '_L%.2f.png'%psnr_l)
84 | SR_right_img = transforms.ToPILImage()(torch.squeeze(SR_right4.data.cpu(), 0))
85 | SR_right_img.save(save_path + '/' + scene_name + '_R%.2f.png'%psnr_r)
86 | '''
87 | SR_left_img = transforms.ToPILImage()(torch.squeeze(SR_left3.data.cpu(), 0))
88 | SR_left_img.save(save_path + '/' + scene_name + '_L%.2f.png'%psnr_l3)
89 | SR_right_img = transforms.ToPILImage()(torch.squeeze(SR_right3.data.cpu(), 0))
90 | SR_right_img.save(save_path + '/' + scene_name + '_R%.2f.png'%psnr_r3)
91 |
92 | SR_left_img = transforms.ToPILImage()(torch.squeeze(SR_left2.data.cpu(), 0))
93 | SR_left_img.save(save_path + '/' + scene_name + '_L%.2f.png'%psnr_l2)
94 | SR_right_img = transforms.ToPILImage()(torch.squeeze(SR_right2.data.cpu(), 0))
95 | SR_right_img.save(save_path + '/' + scene_name + '_R%.2f.png'%psnr_r2)
96 |
97 | SR_left_img = transforms.ToPILImage()(torch.squeeze(SR_left1.data.cpu(), 0))
98 | SR_left_img.save(save_path + '/' + scene_name + '_L%.2f.png'%psnr_l1)
99 | SR_right_img = transforms.ToPILImage()(torch.squeeze(SR_right1.data.cpu(), 0))
100 | SR_right_img.save(save_path + '/' + scene_name + '_R%.2f.png'%psnr_r1)
101 | '''
102 |
103 |
104 | print(cfg.dataset + ' mean psnr left: ', float(np.array(psnr_list).mean()))
105 | print(cfg.dataset + ' mean psnr average: ', float(np.array(psnr_list_r).mean()))
106 | print(cfg.dataset + ' mean psnr left intermediate3: ', float(np.array(psnr_list_m3).mean()))
107 | print(cfg.dataset + ' mean psnr average intermediate3: ', float(np.array(psnr_list_r_m3).mean()))
108 | print(cfg.dataset + ' mean psnr left intermediate2: ', float(np.array(psnr_list_m2).mean()))
109 | print(cfg.dataset + ' mean psnr average intermediate2: ', float(np.array(psnr_list_r_m2).mean()))
110 | print(cfg.dataset + ' mean psnr left intermediate1: ', float(np.array(psnr_list_m1).mean()))
111 | print(cfg.dataset + ' mean psnr average intermediate1: ', float(np.array(psnr_list_r_m1).mean()))
112 |
113 | if __name__ == '__main__':
114 | cfg = parse_args()
115 | dataset_list = ['Middlebury', 'KITTI2012','KITTI2015','Flickr1024']
116 | net = SSRDEFNet(cfg.scale_factor).cuda()
117 | net = torch.nn.DataParallel(net)
118 | net.eval()
119 |
120 | for j in range(80,81):
121 | loadname = './checkpoints/SSRDEF_4xSR_epoch' + str(j) + '.pth.tar'
122 | print(loadname)
123 | model = torch.load(loadname)
124 | net.load_state_dict(model['state_dict'])
125 | for i in range(len(dataset_list)):
126 | cfg.dataset = dataset_list[i]
127 | test(cfg, loadname, net)
128 | print('Finished!')
129 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from torch.autograd import Variable
2 | from torch.utils.data import DataLoader
3 | import torch.backends.cudnn as cudnn
4 | import argparse
5 | from utils import *
6 | from model import *
7 | from torchvision.transforms import ToTensor
8 | import os
9 | import torch.nn.functional as F
10 | from loss import *
11 |
12 |
13 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
14 | os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3'
15 |
16 |
17 | def get_parameter_number(net):
18 | total_num = sum(p.numel() for p in net.parameters())
19 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
20 | return {'Total': total_num, 'Trainable': trainable_num}
21 |
22 | def parse_args():
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--scale_factor", type=int, default=4)
25 | parser.add_argument('--device', type=str, default='cuda')
26 | parser.add_argument('--batch_size', type=int, default=12)
27 | parser.add_argument('--lr', type=float, default=2e-4, help='initial learning rate')
28 | parser.add_argument('--gamma', type=float, default=0.5, help='')
29 | parser.add_argument('--start_epoch', type=int, default=0, help='start epoch')
30 | parser.add_argument('--n_epochs', type=int, default=80, help='number of epochs to train')
31 | parser.add_argument('--n_steps', type=int, default=30, help='number of epochs to update learning rate')
32 | parser.add_argument('--trainset_dir', type=str, default='./data/train/Flickr1024_patches')
33 | parser.add_argument('--model_name', type=str, default='SSRDEF')
34 | parser.add_argument('--load_pretrain', type=bool, default=False)
35 | parser.add_argument('--model_path', type=str, default='')
36 | parser.add_argument('--testset_dir', type=str, default='./data/test/')
37 | return parser.parse_args()
38 |
39 | def warpfeature(feat, disp_range_samples, cost_prob, ndisp):
40 | bs, channels, height, width = feat.size()
41 | mh,_ = torch.meshgrid([torch.arange(0, height, dtype=feat.dtype, device=feat.device), torch.arange(0, width, dtype=feat.dtype, device=feat.device)]) # (H *W)
42 | mh = mh.reshape(1, 1, height, width).repeat(bs, ndisp, 1, 1)
43 |
44 | cur_disp_coords_y = mh
45 | cur_disp_coords_x = disp_range_samples
46 |
47 | coords_x = cur_disp_coords_x / ((width - 1.0) / 2.0) - 1.0 # trans to -1 - 1
48 | coords_y = cur_disp_coords_y / ((height - 1.0) / 2.0) - 1.0
49 | grid = torch.stack([coords_x, coords_y], dim=4).view(bs, ndisp * height, width, 2) #(B, D, H, W, 2)->(B, D*H, W, 2)
50 |
51 | #warped = F.grid_sample(feat, grid.view(bs, ndisp * height, width, 2), mode='bilinear', padding_mode='zeros').view(bs, channels, ndisp, height, width)
52 | warped_feat = cost_prob[:, 0, :, :].unsqueeze(1) * F.grid_sample(feat, grid[:, :height, :, :], mode='bilinear', padding_mode='zeros').view(bs, channels,height, width)
53 | for i in range(1, ndisp):
54 | warped_feat += cost_prob[:, i, :, :].unsqueeze(1) * F.grid_sample(feat, grid[:, i*height:(i+1)*height, :, :], mode='bilinear', padding_mode='zeros').view(bs, channels,height, width)
55 |
56 | return warped_feat
57 |
58 | def dispwarpfeature(feat, disp):
59 | bs, channels, height, width = feat.size()
60 | mh,_ = torch.meshgrid([torch.arange(0, height, dtype=feat.dtype, device=feat.device), torch.arange(0, width, dtype=feat.dtype, device=feat.device)]) # (H *W)
61 | mh = mh.reshape(1, 1, height, width).repeat(bs, 1, 1, 1)
62 |
63 | cur_disp_coords_y = mh
64 | cur_disp_coords_x = disp
65 |
66 | coords_x = cur_disp_coords_x / ((width - 1.0) / 2.0) - 1.0 # trans to -1 - 1
67 | coords_y = cur_disp_coords_y / ((height - 1.0) / 2.0) - 1.0
68 | grid = torch.stack([coords_x, coords_y], dim=4).view(bs, height, width, 2) #(B, D, H, W, 2)->(B, D*H, W, 2)
69 |
70 | #warped = F.grid_sample(feat, grid.view(bs, ndisp * height, width, 2), mode='bilinear', padding_mode='zeros').view(bs, channels, ndisp, height, width)
71 | warped_feat = F.grid_sample(feat, grid, mode='bilinear', padding_mode='zeros').view(bs, channels,height, width)
72 |
73 | return warped_feat
74 |
75 | def cal_grad(image):
76 | """
77 | Compute image-edge-aware weights exp(-10 * mean|grad I|) along x and y; strong image edges get low weight (for edge-aware disparity smoothness).
78 | """
79 |
80 | def gradient(pred):
81 | D_dy = pred[:, :, 1:, :] - pred[:, :, :-1, :]
82 | D_dx = pred[:, :, :, 1:] - pred[:, :, :, :-1]
83 | D_dy = F.pad(D_dy, pad=(0,0,0,1), mode="constant", value=0)
84 | D_dx = F.pad(D_dx, pad=(0,1,0,0), mode="constant", value=0)
85 | return D_dx, D_dy
86 |
87 |
88 | img_grad_x, img_grad_y = gradient(image)
89 | weights_x = torch.exp(-10.0 * torch.mean(torch.abs(img_grad_x), 1, keepdim=True))
90 | weights_y = torch.exp(-10.0 * torch.mean(torch.abs(img_grad_y), 1, keepdim=True))
91 |
92 | return weights_x, weights_y
93 |
94 | def load_pretrain(model, pretrained_dict):
95 | torch_params = model.state_dict()
96 | for k,v in pretrained_dict.items():
97 | print(k)
98 | pretrained_dict_1 = {k: v for k, v in pretrained_dict.items() if k in torch_params}
99 | torch_params.update(pretrained_dict_1)
100 | model.load_state_dict(torch_params)
101 |
102 | def train(train_loader, cfg):
103 | net = SSRDEFNet(cfg.scale_factor).cuda()
104 | print(get_parameter_number(net))
105 | cudnn.benchmark = True
106 | scale = cfg.scale_factor
107 |
108 | net = torch.nn.DataParallel(net)
109 |
110 | if cfg.load_pretrain:
111 | if os.path.isfile(cfg.model_path):
112 | model = torch.load(cfg.model_path)
113 | net.load_state_dict(model['state_dict'])
114 | cfg.start_epoch = model["epoch"]
115 | else:
116 | print("=> no model found at '{}'".format(cfg.model_path))
117 |
118 |
119 | # net = torch.nn.DataParallel(net, device_ids=[0, 1])
120 | criterion_L1 = torch.nn.L1Loss().cuda()
121 | optimizer = torch.optim.Adam([paras for paras in net.parameters() if paras.requires_grad == True], lr=cfg.lr)
122 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=cfg.n_steps, gamma=cfg.gamma)
123 |
124 | loss_epoch = []
125 | loss_list = []
126 | psnr_epoch = []
127 | psnr_epoch_r = []
128 | psnr_epoch_m = []
129 | psnr_epoch_r_m = []
130 |
131 | for idx_epoch in range(cfg.start_epoch, cfg.n_epochs):
132 |
133 | for idx_iter, (HR_left, HR_right, LR_left, LR_right) in enumerate(train_loader):
134 | b, c, h, w = LR_left.shape
135 | _, _, h2, w2 = HR_left.shape
136 | HR_left, HR_right, LR_left, LR_right = Variable(HR_left).cuda(), Variable(HR_right).cuda(),\
137 | Variable(LR_left).cuda(), Variable(LR_right).cuda()
138 |
139 | SR_left, SR_right, SR_left2, SR_right2, SR_left3, SR_right3, SR_left4, SR_right4,\
140 | (M_right_to_left, M_left_to_right), (disp1, disp2), (V_left, V_right), (V_left2, V_right2), (disp1_high, disp2_high),\
141 | (M_right_to_left3, M_left_to_right3), (disp1_3, disp2_3), (V_left3, V_right3), (V_left4, V_right4), (disp1_high_2, disp2_high_2)\
142 | =net(LR_left, LR_right, is_training=1)
143 |
144 | ''' SR Loss '''
145 | loss_SR = criterion_L1(SR_left, HR_left) + criterion_L1(SR_right, HR_right) + criterion_L1(SR_left2, HR_left) + criterion_L1(SR_right2, HR_right) +\
146 | criterion_L1(SR_left3, HR_left) + criterion_L1(SR_right3, HR_right) + criterion_L1(SR_left4, HR_left) + criterion_L1(SR_right4, HR_right)
147 |
148 | loss_S = loss_disp_smoothness(disp1_high, HR_left) + loss_disp_smoothness(disp2_high, HR_right) + \
149 | loss_disp_smoothness(disp1_high_2, HR_left) + loss_disp_smoothness(disp2_high_2, HR_right)
150 |
151 | loss_P = loss_disp_unsupervised(HR_left, HR_right, disp1, F.interpolate(V_left, scale_factor=4, mode='nearest')) + loss_disp_unsupervised(HR_right, HR_left, disp2, F.interpolate(V_right, scale_factor=4, mode='nearest')) +\
152 | loss_disp_unsupervised(HR_left, HR_right, disp1_high, V_left2) + loss_disp_unsupervised(HR_right, HR_left, disp2_high, V_right2) + \
153 | loss_disp_unsupervised(HR_left, HR_right, disp1_3, F.interpolate(V_left3, scale_factor=4, mode='nearest')) + loss_disp_unsupervised(HR_right, HR_left, disp2_3, F.interpolate(V_right3, scale_factor=4, mode='nearest')) +\
154 | loss_disp_unsupervised(HR_left, HR_right, disp1_high_2, V_left4) + loss_disp_unsupervised(HR_right, HR_left, disp2_high_2, V_right4)
155 |
156 | ''' Photometric Loss '''
157 | Res_left = torch.abs(HR_left - F.interpolate(LR_left, scale_factor=scale, mode='bicubic', align_corners=False))
158 | Res_left_low = F.interpolate(Res_left, scale_factor=1 / scale, mode='bicubic', align_corners=False)
159 | Res_right = torch.abs(HR_right - F.interpolate(LR_right, scale_factor=scale, mode='bicubic', align_corners=False))
160 | Res_right_low = F.interpolate(Res_right, scale_factor=1 / scale, mode='bicubic', align_corners=False)
161 | Res_leftT_low = torch.bmm(M_right_to_left.contiguous().view(b * h, w, w), Res_right_low.permute(0, 2, 3, 1).contiguous().view(b * h, w, c)
162 | ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2)
163 | Res_rightT_low = torch.bmm(M_left_to_right.contiguous().view(b * h, w, w), Res_left_low.permute(0, 2, 3, 1).contiguous().view(b * h, w, c)
164 | ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2)
165 | Res_leftT_low2 = torch.bmm(M_right_to_left3.contiguous().view(b * h, w, w), Res_right_low.permute(0, 2, 3, 1).contiguous().view(b * h, w, c)
166 | ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2)
167 | Res_rightT_low2 = torch.bmm(M_left_to_right3.contiguous().view(b * h, w, w), Res_left_low.permute(0, 2, 3, 1).contiguous().view(b * h, w, c)
168 | ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2)
169 | Res_leftT = dispwarpfeature(Res_right, disp1_high)
170 | Res_rightT = dispwarpfeature(Res_left, disp2_high)
171 | Res_leftT2 = dispwarpfeature(Res_right, disp1_high_2)
172 | Res_rightT2 = dispwarpfeature(Res_left, disp2_high_2)
173 |
174 | loss_photo = criterion_L1(Res_left_low * V_left.repeat(1, 3, 1, 1), Res_leftT_low * V_left.repeat(1, 3, 1, 1)) + \
175 | criterion_L1(Res_right_low * V_right.repeat(1, 3, 1, 1), Res_rightT_low * V_right.repeat(1, 3, 1, 1)) + \
176 | criterion_L1(Res_left * V_left2.repeat(1, 3, 1, 1), Res_leftT * V_left2.repeat(1, 3, 1, 1)) + \
177 | criterion_L1(Res_right * V_right2.repeat(1, 3, 1, 1), Res_rightT * V_right2.repeat(1, 3, 1, 1)) +\
178 | criterion_L1(Res_left_low * V_left3.repeat(1, 3, 1, 1), Res_leftT_low2 * V_left3.repeat(1, 3, 1, 1)) + \
179 | criterion_L1(Res_right_low * V_right3.repeat(1, 3, 1, 1), Res_rightT_low2 * V_right3.repeat(1, 3, 1, 1)) + \
180 | criterion_L1(Res_left * V_left4.repeat(1, 3, 1, 1), Res_leftT2 * V_left4.repeat(1, 3, 1, 1)) + \
181 | criterion_L1(Res_right * V_right4.repeat(1, 3, 1, 1), Res_rightT2 * V_right4.repeat(1, 3, 1, 1))
182 |
183 | loss_h = criterion_L1(M_right_to_left[:, :-1, :, :], M_right_to_left[:, 1:, :, :]) + \
184 | criterion_L1(M_left_to_right[:, :-1, :, :], M_left_to_right[:, 1:, :, :]) + \
185 | criterion_L1(M_right_to_left3[:, :-1, :, :], M_right_to_left3[:, 1:, :, :]) + \
186 | criterion_L1(M_left_to_right3[:, :-1, :, :], M_left_to_right3[:, 1:, :, :])
187 |
188 | loss_w = criterion_L1(M_right_to_left[:, :, :-1, :-1], M_right_to_left[:, :, 1:, 1:]) + \
189 | criterion_L1(M_left_to_right[:, :, :-1, :-1], M_left_to_right[:, :, 1:, 1:]) + \
190 | criterion_L1(M_right_to_left3[:, :, :-1, :-1], M_right_to_left3[:, :, 1:, 1:]) + \
191 | criterion_L1(M_left_to_right3[:, :, :-1, :-1], M_left_to_right3[:, :, 1:, 1:])
192 |
193 | loss_smooth = loss_w + loss_h
194 |
195 | ''' Cycle Loss '''
196 | Res_left_cycle_low = torch.bmm(M_right_to_left.contiguous().view(b * h, w, w), Res_rightT_low.permute(0, 2, 3, 1).contiguous().view(b * h, w, c)
197 | ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2)
198 | Res_right_cycle_low = torch.bmm(M_left_to_right.contiguous().view(b * h, w, w), Res_leftT_low.permute(0, 2, 3, 1).contiguous().view(b * h, w, c)
199 | ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2)
200 | Res_left_cycle = dispwarpfeature(Res_rightT, disp1_high)
201 | Res_right_cycle = dispwarpfeature(Res_leftT, disp2_high)
202 |
203 | Res_left_cycle_low2 = torch.bmm(M_right_to_left3.contiguous().view(b * h, w, w), Res_rightT_low2.permute(0, 2, 3, 1).contiguous().view(b * h, w, c)
204 | ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2)
205 | Res_right_cycle_low2 = torch.bmm(M_left_to_right3.contiguous().view(b * h, w, w), Res_leftT_low2.permute(0, 2, 3, 1).contiguous().view(b * h, w, c)
206 | ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2)
207 | Res_left_cycle2 = dispwarpfeature(Res_rightT2, disp1_high_2)
208 | Res_right_cycle2 = dispwarpfeature(Res_leftT2, disp2_high_2)
209 |
210 | loss_cycle = criterion_L1(Res_left_low * V_left.repeat(1, 3, 1, 1), Res_left_cycle_low * V_left.repeat(1, 3, 1, 1)) + \
211 | criterion_L1(Res_right_low * V_right.repeat(1, 3, 1, 1), Res_right_cycle_low * V_right.repeat(1, 3, 1, 1)) + \
212 | criterion_L1(Res_left * V_left2.repeat(1, 3, 1, 1), Res_left_cycle * V_left2.repeat(1, 3, 1, 1)) + \
213 | criterion_L1(Res_right * V_right2.repeat(1, 3, 1, 1), Res_right_cycle * V_right2.repeat(1, 3, 1, 1)) +\
214 | criterion_L1(Res_left_low * V_left3.repeat(1, 3, 1, 1), Res_left_cycle_low2 * V_left3.repeat(1, 3, 1, 1)) + \
215 | criterion_L1(Res_right_low * V_right3.repeat(1, 3, 1, 1), Res_right_cycle_low2 * V_right3.repeat(1, 3, 1, 1)) + \
216 | criterion_L1(Res_left * V_left4.repeat(1, 3, 1, 1), Res_left_cycle2 * V_left4.repeat(1, 3, 1, 1)) + \
217 | criterion_L1(Res_right * V_right4.repeat(1, 3, 1, 1), Res_right_cycle2 * V_right4.repeat(1, 3, 1, 1))
218 |
219 | ''' Consistency Loss '''
220 | SR_left_res = F.interpolate(torch.abs(HR_left - SR_left), scale_factor=1 / scale, mode='bicubic', align_corners=False)
221 | SR_right_res = F.interpolate(torch.abs(HR_right - SR_right), scale_factor=1 / scale, mode='bicubic', align_corners=False)
222 | SR_left_res3 = F.interpolate(torch.abs(HR_left - SR_left3), scale_factor=1 / scale, mode='bicubic', align_corners=False)
223 | SR_right_res3 = F.interpolate(torch.abs(HR_right - SR_right3), scale_factor=1 / scale, mode='bicubic', align_corners=False)
224 |
225 | SR_left_res2 = torch.abs(HR_left - SR_left2)
226 | SR_right_res2 = torch.abs(HR_right - SR_right2)
227 | SR_left_res4 = torch.abs(HR_left - SR_left4)
228 | SR_right_res4 = torch.abs(HR_right - SR_right4)
229 |
230 | SR_left_resT = torch.bmm(M_right_to_left.detach().contiguous().view(b * h, w, w), SR_right_res.permute(0, 2, 3, 1).contiguous().view(b * h, w, c)
231 | ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2)
232 | SR_right_resT = torch.bmm(M_left_to_right.detach().contiguous().view(b * h, w, w), SR_left_res.permute(0, 2, 3, 1).contiguous().view(b * h, w, c)
233 | ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2)
234 | SR_left_resT2 = dispwarpfeature(SR_right_res2, disp1_high)
235 | SR_right_resT2 = dispwarpfeature(SR_left_res2, disp2_high)
236 |
237 | SR_left_resT3 = torch.bmm(M_right_to_left3.detach().contiguous().view(b * h, w, w), SR_right_res3.permute(0, 2, 3, 1).contiguous().view(b * h, w, c)
238 | ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2)
239 | SR_right_resT3 = torch.bmm(M_left_to_right3.detach().contiguous().view(b * h, w, w), SR_left_res3.permute(0, 2, 3, 1).contiguous().view(b * h, w, c)
240 | ).view(b, h, w, c).contiguous().permute(0, 3, 1, 2)
241 | SR_left_resT4 = dispwarpfeature(SR_right_res4, disp1_high_2)
242 | SR_right_resT4 = dispwarpfeature(SR_left_res4, disp2_high_2)
243 |
244 | loss_cons = criterion_L1(SR_left_res * V_left.repeat(1, 3, 1, 1), SR_left_resT * V_left.repeat(1, 3, 1, 1)) + \
245 | criterion_L1(SR_right_res * V_right.repeat(1, 3, 1, 1), SR_right_resT * V_right.repeat(1, 3, 1, 1)) + \
246 | criterion_L1(SR_left_res2 * V_left2.repeat(1, 3, 1, 1), SR_left_resT2 * V_left2.repeat(1, 3, 1, 1)) + \
247 | criterion_L1(SR_right_res2 * V_right2.repeat(1, 3, 1, 1), SR_right_resT2 * V_right2.repeat(1, 3, 1, 1)) + \
248 | criterion_L1(SR_left_res3 * V_left3.repeat(1, 3, 1, 1), SR_left_resT3 * V_left3.repeat(1, 3, 1, 1)) + \
249 | criterion_L1(SR_right_res3 * V_right3.repeat(1, 3, 1, 1), SR_right_resT3 * V_right3.repeat(1, 3, 1, 1)) + \
250 | criterion_L1(SR_left_res4 * V_left4.repeat(1, 3, 1, 1), SR_left_resT4 * V_left4.repeat(1, 3, 1, 1)) + \
251 | criterion_L1(SR_right_res4 * V_right4.repeat(1, 3, 1, 1), SR_right_resT4 * V_right4.repeat(1, 3, 1, 1))
252 | ''' Total Loss '''
253 | loss = loss_SR + 0.1 * loss_cons + 0.1 * (loss_photo + loss_smooth + loss_cycle) + 0.001*loss_S + 0.01*loss_P
254 | optimizer.zero_grad()
255 | loss.backward()
256 | optimizer.step()
257 |
258 | psnr_epoch.append(cal_psnr(HR_left[:,:,:,64:].data.cpu(), SR_left4[:,:,:,64:].data.cpu()))
259 | psnr_epoch_r.append(cal_psnr(HR_right[:,:,:,:HR_right.shape[3]-64].data.cpu(), SR_right4[:,:,:,:HR_right.shape[3]-64].data.cpu()))
260 |
261 | psnr_epoch_m.append(cal_psnr(HR_left[:,:,:,64:].data.cpu(), SR_left2[:,:,:,64:].data.cpu()))
262 | psnr_epoch_r_m.append(cal_psnr(HR_right[:,:,:,:HR_right.shape[3]-64].data.cpu(), SR_right2[:,:,:,:HR_right.shape[3]-64].data.cpu()))
263 | loss_epoch.append(loss.data.cpu())
264 | if idx_iter%300==0:
265 | print("SRloss: {:.4f} Photoloss: {:.5f} Smoothloss: {:.5f} Cycleloss: {:.5f} Consloss: {:.5f} Ploss: {:.5f} Sloss: {:.5f}".format(loss_SR.item(), 0.1*loss_photo.item(), 0.1*loss_smooth.item(), 0.1*loss_cycle.item(), 0.1*loss_cons.item(), 0.01*loss_P.item(), 0.001*loss_S.item()))
266 | print(torch.mean(V_left2))
267 | print(torch.mean(V_left4))
268 |
269 | scheduler.step()
270 |
271 |
272 | if idx_epoch % 1 == 0:
273 | loss_list.append(float(np.array(loss_epoch).mean()))
274 |
275 | print('Epoch--%4d, loss--%f, loss_SR--%f, loss_photo--%f, loss_smooth--%f, loss_cycle--%f, loss_cons--%f' %
276 | (idx_epoch + 1, float(np.array(loss_epoch).mean()), float(np.array(loss_SR.data.cpu()).mean()),
277 | float(np.array(loss_photo.data.cpu()).mean()), float(np.array(loss_smooth.data.cpu()).mean()),
278 | float(np.array(loss_cycle.data.cpu()).mean()), float(np.array(loss_cons.data.cpu()).mean())))
279 | print('PSNR left---%f, PSNR right---%f' % (float(np.array(psnr_epoch).mean()), float(np.array(psnr_epoch_r).mean())))
280 | print('intermediate PSNR left---%f, PSNR right---%f' % (float(np.array(psnr_epoch_m).mean()), float(np.array(psnr_epoch_r_m).mean())))
281 | loss_epoch = []
282 | psnr_epoch = []
283 | psnr_epoch_r = []
284 | psnr_epoch_m = []
285 | psnr_epoch_r_m = []
286 |
287 | torch.save({'epoch': idx_epoch + 1, 'state_dict': net.state_dict()},
288 | 'checkpoints/' + cfg.model_name + '_' + str(cfg.scale_factor) + 'xSR_epoch' + str(idx_epoch + 1) + '.pth.tar')
289 |
290 |
291 | def main(cfg):
292 | train_set = TrainSetLoader(cfg)
293 | train_loader = DataLoader(dataset=train_set, num_workers=6, batch_size=cfg.batch_size, shuffle=True)
294 | train(train_loader, cfg)
295 |
296 | if __name__ == '__main__':
297 | cfg = parse_args()
298 | main(cfg)
299 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import os, re, random
3 | from torch.utils.data.dataset import Dataset
4 | import torch
5 | import torch.nn.functional as F
6 | import numpy as np
7 | from skimage import measure
8 |
9 | def read_disp(filename, subset=False):
10 | # Scene Flow dataset
11 | if filename.endswith('pfm'):
12 | # For finalpass and cleanpass, gt disparity is positive, subset is negative
13 | disp = np.ascontiguousarray(_read_pfm(filename)[0])
14 | if subset:
15 | disp = -disp
16 | # KITTI
17 | elif filename.endswith('png'):
18 | disp = _read_kitti_disp(filename)
19 | elif filename.endswith('npy'):
20 | disp = np.load(filename)
21 | else:
22 | raise Exception('Invalid disparity file format!')
23 | return disp # [H, W]
24 |
25 |
26 | def _read_pfm(file):
27 | file = open(file, 'rb')
28 |
29 | color = None
30 | width = None
31 | height = None
32 | scale = None
33 | endian = None
34 |
35 | header = file.readline().rstrip()
36 | if header.decode("ascii") == 'PF':
37 | color = True
38 | elif header.decode("ascii") == 'Pf':
39 | color = False
40 | else:
41 | raise Exception('Not a PFM file.')
42 |
43 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii"))
44 | if dim_match:
45 | width, height = list(map(int, dim_match.groups()))
46 | else:
47 | raise Exception('Malformed PFM header.')
48 |
49 | scale = float(file.readline().decode("ascii").rstrip())
50 | if scale < 0: # little-endian
51 | endian = '<'
52 | scale = -scale
53 | else:
54 | endian = '>' # big-endian
55 |
56 | data = np.fromfile(file, endian + 'f')
57 | shape = (height, width, 3) if color else (height, width)
58 |
59 | data = np.reshape(data, shape)
60 | data = np.flipud(data)
61 | return data, scale
62 |
63 | def _read_kitti_disp(filename):
64 | depth = np.array(Image.open(filename))
65 | depth = depth.astype(np.float32) / 256.
66 | return depth
67 |
68 | class TrainSetLoader(Dataset):
69 | def __init__(self, cfg):
70 | super(TrainSetLoader, self).__init__()
71 | self.dataset_dir = cfg.trainset_dir + '/patches_x' + str(cfg.scale_factor)
72 | self.file_list = os.listdir(self.dataset_dir)
73 | def __getitem__(self, index):
74 | img_hr_left = Image.open(self.dataset_dir + '/' + self.file_list[index] + '/hr0.png')
75 | img_hr_right = Image.open(self.dataset_dir + '/' + self.file_list[index] + '/hr1.png')
76 | img_lr_left = Image.open(self.dataset_dir + '/' + self.file_list[index] + '/lr0.png')
77 | img_lr_right = Image.open(self.dataset_dir + '/' + self.file_list[index] + '/lr1.png')
78 | img_hr_left = np.array(img_hr_left, dtype=np.float32)
79 | img_hr_right = np.array(img_hr_right, dtype=np.float32)
80 | img_lr_left = np.array(img_lr_left, dtype=np.float32)
81 | img_lr_right = np.array(img_lr_right, dtype=np.float32)
82 | img_hr_left, img_hr_right, img_lr_left, img_lr_right = augmentation(img_hr_left, img_hr_right, img_lr_left, img_lr_right)
83 | return toTensor(img_hr_left), toTensor(img_hr_right), toTensor(img_lr_left), toTensor(img_lr_right)
84 |
85 | def __len__(self):
86 | return len(self.file_list)
87 |
88 | def cal_psnr(img1, img2):
89 | img1_np = np.array(img1)
90 | img2_np = np.array(img2)
91 | return measure.compare_psnr(img1_np, img2_np)
92 |
93 | def augmentation(hr_image_left, hr_image_right, lr_image_left, lr_image_right):
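"""Random flips for a stereo training pair.

A horizontal flip mirrors all images and swaps the left and right views so
the stereo geometry stays consistent; a vertical flip mirrors each view in
place.
"""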
94 | flag=0
95 | if random.random()<0.5: # flip horizontally (swap left/right views)
96 | lr_image_left_ = lr_image_right[:, ::-1, :]
97 | lr_image_right_ = lr_image_left[:, ::-1, :]
98 | hr_image_left_ = hr_image_right[:, ::-1, :]
99 | hr_image_right_ = hr_image_left[:, ::-1, :]
100 | lr_image_left, lr_image_right = lr_image_left_, lr_image_right_
101 | hr_image_left, hr_image_right = hr_image_left_, hr_image_right_
102 | flag=1
103 |
104 | if random.random()<0.5: #flip vertically
105 | lr_image_left = lr_image_left[::-1, :, :]
106 | lr_image_right = lr_image_right[::-1, :, :]
107 | hr_image_left = hr_image_left[::-1, :, :]
108 | hr_image_right = hr_image_right[::-1, :, :]
109 |
110 | return np.ascontiguousarray(hr_image_left), np.ascontiguousarray(hr_image_right), \
111 | np.ascontiguousarray(lr_image_left), np.ascontiguousarray(lr_image_right)
112 |
113 | def toTensor(img):
114 | img = torch.from_numpy(img.transpose((2, 0, 1)))
115 | return img.float().div(255)
116 |
117 | def D1_metric(D_est, D_gt, mask, threshold=3):
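"""Per-image KITTI D1 error: the fraction of masked pixels whose disparity
error exceeds `threshold` pixels and 5% of the ground-truth disparity."""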
118 | mask = mask.byte()
119 | error = []
120 | for i in range(D_gt.size(0)):
121 | D_est_, D_gt_ = D_est[i,...][mask[i,...]], D_gt[i,...][mask[i,...]]
122 | if len(D_gt_) > 0:
123 | E = torch.abs(D_gt_ - D_est_)
124 | err_mask = (E > threshold) & (E / D_gt_.abs() > 0.05)
125 | error.append(torch.mean(err_mask.float()).data.cpu())
126 | return error
127 |
128 |
129 | def EPE_metric(D_est, D_gt, mask):
130 | mask = mask.byte()
131 | error = []
132 | for i in range(D_gt.size(0)):
133 | D_est_, D_gt_ = D_est[i,...][mask[i,...]], D_gt[i,...][mask[i,...]]
134 | if len(D_gt_) > 0:
135 | error.append(F.l1_loss(D_est_, D_gt_, size_average=True).data.cpu())
136 | return error
137 |
138 |
139 |
--------------------------------------------------------------------------------