├── .gitignore
├── LICENSE
├── README.md
├── ckpt
└── .gitkeep
├── config.py
├── core
├── __init__.py
├── build.py
├── test_deblur.py
├── test_disp.py
├── test_stereodeblur.py
├── train_deblur.py
├── train_disp.py
└── train_stereodeblur.py
├── datasets
├── flyingthings3d.json
└── stereo_deblur_data.json
├── losses
├── __init__.py
└── multiscaleloss.py
├── models
├── DeblurNet.py
├── DispNet_Bi.py
├── StereoDeblurNet.py
├── VGG19.py
├── __init__.py
└── submodules.py
├── requirements.txt
├── runner.py
└── utils
├── __init__.py
├── data_loaders.py
├── data_transforms.py
├── imgio_gen.py
└── network_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
103 | # PyCharm
104 | .idea
105 |
106 | # Checkpoints
107 | ckpt/
108 |
109 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Shangchen Zhou @ SenseTime
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DAVANet
2 |
3 | Code repo for the paper "DAVANet: Stereo Deblurring with View Aggregation" (CVPR'19, Oral). [[Paper]](http://openaccess.thecvf.com/content_CVPR_2019/papers/Zhou_DAVANet_Stereo_Deblurring_With_View_Aggregation_CVPR_2019_paper.pdf) [[Project Page]](https://shangchenzhou.com/projects/davanet/)
4 |
5 |
6 |
7 |
8 |
9 |
10 | ## Stereo Blur Dataset
11 |
12 |
13 |
14 |
15 | Download the dataset (192.5GB, unzipped 202.2GB) from [[Data Website]](https://stereoblur.shangchenzhou.com/).
16 |
17 | ## Pretrained Models
18 |
19 | You can download the pretrained model (34.8MB) of DAVANet from [[Here]](https://drive.google.com/file/d/1oVhKnPe_zrRa_JQUinW52ycJ2EGoAcHG/view?usp=sharing).
20 |
21 | (Note: the downloaded model file does not need to be unzipped; load it directly.)
22 |
23 | ## Prerequisites
24 |
25 | - Linux (tested on Ubuntu 14.04/16.04)
26 | - Python 2.7+
27 | PyTorch 0.4.1
28 | - easydict
29 | - tensorboardX
30 | - pyexr
31 |
32 | #### Installation
33 |
34 | ```
35 | pip install -r requirements.txt
36 | ```
37 |
38 | ## Get Started
39 |
40 | Use the following command to train the neural network:
41 |
42 | ```
43 | python runner.py \
44 | --phase 'train'\
45 | --data [dataset path]\
46 | --out [output path]
47 | ```
48 |
49 | Use the following command to test the neural network:
50 |
51 | ```
52 | python runner.py \
53 | --phase 'test'\
54 | --weights './ckpt/best-ckpt.pth.tar'\
55 | --data [dataset path]\
56 | --out [output path]
57 | ```
58 | Use the following command to resume training the neural network:
59 |
60 | ```
61 | python runner.py \
62 | --phase 'resume'\
63 | --weights './ckpt/best-ckpt.pth.tar'\
64 | --data [dataset path]\
65 | --out [output path]
66 | ```
67 | You can also use the following simple command after changing the settings in config.py (a sketch of the relevant fields follows this file):
68 |
69 | ```
70 | python runner.py
71 | ```
72 |
73 | ## Results on the testing dataset
74 |
75 |
76 |
77 |
78 |
79 | ## Citation
80 | If you find DAVANet or the Stereo Blur Dataset useful in your research, please consider citing:
81 |
82 | ```
83 | @inproceedings{zhou2019davanet,
84 | title={{DAVANet}: Stereo Deblurring with View Aggregation},
85 | author={Zhou, Shangchen and Zhang, Jiawei and Zuo, Wangmeng and Xie, Haozhe and Pan, Jinshan and Ren, Jimmy},
86 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
87 | year={2019}
88 | }
89 | ```
90 |
91 | ## Contact
92 |
93 | We are glad to hear from you if you have any suggestions or questions.
94 |
95 | Please email shangchenzhou@gmail.com.
96 |
97 | ## Reference
98 | [1] Zhe Hu, Li Xu, and Ming-Hsuan Yang. Joint depth estimation and camera shake removal from single blurry image. In *CVPR*, 2014.
99 |
100 | [2] Seungjun Nah, Tae Hyun Kim, and Kyoung Mu Lee. Deep multi-scale convolutional neural network for dynamic scene deblurring. In *CVPR*, 2017.
101 |
102 | [3] Orest Kupyn, Volodymyr Budzan, Mykola Mykhailych, Dmytro Mishkin, and Jiri Matas. DeblurGAN: Blind motion deblurring using conditional adversarial networks. In *CVPR*, 2018.
103 |
104 | [4] Jiawei Zhang, Jinshan Pan, Jimmy Ren, Yibing Song, Linchao Bao, Rynson WH Lau, and Ming-Hsuan Yang. Dynamic scene deblurring using spatially variant recurrent neural networks. In *CVPR*, 2018.
105 |
106 | [5] Xin Tao, Hongyun Gao, Xiaoyong Shen, Jue Wang, and Jiaya Jia. Scale-recurrent network for deep image deblurring. In *CVPR*, 2018.
107 |
108 | ## License
109 |
110 | This project is open-sourced under the MIT license.
111 |
--------------------------------------------------------------------------------
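A note on the bare `python runner.py` invocation above: with no flags, everything is read from `config.py` (reproduced later in this repo). A minimal sketch of the fields typically edited first; the paths below are placeholders, not shipped defaults:

```python
# Sketch of the config.py fields the flag-less run relies on (paths are placeholders).
__C.NETWORK.PHASE = 'train'                      # 'train', 'test' or 'resume'
__C.CONST.WEIGHTS = './ckpt/best-ckpt.pth.tar'   # only needed for 'test' / 'resume'
__C.DIR.DATASET_ROOT = '/path/to/stereo_blur_dataset/'
__C.DIR.OUT_PATH = '/path/to/output'
```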
/ckpt/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sczhou/DAVANet/2bbe35ae01c0f1af718a1bc19272cda5ed3c320a/ckpt/.gitkeep
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Developed by Shangchen Zhou
5 |
6 | from easydict import EasyDict as edict
7 | import socket
8 |
9 | __C = edict()
10 | cfg = __C
11 |
12 | #
13 | # Common
14 | #
15 | __C.CONST = edict()
16 | __C.CONST.DEVICE = 'all' # '0'
17 | __C.CONST.NUM_WORKER = 1 # number of data workers
18 | __C.CONST.WEIGHTS = '/data/code/StereodeblurNet-release/ckpt/best-ckpt.pth.tar'
19 | __C.CONST.TRAIN_BATCH_SIZE = 1
20 | __C.CONST.TEST_BATCH_SIZE = 1
21 |
22 |
23 | #
24 | # Dataset
25 | #
26 | __C.DATASET = edict()
27 | __C.DATASET.DATASET_NAME = 'StereoDeblur' # FlyingThings3D, StereoDeblur
28 | __C.DATASET.WITH_MASK = True
29 |
30 | if cfg.DATASET.DATASET_NAME == 'StereoDeblur':
31 | __C.DATASET.SPARSE = True
32 | else:
33 | __C.DATASET.SPARSE = False
34 |
35 | #
36 | # Directories
37 | #
38 | __C.DIR = edict()
39 | __C.DIR.OUT_PATH = '/data/code/StereodeblurNet/output'
40 |
41 | # For FlyingThings3D Dataset
42 | if cfg.DATASET.DATASET_NAME == 'FlyingThings3D':
43 | __C.DIR.DATASET_JSON_FILE_PATH = './datasets/flyingthings3d.json'
44 | __C.DIR.DATASET_ROOT = '/data/scene_flow/FlyingThings3D/'
45 | __C.DIR.IMAGE_LEFT_PATH = __C.DIR.DATASET_ROOT + '%s/%s/%s/%s/left/%s.png'
46 | __C.DIR.IMAGE_RIGHT_PATH = __C.DIR.DATASET_ROOT + '%s/%s/%s/%s/right/%s.png'
47 | __C.DIR.DISPARITY_LEFT_PATH = __C.DIR.DATASET_ROOT + 'disparity/%s/%s/%s/left/%s.pfm'
48 | __C.DIR.DISPARITY_RIGHT_PATH = __C.DIR.DATASET_ROOT + 'disparity/%s/%s/%s/right/%s.pfm'
49 |
50 | # For Stereo_Blur_Dataset
51 | elif cfg.DATASET.DATASET_NAME == 'StereoDeblur':
52 | __C.DIR.DATASET_JSON_FILE_PATH = './datasets/stereo_deblur_data.json'
53 | __C.DIR.DATASET_ROOT = '/data1/stereo_deblur_data_final_gamma/'
54 | __C.DIR.IMAGE_LEFT_BLUR_PATH = __C.DIR.DATASET_ROOT + '%s/image_left_blur_ga/%s.png'
55 | __C.DIR.IMAGE_LEFT_CLEAR_PATH = __C.DIR.DATASET_ROOT + '%s/image_left/%s.png'
56 | __C.DIR.IMAGE_RIGHT_BLUR_PATH = __C.DIR.DATASET_ROOT + '%s/image_right_blur_ga/%s.png'
57 | __C.DIR.IMAGE_RIGHT_CLEAR_PATH = __C.DIR.DATASET_ROOT + '%s/image_right/%s.png'
58 | __C.DIR.DISPARITY_LEFT_PATH = __C.DIR.DATASET_ROOT + '%s/disparity_left/%s.exr'
59 | __C.DIR.DISPARITY_RIGHT_PATH = __C.DIR.DATASET_ROOT + '%s/disparity_right/%s.exr'
60 |
61 | #
62 | # data augmentation
63 | #
64 | __C.DATA = edict()
65 | __C.DATA.STD = [255.0, 255.0, 255.0]
66 | __C.DATA.MEAN = [0.0, 0.0, 0.0]
67 | __C.DATA.DIV_DISP = 40.0 # 40.0 for disparity
68 | __C.DATA.CROP_IMG_SIZE = [256, 256] # Crop image size: height, width
69 | __C.DATA.GAUSSIAN = [0, 1e-4] # mu, std_var
70 | __C.DATA.COLOR_JITTER = [0.2, 0.15, 0.3, 0.1] # brightness, contrast, saturation, hue
71 |
72 | #
73 | # Network
74 | #
75 | __C.NETWORK = edict()
76 | __C.NETWORK.DISPNETARCH = 'DispNet_Bi' # available options: DispNet_Bi
77 | __C.NETWORK.DEBLURNETARCH = 'StereoDeblurNet' # available options: DeblurNet, StereoDeblurNet
78 | __C.NETWORK.LEAKY_VALUE = 0.1
79 | __C.NETWORK.BATCHNORM = False
80 | __C.NETWORK.PHASE = 'train' # available options: 'train', 'test', 'resume'
81 | __C.NETWORK.MODULE = 'all' # available options: 'dispnet', 'deblurnet', 'all'
82 | #
83 | # Training
84 | #
85 |
86 | __C.TRAIN = edict()
87 | __C.TRAIN.USE_PERCET_LOSS = True
88 | __C.TRAIN.NUM_EPOCHES = 400 # maximum number of epochs
89 | __C.TRAIN.BRIGHTNESS = .25
90 | __C.TRAIN.CONTRAST = .25
91 | __C.TRAIN.SATURATION = .25
92 | __C.TRAIN.HUE = .25
93 | __C.TRAIN.DISPNET_LEARNING_RATE = 1e-6
94 | __C.TRAIN.DEBLURNET_LEARNING_RATE = 1e-4
95 | __C.TRAIN.DISPNET_LR_MILESTONES = [100,200,300]
96 | __C.TRAIN.DEBLURNET_LR_MILESTONES = [80,160,240]
97 | __C.TRAIN.LEARNING_RATE_DECAY = 0.1 # Multiplicative factor of learning rate decay
98 | __C.TRAIN.MOMENTUM = 0.9
99 | __C.TRAIN.BETA = 0.999
100 | __C.TRAIN.BIAS_DECAY = 0.0 # regularization of bias, default: 0
101 | __C.TRAIN.WEIGHT_DECAY = 0.0 # regularization of weight, default: 0
102 | __C.TRAIN.PRINT_FREQ = 10
103 | __C.TRAIN.SAVE_FREQ = 5 # a checkpoint is saved every SAVE_FREQ epochs
104 |
105 | __C.LOSS = edict()
106 | __C.LOSS.MULTISCALE_WEIGHTS = [0.3, 0.3, 0.2, 0.1, 0.1]
107 |
108 | #
109 | # Testing options
110 | #
111 | __C.TEST = edict()
112 | __C.TEST.VISUALIZATION_NUM = 3
113 | __C.TEST.PRINT_FREQ = 5
114 | if __C.NETWORK.PHASE == 'test':
115 | __C.CONST.TEST_BATCH_SIZE = 1
116 |
--------------------------------------------------------------------------------
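Because `cfg` and `__C` alias the same `EasyDict`, any module that does `from config import cfg` sees these values through plain attribute access. A minimal sketch of that mechanism (assuming only that `easydict` is installed):

```python
from easydict import EasyDict as edict

cfg = edict()
cfg.CONST = edict()
cfg.CONST.TRAIN_BATCH_SIZE = 1

# Attribute access and key access are interchangeable on an EasyDict, which is
# why config.py can build nested namespaces on the fly and a caller can
# override single fields at runtime.
assert cfg.CONST.TRAIN_BATCH_SIZE == cfg['CONST']['TRAIN_BATCH_SIZE']
cfg.CONST.TRAIN_BATCH_SIZE = 2
print(cfg.CONST.TRAIN_BATCH_SIZE)  # -> 2
```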
/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sczhou/DAVANet/2bbe35ae01c0f1af718a1bc19272cda5ed3c320a/core/__init__.py
--------------------------------------------------------------------------------
/core/build.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Developed by Shangchen Zhou
5 |
6 | import os
7 | import sys
8 | import torch.backends.cudnn
9 | import torch.utils.data
10 |
11 | import utils.data_loaders
12 | import utils.data_transforms
13 | import utils.network_utils
14 | import models
15 | from models.DispNet_Bi import DispNet_Bi
16 | from models.DeblurNet import DeblurNet
17 | from models.StereoDeblurNet import StereoDeblurNet
18 |
19 | from datetime import datetime as dt
20 | from tensorboardX import SummaryWriter
21 | from core.train_disp import train_dispnet
22 | from core.test_disp import test_dispnet
23 | from core.train_deblur import train_deblurnet
24 | from core.test_deblur import test_deblurnet
25 | from core.train_stereodeblur import train_stereodeblurnet
26 | from core.test_stereodeblur import test_stereodeblurnet
27 | from losses.multiscaleloss import *
28 |
29 | def bulid_net(cfg):
30 |
31 | # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
32 | torch.backends.cudnn.benchmark = True
33 |
34 | # Set up data augmentation
35 | train_transforms = utils.data_transforms.Compose([
36 | utils.data_transforms.ColorJitter(cfg.DATA.COLOR_JITTER),
37 | utils.data_transforms.Normalize(mean=cfg.DATA.MEAN, std=cfg.DATA.STD, div_disp=cfg.DATA.DIV_DISP),
38 | utils.data_transforms.RandomCrop(cfg.DATA.CROP_IMG_SIZE),
39 | utils.data_transforms.RandomVerticalFlip(),
40 | utils.data_transforms.RandomColorChannel(),
41 | utils.data_transforms.RandomGaussianNoise(cfg.DATA.GAUSSIAN),
42 | utils.data_transforms.ToTensor(),
43 | ])
44 |
45 | test_transforms = utils.data_transforms.Compose([
46 | utils.data_transforms.Normalize(mean=cfg.DATA.MEAN, std=cfg.DATA.STD, div_disp=cfg.DATA.DIV_DISP),
47 | utils.data_transforms.ToTensor(),
48 | ])
49 |
50 | # Set up data loader
51 | dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.DATASET_NAME]()
52 | if cfg.NETWORK.PHASE in ['train', 'resume']:
53 | train_data_loader = torch.utils.data.DataLoader(
54 | dataset=dataset_loader.get_dataset(utils.data_loaders.DatasetType.TRAIN, train_transforms),
55 | batch_size=cfg.CONST.TRAIN_BATCH_SIZE,
56 | num_workers=cfg.CONST.NUM_WORKER, pin_memory=True, shuffle=True)
57 |
58 | test_data_loader = torch.utils.data.DataLoader(
59 | dataset=dataset_loader.get_dataset(utils.data_loaders.DatasetType.TEST, test_transforms),
60 | batch_size=cfg.CONST.TEST_BATCH_SIZE,
61 | num_workers=cfg.CONST.NUM_WORKER, pin_memory=True, shuffle=False)
62 |
63 | # Set up networks
64 | dispnet = models.__dict__[cfg.NETWORK.DISPNETARCH].__dict__[cfg.NETWORK.DISPNETARCH]()
65 | deblurnet = models.__dict__[cfg.NETWORK.DEBLURNETARCH].__dict__[cfg.NETWORK.DEBLURNETARCH]()
66 |
67 |
68 | print('[DEBUG] %s Parameters in %s: %d.' % (dt.now(), cfg.NETWORK.DISPNETARCH,
69 | utils.network_utils.count_parameters(dispnet)))
70 |
71 | print('[DEBUG] %s Parameters in %s: %d.' % (dt.now(), cfg.NETWORK.DEBLURNETARCH,
72 | utils.network_utils.count_parameters(deblurnet)))
73 |
74 | # Initialize weights of networks
75 | dispnet.apply(utils.network_utils.init_weights_kaiming)
76 | deblurnet.apply(utils.network_utils.init_weights_xavier)
77 | # Set up solver
78 | dispnet_solver = torch.optim.Adam(filter(lambda p: p.requires_grad, dispnet.parameters()), lr=cfg.TRAIN.DISPNET_LEARNING_RATE,
79 | betas=(cfg.TRAIN.MOMENTUM, cfg.TRAIN.BETA))
80 | deblurnet_solver = torch.optim.Adam(filter(lambda p: p.requires_grad, deblurnet.parameters()), lr=cfg.TRAIN.DEBLURNET_LEARNING_RATE,
81 | betas=(cfg.TRAIN.MOMENTUM, cfg.TRAIN.BETA))
82 |
83 | if torch.cuda.is_available():
84 | dispnet = torch.nn.DataParallel(dispnet).cuda()
85 | deblurnet = torch.nn.DataParallel(deblurnet).cuda()
86 |
87 | # Load pretrained model if exists
88 | init_epoch = 0
89 | Best_Epoch = -1
90 | Best_Disp_EPE = float('Inf')
91 | Best_Img_PSNR = 0
92 | if cfg.NETWORK.PHASE in ['test', 'resume']:
93 | print('[INFO] %s Recovering from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))
94 | checkpoint = torch.load(cfg.CONST.WEIGHTS)
95 |
96 | if cfg.NETWORK.MODULE == 'dispnet':
97 | dispnet.load_state_dict(checkpoint['dispnet_state_dict'])
98 | init_epoch = checkpoint['epoch_idx']+1
99 | Best_Disp_EPE = checkpoint['Best_Disp_EPE']
100 | Best_Epoch = checkpoint['Best_Epoch']
101 | dispnet_solver.load_state_dict(checkpoint['dispnet_solver_state_dict'])
102 | print('[INFO] {0} Recover complete. Current epoch #{1}, Best_Disp_EPE = {2} at epoch #{3}.' \
103 | .format(dt.now(), init_epoch, Best_Disp_EPE, Best_Epoch))
104 | elif cfg.NETWORK.MODULE == 'deblurnet':
105 | deblurnet.load_state_dict(checkpoint['deblurnet_state_dict'])
106 | init_epoch = checkpoint['epoch_idx']+1
107 | Best_Img_PSNR = checkpoint['Best_Img_PSNR']
108 | Best_Epoch = checkpoint['Best_Epoch']
109 | deblurnet_solver.load_state_dict(checkpoint['deblurnet_solver_state_dict'])
110 | print('[INFO] {0} Recover complete. Current epoch #{1}, Best_Img_PSNR = {2} at epoch #{3}.' \
111 | .format(dt.now(), init_epoch, Best_Img_PSNR, Best_Epoch))
112 |
113 | elif cfg.NETWORK.MODULE == 'all':
114 | Best_Img_PSNR = checkpoint['Best_Img_PSNR']
115 | dispnet.load_state_dict(checkpoint['dispnet_state_dict'])
116 | deblurnet.load_state_dict(checkpoint['deblurnet_state_dict'])
117 | print('[INFO] {0} Recover complete. Best_Img_PSNR = {1}'.format(dt.now(), Best_Img_PSNR))
118 |
119 |
120 | # Set up learning rate scheduler to decay learning rates dynamically
121 | dispnet_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(dispnet_solver,
122 | milestones=cfg.TRAIN.DISPNET_LR_MILESTONES,
123 | gamma=cfg.TRAIN.LEARNING_RATE_DECAY)
124 | deblurnet_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(deblurnet_solver,
125 | milestones=cfg.TRAIN.DEBLURNET_LR_MILESTONES,
126 | gamma=cfg.TRAIN.LEARNING_RATE_DECAY)
127 |
128 | # Summary writer for TensorBoard
129 | if cfg.NETWORK.MODULE == 'dispnet':
130 | output_dir = os.path.join(cfg.DIR.OUT_PATH, dt.now().isoformat()+'_'+cfg.NETWORK.DISPNETARCH, '%s')
131 | elif cfg.NETWORK.MODULE == 'deblurnet':
132 | output_dir = os.path.join(cfg.DIR.OUT_PATH, dt.now().isoformat()+'_'+cfg.NETWORK.DEBLURNETARCH, '%s')
133 | elif cfg.NETWORK.MODULE == 'all':
134 | output_dir = os.path.join(cfg.DIR.OUT_PATH, dt.now().isoformat() + '_' + cfg.NETWORK.DEBLURNETARCH, '%s')
135 | log_dir = output_dir % 'logs'
136 | ckpt_dir = output_dir % 'checkpoints'
137 | train_writer = SummaryWriter(os.path.join(log_dir, 'train'))
138 | test_writer = SummaryWriter(os.path.join(log_dir, 'test'))
139 |
140 |
141 | if cfg.NETWORK.PHASE in ['train', 'resume']:
142 | # train and val
143 | if cfg.NETWORK.MODULE == 'dispnet':
144 | train_dispnet(cfg, init_epoch, train_data_loader, test_data_loader, dispnet, dispnet_solver,
145 | dispnet_lr_scheduler, ckpt_dir, train_writer, test_writer, Best_Disp_EPE, Best_Epoch)
146 | return
147 | elif cfg.NETWORK.MODULE == 'deblurnet':
148 | train_deblurnet(cfg, init_epoch, train_data_loader, test_data_loader, deblurnet, deblurnet_solver,
149 | deblurnet_lr_scheduler, ckpt_dir, train_writer, test_writer, Best_Img_PSNR, Best_Epoch)
150 | return
151 | elif cfg.NETWORK.MODULE == 'all':
152 | train_stereodeblurnet(cfg, init_epoch, train_data_loader, test_data_loader,
153 | dispnet, dispnet_solver, dispnet_lr_scheduler,
154 | deblurnet, deblurnet_solver, deblurnet_lr_scheduler,
155 | ckpt_dir, train_writer, test_writer,
156 | Best_Disp_EPE, Best_Img_PSNR, Best_Epoch)
157 |
158 | else:
159 |         assert os.path.exists(cfg.CONST.WEIGHTS), '[FATAL] Please specify the checkpoint file path!'
160 | if cfg.NETWORK.MODULE == 'dispnet':
161 | test_dispnet(cfg, init_epoch, test_data_loader, dispnet, test_writer)
162 | return
163 | elif cfg.NETWORK.MODULE == 'deblurnet':
164 | test_deblurnet(cfg, init_epoch, test_data_loader, deblurnet, test_writer)
165 | return
166 | elif cfg.NETWORK.MODULE == 'all':
167 | test_stereodeblurnet(cfg, init_epoch, test_data_loader, dispnet, deblurnet, test_writer)
168 |
--------------------------------------------------------------------------------
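The two `models.__dict__[...]` lookups in `bulid_net` select the network class by its name string: the first index resolves the submodule (e.g. `models/DispNet_Bi.py`), the second the class of the same name inside it. A self-contained sketch of the pattern, using placeholder stand-ins rather than the real networks:

```python
import types

# Placeholder module and class standing in for models/DispNet_Bi.py.
dispnet_bi = types.ModuleType('DispNet_Bi')

class DispNet_Bi(object):
    pass

dispnet_bi.DispNet_Bi = DispNet_Bi

models = types.ModuleType('models')
models.DispNet_Bi = dispnet_bi

arch = 'DispNet_Bi'  # what cfg.NETWORK.DISPNETARCH holds
# Equivalent to models.__dict__[arch].__dict__[arch]()
net = getattr(getattr(models, arch), arch)()
print(type(net).__name__)  # -> DispNet_Bi
```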
/core/test_deblur.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Developed by Shangchen Zhou
5 | import os
6 | import sys
7 | import torch.backends.cudnn
8 | import torch.utils.data
9 | import numpy as np
10 | import utils.data_loaders
11 | import utils.data_transforms
12 | import utils.network_utils
13 | from losses.multiscaleloss import *
14 | from time import time
15 | import cv2
16 |
17 | def mkdir(path):
18 |     # Recursively create the directory and any missing parents (like os.makedirs)
19 |     if path == '' or os.path.isdir(path):
20 |         return
21 |     mkdir(os.path.split(path)[0])
22 |     os.mkdir(path)
23 |
24 | def test_deblurnet(cfg, epoch_idx, test_data_loader, deblurnet, test_writer):
25 |
26 | # Testing loop
27 | n_batches = len(test_data_loader)
28 | test_epe = dict()
29 |     # Batch average metrics
30 | batch_time = utils.network_utils.AverageMeter()
31 | data_time = utils.network_utils.AverageMeter()
32 | img_PSNRs = utils.network_utils.AverageMeter()
33 | batch_end_time = time()
34 |
35 | test_psnr = dict()
36 | g_names= 'init'
37 | save_num = 0
38 | for batch_idx, (names, images, DISPs, OCCs) in enumerate(test_data_loader):
39 | data_time.update(time() - batch_end_time)
40 | if not g_names == names:
41 | g_names = names
42 | save_num = 0
43 | save_num = save_num+1
44 | # Switch models to testing mode
45 | deblurnet.eval()
46 |
47 | if cfg.NETWORK.PHASE == 'test':
48 | assert (len(names) == 1)
49 | name = names[0]
50 | if not name in test_psnr:
51 | test_psnr[name] = {
52 | 'n_samples': 0,
53 | 'psnr': []
54 | }
55 |
56 | with torch.no_grad():
57 | # Get data from data loader
58 | imgs = [utils.network_utils.var_or_cuda(img) for img in images]
59 | img_blur_left, img_blur_right, img_clear_left, img_clear_right = imgs
60 |
61 |             # Run the deblurring network on each view
62 | output_img_clear_left = deblurnet(img_blur_left)
63 | output_img_clear_right = deblurnet(img_blur_right)
64 |
65 | # Append loss and accuracy to average metrics
66 | img_PSNR = PSNR(output_img_clear_left, img_clear_left) / 2 + PSNR(output_img_clear_right, img_clear_right) / 2
67 | img_PSNRs.update(img_PSNR.item(), cfg.CONST.TEST_BATCH_SIZE)
68 |
69 | if cfg.NETWORK.PHASE == 'test':
70 | test_psnr[name]['n_samples'] += 1
71 | test_psnr[name]['psnr'].append(img_PSNR)
72 |
73 | batch_time.update(time() - batch_end_time)
74 | batch_end_time = time()
75 |
76 | # Print result
77 | if (batch_idx+1) % cfg.TEST.PRINT_FREQ == 0:
78 | print('[TEST] [Epoch {0}/{1}][Batch {2}/{3}]\t BatchTime {4}\t DataTime {5}\t\t ImgPSNR {6}'
79 | .format(epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, batch_idx + 1, n_batches, batch_time, data_time, img_PSNRs))
80 |
81 | if batch_idx < cfg.TEST.VISUALIZATION_NUM:
82 | if epoch_idx == 0 or cfg.NETWORK.PHASE in ['test', 'resume']:
83 | test_writer.add_image('DeblurNet/IMG_BLUR_LEFT'+str(batch_idx+1),
84 | images[0][0][[2,1,0],:,:] + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1), epoch_idx+1)
85 | test_writer.add_image('DeblurNet/IMG_BLUR_RIGHT'+str(batch_idx+1),
86 | images[1][0][[2,1,0],:,:] + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1), epoch_idx+1)
87 | test_writer.add_image('DeblurNet/IMG_CLEAR_LEFT' + str(batch_idx + 1),
88 | images[2][0][[2,1,0],:,:] + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1), epoch_idx + 1)
89 | test_writer.add_image('DeblurNet/IMG_CLEAR_RIGHT' + str(batch_idx + 1),
90 | images[3][0][[2,1,0],:,:] + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1), epoch_idx + 1)
91 |
92 | test_writer.add_image('DeblurNet/OUT_IMG_CLEAR_LEFT'+str(batch_idx+1), output_img_clear_left[0][[2,1,0],:,:].cpu().clamp(0.0,1.0) + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1), epoch_idx+1)
93 | test_writer.add_image('DeblurNet/OUT_IMG_CLEAR_RIGHT'+str(batch_idx+1), output_img_clear_right[0][[2,1,0],:,:].cpu().clamp(0.0,1.0) + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1), epoch_idx+1)
94 |
95 | if cfg.NETWORK.PHASE == 'test':
96 | left_out_dir = os.path.join(cfg.DIR.OUT_PATH,'single',names[0],'left')
97 | right_out_dir = os.path.join(cfg.DIR.OUT_PATH,'single',names[0],'right')
98 | if not os.path.isdir(left_out_dir):
99 | mkdir(left_out_dir)
100 | if not os.path.isdir(right_out_dir):
101 | mkdir(right_out_dir)
102 | print(left_out_dir+'/'+str(save_num).zfill(4)+'.png')
103 | cv2.imwrite(left_out_dir+'/'+str(save_num).zfill(4)+'.png', (output_img_clear_left.clamp(0.0, 1.0)[0].cpu().numpy().transpose(1, 2, 0) * 255.0).astype(np.uint8),
104 | [int(cv2.IMWRITE_PNG_COMPRESSION), 5])
105 |
106 | cv2.imwrite(right_out_dir + '/' + str(save_num).zfill(4) + '.png',
107 | (output_img_clear_right.clamp(0.0, 1.0)[0].cpu().numpy().transpose(1, 2, 0) * 255.0).astype(np.uint8),[int(cv2.IMWRITE_PNG_COMPRESSION), 5])
108 |
109 | if cfg.NETWORK.PHASE == 'test':
110 |
111 | # Output test results
112 | print('============================ TEST RESULTS ============================')
113 | print('[TEST] Total_Mean_PSNR:' + str(img_PSNRs.avg))
114 | for name in test_psnr:
115 | test_psnr[name]['psnr'] = np.mean(test_psnr[name]['psnr'], axis=0)
116 | print('[TEST] Name: {0}\t Num: {1}\t Mean_PSNR: {2}'.format(name, test_psnr[name]['n_samples'],
117 | test_psnr[name]['psnr']))
118 |
119 |         result_file = open(os.path.join(cfg.DIR.OUT_PATH, 'test_result.txt'), 'w')
120 |         sys.stdout = result_file
121 |         print('============================ TEST RESULTS ============================')
122 |         print('[TEST] Total_Mean_PSNR:' + str(img_PSNRs.avg))
123 |         for name in test_psnr:
124 |             print('[TEST] Name: {0}\t Num: {1}\t Mean_PSNR: {2}'.format(name, test_psnr[name]['n_samples'], test_psnr[name]['psnr']))
125 |         sys.stdout = sys.__stdout__  # restore stdout so later output is not sent to the closed file
126 |         result_file.close()
127 | else:
128 | # Output val results
129 | print('============================ TEST RESULTS ============================')
130 | print('[TEST] Total_Mean_PSNR:' + str(img_PSNRs.avg))
131 | print('[TEST] [Epoch{0}]\t BatchTime_avg {1}\t DataTime_avg {2}\t ImgPSNR_avg {3}\n'
132 | .format(cfg.TRAIN.NUM_EPOCHES, batch_time.avg, data_time.avg, img_PSNRs.avg))
133 |
134 | # Add testing results to TensorBoard
135 | test_writer.add_scalar('DeblurNet/EpochPSNR_1_TEST', img_PSNRs.avg, epoch_idx + 1)
136 |
137 | return img_PSNRs.avg
--------------------------------------------------------------------------------
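`PSNR` is imported via the star import from `losses/multiscaleloss.py`, which is not included in this section. Since images are normalized to [0, 1] (`cfg.DATA.STD` is 255), a plausible sketch of the metric, not the repo's exact implementation:

```python
import torch

def psnr_sketch(output, target, max_val=1.0):
    # PSNR = 10 * log10(MAX^2 / MSE), for tensors already scaled to [0, 1].
    mse = torch.mean((output.clamp(0.0, 1.0) - target) ** 2)
    return 10.0 * torch.log10(max_val ** 2 / mse)
```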
/core/test_disp.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Developed by Shangchen Zhou
5 |
6 | import torch.backends.cudnn
7 | import torch.utils.data
8 |
9 | import utils.data_loaders
10 | import utils.data_transforms
11 | import utils.network_utils
12 | from losses.multiscaleloss import *
13 | import torchvision
14 |
15 | from time import time
16 |
17 | def test_dispnet(cfg, epoch_idx, test_data_loader, dispnet, test_writer):
18 |
19 | # Testing loop
20 | n_batches = len(test_data_loader)
21 | test_epe = dict()
22 |     # Batch average metrics
23 | batch_time = utils.network_utils.AverageMeter()
24 | data_time = utils.network_utils.AverageMeter()
25 | disp_EPEs = utils.network_utils.AverageMeter()
26 | test_time = utils.network_utils.AverageMeter()
27 |
28 | batch_end_time = time()
29 | for batch_idx, (_, images, disps, occs) in enumerate(test_data_loader):
30 | data_time.update(time() - batch_end_time)
31 |
32 | # Switch models to testing mode
33 |         dispnet.eval()
34 |
35 | with torch.no_grad():
36 | # Get data from data loader
37 | disparities = disps
38 | imgs = [utils.network_utils.var_or_cuda(img) for img in images]
39 | imgs = torch.cat(imgs, 1)
40 | ground_truth_disps = [utils.network_utils.var_or_cuda(disp) for disp in disparities]
41 | ground_truth_disps = torch.cat(ground_truth_disps, 1)
42 | occs = [utils.network_utils.var_or_cuda(occ) for occ in occs]
43 | occs = torch.cat(occs, 1)
44 |
45 |             # Run the disparity network
46 | torch.cuda.synchronize()
47 | test_time_start = time()
48 | output_disps = dispnet(imgs)
49 | torch.cuda.synchronize()
50 | test_time.update(time() - test_time_start)
51 | print('[TIME] {0}'.format(test_time))
52 |
53 | disp_EPE = cfg.DATA.DIV_DISP * realEPE(output_disps, ground_truth_disps, occs)
54 |
55 | # Append loss and accuracy to average metrics
56 | disp_EPEs.update(disp_EPE.item(), cfg.CONST.TEST_BATCH_SIZE)
57 |
58 | batch_time.update(time() - batch_end_time)
59 | batch_end_time = time()
60 |
61 | # Print result
62 | if (batch_idx+1) % cfg.TEST.PRINT_FREQ == 0:
63 | print('[TEST] [Epoch {0}/{1}][Batch {2}/{3}]\t BatchTime {4}\t DataTime {5}\t\t DispEPE {6}'
64 | .format(epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, batch_idx + 1, n_batches, batch_time, data_time, disp_EPEs))
65 |
66 | if batch_idx < cfg.TEST.VISUALIZATION_NUM:
67 |
68 | if epoch_idx == 0 or cfg.NETWORK.PHASE in ['test', 'resume']:
69 | test_writer.add_image('DispNet/IMG_LEFT'+str(batch_idx+1),
70 | images[0][0][[2,1,0],:,:] + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1), epoch_idx+1)
71 | test_writer.add_image('DispNet/IMG_RIGHT'+str(batch_idx+1),
72 | images[1][0][[2,1,0],:,:] + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1), epoch_idx+1)
73 |
74 | gt_disp_left, gt_disp_right = utils.network_utils.graybi2rgb(ground_truth_disps[0])
75 | test_writer.add_image('DispNet/DISP_GT_LEFT' +str(batch_idx + 1), gt_disp_left, epoch_idx + 1)
76 | test_writer.add_image('DispNet/DISP_GT_RIGHT'+str(batch_idx+1), gt_disp_right, epoch_idx+1)
77 |
78 | b, _, h, w = imgs.size()
79 | output_disps_up = torch.nn.functional.interpolate(output_disps, size=(h, w), mode = 'bilinear', align_corners=True)
80 | output_disp_up_left, output_disp_up_right = utils.network_utils.graybi2rgb(output_disps_up[0])
81 | test_writer.add_image('DispNet/DISP_OUT_LEFT_'+str(batch_idx+1), output_disp_up_left, epoch_idx+1)
82 | test_writer.add_image('DispNet/DISP_OUT_RIGHT_'+str(batch_idx+1), output_disp_up_right, epoch_idx+1)
83 |
84 |
85 |
86 | # Output testing results
87 | print('============================ TEST RESULTS ============================')
88 | print('[TEST] [Epoch{0}]\t BatchTime_avg {1}\t DataTime_avg {2}\t DispEPE_avg {3}\n'
89 | .format(cfg.TRAIN.NUM_EPOCHES, batch_time.avg, data_time.avg, disp_EPEs.avg))
90 |
91 | # Add testing results to TensorBoard
92 | test_writer.add_scalar('DispNet/EpochEPE_1_TEST', disp_EPEs.avg, epoch_idx+1)
93 | return disp_EPEs.avg
94 |
--------------------------------------------------------------------------------
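`realEPE` also comes from `losses/multiscaleloss.py` (not shown here). Disparities are divided by `cfg.DATA.DIV_DISP` (40.0) during normalization, which is why the caller multiplies the error back up. A hedged sketch of a masked endpoint error, assuming `occ_masks` is 1 at valid (non-occluded) pixels:

```python
import torch.nn.functional as F

def real_epe_sketch(output_disps, gt_disps, occ_masks):
    # Upsample the coarse network output to ground-truth resolution, then
    # average the absolute disparity error over non-occluded pixels only.
    h, w = gt_disps.shape[-2:]
    up = F.interpolate(output_disps, size=(h, w), mode='bilinear', align_corners=True)
    error = (up - gt_disps).abs() * occ_masks
    return error.sum() / occ_masks.sum().clamp(min=1.0)
```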
/core/test_stereodeblur.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Developed by Shangchen Zhou
5 |
6 | import os, sys, cv2
7 | import numpy as np
8 | import torch.backends.cudnn
9 | import torch.utils.data
10 | import torchvision
11 | import utils.data_loaders
12 | import utils.data_transforms
13 | import utils.network_utils
14 | from losses.multiscaleloss import *
15 | from time import time
16 |
17 | def mkdir(path):
18 |     # Recursively create the directory and any missing parents (like os.makedirs)
19 |     if path == '' or os.path.isdir(path):
20 |         return
21 |     mkdir(os.path.split(path)[0])
22 |     os.mkdir(path)
23 |
24 | def test_stereodeblurnet(cfg, epoch_idx, test_data_loader, dispnet, deblurnet, test_writer):
25 |
26 | # Testing loop
27 | n_batches = len(test_data_loader)
28 |     # Batch average metrics
29 | batch_time = utils.network_utils.AverageMeter()
30 | test_time = utils.network_utils.AverageMeter()
31 | data_time = utils.network_utils.AverageMeter()
32 | disp_EPEs = utils.network_utils.AverageMeter()
33 | img_PSNRs = utils.network_utils.AverageMeter()
34 | batch_end_time = time()
35 | test_psnr = dict()
36 | g_names= 'init'
37 | save_num = 0
38 |
39 | for batch_idx, (names, images, disps, occs) in enumerate(test_data_loader):
40 | data_time.update(time() - batch_end_time)
41 | if not g_names == names:
42 | g_names = names
43 | save_num = 0
44 | save_num = save_num+1
45 | # Switch models to testing mode
46 | dispnet.eval()
47 | deblurnet.eval()
48 | if cfg.NETWORK.PHASE == 'test':
49 | assert (len(names) == 1)
50 | name = names[0]
51 | if not name in test_psnr:
52 | test_psnr[name] = {
53 | 'n_samples': 0,
54 | 'psnr': []
55 | }
56 |
57 | with torch.no_grad():
58 | # Get data from data loader
59 | disparities = disps
60 | imgs = [utils.network_utils.var_or_cuda(img) for img in images]
61 | img_blur_left, img_blur_right, img_clear_left, img_clear_right = imgs
62 | imgs_blur = torch.cat([img_blur_left, img_blur_right], 1)
63 |
64 | ground_truth_disps = [utils.network_utils.var_or_cuda(disp) for disp in disparities]
65 | ground_truth_disps = torch.cat(ground_truth_disps, 1)
66 | occs = [utils.network_utils.var_or_cuda(occ) for occ in occs]
67 | occs = torch.cat(occs, 1)
68 |
69 | # Test the dispnet
70 | # torch.cuda.synchronize()
71 | # test_time_start = time()
72 |
73 | output_disps = dispnet(imgs_blur)
74 |
75 | output_disp_feature = output_disps[1]
76 | output_disps = output_disps[0]
77 |
78 | # Test the deblurnet
79 |             imgs_prd, output_diffs, output_masks = deblurnet(imgs_blur, output_disps, output_disp_feature)
80 |
81 | # torch.cuda.synchronize()
82 | # test_time.update(time() - test_time_start)
83 | # print('[TIME] {0}'.format(test_time))
84 | disp_EPE = cfg.DATA.DIV_DISP * realEPE(output_disps, ground_truth_disps, occs)
85 | disp_EPEs.update(disp_EPE.item(), cfg.CONST.TEST_BATCH_SIZE)
86 |
87 | img_PSNR = (PSNR(imgs_prd[0], img_clear_left) + PSNR(imgs_prd[1], img_clear_right)) / 2
88 |             img_PSNRs.update(img_PSNR.item(), cfg.CONST.TEST_BATCH_SIZE)
89 |
90 | if cfg.NETWORK.PHASE == 'test':
91 | test_psnr[name]['n_samples'] += 1
92 | test_psnr[name]['psnr'].append(img_PSNR)
93 |
94 | batch_time.update(time() - batch_end_time)
95 | batch_end_time = time()
96 |
97 | # Print result
98 | if (batch_idx+1) % cfg.TEST.PRINT_FREQ == 0:
99 | print('[TEST] [Epoch {0}/{1}][Batch {2}/{3}]\t BatchTime {4}\t DataTime {5}\t DispEPE {6}\t imgPSNR {7}'
100 | .format(epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, batch_idx + 1, n_batches, batch_time, data_time, disp_EPEs, img_PSNRs))
101 |
102 | if batch_idx < cfg.TEST.VISUALIZATION_NUM and cfg.NETWORK.PHASE in ['train', 'resume']:
103 |
104 |
105 | if epoch_idx == 0 or cfg.NETWORK.PHASE in ['test', 'resume']:
106 | img_blur_left = images[0][0][[2, 1, 0], :, :] + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
107 | img_blur_right = images[1][0][[2, 1, 0], :, :] + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
108 | img_clear_left = images[2][0][[2, 1, 0], :, :] + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
109 | img_clear_right = images[3][0][[2, 1, 0], :, :] + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
110 | test_writer.add_image('StereoDeblurNet/IMG_BLUR' + str(batch_idx + 1),
111 | torch.cat([img_blur_left, img_blur_right], 2), epoch_idx + 1)
112 |
113 | test_writer.add_image('StereoDeblurNet/IMG_CLEAR' + str(batch_idx + 1),
114 | torch.cat([img_clear_left, img_clear_right], 2), epoch_idx + 1)
115 |
116 | gt_disp_left, gt_disp_right = utils.network_utils.graybi2rgb(ground_truth_disps[0])
117 | test_writer.add_image('StereoDeblurNet/DISP_GT' +str(batch_idx + 1),
118 | torch.cat([gt_disp_left, gt_disp_right], 2), epoch_idx + 1)
119 |
120 | b, _, h, w = imgs[0].size()
121 | diff_out_left, diff_out_right = utils.network_utils.graybi2rgb(torch.cat(output_diffs, 1)[0])
122 | output_masks = torch.nn.functional.interpolate(torch.cat(output_masks, 1), size=(h, w), mode='bilinear', align_corners=True)
123 | mask_out_left, mask_out_right = utils.network_utils.graybi2rgb(output_masks[0])
124 | disp_out_left, disp_out_right = utils.network_utils.graybi2rgb(output_disps[0])
125 | img_out_left = imgs_prd[0][0][[2, 1, 0], :, :].cpu().clamp(0.0, 1.0) + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
126 | img_out_right = imgs_prd[1][0][[2, 1, 0], :, :].cpu().clamp(0.0, 1.0) + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
127 | test_writer.add_image('StereoDeblurNet/IMG_OUT' + str(batch_idx + 1), torch.cat([img_out_left, img_out_right], 2), epoch_idx + 1)
128 | test_writer.add_image('StereoDeblurNet/DISP_OUT'+str(batch_idx+1), torch.cat([disp_out_left, disp_out_right], 2), epoch_idx+1)
129 | test_writer.add_image('StereoDeblurNet/DIFF_OUT'+str(batch_idx+1), torch.cat([diff_out_left, diff_out_right], 2), epoch_idx+1)
130 |                 test_writer.add_image('StereoDeblurNet/MASK_OUT'+str(batch_idx+1), torch.cat([mask_out_left, mask_out_right], 2), epoch_idx+1)
131 | if cfg.NETWORK.PHASE == 'test':
132 | img_left_dir = os.path.join(cfg.DIR.OUT_PATH,'stereo',names[0],'left')
133 | img_right_dir = os.path.join(cfg.DIR.OUT_PATH,'stereo',names[0],'right')
134 |
135 | if not os.path.isdir(img_left_dir):
136 | mkdir(img_left_dir)
137 | if not os.path.isdir(img_right_dir):
138 | mkdir(img_right_dir)
139 |
140 | print(img_left_dir + '/' + str(save_num).zfill(4) + '.png')
141 | cv2.imwrite(img_left_dir + '/' + str(save_num).zfill(4) + '.png',
142 | (imgs_prd[0].clamp(0.0, 1.0)[0].cpu().numpy().transpose(1, 2, 0) * 255.0).astype(
143 | np.uint8),
144 | [int(cv2.IMWRITE_PNG_COMPRESSION), 5])
145 |
146 | print(img_right_dir + '/' + str(save_num).zfill(4) + '.png')
147 | cv2.imwrite(img_right_dir + '/' + str(save_num).zfill(4) + '.png',
148 | (imgs_prd[1].clamp(0.0, 1.0)[0].cpu().numpy().transpose(1, 2, 0) * 255.0).astype(
149 | np.uint8), [int(cv2.IMWRITE_PNG_COMPRESSION), 5])
150 |
151 | # Output testing results
152 |
153 | if cfg.NETWORK.PHASE == 'test':
154 | # Output test results
155 | print('============================ TEST RESULTS ============================')
156 | print('[TEST] Total_Mean_PSNR:' + str(img_PSNRs.avg))
157 | for name in test_psnr:
158 | test_psnr[name]['psnr'] = np.mean(test_psnr[name]['psnr'], axis=0)
159 | print('[TEST] Name: {0}\t Num: {1}\t Mean_PSNR: {2}'.format(name, test_psnr[name]['n_samples'],
160 | test_psnr[name]['psnr']))
161 |
162 |         result_file = open(os.path.join(cfg.DIR.OUT_PATH, 'test_result.txt'), 'w')
163 |         sys.stdout = result_file
164 |         print('============================ TEST RESULTS ============================')
165 |         print('[TEST] Total_Mean_PSNR:' + str(img_PSNRs.avg))
166 |         for name in test_psnr:
167 |             print('[TEST] Name: {0}\t Num: {1}\t Mean_PSNR: {2}'.format(name, test_psnr[name]['n_samples'], test_psnr[name]['psnr']))
168 |         sys.stdout = sys.__stdout__  # restore stdout so later output is not sent to the closed file
169 |         result_file.close()
170 | else:
171 | # Output val results
172 | print('============================ TEST RESULTS ============================')
173 | print('[TEST] Total_Mean_PSNR:' + str(img_PSNRs.avg))
174 | print('[TEST] [Epoch{0}]\t BatchTime_avg {1}\t DataTime_avg {2}\t DispEPE_avg {3}\t ImgPSNR_avg {4}\n'
175 | .format(cfg.TRAIN.NUM_EPOCHES, batch_time.avg, data_time.avg, disp_EPEs.avg, img_PSNRs.avg))
176 |
177 | # Add testing results to TensorBoard
178 | test_writer.add_scalar('StereoDeblurNet/EpochEPE_1_TEST', disp_EPEs.avg, epoch_idx + 1)
179 | test_writer.add_scalar('StereoDeblurNet/EpochPSNR_1_TEST', img_PSNRs.avg, epoch_idx + 1)
180 |
181 | return disp_EPEs.avg, img_PSNRs.avg
--------------------------------------------------------------------------------
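The two `cv2.imwrite` calls above repeat the same tensor-to-image conversion; a small helper capturing that pattern (the channel order stays BGR on disk, matching the `[2, 1, 0]` swaps used only for TensorBoard display):

```python
import cv2
import numpy as np

def save_tensor_as_png(img_tensor, path):
    # img_tensor: CHW float tensor in [0, 1]; converted to HWC uint8 and saved.
    img = (img_tensor.clamp(0.0, 1.0).cpu().numpy().transpose(1, 2, 0) * 255.0).astype(np.uint8)
    cv2.imwrite(path, img, [int(cv2.IMWRITE_PNG_COMPRESSION), 5])

# e.g. save_tensor_as_png(imgs_prd[0][0], img_left_dir + '/' + str(save_num).zfill(4) + '.png')
```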
/core/train_deblur.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Developed by Shangchen Zhou
5 |
6 | import os
7 | import torch.backends.cudnn
8 | import torch.utils.data
9 |
10 | import utils.data_loaders
11 | import utils.data_transforms
12 | import utils.network_utils
13 | import torchvision
14 |
15 | from losses.multiscaleloss import *
16 | from time import time
17 |
18 | from core.test_deblur import test_deblurnet
19 | from models.VGG19 import VGG19
20 |
21 |
22 | def train_deblurnet(cfg, init_epoch, train_data_loader, val_data_loader, deblurnet, deblurnet_solver,
23 | deblurnet_lr_scheduler, ckpt_dir, train_writer, val_writer, Best_Img_PSNR, Best_Epoch):
24 | # Training loop
25 | for epoch_idx in range(init_epoch, cfg.TRAIN.NUM_EPOCHES):
26 | # Tick / tock
27 | epoch_start_time = time()
28 |
29 |         # Batch average metrics
30 | batch_time = utils.network_utils.AverageMeter()
31 | data_time = utils.network_utils.AverageMeter()
32 | test_time = utils.network_utils.AverageMeter()
33 | deblur_losses = utils.network_utils.AverageMeter()
34 | mse_losses = utils.network_utils.AverageMeter()
35 | if cfg.TRAIN.USE_PERCET_LOSS:
36 | percept_losses = utils.network_utils.AverageMeter()
37 | img_PSNRs = utils.network_utils.AverageMeter()
38 |
39 | # Adjust learning rate
40 | deblurnet_lr_scheduler.step()
41 |
42 | batch_end_time = time()
43 | n_batches = len(train_data_loader)
44 | if cfg.TRAIN.USE_PERCET_LOSS:
45 | vggnet = VGG19()
46 | if torch.cuda.is_available():
47 | vggnet = torch.nn.DataParallel(vggnet).cuda()
48 |
49 | for batch_idx, (_, images, DISPs, OCCs) in enumerate(train_data_loader):
50 | # Measure data time
51 |
52 | data_time.update(time() - batch_end_time)
53 | # Get data from data loader
54 | imgs = [utils.network_utils.var_or_cuda(img) for img in images]
55 | img_blur_left, img_blur_right, img_clear_left, img_clear_right = imgs
56 |
57 | # switch models to training mode
58 | deblurnet.train()
59 |
60 | output_img_clear_left = deblurnet(img_blur_left)
61 |
62 | mse_left_loss = mseLoss(output_img_clear_left, img_clear_left)
63 | if cfg.TRAIN.USE_PERCET_LOSS:
64 | percept_left_loss = perceptualLoss(output_img_clear_left, img_clear_left, vggnet)
65 | deblur_left_loss = mse_left_loss + 0.01 * percept_left_loss
66 | else:
67 | deblur_left_loss = mse_left_loss
68 |
69 | img_PSNR_left = PSNR(output_img_clear_left, img_clear_left)
70 |
71 |             # Gradient descent: back-propagate the left-view loss first
72 | deblurnet_solver.zero_grad()
73 | deblur_left_loss.backward()
74 |
75 | # For right
76 | output_img_clear_right = deblurnet(img_blur_right)
77 | mse_right_loss = mseLoss(output_img_clear_right, img_clear_right)
78 | if cfg.TRAIN.USE_PERCET_LOSS:
79 | percept_right_loss = perceptualLoss(output_img_clear_right, img_clear_right, vggnet)
80 | deblur_right_loss = mse_right_loss + 0.01 * percept_right_loss
81 | else:
82 | deblur_right_loss = mse_right_loss
83 |
84 | img_PSNR_right = PSNR(output_img_clear_right, img_clear_right)
85 |
86 |             # Gradient descent: accumulate the right-view gradients on top of the
87 |             # left-view ones (zeroing the gradients here would discard the left pass)
88 | deblur_right_loss.backward()
89 | deblurnet_solver.step()
90 |
91 | mse_loss = (mse_left_loss + mse_right_loss) / 2
92 | mse_losses.update(mse_loss.item(), cfg.CONST.TRAIN_BATCH_SIZE)
93 | if cfg.TRAIN.USE_PERCET_LOSS:
94 | percept_loss = 0.01 *(percept_left_loss + percept_right_loss) / 2
95 | percept_losses.update(percept_loss.item(), cfg.CONST.TRAIN_BATCH_SIZE)
96 |
97 | deblur_loss = (deblur_left_loss + deblur_right_loss) / 2
98 | deblur_losses.update(deblur_loss.item(), cfg.CONST.TRAIN_BATCH_SIZE)
99 | img_PSNR = img_PSNR_left / 2 + img_PSNR_right / 2
100 | img_PSNRs.update(img_PSNR.item(), cfg.CONST.TRAIN_BATCH_SIZE)
101 |
102 | # Append loss to TensorBoard
103 | n_itr = epoch_idx * n_batches + batch_idx
104 | train_writer.add_scalar('DeblurNet/MSELoss_0_TRAIN', mse_loss.item(), n_itr)
105 | if cfg.TRAIN.USE_PERCET_LOSS:
106 | train_writer.add_scalar('DeblurNet/PerceptLoss_0_TRAIN', percept_loss.item(), n_itr)
107 | train_writer.add_scalar('DeblurNet/DeblurLoss_0_TRAIN', deblur_loss.item(), n_itr)
108 |
109 | # Tick / tock
110 | batch_time.update(time() - batch_end_time)
111 | batch_end_time = time()
112 |
113 | if (batch_idx + 1) % cfg.TRAIN.PRINT_FREQ == 0:
114 | if cfg.TRAIN.USE_PERCET_LOSS:
115 | print('[TRAIN] [Ech {0}/{1}][Bch {2}/{3}]\t BT {4}\t DT {5}\t Loss {6} [{7}, {8}]\t PSNR {9}'
116 | .format(epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, batch_idx + 1, n_batches, batch_time, data_time,
117 | deblur_losses, mse_losses, percept_losses, img_PSNRs))
118 | else:
119 | print('[TRAIN] [Ech {0}/{1}][Bch {2}/{3}]\t BT {4}\t DT {5}\t DeblurLoss {6} \t PSNR {7}'
120 | .format(epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, batch_idx + 1, n_batches, batch_time, data_time,
121 | deblur_losses, img_PSNRs))
122 |
123 | if batch_idx < cfg.TEST.VISUALIZATION_NUM:
124 |
125 | img_left_blur = images[0][0][[2, 1, 0], :, :] + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
126 | img_right_blur = images[1][0][[2, 1, 0], :, :] + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
127 | img_left_clear = images[2][0][[2, 1, 0], :, :] + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
128 | img_right_clear = images[3][0][[2, 1, 0], :, :] + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
129 | out_left = output_img_clear_left[0][[2,1,0],:,:].cpu().clamp(0.0,1.0) + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
130 | out_right = output_img_clear_right[0][[2,1,0],:,:].cpu().clamp(0.0,1.0) + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
131 | result = torch.cat([torch.cat([img_left_blur, img_right_blur], 2),torch.cat([img_left_clear, img_right_clear], 2),torch.cat([out_left, out_right], 2)],1)
132 | result = torchvision.utils.make_grid(result, nrow=1, normalize=True)
133 | train_writer.add_image('DeblurNet/TRAIN_RESULT' + str(batch_idx + 1), result, epoch_idx + 1)
134 |
135 |
136 | # Append epoch loss to TensorBoard
137 | train_writer.add_scalar('DeblurNet/EpochPSNR_0_TRAIN', img_PSNRs.avg, epoch_idx + 1)
138 |
139 | # Tick / tock
140 | epoch_end_time = time()
141 | print('[TRAIN] [Epoch {0}/{1}]\t EpochTime {2}\t DeblurLoss_avg {3}\t ImgPSNR_avg {4}'
142 | .format(epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, epoch_end_time - epoch_start_time, deblur_losses.avg,
143 | img_PSNRs.avg))
144 |
145 | # Validate the training models
146 | img_PSNR = test_deblurnet(cfg, epoch_idx, val_data_loader, deblurnet, val_writer)
147 |
148 | # Save weights to file
149 | if (epoch_idx + 1) % cfg.TRAIN.SAVE_FREQ == 0:
150 | if not os.path.exists(ckpt_dir):
151 | os.makedirs(ckpt_dir)
152 |
153 | utils.network_utils.save_deblur_checkpoints(os.path.join(ckpt_dir, 'ckpt-epoch-%04d.pth.tar' % (epoch_idx + 1)), \
154 | epoch_idx + 1, deblurnet, deblurnet_solver, Best_Img_PSNR,
155 | Best_Epoch)
156 | if img_PSNR > Best_Img_PSNR:
157 | if not os.path.exists(ckpt_dir):
158 | os.makedirs(ckpt_dir)
159 |
160 | Best_Img_PSNR = img_PSNR
161 | Best_Epoch = epoch_idx + 1
162 | utils.network_utils.save_deblur_checkpoints(os.path.join(ckpt_dir, 'best-ckpt.pth.tar'), \
163 | epoch_idx + 1, deblurnet, deblurnet_solver, Best_Img_PSNR,
164 | Best_Epoch)
165 |
166 | # Close SummaryWriter for TensorBoard
167 | train_writer.close()
168 | val_writer.close()
169 |
170 |
171 |
--------------------------------------------------------------------------------
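The loop above back-propagates the left-view loss, then the right-view loss, and steps once, relying on PyTorch accumulating gradients across `backward()` calls. A minimal, self-contained illustration of that semantics with a toy parameter (not the repo's model):

```python
import torch

w = torch.nn.Parameter(torch.tensor(1.0))
opt = torch.optim.Adam([w], lr=0.1)

opt.zero_grad()
(2.0 * w).backward()   # "left" loss: d/dw = 2
(3.0 * w).backward()   # "right" loss: gradients accumulate, 2 + 3 = 5
print(w.grad.item())   # -> 5.0
opt.step()             # one update driven by both views
```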
/core/train_disp.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Developed by Shangchen Zhou
5 |
6 | import os
7 | import torch.backends.cudnn
8 | import torch.utils.data
9 |
10 | import utils.data_loaders
11 | import utils.data_transforms
12 | import utils.network_utils
13 | import torchvision
14 |
15 | from losses.multiscaleloss import *
16 | from time import time
17 |
18 | from core.test_disp import test_dispnet
19 |
20 |
21 | def train_dispnet(cfg, init_epoch, train_data_loader, val_data_loader, dispnet, dispnet_solver,
22 | dispnet_lr_scheduler, ckpt_dir, train_writer, val_writer, Best_Disp_EPE, Best_Epoch):
23 |     # Training loop
24 |     # (Best_Disp_EPE carries over from the loaded checkpoint when resuming)
25 | for epoch_idx in range(init_epoch, cfg.TRAIN.NUM_EPOCHES):
26 | # Tick / tock
27 | epoch_start_time = time()
28 |
29 |         # Batch average metrics
30 | batch_time = utils.network_utils.AverageMeter()
31 | data_time = utils.network_utils.AverageMeter()
32 | disp_losses = utils.network_utils.AverageMeter()
33 | disp_EPEs = utils.network_utils.AverageMeter()
34 | disp_EPEs_blur = utils.network_utils.AverageMeter()
35 | disp_EPEs_clear = utils.network_utils.AverageMeter()
36 |
37 | # Adjust learning rate
38 | dispnet_lr_scheduler.step()
39 |
40 | batch_end_time = time()
41 | n_batches = len(train_data_loader)
42 |
43 | for batch_idx, (_, images, disps, occs) in enumerate(train_data_loader):
44 | # Measure data time
45 |
46 | data_time.update(time() - batch_end_time)
47 | # Get data from data loader
48 | disparities = disps
49 | imgs_blur = [utils.network_utils.var_or_cuda(img) for img in images[:2]]
50 | imgs_clear = [utils.network_utils.var_or_cuda(img) for img in images[2:]]
51 |
52 | imgs_blur = torch.cat(imgs_blur, 1)
53 | imgs_clear = torch.cat(imgs_clear, 1)
54 | ground_truth_disps = [utils.network_utils.var_or_cuda(disp) for disp in disparities]
55 | ground_truth_disps = torch.cat(ground_truth_disps, 1)
56 | occs = [utils.network_utils.var_or_cuda(occ) for occ in occs]
57 | occs = torch.cat(occs, 1)
58 |
59 | # switch models to training mode
60 | dispnet.train()
61 |
62 | # Train the model
63 | output_disps_blur = dispnet(imgs_blur)
64 |
65 | disp_loss_blur = multiscaleLoss(output_disps_blur, ground_truth_disps, imgs_blur, occs, cfg.LOSS.MULTISCALE_WEIGHTS)
66 | disp_EPE_blur = cfg.DATA.DIV_DISP * realEPE(output_disps_blur[0], ground_truth_disps, occs)
67 | disp_EPEs_blur.update(disp_EPE_blur.item(), cfg.CONST.TRAIN_BATCH_SIZE)
68 |
69 | output_disps_clear = dispnet(imgs_clear)
70 |
71 | disp_loss_clear = multiscaleLoss(output_disps_clear, ground_truth_disps, imgs_clear, occs, cfg.LOSS.MULTISCALE_WEIGHTS)
72 | disp_EPE_clear = cfg.DATA.DIV_DISP * realEPE(output_disps_clear[0], ground_truth_disps, occs)
73 | disp_EPEs_clear.update(disp_EPE_clear.item(), cfg.CONST.TRAIN_BATCH_SIZE)
74 |
75 |             # Gradient descent: back-propagate the averaged blur + clear loss
76 |             disp_loss = (disp_loss_blur + disp_loss_clear) / 2.0
77 |             dispnet_solver.zero_grad()
78 |             disp_loss.backward()
79 |             dispnet_solver.step()
80 |
81 |             disp_EPE = (disp_EPE_blur + disp_EPE_clear) / 2.0
82 | disp_losses.update(disp_loss.item(), cfg.CONST.TRAIN_BATCH_SIZE)
83 | disp_EPEs.update(disp_EPE.item(), cfg.CONST.TRAIN_BATCH_SIZE)
84 |
85 |
86 | # Append loss to TensorBoard
87 | n_itr = epoch_idx * n_batches + batch_idx
88 | train_writer.add_scalar('DispNet/BatchLoss_0_TRAIN', disp_loss.item(), n_itr)
89 |
90 | # Tick / tock
91 | batch_time.update(time() - batch_end_time)
92 | batch_end_time = time()
93 |
94 | if (batch_idx+1) % cfg.TRAIN.PRINT_FREQ == 0:
95 | print(
96 | '[TRAIN] [Epoch {0}/{1}][Batch {2}/{3}]\t BatchTime {4}\t DataTime {5}\t DispLoss {6}\t blurEPE {7}\t clearEPE {8}'
97 | .format(epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, batch_idx + 1, n_batches, batch_time, data_time,
98 | disp_losses, disp_EPEs_blur, disp_EPEs_clear))
99 |
100 | if batch_idx < cfg.TEST.VISUALIZATION_NUM:
101 | img_left_blur = images[0][0][[2, 1, 0], :, :] + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
102 | img_right_blur = images[1][0][[2, 1, 0], :, :] + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
103 | img_left_clear = images[2][0][[2, 1, 0], :, :] + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
104 | img_right_clear = images[3][0][[2, 1, 0], :, :] + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
105 | gt_disp_left, gt_disp_right = utils.network_utils.graybi2rgb(ground_truth_disps[0])
106 | b, _, h, w = imgs_clear.size()
107 | output_disps_up_blur = torch.nn.functional.interpolate(output_disps_blur[0], size=(h, w), mode='bilinear', align_corners=True)
108 | output_disps_up_clear = torch.nn.functional.interpolate(output_disps_clear[0], size=(h, w), mode='bilinear', align_corners=True)
109 | output_disp_up_left_blur, output_disp_up_right_blur = utils.network_utils.graybi2rgb(output_disps_up_blur[0])
110 | output_disp_up_left_clear, output_disp_up_right_clear = utils.network_utils.graybi2rgb(output_disps_up_clear[0])
111 | result = torch.cat([torch.cat([img_left_blur, img_right_blur], 2),
112 | torch.cat([img_left_clear, img_right_clear], 2),
113 | torch.cat([gt_disp_left, gt_disp_right], 2),
114 | torch.cat([output_disp_up_left_blur, output_disp_up_right_blur], 2),
115 | torch.cat([output_disp_up_left_clear, output_disp_up_right_clear], 2)],1)
116 | result = torchvision.utils.make_grid(result, nrow=1, normalize=True)
117 | train_writer.add_image('DispNet/TRAIN_RESULT' + str(batch_idx + 1), result, epoch_idx + 1)
118 |
119 | # Append epoch loss to TensorBoard
120 | train_writer.add_scalar('DispNet/EpochEPE_0_TRAIN', disp_EPEs.avg, epoch_idx + 1)
121 |
122 | # Tick / tock
123 | epoch_end_time = time()
124 | print('[TRAIN] [Epoch {0}/{1}]\t EpochTime {2}\t DispLoss_avg {3}\t DispEPE_avg {4}'
125 | .format(epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, epoch_end_time - epoch_start_time, disp_losses.avg,
126 | disp_EPEs.avg))
127 |
128 |
129 | # Save weights to file
130 | if (epoch_idx + 1) % cfg.TRAIN.SAVE_FREQ == 0:
131 | if not os.path.exists(ckpt_dir):
132 | os.makedirs(ckpt_dir)
133 |
134 | utils.network_utils.save_disp_checkpoints(os.path.join(ckpt_dir, 'ckpt-epoch-%04d.pth.tar' % (epoch_idx + 1)), \
135 | epoch_idx + 1, dispnet, dispnet_solver, Best_Disp_EPE,
136 | Best_Epoch)
137 | if disp_EPEs.avg < Best_Disp_EPE:
138 | if not os.path.exists(ckpt_dir):
139 | os.makedirs(ckpt_dir)
140 |
141 | Best_Disp_EPE = disp_EPEs.avg
142 | Best_Epoch = epoch_idx + 1
143 | utils.network_utils.save_disp_checkpoints(os.path.join(ckpt_dir, 'best-ckpt.pth.tar'), \
144 | epoch_idx + 1, dispnet, dispnet_solver, Best_Disp_EPE,
145 | Best_Epoch)
146 |
147 | # Close SummaryWriter for TensorBoard
148 | train_writer.close()
149 | val_writer.close()
150 |
151 |
--------------------------------------------------------------------------------
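`multiscaleLoss` lives in `losses/multiscaleloss.py`, which this section does not include. Given that the network returns a pyramid of disparity maps (`output_disps_blur[0]` is the finest scale) and `cfg.LOSS.MULTISCALE_WEIGHTS` has five entries, a plausible sketch of the loss, not the repo's exact code:

```python
import torch.nn.functional as F

def multiscale_loss_sketch(output_disps, gt_disps, occ_masks, weights):
    # Blend per-scale masked L1 terms; each prediction in the pyramid is compared
    # against a ground truth resized to its own resolution.
    total = 0.0
    for pred, weight in zip(output_disps, weights):
        h, w = pred.shape[-2:]
        gt = F.interpolate(gt_disps, size=(h, w), mode='bilinear', align_corners=True)
        mask = F.interpolate(occ_masks, size=(h, w), mode='bilinear', align_corners=True)
        total = total + weight * ((pred - gt).abs() * mask).mean()
    return total
```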
/core/train_stereodeblur.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Developed by Shangchen Zhou
5 |
6 | import os
7 | import torch.backends.cudnn
8 | import torch.utils.data
9 |
10 | import utils.data_loaders
11 | import utils.data_transforms
12 | import utils.network_utils
13 | import torchvision
14 |
15 | from losses.multiscaleloss import *
16 | from time import time
17 |
18 | from core.test_stereodeblur import test_stereodeblurnet
19 | from models.VGG19 import VGG19
20 |
21 |
22 | def train_stereodeblurnet(cfg, init_epoch, train_data_loader, val_data_loader,
23 | dispnet, dispnet_solver, dispnet_lr_scheduler,
24 | deblurnet, deblurnet_solver, deblurnet_lr_scheduler,
25 | ckpt_dir, train_writer, val_writer,
26 | Disp_EPE, Best_Img_PSNR, Best_Epoch):
27 | # Training loop
28 | for epoch_idx in range(init_epoch, cfg.TRAIN.NUM_EPOCHES):
29 | # Tick / tock
30 | epoch_start_time = time()
31 |
32 |         # Batch average metrics
33 | batch_time = utils.network_utils.AverageMeter()
34 | data_time = utils.network_utils.AverageMeter()
35 | disp_EPEs = utils.network_utils.AverageMeter()
36 | deblur_mse_losses = utils.network_utils.AverageMeter()
37 |         # the perceptual-loss meter is created unconditionally so the log below can always print it
38 |         deblur_percept_losses = utils.network_utils.AverageMeter()
39 | deblur_losses = utils.network_utils.AverageMeter()
40 | img_PSNRs = utils.network_utils.AverageMeter()
41 |
42 | # Adjust learning rate
43 | dispnet_lr_scheduler.step()
44 | deblurnet_lr_scheduler.step()
45 |
46 | batch_end_time = time()
47 | n_batches = len(train_data_loader)
48 |
49 | vggnet = VGG19()
50 | if torch.cuda.is_available():
51 | vggnet = torch.nn.DataParallel(vggnet).cuda()
52 |
53 | for batch_idx, (_, images, disps, occ_masks) in enumerate(train_data_loader):
54 | # Measure data time
55 |
56 | data_time.update(time() - batch_end_time)
57 | # Get data from data loader
58 | imgs = [utils.network_utils.var_or_cuda(img) for img in images]
59 | img_blur_left, img_blur_right, img_clear_left, img_clear_right = imgs
60 |
61 | imgs_blur = torch.cat([img_blur_left, img_blur_right], 1)
62 | ground_truth_disps = [utils.network_utils.var_or_cuda(disp) for disp in disps]
63 | ground_truth_disps = torch.cat(ground_truth_disps, 1)
64 | occ_masks = [utils.network_utils.var_or_cuda(occ_mask) for occ_mask in occ_masks]
65 | occ_masks = torch.cat(occ_masks, 1)
66 |
67 | # switch models to training mode
68 | dispnet.train()
69 | deblurnet.train()
70 |
71 | # Train the model
72 | output_disps = dispnet(imgs_blur)
73 |
74 | output_disp_feature = output_disps[-1]
75 | output_disps = output_disps[:-1]
76 | imgs_prd, output_diffs, output_masks = deblurnet(imgs_blur, output_disps[0], output_disp_feature)
77 |
78 | disp_EPE = cfg.DATA.DIV_DISP * realEPE(output_disps[0], ground_truth_disps, occ_masks)
79 | disp_EPEs.update(disp_EPE.item(), cfg.CONST.TRAIN_BATCH_SIZE)
80 |
81 | # deblur loss
82 | deblur_mse_left_loss = mseLoss(imgs_prd[0], img_clear_left)
83 | deblur_mse_right_loss = mseLoss(imgs_prd[1], img_clear_right)
84 | deblur_mse_loss = (deblur_mse_left_loss + deblur_mse_right_loss) / 2
85 | deblur_mse_losses.update(deblur_mse_loss.item(), cfg.CONST.TRAIN_BATCH_SIZE)
86 |             if cfg.TRAIN.USE_PERCET_LOSS:
87 | deblur_percept_left_loss = perceptualLoss(imgs_prd[0], img_clear_left, vggnet)
88 | deblur_percept_right_loss = perceptualLoss(imgs_prd[1], img_clear_right, vggnet)
89 | deblur_percept_loss = (deblur_percept_left_loss + deblur_percept_right_loss) / 2
90 | deblur_percept_losses.update(deblur_percept_loss.item(), cfg.CONST.TRAIN_BATCH_SIZE)
91 | deblur_loss = deblur_mse_loss + 0.01 * deblur_percept_loss
92 | else:
93 | deblur_loss = deblur_mse_loss
94 | deblur_losses.update(deblur_loss.item(), cfg.CONST.TRAIN_BATCH_SIZE)
95 |
96 | img_PSNR = (PSNR(imgs_prd[0], img_clear_left) + PSNR(imgs_prd[1], img_clear_right)) / 2
97 | img_PSNRs.update(img_PSNR.item(), cfg.CONST.TRAIN_BATCH_SIZE)
98 |
99 | deblurnet_solver.zero_grad()
100 | deblurnet_loss = deblur_loss
101 | deblurnet_loss.backward()
102 | deblurnet_solver.step()
103 |
104 | # Append loss to TensorBoard
105 | n_itr = epoch_idx * n_batches + batch_idx
106 |
107 | train_writer.add_scalar('StereoDeblurNet/DeblurLoss_0_TRAIN', deblur_loss.item(), n_itr)
108 | train_writer.add_scalar('StereoDeblurNet/DeblurMSELoss_0_TRAIN', deblur_mse_loss.item(), n_itr)
109 |             if cfg.TRAIN.USE_PERCET_LOSS:
110 | train_writer.add_scalar('StereoDeblurNet/DeblurPerceptLoss_0_TRAIN', deblur_percept_loss.item(), n_itr)
111 |
112 | # Tick / tock
113 | batch_time.update(time() - batch_end_time)
114 | batch_end_time = time()
115 |
116 | if (batch_idx + 1) % cfg.TRAIN.PRINT_FREQ == 0:
117 | print(
118 | '[TRAIN] [Ech {0}/{1}][Bch {2}/{3}] BT {4} DT {5} EPE {6} DeblurLoss {7} [{8}, {9}] PSNR {10}'
119 | .format(epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, batch_idx + 1, n_batches, batch_time, data_time,
120 | disp_EPEs, deblur_losses, deblur_mse_losses, deblur_percept_losses, img_PSNRs))
121 |
122 | if batch_idx < cfg.TEST.VISUALIZATION_NUM:
123 | img_blur_left = images[0][0][[2, 1, 0], :, :] + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
124 | img_blur_right = images[1][0][[2, 1, 0], :, :] + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
125 | img_clear_left = images[2][0][[2, 1, 0], :, :] + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
126 | img_clear_right = images[3][0][[2, 1, 0], :, :] + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
127 | img_out_left = imgs_prd[0][0][[2, 1, 0], :, :].cpu().clamp(0.0, 1.0) + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
128 | img_out_right = imgs_prd[1][0][[2, 1, 0], :, :].cpu().clamp(0.0, 1.0) + torch.Tensor(cfg.DATA.MEAN).view(3, 1, 1)
129 | disp_gt_left, disp_gt_right = utils.network_utils.graybi2rgb(ground_truth_disps[0])
130 | b, _, h, w = imgs[0].shape
131 | diff_out_left, diff_out_right = utils.network_utils.graybi2rgb(torch.cat(output_diffs, 1)[0])
132 | output_masks = torch.nn.functional.interpolate(torch.cat(output_masks, 1), size=(h, w), mode='bilinear', align_corners=True)
133 | mask_out_left, mask_out_right = utils.network_utils.graybi2rgb(output_masks[0])
134 | disp_out_left, disp_out_right = utils.network_utils.graybi2rgb(output_disps[0][0])
135 | result = torch.cat([torch.cat([img_blur_left, img_blur_right], 2),
136 | torch.cat([img_clear_left, img_clear_right], 2),
137 | torch.cat([img_out_left, img_out_right], 2),
138 | torch.cat([disp_gt_left, disp_gt_right], 2),
139 | torch.cat([disp_out_left, disp_out_right], 2),
140 | torch.cat([diff_out_left, diff_out_right], 2),
141 | torch.cat([mask_out_left, mask_out_right], 2)], 1)
142 | result = torchvision.utils.make_grid(result, nrow=1, normalize=True)
143 | train_writer.add_image('StereoDeblurNet/TRAIN_RESULT' + str(batch_idx + 1), result, epoch_idx + 1)
144 |
145 | # Append epoch loss to TensorBoard
146 | train_writer.add_scalar('StereoDeblurNet/EpochEPE_0_TRAIN', disp_EPEs.avg, epoch_idx + 1)
147 | train_writer.add_scalar('StereoDeblurNet/EpochPSNR_0_TRAIN', img_PSNRs.avg, epoch_idx + 1)
148 |
149 | # Tick / tock
150 | epoch_end_time = time()
151 | print('[TRAIN] [Epoch {0}/{1}]\t EpochTime {2}\t DispEPE_avg {3}\t ImgPSNR_avg {4}'
152 | .format(epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, epoch_end_time - epoch_start_time, disp_EPEs.avg, img_PSNRs.avg))
153 |
154 | # Validate the training models
155 | Disp_EPE, img_PSNR = test_stereodeblurnet(cfg, epoch_idx, val_data_loader, dispnet, deblurnet, val_writer)
156 |
157 | # Save weights to file
158 | if (epoch_idx + 1) % cfg.TRAIN.SAVE_FREQ == 0:
159 | if not os.path.exists(ckpt_dir):
160 | os.makedirs(ckpt_dir)
161 |
162 | utils.network_utils.save_checkpoints(os.path.join(ckpt_dir, 'ckpt-epoch-%04d.pth.tar' % (epoch_idx + 1)), \
163 | epoch_idx + 1, dispnet, dispnet_solver, deblurnet, deblurnet_solver, \
164 | Disp_EPE, Best_Img_PSNR, Best_Epoch)
165 | if img_PSNR >= Best_Img_PSNR:
166 | if not os.path.exists(ckpt_dir):
167 | os.makedirs(ckpt_dir)
168 |
169 | Best_Img_PSNR = img_PSNR
170 | Best_Epoch = epoch_idx + 1
171 | utils.network_utils.save_checkpoints(os.path.join(ckpt_dir, 'best-ckpt.pth.tar'), \
172 | epoch_idx + 1, dispnet, dispnet_solver, deblurnet, deblurnet_solver, \
173 | Disp_EPE, Best_Img_PSNR, Best_Epoch)
174 |
175 | # Close SummaryWriter for TensorBoard
176 | train_writer.close()
177 | val_writer.close()
178 |
179 |
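A compact restatement of the objective assembled in the loop above (a sketch, not repo code; the weighting matches the loop and the perceptual term is optional):

def stereo_deblur_loss(mse_l, mse_r, percept_l=None, percept_r=None, w=0.01):
    # Average the per-view MSE terms; add the 0.01-weighted perceptual term when enabled
    loss = (mse_l + mse_r) / 2
    if percept_l is not None and percept_r is not None:
        loss = loss + w * (percept_l + percept_r) / 2
    return loss

Note that only this deblurring loss is backpropagated above; the disparity EPE is computed for logging, and dispnet_solver is never stepped in this loop.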
--------------------------------------------------------------------------------
/losses/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sczhou/DAVANet/2bbe35ae01c0f1af718a1bc19272cda5ed3c320a/losses/__init__.py
--------------------------------------------------------------------------------
/losses/multiscaleloss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from config import cfg
4 | from utils.network_utils import *
5 | #
6 | # Disparity Loss
7 | #
8 | def EPE(output, target, occ_mask):
9 | N = torch.sum(occ_mask)
10 | d_diff = output - target
11 | EPE_map = torch.abs(d_diff)
12 | EPE_map = torch.mul(EPE_map, occ_mask)
13 |
14 | EPE_mean = torch.sum(EPE_map)/N
15 | return EPE_mean
16 |
17 | def multiscaleLoss(outputs, target, img, occ_mask, weights):
18 |
19 | def one_scale(output, target, occ_mask):
20 | b, _, h, w = output.size()
21 | occ_mask = nn.functional.adaptive_max_pool2d(occ_mask, (h, w))
22 | if cfg.DATASET.SPARSE:
23 | target_scaled = nn.functional.adaptive_max_pool2d(target, (h, w))
24 | else:
25 | target_scaled = nn.functional.adaptive_avg_pool2d(target, (h, w))
26 | return EPE(output, target_scaled, occ_mask)
27 |
28 | if type(outputs) not in [tuple, list]:
29 | outputs = [outputs]
30 |
31 | assert(len(weights) == len(outputs))
32 |
33 | loss = 0
34 | for output, weight in zip(outputs, weights):
35 | loss += weight * one_scale(output, target, occ_mask)
36 | return loss
37 |
38 | def realEPE(output, target, occ_mask):
39 | b, _, h, w = target.size()
40 |     upsampled_output = nn.functional.interpolate(output, size=(h, w), mode='bilinear', align_corners=True)
41 | return EPE(upsampled_output, target, occ_mask)
42 |
43 | #
44 | # Deblurring Loss
45 | #
46 | def mseLoss(output, target):
47 |     mse_loss = nn.MSELoss(reduction='elementwise_mean')  # 'elementwise_mean' is the torch==0.4.1 spelling of 'mean'
48 | MSE = mse_loss(output, target)
49 | return MSE
50 |
51 | def PSNR(output, target, max_val=1.0):
52 |     output = output.clamp(0.0, 1.0)
53 | mse = torch.pow(target - output, 2).mean()
54 | if mse == 0:
55 | return torch.Tensor([100.0])
56 | return 10 * torch.log10(max_val**2 / mse)
57 |
58 |
59 | def perceptualLoss(fakeIm, realIm, vggnet):
60 |     '''
61 |     Uses VGG19 conv1_2, conv2_2 and conv3_3 features, taken before the ReLU layers
62 |     '''
63 |
64 | weights = [1, 0.2, 0.04]
65 | features_fake = vggnet(fakeIm)
66 | features_real = vggnet(realIm)
67 | features_real_no_grad = [f_real.detach() for f_real in features_real]
68 | mse_loss = nn.MSELoss(reduction='elementwise_mean')
69 |
70 | loss = 0
71 | for i in range(len(features_real)):
72 | loss_i = mse_loss(features_fake[i], features_real_no_grad[i])
73 | loss = loss + loss_i * weights[i]
74 |
75 | return loss
76 |
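A toy check of the helpers above on random tensors (a sketch; importing this module also pulls in the repo's config):

import torch
from losses.multiscaleloss import EPE, realEPE, PSNR

target = torch.rand(2, 2, 64, 64)        # bi-directional ground-truth disparity
occ    = torch.ones_like(target)         # occlusion mask: 1 marks visible pixels
coarse = torch.rand(2, 2, 16, 16)        # quarter-resolution prediction

print(realEPE(coarse, target, occ))      # bilinear upsample, then masked mean abs error
print(EPE(target, target, occ))          # identical inputs give 0
print(PSNR(target, target))              # identical images hit the 100.0 sentinel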
--------------------------------------------------------------------------------
/models/DeblurNet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Developed by Shangchen Zhou
5 |
6 |
7 |
8 | from models.submodules import *
9 | class DeblurNet(nn.Module):
10 | def __init__(self):
11 | super(DeblurNet, self).__init__()
12 | # encoder
13 | ks = 3
14 | self.conv1_1 = conv(3, 32, kernel_size=ks, stride=1)
15 | self.conv1_2 = resnet_block(32, kernel_size=ks)
16 | self.conv1_3 = resnet_block(32, kernel_size=ks)
17 | self.conv1_4 = resnet_block(32, kernel_size=ks)
18 |
19 | self.conv2_1 = conv(32, 64, kernel_size=ks, stride=2)
20 | self.conv2_2 = resnet_block(64, kernel_size=ks)
21 | self.conv2_3 = resnet_block(64, kernel_size=ks)
22 | self.conv2_4 = resnet_block(64, kernel_size=ks)
23 |
24 |
25 | self.conv3_1 = conv(64, 128, kernel_size=ks, stride=2)
26 | self.conv3_2 = resnet_block(128, kernel_size=ks)
27 | self.conv3_3 = resnet_block(128, kernel_size=ks)
28 | self.conv3_4 = resnet_block(128, kernel_size=ks)
29 |
30 |         dilation = [1, 2, 3, 4]
31 |         self.convd_1 = resnet_block(128, kernel_size=ks, dilation=[2, 1])
32 |         self.convd_2 = resnet_block(128, kernel_size=ks, dilation=[3, 1])
33 |         self.convd_3 = ms_dilate_block(128, kernel_size=ks, dilation=dilation)
34 |
35 | # decoder
36 |         self.upconv3_i = conv(128, 128, kernel_size=ks, stride=1)
37 | self.upconv3_3 = resnet_block(128, kernel_size=ks)
38 | self.upconv3_2 = resnet_block(128, kernel_size=ks)
39 | self.upconv3_1 = resnet_block(128, kernel_size=ks)
40 |
41 | self.upconv2_u = upconv(128, 64)
42 |         self.upconv2_i = conv(128, 64, kernel_size=ks, stride=1)
43 | self.upconv2_3 = resnet_block(64, kernel_size=ks)
44 | self.upconv2_2 = resnet_block(64, kernel_size=ks)
45 | self.upconv2_1 = resnet_block(64, kernel_size=ks)
46 |
47 | self.upconv1_u = upconv(64, 32)
48 |         self.upconv1_i = conv(64, 32, kernel_size=ks, stride=1)
49 | self.upconv1_3 = resnet_block(32, kernel_size=ks)
50 | self.upconv1_2 = resnet_block(32, kernel_size=ks)
51 | self.upconv1_1 = resnet_block(32, kernel_size=ks)
52 |
53 | self.img_prd = conv(32, 3, kernel_size=ks, stride=1)
54 |
55 | def forward(self, x):
56 | # encoder
57 | conv1 = self.conv1_4(self.conv1_3(self.conv1_2(self.conv1_1(x))))
58 | conv2 = self.conv2_4(self.conv2_3(self.conv2_2(self.conv2_1(conv1))))
59 | conv3 = self.conv3_4(self.conv3_3(self.conv3_2(self.conv3_1(conv2))))
60 | convd = self.convd_3(self.convd_2(self.convd_1(conv3)))
61 |
62 | # decoder
63 | cat3 = self.upconv3_i(convd)
64 | upconv2 = self.upconv2_u(self.upconv3_1(self.upconv3_2(self.upconv3_3(cat3))))
65 | cat2 = self.upconv2_i(cat_with_crop(conv2, [conv2, upconv2]))
66 | upconv1 = self.upconv1_u(self.upconv2_1(self.upconv2_2(self.upconv2_3(cat2))))
67 | cat1 = self.upconv1_i(cat_with_crop(conv1, [conv1, upconv1]))
68 | img_prd = self.img_prd(self.upconv1_1(self.upconv1_2(self.upconv1_3(cat1))))
69 |
70 | return img_prd + x
71 |
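A shape sanity check (a sketch): the network is fully convolutional and predicts a residual, so the output resolution equals the input resolution for sizes divisible by 4:

import torch
from models.DeblurNet import DeblurNet

net = DeblurNet().eval()
with torch.no_grad():
    x = torch.rand(1, 3, 128, 128)   # blurry RGB input
    y = net(x)                       # img_prd + x
print(y.shape)                       # torch.Size([1, 3, 128, 128])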
--------------------------------------------------------------------------------
/models/DispNet_Bi.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Developed by Shangchen Zhou
5 |
6 | from models.submodules import *
7 |
8 | class DispNet_Bi(nn.Module):
9 | def __init__(self):
10 | super(DispNet_Bi, self).__init__()
11 | # encoder
12 | ks = 3
13 | self.conv0 = conv(6, 48, kernel_size=ks, stride=1)
14 |
15 | self.conv1_1 = conv(48, 48, kernel_size=ks, stride=2)
16 | self.conv1_2 = conv(48, 48, kernel_size=ks, stride=1)
17 |
18 | self.conv2_1 = conv(48, 96, kernel_size=ks, stride=2)
19 | self.conv2_2 = conv(96, 96, kernel_size=ks, stride=1)
20 |
21 | self.conv3_1 = conv(96, 128, kernel_size=ks, stride=2)
22 | self.conv3_2 = conv(128, 128, kernel_size=ks, stride=1)
23 |
24 | self.conv4_1 = resnet_block(128, kernel_size=ks)
25 | self.conv4_2 = resnet_block(128, kernel_size=ks)
26 |
27 | self.convd_1 = resnet_block(128, kernel_size=ks, dilation=[2, 1])
28 | self.convd_2 = ms_dilate_block(128, kernel_size=ks, dilation=[1, 2, 3, 4])
29 |
30 | # decoder
31 | self.upconvd_i = conv(128, 128, kernel_size=ks, stride=1)
32 | self.dispd = predict_disp_bi(128)
33 |
34 | self.upconv3 = conv(128, 128, kernel_size=ks, stride=1)
35 | self.upconv3_i = conv(258, 128, kernel_size=ks, stride=1)
36 | self.upconv3_f = conv(128, 128, kernel_size=ks, stride=1)
37 | self.disp3 = predict_disp_bi(128)
38 |
39 | self.updisp3 = up_disp_bi()
40 | self.upconv2 = upconv(128, 96)
41 | self.upconv2_i = conv(194, 96, kernel_size=ks, stride=1)
42 | self.upconv2_f = conv(96, 96, kernel_size=ks, stride=1)
43 | self.disp2 = predict_disp_bi(96)
44 |
45 | self.updisp2 = up_disp_bi()
46 | self.upconv1 = upconv(96, 48)
47 | self.upconv1_i = conv(50, 48, kernel_size=ks, stride=1)
48 | self.upconv1_f = conv(48, 48, kernel_size=ks, stride=1)
49 | self.disp1 = predict_disp_bi(48)
50 |
51 | self.updisp1 = up_disp_bi()
52 | self.upconv0 = upconv(48, 32)
53 | self.upconv0_i = conv(34, 32, kernel_size=ks, stride=1)
54 | self.upconv0_f = conv(32, 32, kernel_size=ks, stride=1)
55 | self.disp0 = predict_disp_bi(32)
56 |
57 | def forward(self, x):
58 | # encoder
59 | conv0 = self.conv0(x)
60 | conv1 = self.conv1_2(self.conv1_1(conv0))
61 | conv2 = self.conv2_2(self.conv2_1(conv1))
62 | conv3 = self.conv3_2(self.conv3_1(conv2))
63 | conv4 = self.conv4_2(self.conv4_1(conv3))
64 | convd = self.convd_2(self.convd_1(conv4))
65 |
66 | # decoder
67 | upconvd_i = self.upconvd_i(convd)
68 | disp4 = self.dispd(upconvd_i)
69 |
70 | upconv3 = self.upconv3(upconvd_i)
71 | cat3 = torch.cat([conv3, upconv3, disp4], 1)
72 | upconv3_i = self.upconv3_f(self.upconv3_i(cat3))
73 | disp3 = self.disp3(upconv3_i) + disp4
74 |
75 | updisp3 = self.updisp3(disp3)
76 | upconv2 = self.upconv2(upconv3_i)
77 | cat2 = torch.cat([conv2, upconv2, updisp3], 1)
78 | upconv2_i = self.upconv2_f(self.upconv2_i(cat2))
79 | disp2 = self.disp2(upconv2_i) + updisp3
80 |
81 | updisp2 = self.updisp2(disp2)
82 | upconv1 = self.upconv1(upconv2_i)
83 | cat1 = torch.cat([upconv1, updisp2], 1)
84 | upconv1_i = self.upconv1_f(self.upconv1_i(cat1))
85 | disp1 = self.disp1(upconv1_i) + updisp2
86 |
87 | updisp1 = self.updisp1(disp1)
88 | upconv0 = self.upconv0(upconv1_i)
89 | cat0 = torch.cat([upconv0, updisp1], 1)
90 | upconv0_i = self.upconv0_f(self.upconv0_i(cat0))
91 | disp0 = self.disp0(upconv0_i) + updisp1
92 |
93 |         # In training mode, return every disparity scale (for the multi-scale loss)
94 |         # together with the last decoder feature map; in eval mode, return only the
95 |         # full-resolution disparity and that feature map.
96 |
97 |
98 | if self.training:
99 | return disp0, disp1, disp2, disp3, disp4, upconv0_i
100 | else:
101 | return disp0, upconv0_i
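
The final branch above returns different tuples in train and eval mode; a small sketch with a toy input (the channel-concatenated left/right pair):

import torch
from models.DispNet_Bi import DispNet_Bi

net = DispNet_Bi()
x = torch.rand(1, 6, 64, 64)          # left and right RGB images, concatenated

net.train()
outs = net(x)                         # disp0..disp4 plus the last decoder feature
print([t.shape for t in outs])

net.eval()
disp0, feat = net(x)                  # full-resolution bi-directional disparity
print(disp0.shape, feat.shape)        # (1, 2, 64, 64) and (1, 32, 64, 64)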
--------------------------------------------------------------------------------
/models/StereoDeblurNet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Developed by Shangchen Zhou
5 |
6 | from models.submodules import *
7 | from utils.network_utils import *
8 | from config import cfg
9 |
10 | class StereoDeblurNet(nn.Module):
11 | def __init__(self):
12 | super(StereoDeblurNet, self).__init__()
13 | # encoder
14 | ks = 3
15 | self.conv1_1 = conv(3, 32, kernel_size=ks, stride=1)
16 | self.conv1_2 = resnet_block(32, kernel_size=ks)
17 | self.conv1_3 = resnet_block(32, kernel_size=ks)
18 | self.conv1_4 = resnet_block(32, kernel_size=ks)
19 |
20 | self.conv2_1 = conv(32, 64, kernel_size=ks, stride=2)
21 | self.conv2_2 = resnet_block(64, kernel_size=ks)
22 | self.conv2_3 = resnet_block(64, kernel_size=ks)
23 | self.conv2_4 = resnet_block(64, kernel_size=ks)
24 |
25 | self.conv3_1 = conv(64, 128, kernel_size=ks, stride=2)
26 | self.conv3_2 = resnet_block(128, kernel_size=ks)
27 | self.conv3_3 = resnet_block(128, kernel_size=ks)
28 | self.conv3_4 = resnet_block(128, kernel_size=ks)
29 |
30 |         dilation = [1, 2, 3, 4]
31 |         self.convd_1 = resnet_block(128, kernel_size=ks, dilation=[2, 1])
32 |         self.convd_2 = resnet_block(128, kernel_size=ks, dilation=[3, 1])
33 |         self.convd_3 = ms_dilate_block(128, kernel_size=ks, dilation=dilation)
34 |
35 | self.gatenet = gatenet()
36 |
37 | self.depth_sense_l = depth_sense(33, 32, kernel_size=ks)
38 | self.depth_sense_r = depth_sense(33, 32, kernel_size=ks)
39 |
40 | # decoder
41 |         self.upconv3_i = conv(288, 128, kernel_size=ks, stride=1)
42 | self.upconv3_3 = resnet_block(128, kernel_size=ks)
43 | self.upconv3_2 = resnet_block(128, kernel_size=ks)
44 | self.upconv3_1 = resnet_block(128, kernel_size=ks)
45 |
46 | self.upconv2_u = upconv(128, 64)
47 |         self.upconv2_i = conv(128, 64, kernel_size=ks, stride=1)
48 | self.upconv2_3 = resnet_block(64, kernel_size=ks)
49 | self.upconv2_2 = resnet_block(64, kernel_size=ks)
50 | self.upconv2_1 = resnet_block(64, kernel_size=ks)
51 |
52 | self.upconv1_u = upconv(64, 32)
53 |         self.upconv1_i = conv(64, 32, kernel_size=ks, stride=1)
54 | self.upconv1_3 = resnet_block(32, kernel_size=ks)
55 | self.upconv1_2 = resnet_block(32, kernel_size=ks)
56 | self.upconv1_1 = resnet_block(32, kernel_size=ks)
57 |
58 | self.img_prd = conv(32, 3, kernel_size=ks, stride=1)
59 |
60 | def forward(self, imgs, disps_bi, disp_feature):
61 | img_left = imgs[:,:3]
62 | img_right = imgs[:,3:]
63 |
64 | disp_left = disps_bi[:, 0]
65 | disp_right = disps_bi[:, 1]
66 |
67 | # encoder-left
68 | conv1_left = self.conv1_4(self.conv1_3(self.conv1_2(self.conv1_1(img_left))))
69 | conv2_left = self.conv2_4(self.conv2_3(self.conv2_2(self.conv2_1(conv1_left))))
70 | conv3_left = self.conv3_4(self.conv3_3(self.conv3_2(self.conv3_1(conv2_left))))
71 | convd_left = self.convd_3(self.convd_2(self.convd_1(conv3_left)))
72 |
73 | # encoder-right
74 | conv1_right = self.conv1_4(self.conv1_3(self.conv1_2(self.conv1_1(img_right))))
75 | conv2_right = self.conv2_4(self.conv2_3(self.conv2_2(self.conv2_1(conv1_right))))
76 | conv3_right = self.conv3_4(self.conv3_3(self.conv3_2(self.conv3_1(conv2_right))))
77 | convd_right = self.convd_3(self.convd_2(self.convd_1(conv3_right)))
78 |
79 | b, c, h, w = convd_left.shape
80 |
81 | warp_img_left = disp_warp(img_right, -disp_left*cfg.DATA.DIV_DISP, cuda=True)
82 | warp_img_right = disp_warp(img_left, disp_right*cfg.DATA.DIV_DISP, cuda=True)
83 | diff_left = torch.sum(torch.abs(img_left - warp_img_left), 1).view(b,1,*warp_img_left.shape[-2:])
84 | diff_right = torch.sum(torch.abs(img_right - warp_img_right), 1).view(b,1,*warp_img_right.shape[-2:])
85 | diff_2_left = nn.functional.adaptive_avg_pool2d(diff_left, (h, w))
86 | diff_2_right = nn.functional.adaptive_avg_pool2d(diff_right, (h, w))
87 |
88 | disp_2_left = nn.functional.adaptive_avg_pool2d(disp_left, (h, w))
89 | disp_2_right = nn.functional.adaptive_avg_pool2d(disp_right, (h, w))
90 |
91 | disp_feature_2 = nn.functional.adaptive_avg_pool2d(disp_feature, (h, w))
92 |
93 | depth_aware_left = self.depth_sense_l(torch.cat([disp_feature_2, disp_2_left.view(b,1,h,w)], 1))
94 | depth_aware_right = self.depth_sense_r(torch.cat([disp_feature_2, disp_2_right.view(b,1,h,w)], 1))
95 |
96 |         # gate maps: larger values mean the warping is more accurate there, so more of the warped feature is used
97 | gate_left = self.gatenet(diff_2_left)
98 | gate_right = self.gatenet(diff_2_right)
99 |
100 | warp_convd_left = disp_warp(convd_right, -disp_2_left)
101 | warp_convd_right = disp_warp(convd_left, disp_2_right)
102 |
103 | # aggregate features
104 | agg_left = convd_left * (1.0-gate_left) + warp_convd_left * gate_left.repeat(1,c,1,1)
105 | agg_right = convd_right * (1.0-gate_right) + warp_convd_right * gate_right.repeat(1,c,1,1)
106 |
107 | # decoder-left
108 | cat3_left = self.upconv3_i(torch.cat([convd_left, agg_left, depth_aware_left], 1))
109 | upconv3_left = self.upconv3_1(self.upconv3_2(self.upconv3_3(cat3_left))) # upconv3 feature
110 |
111 | upconv2_u_left = self.upconv2_u(upconv3_left)
112 | cat2_left = self.upconv2_i(torch.cat([conv2_left, upconv2_u_left],1))
113 | upconv2_left = self.upconv2_1(self.upconv2_2(self.upconv2_3(cat2_left))) # upconv2 feature
114 | upconv1_u_left = self.upconv1_u(upconv2_left)
115 | cat1_left = self.upconv1_i(torch.cat([conv1_left, upconv1_u_left], 1))
116 |
117 | upconv1_left = self.upconv1_1(self.upconv1_2(self.upconv1_3(cat1_left))) # upconv1 feature
118 | img_prd_left = self.img_prd(upconv1_left) + img_left # predict img
119 |
120 | # decoder-right
121 | cat3_right = self.upconv3_i(torch.cat([convd_right, agg_right, depth_aware_right], 1))
122 | upconv3_right = self.upconv3_1(self.upconv3_2(self.upconv3_3(cat3_right))) # upconv3 feature
123 |
124 | upconv2_u_right = self.upconv2_u(upconv3_right)
125 | cat2_right = self.upconv2_i(torch.cat([conv2_right, upconv2_u_right], 1))
126 | upconv2_right = self.upconv2_1(self.upconv2_2(self.upconv2_3(cat2_right))) # upconv2 feature
127 | upconv1_u_right = self.upconv1_u(upconv2_right)
128 | cat1_right = self.upconv1_i(torch.cat([conv1_right, upconv1_u_right], 1))
129 |
130 | upconv1_right = self.upconv1_1(self.upconv1_2(self.upconv1_3(cat1_right))) # upconv1 feature
131 | img_prd_right = self.img_prd(upconv1_right) + img_right # predict img
132 |
133 | imgs_prd = [img_prd_left, img_prd_right]
134 |
135 | diff = [diff_left, diff_right]
136 | gate = [gate_left, gate_right]
137 |
138 | return imgs_prd, diff, gate
139 |
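The heart of the forward pass is the gated aggregation agg = own * (1 - gate) + warped * gate; a standalone toy illustration of that blend (hypothetical tensors, not the network itself):

import torch

b, c, h, w = 1, 4, 8, 8
own    = torch.rand(b, c, h, w)        # the view's own encoder feature
warped = torch.rand(b, c, h, w)        # the other view's feature after disparity warping
gate   = torch.rand(b, 1, h, w)        # per-pixel confidence in the warp

agg = own * (1.0 - gate) + warped * gate.repeat(1, c, 1, 1)
print(agg.shape)                       # torch.Size([1, 4, 8, 8])
# gate -> 0 keeps the view's own feature; gate -> 1 trusts the warped one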
--------------------------------------------------------------------------------
/models/VGG19.py:
--------------------------------------------------------------------------------
1 | from models.submodules import *
2 | import torchvision.models
3 |
4 | class VGG19(nn.Module):
5 | def __init__(self):
6 | super(VGG19, self).__init__()
7 |         '''
8 |         Uses VGG19 conv1_2, conv2_2 and conv3_3 features, taken before the ReLU layers
9 |         '''
10 | self.feature_list = [2, 7, 14]
11 | vgg19 = torchvision.models.vgg19(pretrained=True)
12 |
13 | self.model = torch.nn.Sequential(*list(vgg19.features.children())[:self.feature_list[-1]+1])
14 |
15 |
16 | def forward(self, x):
17 | x = (x-0.5)/0.5
18 | features = []
19 | for i, layer in enumerate(list(self.model)):
20 | x = layer(x)
21 | if i in self.feature_list:
22 | features.append(x)
23 |
24 | return features
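
A usage sketch (the first call downloads the pretrained VGG19 weights via torchvision):

import torch
from models.VGG19 import VGG19

vgg = VGG19().eval()
with torch.no_grad():
    feats = vgg(torch.rand(1, 3, 128, 128))
print([f.shape for f in feats])
# conv1_2: (1, 64, 128, 128), conv2_2: (1, 128, 64, 64), conv3_3: (1, 256, 32, 32)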
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/models/submodules.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Developed by Shangchen Zhou
5 |
6 | import torch.nn as nn
7 | import torch
8 | import numpy as np
9 | from config import cfg
10 |
11 | def conv(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True):
12 |     return nn.Sequential(
13 |         nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias),
14 |         nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE, inplace=True)
15 | )
16 |
17 | def predict_disp(in_channels):
18 | return nn.Conv2d(in_channels,1,kernel_size=3,stride=1,padding=1,bias=True)
19 |
20 | def predict_disp_bi(in_channels):
21 | return nn.Conv2d(in_channels,2,kernel_size=3,stride=1,padding=1,bias=True)
22 |
23 | def up_disp_bi():
24 | return nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
25 |
26 | def predict_occ(in_channels):
27 | return nn.Conv2d(in_channels,1,kernel_size=3,stride=1,padding=1,bias=True)
28 |
29 | def predict_occ_bi(in_channels):
30 | return nn.Conv2d(in_channels,2,kernel_size=3,stride=1,padding=1,bias=True)
31 |
32 | def upconv(in_channels, out_channels):
33 | return nn.Sequential(
34 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=True),
35 |         nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE, inplace=True)
36 | )
37 |
38 | def resnet_block(in_channels, kernel_size=3, dilation=[1,1], bias=True):
39 | return ResnetBlock(in_channels, kernel_size, dilation, bias=bias)
40 |
41 | class ResnetBlock(nn.Module):
42 | def __init__(self, in_channels, kernel_size, dilation, bias):
43 | super(ResnetBlock, self).__init__()
44 | self.stem = nn.Sequential(
45 | nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=1, dilation=dilation[0], padding=((kernel_size-1)//2)*dilation[0], bias=bias),
46 | nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE, inplace=True),
47 | nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=1, dilation=dilation[1], padding=((kernel_size-1)//2)*dilation[1], bias=bias),
48 | )
49 | def forward(self, x):
50 | out = self.stem(x) + x
51 | return out
52 |
53 |
54 | def gatenet(bias=True):
55 | return nn.Sequential(
56 | nn.Conv2d(1, 16, kernel_size=3, stride=1, dilation=1, padding=1, bias=bias),
57 |         nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE, inplace=True),
58 | resnet_block(16, kernel_size=1),
59 | nn.Conv2d(16, 1, kernel_size=1, padding=0),
60 | nn.Sigmoid()
61 | )
62 |
63 | def depth_sense(in_channels, out_channels, kernel_size=3, dilation=1, bias=True):
64 |     return nn.Sequential(
65 |         nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, dilation=dilation, padding=((kernel_size - 1) // 2)*dilation, bias=bias),
66 |         nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE, inplace=True),
67 |         resnet_block(out_channels, kernel_size=3),
68 |     )
69 |
70 | def conv2x(in_channels, kernel_size=3, dilation=[1,1], bias=True):
71 |     return nn.Sequential(
72 |         nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=1, dilation=dilation[0], padding=((kernel_size-1)//2)*dilation[0], bias=bias),
73 |         nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE, inplace=True),
74 | nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=1, dilation=dilation[1], padding=((kernel_size-1)//2)*dilation[1], bias=bias),
75 | nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE, inplace=True)
76 | )
77 |
78 |
79 | def ms_dilate_block(in_channels, kernel_size=3, dilation=[1,1,1,1], bias=True):
80 | return MSDilateBlock(in_channels, kernel_size, dilation, bias)
81 |
82 | class MSDilateBlock(nn.Module):
83 | def __init__(self, in_channels, kernel_size, dilation, bias):
84 | super(MSDilateBlock, self).__init__()
85 |         self.conv1 = conv(in_channels, in_channels, kernel_size, dilation=dilation[0], bias=bias)
86 |         self.conv2 = conv(in_channels, in_channels, kernel_size, dilation=dilation[1], bias=bias)
87 |         self.conv3 = conv(in_channels, in_channels, kernel_size, dilation=dilation[2], bias=bias)
88 |         self.conv4 = conv(in_channels, in_channels, kernel_size, dilation=dilation[3], bias=bias)
89 | self.convi = nn.Conv2d(in_channels*4, in_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size-1)//2, bias=bias)
90 | def forward(self, x):
91 | conv1 = self.conv1(x)
92 | conv2 = self.conv2(x)
93 | conv3 = self.conv3(x)
94 | conv4 = self.conv4(x)
95 | cat = torch.cat([conv1, conv2, conv3, conv4], 1)
96 | out = self.convi(cat) + x
97 | return out
98 |
99 |
100 | def cat_with_crop(target, tensors):
101 |     output = []
102 |     for item in tensors:
103 |         if item.size()[2:] == target.size()[2:]:
104 |             output.append(item)
105 |         else:
106 |             output.append(item[:, :, :target.size(2), :target.size(3)])
107 |     output = torch.cat(output, 1)
108 |     return output
109 |
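All helpers above pad with ((kernel_size - 1) // 2) * dilation, so spatial size is preserved at stride 1; a quick check (a sketch; importing this module also imports the repo's config):

import torch
from models.submodules import conv, resnet_block, ms_dilate_block

x = torch.rand(1, 16, 32, 32)
print(conv(16, 16, kernel_size=3, dilation=2)(x).shape)      # (1, 16, 32, 32)
print(resnet_block(16, dilation=[3, 1])(x).shape)            # (1, 16, 32, 32)
print(ms_dilate_block(16, dilation=[1, 2, 3, 4])(x).shape)   # (1, 16, 32, 32)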
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | argparse
2 | easydict
3 | numpy
4 | matplotlib
5 | scipy
6 | opencv-python==4.0.0.21
7 | torch==0.4.1
8 | torchvision==0.2.0
9 | tensorboardX
10 | OpenEXR
11 | pyexr
12 |
--------------------------------------------------------------------------------
/runner.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Developed by Shangchen Zhou
5 |
6 | import matplotlib
7 | import os
8 | # Fix problem: no $DISPLAY environment variable
9 | matplotlib.use('Agg')
10 |
11 | from argparse import ArgumentParser
12 | from pprint import pprint
13 |
14 | from config import cfg
15 | from core.build import bulid_net
16 | import torch
17 |
18 | def get_args_from_command_line():
19 |     parser = ArgumentParser(description='Argument parser for the network runner')
20 | parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [cuda]', default=cfg.CONST.DEVICE, type=str)
21 | parser.add_argument('--phase', dest='phase', help='phase of CNN', default=cfg.NETWORK.PHASE, type=str)
22 | parser.add_argument('--weights', dest='weights', help='Initialize network from the weights file', default=cfg.CONST.WEIGHTS, type=str)
23 | parser.add_argument('--data', dest='data_path', help='Set dataset root_path', default=cfg.DIR.DATASET_ROOT, type=str)
24 | parser.add_argument('--out', dest='out_path', help='Set output path', default=cfg.DIR.OUT_PATH)
25 | args = parser.parse_args()
26 | return args
27 |
28 | def main():
29 | # Get args from command line
30 | args = get_args_from_command_line()
31 |
32 | if args.gpu_id is not None:
33 | cfg.CONST.DEVICE = args.gpu_id
34 | if args.phase is not None:
35 | cfg.NETWORK.PHASE = args.phase
36 | if args.weights is not None:
37 | cfg.CONST.WEIGHTS = args.weights
38 | if args.data_path is not None:
39 | cfg.DIR.DATASET_ROOT = args.data_path
40 | if args.out_path is not None:
41 | cfg.DIR.OUT_PATH = args.out_path
42 |
43 | # Print config
44 | print('Use config:')
45 | pprint(cfg)
46 |
47 | # Set GPU to use
48 |     if isinstance(cfg.CONST.DEVICE, str) and cfg.CONST.DEVICE != 'all':
49 |         os.environ["CUDA_VISIBLE_DEVICES"] = cfg.CONST.DEVICE
50 |     print('CUDA DEVICES NUMBER: ' + str(torch.cuda.device_count()))
51 |
52 | # Setup Network & Start train/test process
53 | bulid_net(cfg)
54 |
55 | if __name__ == '__main__':
56 | main()
57 |
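Illustrative invocations (paths are placeholders; the accepted --phase values are defined by cfg.NETWORK.PHASE and core/build.py):

python runner.py --phase train --data /path/to/dataset --out ./output
python runner.py --phase test --weights ./ckpt/best-ckpt.pth.tar --gpu 0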
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sczhou/DAVANet/2bbe35ae01c0f1af718a1bc19272cda5ed3c320a/utils/__init__.py
--------------------------------------------------------------------------------
/utils/data_loaders.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Developed by Shangchen Zhou
5 |
6 | import cv2
7 | import json
8 | import numpy as np
9 | import os
10 | import io
11 | import random
12 | import scipy.io
13 | import sys
14 | import torch.utils.data.dataset
15 |
16 | from config import cfg
17 | from datetime import datetime as dt
18 | from enum import Enum, unique
19 | from utils.imgio_gen import readgen
20 | import utils.network_utils
21 |
22 | class DatasetType(Enum):
23 | TRAIN = 0
24 | TEST = 1
25 |
26 | class FlyingThings3DDataset(torch.utils.data.dataset.Dataset):
27 |     """FlyingThings3DDataset class used for PyTorch DataLoader"""
28 |
29 | def __init__(self, file_list_with_metadata, transforms = None):
30 | self.file_list = file_list_with_metadata
31 | self.transforms = transforms
32 |
33 | def __len__(self):
34 | return len(self.file_list)
35 |
36 | def __getitem__(self, idx):
37 | imgs, disps = self.get_datum(idx)
38 | imgs, disps = self.transforms(imgs, disps)
39 | if cfg.DATASET.WITH_MASK:
40 | occs = utils.network_utils.get_occ([img.view(1, *img.shape) for img in imgs], [disp * cfg.DATA.DIV_DISP for disp in disps], cuda=False)
41 | else:
42 | _, H, W = imgs[0].shape
43 | occs = [torch.ones((1,H,W), dtype=torch.float32), torch.ones((1,H,W), dtype=torch.float32)]
44 | name = []
45 | return name, imgs, disps, occs
46 |
47 | def get_datum(self, idx):
48 | img_left_path = self.file_list[idx]['img_left']
49 | img_right_path = self.file_list[idx]['img_right']
50 | disp_left_path = self.file_list[idx]['disp_left']
51 | disp_right_path = self.file_list[idx]['disp_right']
52 |
53 | img_left = readgen(img_left_path).astype(np.float32)
54 | img_right = readgen(img_right_path).astype(np.float32)
55 | imgs = [img_left, img_right]
56 |
57 | disp_left = readgen(disp_left_path).astype(np.float32)
58 | disp_right = readgen(disp_right_path).astype(np.float32)
59 |
60 | disps = [disp_left, disp_right]
61 | return imgs, disps
62 | # //////////////////////////////// = End of FlyingThings3DDataset Class Definition = ///////////////////////////////// #
63 |
64 | class FlyingThings3DDataLoader:
65 | def __init__(self):
66 | self.img_left_path_template = cfg.DIR.IMAGE_LEFT_PATH
67 | self.img_right_path_template = cfg.DIR.IMAGE_RIGHT_PATH
68 | self.disp_left_path_template = cfg.DIR.DISPARITY_LEFT_PATH
69 | self.disp_right_path_template = cfg.DIR.DISPARITY_RIGHT_PATH
70 | # Load all files of the dataset
71 | with io.open(cfg.DIR.DATASET_JSON_FILE_PATH, encoding='utf-8') as file:
72 | self.files_list = json.loads(file.read())
73 |
74 | def get_dataset(self, dataset_type, transforms=None):
75 | files = []
76 | # Load data for each category
77 | for file in self.files_list:
78 | if dataset_type == DatasetType.TRAIN and (file['phase'] == 'TRAIN' or file['phase'] == 'TEST'):
79 | categories = file['categories']
80 | phase = file['phase']
81 | classes = file['classes']
82 | names = file['names']
83 | samples = file['sample']
84 | print('[INFO] %s Collecting files of Taxonomy [categories = %s, phase = %s, classes = %s, names = %s]' % (
85 | dt.now(), categories, phase, classes, names))
86 | files.extend(
87 | self.get_files_of_taxonomy(categories, phase, classes, names, samples))
88 | elif dataset_type == DatasetType.TEST and file['phase'] == 'TEST':
89 | categories = file['categories']
90 | phase = file['phase']
91 | classes = file['classes']
92 | names = file['names']
93 | samples = file['sample']
94 | print('[INFO] %s Collecting files of Taxonomy [categories = %s, phase = %s, classes = %s, names = %s]' % (
95 | dt.now(), categories, phase, classes, names))
96 | files.extend(
97 | self.get_files_of_taxonomy(categories, phase, classes, names, samples))
98 |
99 | print('[INFO] %s Complete collecting files of the dataset for %s. Total files: %d.' % (dt.now(), dataset_type.name, len(files)))
100 | return FlyingThings3DDataset(files, transforms)
101 |
102 |     def get_files_of_taxonomy(self, categories, phase, classes, names, samples):
103 |
104 | # n_samples = len(samples)
105 | files_of_taxonomy = []
106 | for sample_idx, sample_name in enumerate(samples):
107 | # Get file path of img
108 | img_left_path = self.img_left_path_template % (categories, phase, classes, names, sample_name)
109 | img_right_path = self.img_right_path_template % (categories, phase, classes, names, sample_name)
110 | disp_left_path = self.disp_left_path_template % (phase, classes, names, sample_name)
111 | disp_right_path = self.disp_right_path_template % (phase, classes, names, sample_name)
112 |
113 | if os.path.exists(img_left_path) and os.path.exists(img_right_path) and os.path.exists(
114 | disp_left_path) and os.path.exists(disp_right_path):
115 | files_of_taxonomy.append({
116 | 'img_left': img_left_path,
117 | 'img_right': img_right_path,
118 | 'disp_left': disp_left_path,
119 | 'disp_right': disp_right_path,
120 | 'categories': categories,
121 | 'classes' : classes,
122 | 'names': names,
123 | 'sample_name': sample_name
124 | })
125 | return files_of_taxonomy
126 | # /////////////////////////////// = End of FlyingThings3DDataLoader Class Definition = /////////////////////////////// #
127 |
128 | class StereoDeblurDataset(torch.utils.data.dataset.Dataset):
129 | """StereoDeblurDataset class used for PyTorch DataLoader"""
130 |
131 | def __init__(self, file_list_with_metadata, transforms = None):
132 | self.file_list = file_list_with_metadata
133 | self.transforms = transforms
134 |
135 | def __len__(self):
136 | return len(self.file_list)
137 |
138 | def __getitem__(self, idx):
139 | name, imgs, disps = self.get_datum(idx)
140 | imgs, disps = self.transforms(imgs, disps)
141 | occs = utils.network_utils.get_occ([img.view(1, *img.shape) for img in imgs[-2:]], [disp * cfg.DATA.DIV_DISP for disp in disps], cuda=False)
142 |         # zero out NaN/Inf disparity pixels, which the occlusion masks flag as 0
143 | disps[0][occs[0]==0] = 0
144 | disps[1][occs[1]==0] = 0
145 | return name, imgs, disps, occs
146 |
147 | def get_datum(self, idx):
148 |
149 | name = self.file_list[idx]['name']
150 | img_blur_left_path = self.file_list[idx]['img_blur_left']
151 | img_blur_right_path = self.file_list[idx]['img_blur_right']
152 | img_clear_left_path = self.file_list[idx]['img_clear_left']
153 | img_clear_right_path = self.file_list[idx]['img_clear_right']
154 | disp_left_path = self.file_list[idx]['disp_left']
155 | disp_right_path = self.file_list[idx]['disp_right']
156 |
157 | img_blur_left = readgen(img_blur_left_path).astype(np.float32)
158 | img_blur_right = readgen(img_blur_right_path).astype(np.float32)
159 | img_clear_left = readgen(img_clear_left_path).astype(np.float32)
160 | img_clear_right = readgen(img_clear_right_path).astype(np.float32)
161 | imgs = [img_blur_left, img_blur_right, img_clear_left, img_clear_right]
162 |
163 | disp_left = readgen(disp_left_path).astype(np.float32)
164 | disp_right = readgen(disp_right_path).astype(np.float32)
165 |
166 | disps = [disp_left, disp_right]
167 | return name, imgs, disps
168 | # //////////////////////////////// = End of StereoDeblurDataset Class Definition = ///////////////////////////////// #
169 |
170 | class StereoDeblurLoader:
171 | def __init__(self):
172 | self.img_left_blur_path_template = cfg.DIR.IMAGE_LEFT_BLUR_PATH
173 | self.img_left_clear_path_template = cfg.DIR.IMAGE_LEFT_CLEAR_PATH
174 | self.img_right_blur_path_template = cfg.DIR.IMAGE_RIGHT_BLUR_PATH
175 | self.img_right_clear_path_template = cfg.DIR.IMAGE_RIGHT_CLEAR_PATH
176 | self.disp_left_path_template = cfg.DIR.DISPARITY_LEFT_PATH
177 | self.disp_right_path_template = cfg.DIR.DISPARITY_RIGHT_PATH
178 | # Load all files of the dataset
179 | with io.open(cfg.DIR.DATASET_JSON_FILE_PATH, encoding='utf-8') as file:
180 | self.files_list = json.loads(file.read())
181 |
182 | def get_dataset(self, dataset_type, transforms=None):
183 | files = []
184 | # Load data for each sequence
185 | for file in self.files_list:
186 | if dataset_type == DatasetType.TRAIN and file['phase'] == 'Train':
187 | name = file['name']
188 | pair_num = file['pair_num']
189 | samples = file['sample']
190 | files_num_old = len(files)
191 | files.extend(self.get_files_of_taxonomy(name, samples))
192 |                 print('[INFO] %s Collecting files of Taxonomy [Name = %s, Pair Number = %s, Loaded = %r]' % (
193 | dt.now(), name, pair_num, pair_num == (len(files)-files_num_old)))
194 | elif dataset_type == DatasetType.TEST and file['phase'] == 'Test':
195 | name = file['name']
196 | pair_num = file['pair_num']
197 | samples = file['sample']
198 | files_num_old = len(files)
199 | files.extend(self.get_files_of_taxonomy(name, samples))
200 |                 print('[INFO] %s Collecting files of Taxonomy [Name = %s, Pair Number = %s, Loaded = %r]' % (
201 |                     dt.now(), name, pair_num, pair_num == (len(files)-files_num_old)))
202 |
203 |         print('[INFO] %s Complete collecting files of the dataset for %s. Total Pair Number: %d.\n' % (dt.now(), dataset_type.name, len(files)))
204 | return StereoDeblurDataset(files, transforms)
205 |
206 | def get_files_of_taxonomy(self, name, samples):
207 |
208 | # n_samples = len(samples)
209 | files_of_taxonomy = []
210 | for sample_idx, sample_name in enumerate(samples):
211 | # Get file path of img
212 | img_left_clear_path = self.img_left_clear_path_template % (name, sample_name)
213 | img_right_clear_path = self.img_right_clear_path_template % (name, sample_name)
214 | img_left_blur_path = self.img_left_blur_path_template % (name, sample_name)
215 | img_right_blur_path = self.img_right_blur_path_template % (name, sample_name)
216 | disp_left_path = self.disp_left_path_template % (name, sample_name)
217 | disp_right_path = self.disp_right_path_template % (name, sample_name)
218 |
219 | if os.path.exists(img_left_blur_path) and os.path.exists(img_right_blur_path) and os.path.exists(
220 | img_left_clear_path) and os.path.exists(img_right_clear_path) and os.path.exists(
221 | disp_left_path) and os.path.exists(disp_right_path):
222 | files_of_taxonomy.append({
223 | 'name': name,
224 | 'img_blur_left': img_left_blur_path,
225 | 'img_blur_right': img_right_blur_path,
226 | 'img_clear_left': img_left_clear_path,
227 | 'img_clear_right': img_right_clear_path,
228 | 'disp_left': disp_left_path,
229 | 'disp_right': disp_right_path,
230 | })
231 | return files_of_taxonomy
232 | # /////////////////////////////// = End of StereoDeblurLoader Class Definition = /////////////////////////////// #
233 |
234 |
235 | DATASET_LOADER_MAPPING = {
236 | 'FlyingThings3D': FlyingThings3DDataLoader,
237 | 'StereoDeblur': StereoDeblurLoader,
238 | }
239 |
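Both loaders feed a standard torch DataLoader; a wiring sketch (assumes the dataset JSON and the image/disparity paths configured in config.py exist on disk):

import torch.utils.data
import utils.data_transforms
from utils.data_loaders import DATASET_LOADER_MAPPING, DatasetType

loader = DATASET_LOADER_MAPPING['StereoDeblur']()            # or 'FlyingThings3D'
transforms = utils.data_transforms.Compose([utils.data_transforms.ToTensor()])
dataset = loader.get_dataset(DatasetType.TRAIN, transforms)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
name, imgs, disps, occs = next(iter(train_loader))           # one batch, as in the training loops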
--------------------------------------------------------------------------------
/utils/data_transforms.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Developed by Shangchen Zhou
5 | '''ref: http://pytorch.org/docs/master/torchvision/transforms.html'''
6 |
7 |
8 | import cv2
9 | import numpy as np
10 | import torch
11 | import torchvision.transforms.functional as F
12 | from config import cfg
13 | from PIL import Image
14 | import random
15 | import numbers
16 | class Compose(object):
17 | """ Composes several co_transforms together.
18 | For example:
19 | >>> transforms.Compose([
20 | >>> transforms.CenterCrop(10),
21 | >>> transforms.ToTensor(),
22 | >>> ])
23 | """
24 |
25 | def __init__(self, transforms):
26 | self.transforms = transforms
27 |
28 | def __call__(self, inputs, disps):
29 | for t in self.transforms:
30 | inputs,disps = t(inputs, disps)
31 | return inputs, disps
32 |
33 |
34 | class ColorJitter(object):
35 |     def __init__(self, color_adjust_para):
36 |         """Adjust brightness, contrast, saturation and hue of PIL Images.
37 |
38 |         brightness, contrast and saturation factors are sampled from
39 |         [max(0, 1 - x), 1 + x]; the hue factor from [-hue, hue], 0 <= hue <= 0.5.
40 |         Input: PIL Image. Output: PIL Image.
41 |         """
42 |         self.brightness, self.contrast, self.saturation, self.hue = color_adjust_para
43 |
44 | def __call__(self, inputs, disps):
45 | inputs = [Image.fromarray(np.uint8(inp)) for inp in inputs]
46 | if self.brightness > 0:
47 | brightness_factor = np.random.uniform(max(0, 1 - self.brightness), 1 + self.brightness)
48 | inputs = [F.adjust_brightness(inp, brightness_factor) for inp in inputs]
49 |
50 | if self.contrast > 0:
51 | contrast_factor = np.random.uniform(max(0, 1 - self.contrast), 1 + self.contrast)
52 | inputs = [F.adjust_contrast(inp, contrast_factor) for inp in inputs]
53 |
54 | if self.saturation > 0:
55 | saturation_factor = np.random.uniform(max(0, 1 - self.saturation), 1 + self.saturation)
56 | inputs = [F.adjust_saturation(inp, saturation_factor) for inp in inputs]
57 |
58 | if self.hue > 0:
59 | hue_factor = np.random.uniform(-self.hue, self.hue)
60 | inputs = [F.adjust_hue(inp, hue_factor) for inp in inputs]
61 |
62 | inputs = [np.asarray(inp) for inp in inputs]
63 | inputs = [inp.clip(0,255) for inp in inputs]
64 |
65 | return inputs, disps
66 |
67 | class RandomColorChannel(object):
68 | def __call__(self, inputs, disps):
69 | random_order = np.random.permutation(3)
70 | inputs = [inp[:,:,random_order] for inp in inputs]
71 |
72 | return inputs, disps
73 |
74 | class RandomGaussianNoise(object):
75 | def __init__(self, gaussian_para):
76 | self.mu = gaussian_para[0]
77 | self.std_var = gaussian_para[1]
78 |
79 | def __call__(self, inputs, disps):
80 |
81 | shape = inputs[0].shape
82 | gaussian_noise = np.random.normal(self.mu, self.std_var, shape)
83 | # only apply to blurry images
84 | inputs[0] = inputs[0]+gaussian_noise
85 | inputs[1] = inputs[1]+gaussian_noise
86 |
87 | inputs = [inp.clip(0, 1) for inp in inputs]
88 |
89 | return inputs, disps
90 |
91 | class Normalize(object):
92 | def __init__(self, mean, std, div_disp):
93 | self.mean = mean
94 | self.std = std
95 | self.div_disp = div_disp
96 | def __call__(self, inputs, disps):
97 | assert(all([isinstance(inp, np.ndarray) for inp in inputs]))
98 | inputs = [inp/self.std -self.mean for inp in inputs]
99 | disps = [d/self.div_disp for d in disps]
100 | return inputs, disps
101 |
102 | class CenterCrop(object):
103 |
104 | def __init__(self, crop_size):
105 |         """Set the output height and width of the crop window"""
106 |
107 | self.crop_size_h = crop_size[0]
108 | self.crop_size_w = crop_size[1]
109 |
110 | def __call__(self, inputs, disps):
111 | input_size_h, input_size_w, _ = inputs[0].shape
112 | x_start = int(round((input_size_w - self.crop_size_w) / 2.))
113 | y_start = int(round((input_size_h - self.crop_size_h) / 2.))
114 |
115 | inputs = [inp[y_start: y_start + self.crop_size_h, x_start: x_start + self.crop_size_w] for inp in inputs]
116 | disps = [disp[y_start: y_start + self.crop_size_h, x_start: x_start + self.crop_size_w] for disp in disps]
117 |
118 | return inputs, disps
119 |
120 | class RandomCrop(object):
121 |
122 | def __init__(self, crop_size):
123 |         """Set the output height and width of the crop window"""
124 | self.crop_size_h = crop_size[0]
125 | self.crop_size_w = crop_size[1]
126 |
127 | def __call__(self, inputs, disps):
128 | input_size_h, input_size_w, _ = inputs[0].shape
129 | x_start = random.randint(0, input_size_w - self.crop_size_w)
130 | y_start = random.randint(0, input_size_h - self.crop_size_h)
131 | inputs = [inp[y_start: y_start + self.crop_size_h, x_start: x_start + self.crop_size_w] for inp in inputs]
132 | disps = [disp[y_start: y_start + self.crop_size_h, x_start: x_start + self.crop_size_w] for disp in disps]
133 |
134 | return inputs, disps
135 |
136 | class RandomHorizontalFlip(object):
137 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 left-right"""
138 |
139 | def __call__(self, inputs, disps):
140 |         if random.random() < 0.5:
141 |             '''Swap left and right while flipping, to keep the stereo search direction'''
142 |             # Simultaneous assignment so each flipped view is built from the
143 |             # pre-flip tensors rather than an already-overwritten one
144 |             inputs[0], inputs[1] = np.copy(np.fliplr(inputs[1])), np.copy(np.fliplr(inputs[0]))
145 |             inputs[2], inputs[3] = np.copy(np.fliplr(inputs[3])), np.copy(np.fliplr(inputs[2]))
146 |
147 |             disps[0], disps[1] = np.copy(np.fliplr(disps[1])), np.copy(np.fliplr(disps[0]))
148 |
149 |
150 | return inputs, disps
151 |
152 |
153 | class RandomVerticalFlip(object):
154 | """Randomly vertically flips the given PIL.Image with a probability of 0.5 up-down"""
155 | def __call__(self, inputs, disps):
156 | if random.random() < 0.5:
157 | inputs = [np.copy(np.flipud(inp)) for inp in inputs]
158 | disps = [np.copy(np.flipud(disp)) for disp in disps]
159 | return inputs, disps
160 |
161 |
162 | class ToTensor(object):
163 | """Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W)."""
164 |
165 | def __call__(self, inputs, disps):
166 | assert(isinstance(inputs[0], np.ndarray) and isinstance(inputs[1], np.ndarray))
167 | inputs = [np.transpose(inp, (2, 0, 1)) for inp in inputs]
168 | inputs_tensor = [torch.from_numpy(inp).float() for inp in inputs]
169 |
170 | assert(isinstance(disps[0], np.ndarray) and isinstance(disps[1], np.ndarray))
171 | disps_tensor = [torch.from_numpy(d) for d in disps]
172 | disps_tensor = [d.view(1, d.size()[0],d.size()[1]).float() for d in disps_tensor]
173 | return inputs_tensor, disps_tensor
174 |
175 |
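A sketch of composing these co-transforms on toy data (all numeric parameters below are illustrative; the real values come from config.py):

import numpy as np
import utils.data_transforms as T

imgs  = [255 * np.random.rand(270, 480, 3).astype(np.float32) for _ in range(4)]
disps = [40 * np.random.rand(270, 480).astype(np.float32) for _ in range(2)]

train_transforms = T.Compose([
    T.ColorJitter([0.2, 0.15, 0.3, 0.1]),    # brightness, contrast, saturation, hue
    T.RandomCrop([256, 256]),
    T.RandomVerticalFlip(),
    T.Normalize(mean=0.5, std=255.0, div_disp=40.0),
    T.ToTensor(),
])
imgs_t, disps_t = train_transforms(imgs, disps)
print(imgs_t[0].shape, disps_t[0].shape)     # (3, 256, 256) and (1, 256, 256)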
--------------------------------------------------------------------------------
/utils/imgio_gen.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3.4
2 |
3 | import os
4 | import re
5 | import numpy as np
6 | import uuid
7 | from scipy import misc
8 |
9 | import pyexr
10 | import sys
11 | import cv2
12 |
13 | def readgen(file):
14 | if file.endswith('.float3'): return readFloat(file)
15 | elif file.endswith('.flo'): return readFlow(file)
16 | elif file.endswith('.ppm'): return readImage(file)
17 | elif file.endswith('.pgm'): return readImage(file)
18 | elif file.endswith('.png'): return readImage(file)
19 | elif file.endswith('.jpg'): return readImage(file)
20 | elif file.endswith('.pfm'): return readPFM(file)[0]
21 | elif file.endswith('.exr'): return pyexr.open(file).get() #https://github.com/tvogels/pyexr
22 | else: raise Exception('don\'t know how to read %s' % file)
23 |
24 | def writegen(file, data):
25 | if file.endswith('.float3'): return writeFloat(file, data)
26 | elif file.endswith('.flo'): return writeFlow(file, data)
27 | elif file.endswith('.ppm'): return writeImage(file, data)
28 | elif file.endswith('.pgm'): return writeImage(file, data)
29 | elif file.endswith('.png'): return writeImage(file, data)
30 | elif file.endswith('.jpg'): return writeImage(file, data)
31 | elif file.endswith('.pfm'): return writePFM(file, data)
32 | elif file.endswith('.exr'): return pyexr.write(file, data) #https://github.com/tvogels/pyexr
33 | else: raise Exception('don\'t know how to write %s' % file)
34 |
35 | def readPFM(file):
36 | file = open(file, 'rb')
37 |
38 | color = None
39 | width = None
40 | height = None
41 | scale = None
42 | endian = None
43 |
44 | header = file.readline().rstrip()
45 | if header.decode("ascii") == 'PF':
46 | color = True
47 | elif header.decode("ascii") == 'Pf':
48 | color = False
49 | else:
50 | raise Exception('Not a PFM file.')
51 |
52 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii"))
53 | if dim_match:
54 | width, height = list(map(int, dim_match.groups()))
55 | else:
56 | raise Exception('Malformed PFM header.')
57 |
58 | scale = float(file.readline().decode("ascii").rstrip())
59 | if scale < 0: # little-endian
60 | endian = '<'
61 | scale = -scale
62 | else:
63 | endian = '>' # big-endian
64 |
65 | data = np.fromfile(file, endian + 'f')
66 | shape = (height, width, 3) if color else (height, width)
67 |
68 | data = np.reshape(data, shape)
69 | data = np.flipud(data)
70 | return data, scale
71 |
72 | def writePFM(file, image, scale=1):
73 | file = open(file, 'wb')
74 | color = None
75 | if image.dtype.name != 'float32':
76 | raise Exception('Image dtype must be float32.')
77 |
78 | image = np.flipud(image)
79 |
80 | if len(image.shape) == 3 and image.shape[2] == 3: # color image
81 | color = True
82 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale
83 | color = False
84 | else:
85 | raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')
86 |
87 |     file.write(('PF\n' if color else 'Pf\n').encode())
88 |     file.write(('%d %d\n' % (image.shape[1], image.shape[0])).encode())
89 |
90 |     endian = image.dtype.byteorder
91 |
92 |     if endian == '<' or (endian == '=' and sys.byteorder == 'little'):
93 |         scale = -scale
94 |
95 |     file.write(('%f\n' % scale).encode())
96 |
97 | image.tofile(file)
98 |
99 | def readFlow(name):
100 | if name.endswith('.pfm') or name.endswith('.PFM'):
101 | return readPFM(name)[0][:,:,0:2]
102 |
103 | f = open(name, 'rb')
104 |
105 | header = f.read(4)
106 | if header.decode("utf-8") != 'PIEH':
107 | raise Exception('Flow file header does not contain PIEH')
108 |
109 | width = np.fromfile(f, np.int32, 1).squeeze()
110 | height = np.fromfile(f, np.int32, 1).squeeze()
111 |
112 | flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2))
113 |
114 | return flow.astype(np.float32)
115 |
116 | def readImage(name):
117 | if name.endswith('.pfm') or name.endswith('.PFM'):
118 | data = readPFM(name)[0]
119 | if len(data.shape)==3:
120 | return data[:,:,0:3]
121 | else:
122 | return data
123 |
124 | return cv2.imread(name)
125 |
126 | def writeImage(name, data):
127 | if name.endswith('.pfm') or name.endswith('.PFM'):
128 | return writePFM(name, data, 1)
129 |
130 | return misc.imsave(name, data)
131 |
132 | def writeFlow(name, flow):
133 | f = open(name, 'wb')
134 | f.write('PIEH'.encode('utf-8'))
135 | np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
136 | flow = flow.astype(np.float32)
137 | flow.tofile(f)
138 |
139 | def readFloat(name):
140 | f = open(name, 'rb')
141 |
142 |     if f.readline().decode("utf-8") != 'float\n':
143 | raise Exception('float file %s did not contain keyword' % name)
144 |
145 | dim = int(f.readline())
146 |
147 | dims = []
148 | count = 1
149 | for i in range(0, dim):
150 | d = int(f.readline())
151 | dims.append(d)
152 | count *= d
153 |
154 | dims = list(reversed(dims))
155 |
156 | data = np.fromfile(f, np.float32, count).reshape(dims)
157 | if dim > 2:
158 | data = np.transpose(data, (2, 1, 0))
159 | data = np.transpose(data, (1, 0, 2))
160 |
161 | return data
162 |
163 | def writeFloat(name, data):
164 | f = open(name, 'wb')
165 |
166 | dim=len(data.shape)
167 | if dim>3:
168 | raise Exception('bad float file dimension: %d' % dim)
169 |
170 | f.write(('float\n').encode('ascii'))
171 | f.write(('%d\n' % dim).encode('ascii'))
172 |
173 | if dim == 1:
174 | f.write(('%d\n' % data.shape[0]).encode('ascii'))
175 | else:
176 | f.write(('%d\n' % data.shape[1]).encode('ascii'))
177 | f.write(('%d\n' % data.shape[0]).encode('ascii'))
178 | for i in range(2, dim):
179 | f.write(('%d\n' % data.shape[i]).encode('ascii'))
180 |
181 | data = data.astype(np.float32)
182 | if dim==2:
183 | data.tofile(f)
184 |
185 | else:
186 | np.transpose(data, (2, 0, 1)).tofile(f)
187 |
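A round-trip sketch for the .flo helpers above (PIEH header, int32 dims, float32 payload):

import os
import tempfile
import numpy as np
from utils.imgio_gen import writeFlow, readFlow

flow = np.random.rand(32, 48, 2).astype(np.float32)   # H x W x 2 flow field
path = os.path.join(tempfile.gettempdir(), 'demo.flo')
writeFlow(path, flow)
assert np.allclose(readFlow(path), flow)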
--------------------------------------------------------------------------------
/utils/network_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Developed by Shangchen Zhou
5 |
6 | import os
7 | import sys
8 | import torch
9 | import numpy as np
10 | from datetime import datetime as dt
11 | from config import cfg
12 | import torch.nn.functional as F
13 |
14 | import cv2
15 |
16 | def mkdir(path):
17 | if not os.path.isdir(path):
18 | mkdir(os.path.split(path)[0])
19 | else:
20 | return
21 | os.mkdir(path)
22 |
23 | def var_or_cuda(x):
24 | if torch.cuda.is_available():
25 | x = x.cuda(non_blocking=True)
26 | return x
27 |
28 |
29 | def init_weights_xavier(m):
30 | if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d):
31 | torch.nn.init.xavier_uniform_(m.weight)
32 | if m.bias is not None:
33 | torch.nn.init.constant_(m.bias, 0)
34 | elif type(m) == torch.nn.BatchNorm2d or type(m) == torch.nn.InstanceNorm2d:
35 | if m.weight is not None:
36 | torch.nn.init.constant_(m.weight, 1)
37 | torch.nn.init.constant_(m.bias, 0)
38 | elif type(m) == torch.nn.Linear:
39 | torch.nn.init.normal_(m.weight, 0, 0.01)
40 | torch.nn.init.constant_(m.bias, 0)
41 |
42 | def init_weights_kaiming(m):
43 | if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d):
44 | torch.nn.init.kaiming_normal_(m.weight)
45 | if m.bias is not None:
46 | torch.nn.init.constant_(m.bias, 0)
47 | elif type(m) == torch.nn.BatchNorm2d or type(m) == torch.nn.InstanceNorm2d:
48 | if m.weight is not None:
49 | torch.nn.init.constant_(m.weight, 1)
50 | torch.nn.init.constant_(m.bias, 0)
51 | elif type(m) == torch.nn.Linear:
52 | torch.nn.init.normal_(m.weight, 0, 0.01)
53 | torch.nn.init.constant_(m.bias, 0)
54 |
55 | def save_disp_checkpoints(file_path, epoch_idx, dispnet, dispnet_solver, Best_Disp_EPE, Best_Epoch):
56 | print('[INFO] %s Saving checkpoint to %s ...' % (dt.now(), file_path))
57 | checkpoint = {
58 | 'epoch_idx': epoch_idx,
59 | 'Best_Disp_EPE': Best_Disp_EPE,
60 | 'Best_Epoch': Best_Epoch,
61 | 'dispnet_state_dict': dispnet.state_dict(),
62 | 'dispnet_solver_state_dict': dispnet_solver.state_dict(),
63 | }
64 | torch.save(checkpoint, file_path)
65 |
66 | def save_deblur_checkpoints(file_path, epoch_idx, deblurnet, deblurnet_solver, Best_Img_PSNR, Best_Epoch):
67 | print('[INFO] %s Saving checkpoint to %s ...\n' % (dt.now(), file_path))
68 | checkpoint = {
69 | 'epoch_idx': epoch_idx,
70 | 'Best_Img_PSNR': Best_Img_PSNR,
71 | 'Best_Epoch': Best_Epoch,
72 | 'deblurnet_state_dict': deblurnet.state_dict(),
73 | 'deblurnet_solver_state_dict': deblurnet_solver.state_dict(),
74 | }
75 | torch.save(checkpoint, file_path)
76 |
77 | def save_checkpoints(file_path, epoch_idx, dispnet, dispnet_solver, deblurnet, deblurnet_solver, Disp_EPE, Best_Img_PSNR, Best_Epoch):
78 | print('[INFO] %s Saving checkpoint to %s ...' % (dt.now(), file_path))
79 | checkpoint = {
80 | 'epoch_idx': epoch_idx,
81 | 'Disp_EPE': Disp_EPE,
82 | 'Best_Img_PSNR': Best_Img_PSNR,
83 | 'Best_Epoch': Best_Epoch,
84 | 'dispnet_state_dict': dispnet.state_dict(),
85 | 'dispnet_solver_state_dict': dispnet_solver.state_dict(),
86 | 'deblurnet_state_dict': deblurnet.state_dict(),
87 | 'deblurnet_solver_state_dict': deblurnet_solver.state_dict(),
88 | }
89 | torch.save(checkpoint, file_path)
90 |
91 | def count_parameters(model):
92 | return sum(p.numel() for p in model.parameters())
93 |
94 | def get_weight_parameters(model):
95 | return [param for name, param in model.named_parameters() if ('weight' in name)]
96 |
97 | def get_bias_parameters(model):
98 | return [param for name, param in model.named_parameters() if ('bias' in name)]
99 |
100 | class AverageMeter(object):
101 | """Computes and stores the average and current value"""
102 | def __init__(self):
103 | self.reset()
104 |
105 | def reset(self):
106 | self.val = 0
107 | self.avg = 0
108 | self.sum = 0
109 | self.count = 0
110 |
111 | def update(self, val, n=1):
112 | self.val = val
113 | self.sum += val * n
114 | self.count += n
115 | self.avg = self.sum / self.count
116 |
117 | def __repr__(self):
118 | return '{:.5f} ({:.5f})'.format(self.val, self.avg)
119 |
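120 | # Typical AverageMeter usage (illustrative sketch; `net`, `criterion` and
121 | # `loader` are hypothetical): track a running loss across batches,
122 | # weighting each update by the batch size.
123 | #
124 | #   losses = AverageMeter()
125 | #   for images, labels in loader:
126 | #       loss = criterion(net(images), labels)
127 | #       losses.update(loss.item(), images.size(0))
128 | #   print(losses)  # prints 'val (avg)' via __repr__
129 | 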
130 | def graybi2rgb(graybi):
131 |     '''Convert a stacked pair of gray maps (Tensor: 2 x H x W) into two RGB maps.'''
132 |     assert isinstance(graybi, torch.Tensor)
133 |     _, H, W = graybi.shape
134 |     rgb_1 = torch.zeros((3, H, W))
135 |     rgb_2 = torch.zeros((3, H, W))
136 |     # Normalize both maps by their shared maximum so they use the same scale.
137 |     normalized_gray_map = graybi / graybi.max()
138 |     # Replicate each single-channel map across the three RGB channels.
139 |     rgb_1[0] = normalized_gray_map[0]
140 |     rgb_1[1] = normalized_gray_map[0]
141 |     rgb_1[2] = normalized_gray_map[0]
142 | 
143 |     rgb_2[0] = normalized_gray_map[1]
144 |     rgb_2[1] = normalized_gray_map[1]
145 |     rgb_2[2] = normalized_gray_map[1]
146 |     return rgb_1.clamp(0, 1), rgb_2.clamp(0, 1)
147 | 
148 | 
149 | def get_occ(imgs, disps, cuda=True):
150 |     '''
151 |     Estimate per-pixel visibility masks by photometric consistency.
152 |     img: b, c, h, w
153 |     disp: b, h, w
154 |     Returns [occ_left, occ_right] (each b x h x w); 1 marks pixels whose
155 |     warped counterpart is photometrically consistent, 0 marks occlusions.
156 |     '''
157 |     assert isinstance(imgs[0], torch.Tensor) and isinstance(imgs[1], torch.Tensor)
158 |     assert isinstance(disps[0], torch.Tensor) and isinstance(disps[1], torch.Tensor)
159 |     if cuda:
160 |         imgs = [var_or_cuda(img) for img in imgs]
161 |         disps = [var_or_cuda(disp) for disp in disps]
162 |     alpha = 0.001
163 |     beta = 0.005
164 |     B, _, H, W = imgs[0].shape
165 |     disp_left = disps[0]
166 |     disp_right = disps[1]
167 |     # Validity masks: True where the disparity is finite (not NaN/Inf).
168 |     mask0_left = ~(torch.isnan(disp_left) | torch.isinf(disp_left))
169 |     mask0_right = ~(torch.isnan(disp_right) | torch.isinf(disp_right))
170 |     # Zero out invalid disparities before warping.
171 |     disp_left[~mask0_left] = 0.0
172 |     disp_right[~mask0_right] = 0.0
173 | 
174 |     # Reconstruct each view by warping the other one with its disparity.
175 |     img_warp_left = disp_warp(imgs[1], -disp_left, cuda=cuda)
176 |     img_warp_right = disp_warp(imgs[0], disp_right, cuda=cuda)
177 | 
178 |     # A pixel is kept (1) when its photometric error stays below the
179 |     # brightness-dependent threshold alpha * (I^2 + I_warp^2) + beta.
180 |     diff_left = (imgs[0] - img_warp_left)**2 - (alpha*(imgs[0]**2 + img_warp_left**2) + beta)
181 |     mask1_left = torch.sum(diff_left, 1) <= 0
182 |     occ_left = torch.zeros((B, H, W), dtype=torch.float32, device=disp_left.device)
183 |     occ_left[mask0_left & mask1_left] = 1
184 | 
185 |     diff_right = (imgs[1] - img_warp_right)**2 - (alpha*(imgs[1]**2 + img_warp_right**2) + beta)
186 |     mask1_right = torch.sum(diff_right, 1) <= 0
187 |     occ_right = torch.zeros((B, H, W), dtype=torch.float32, device=disp_right.device)
188 |     occ_right[mask0_right & mask1_right] = 1
189 | 
190 |     return [occ_left, occ_right]
191 | 
192 | 
193 | def disp_warp(img, disp, cuda=True):
194 |     '''
195 |     Backward-warp img along the x-axis by a per-pixel disparity.
196 |     img.shape = b, c, h, w
197 |     disp.shape = b, h, w
198 |     '''
199 |     b, c, h, w = img.shape
200 |     device = 'cuda' if cuda else 'cpu'
201 |     # Pixel-coordinate grids of shape (b, h, w).
202 |     right_coor_x = torch.arange(w, dtype=torch.float32, device=device).repeat(b, h, 1)
203 |     right_coor_y = torch.arange(h, dtype=torch.float32, device=device).repeat(b, w, 1).transpose(1, 2)
204 |     # Shift the x-coordinates by the disparity, then normalize both axes to [-1, 1].
205 |     left_coor_x1 = right_coor_x + disp
206 |     left_coor_norm1 = torch.stack((left_coor_x1 / (w - 1) * 2 - 1, right_coor_y / (h - 1) * 2 - 1), dim=1)
207 |     ## backward warp; align_corners=True matches the (dim - 1) normalization above
208 |     warp_img = F.grid_sample(img, left_coor_norm1.permute(0, 2, 3, 1), align_corners=True)
209 | 
210 |     return warp_img
211 | 
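212 | # Minimal usage sketches (illustrative only; `net` and the tensors below are
213 | # hypothetical). The weight/bias parameter splits above are presumably meant
214 | # for per-group optimizer settings, e.g. decaying weights but not biases:
215 | #
216 | #   net.apply(init_weights_kaiming)
217 | #   solver = torch.optim.Adam([
218 | #       {'params': get_weight_parameters(net), 'weight_decay': 1e-4},
219 | #       {'params': get_bias_parameters(net), 'weight_decay': 0}], lr=1e-4)
220 | #
221 | # Warping and occlusion detection on a stereo pair with known disparities:
222 | #
223 | #   left_img, right_img = torch.rand(1, 3, 256, 512), torch.rand(1, 3, 256, 512)
224 | #   disp_l, disp_r = torch.rand(1, 256, 512) * 32, torch.rand(1, 256, 512) * 32
225 | #   left_from_right = disp_warp(right_img, -disp_l, cuda=False)
226 | #   occ_l, occ_r = get_occ([left_img, right_img], [disp_l, disp_r], cuda=False)
227 | 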
--------------------------------------------------------------------------------