├── LICENSE
├── README.md
├── dataloader
├── Joint_xLabel_dataLoader.py
├── NYUv2_dataLoader.py
└── __init__.py
├── figures
├── 4343-teaser.gif
└── demo.png
├── loss.py
├── models
├── __init__.py
├── attention_networks.py
├── depth_generator_networks.py
└── discriminator_networks.py
├── train.py
├── training
├── base_model.py
├── finetune_the_whole_system_with_depth_loss.py
├── jointly_train_depth_predictor_D_and_attention_module_A.py
├── train_initial_attention_module_A.py
├── train_initial_depth_predictor_D.py
├── train_inpainting_module_I.py
└── train_style_translator_T.py
└── utils
├── __init__.py
├── image_pool.py
└── metrics.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Yunhan Zhao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ARC
2 | This repo contains the PyTorch implementation of:
3 |
4 | [Domain Decluttering: Simplifying Images to Mitigate Synthetic-Real Domain Shift and Improve Depth Estimation](http://openaccess.thecvf.com/content_CVPR_2020/html/Zhao_Domain_Decluttering_Simplifying_Images_to_Mitigate_Synthetic-Real_Domain_Shift_and_CVPR_2020_paper.html)
5 |
6 | [Yunhan Zhao](https://www.ics.uci.edu/~yunhaz5/), [Shu Kong](http://www.cs.cmu.edu/~shuk/), [Daeyun Shin](https://research.dshin.org/) and [Charless Fowlkes](https://www.ics.uci.edu/~fowlkes/)
7 |
8 | CVPR 2020
9 |
10 | For more details, please check our [project website](https://www.ics.uci.edu/~yunhaz5/cvpr2020/domain_decluttering.html)
11 |
12 |
13 |
14 |
15 |
16 | ### Abstract
17 | Leveraging synthetically rendered data offers great potential to improve monocular depth estimation and other geometric estimation tasks, but closing the synthetic-real domain gap is a non-trivial and important task. While much recent work has focused on unsupervised domain adaptation, we consider a more realistic scenario where a large amount of synthetic training data is supplemented by a small set of real images with ground-truth. In this setting, we find that existing domain translation approaches are difficult to train and offer little advantage over simple baselines that use a mix of real and synthetic data. A key failure mode is that real-world images contain novel objects and clutter not present in synthetic training. This high-level domain shift isn’t handled by existing image translation models.
18 |
19 | Based on these observations, we develop an attention module that learns to identify and remove difficult out-of-domain regions in real images in order to improve depth prediction for a model trained primarily on synthetic data. We carry out extensive experiments to validate our attend-remove-complete approach (ARC) and find that it significantly outperforms state-of-the-art domain adaptation methods for depth prediction. Visualizing the removed regions provides interpretable insights into the synthetic-real domain gap.
20 |
21 |
22 |
23 | ## Reference
24 | If you find our work useful in your research, please consider citing our paper:
25 | ```
26 | @inproceedings{zhao2020domain,
27 | title={Domain Decluttering: Simplifying Images to Mitigate Synthetic-Real Domain Shift and Improve Depth Estimation},
28 | author={Zhao, Yunhan and Kong, Shu and Shin, Daeyun and Fowlkes, Charless},
29 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
30 | pages={3330--3340},
31 | year={2020}
32 | }
33 | ```
34 |
35 | ## Contents
36 |
37 | - [Requirements](#requirements)
38 | - [Training Procedures](#training-procedures)
39 | - [Evaluations](#evaluations)
40 | - [Pretrained Models](#pretrained-models)
41 |
42 |
43 | ## Requirements
44 | 1. Python 3.6 on Ubuntu 16.04
45 | 2. PyTorch 1.1.0
46 | 3. Apex 0.1 (optional)
47 |
48 | You also need third-party libraries such as numpy, pillow, torchvision, and (optionally) tensorboardX to run the code. We use apex when training all models, but it is not strictly required.
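
For reference, one way to set up a matching environment (package versions here are illustrative; torchvision 0.3.0 is the release paired with PyTorch 1.1.0):
```bash
pip install torch==1.1.0 torchvision==0.3.0 numpy pillow tensorboardX
# optional: NVIDIA apex for mixed-precision training
git clone https://github.com/NVIDIA/apex && cd apex && pip install -v --no-cache-dir ./
```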
49 |
50 | ## Datasets
51 | You need to download NYUv2 and PBRS and place them in the following structure so the dataloaders can locate the data.
52 | #### Dataset Structure
53 | ```
54 | NYUv2 (real)
55 | | train
56 | | rgb
57 | | depth
58 | | test
59 | | rgb
60 | | depth
61 | PBRS (synthetic)
62 | | train
63 | | rgb
64 | | depth
65 | ```
66 | For the Kitti experiments, you need to download Kitti and vKitti and follow the same structure.
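
The dataloaders locate depth maps by rewriting the RGB path (see `fetch_img_depth` in `dataloader/Joint_xLabel_dataLoader.py`), so file names must match across the `rgb` and `depth` folders. A minimal sketch of the mapping (the path is a hypothetical example):
```python
filename = '<root>/train/rgb/0001.png'  # hypothetical example path
depthname = filename.replace('rgb', 'depth_inpainted').replace('png', 'bin')
# -> '<root>/train/depth_inpainted/0001.bin' (inpainted depth is used during training)
```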
67 | ## Training Procedures
68 | - [1 Train Initial Depth Predictor D](#1-train-initial-depth-predictor-d)
69 | - [2 Train Style Translator T (pretrain T)](#2-train-style-translator-t-pretrain-t)
70 | - [3 Train Initial Attention Module A](#3-train-initial-attention-module-a)
71 | - [4 Train Inpainting Module I (pretrain I)](#4-train-inpainting-module-i-pretrain-i)
72 | - [5 Jointly Train Depth Predictor D and Attention Module A (pretrain A, D)](#5-jointly-train-depth-predictor-d-and-attention-module-a-pretrain-a-d)
73 | - [6 Finetune the Whole System with Depth Loss](#6-finetune-the-whole-system-with-depth-loss-modular-coordinate-descent)
74 |
75 | All training steps share a single `train.py` entry point, so make sure to comment/uncomment the lines corresponding to the step you are running.
76 | ```bash
77 | CUDA_VISIBLE_DEVICES= python train.py \
78 | --path_to_NYUv2= \
79 | --path_to_PBRS= \
80 | --batch_size=4 --total_epoch_num=500 --isTrain --eval_batch_size=1
81 | ```
82 | `batch_size` and `eval_batch_size` can be adjusted to fit your working environment.
83 | #### 1 Train Initial Depth Predictor D
84 | Train an initial depth predictor D with real and synthetic data. The best model is the one with minimum L1 loss. The checkpoints are saved in `./experiments/train_initial_depth_predictor_D/`.
85 | #### 2 Train Style Translator T (pretrain T)
86 | Train the style translator T to obtain a good initialization for it. The best model is picked by visual inspection and the training loss curves.
87 | #### 3 Train Initial Attention Module A
88 | Train an initial attention module A from scratch with descending $\tau$ values, as sketched below.
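For intuition, the outer training loop might look like the sketch below; the schedule values and the `train_A_one_pass` helper are hypothetical, see `training/train_initial_attention_module_A.py` for the actual procedure:
```python
# hypothetical schedule: if tau is the fraction of the image A must keep,
# descending values let A remove progressively more over training
for tau in [1.0, 0.9, 0.8, 0.7]:
    train_A_one_pass(tau)  # hypothetical helper wrapping one pass over the data
```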
89 | #### 4 Train Inpainting Module I (pretrain I)
90 | Train the inpainting module I with T (from step 2) and A (from step 3). This provides a good initialization for I.
91 | #### 5 Jointly Train Depth Predictor D and Attention Module A (pretrain A, D)
92 | Further jointly train depth predictor D and attention module A, starting from D (step 1), T (step 2), A (step 3), and I (step 4). The A and D learned in this step are the initialization used before finetuning the whole system with depth loss. In steps 5 and 6 we train for relatively few epochs, i.e., `total_epoch_num=150`.
93 | #### 6 Finetune the Whole System with Depth Loss (Modular Coordinate Descent)
94 | Lastly, we finetune the whole system with depth loss terms using D (from step 5), T (from step 2), A (from step 5) and I (from step 4). The NYUv2 results reported in the paper are the evaluation results from this step (one-step finetuning).
95 |
96 | ## Evaluations
97 | Evaluate the final results:
98 | ```bash
99 | CUDA_VISIBLE_DEVICES= python train.py \
100 | --path_to_NYUv2= \
101 | --path_to_PBRS= \
102 | --eval_batch_size=1
103 | ```
104 | Make sure to uncomment step 6 in `train.py`. To evaluate on your own data, place it under `/test` following the dataset structure described above.
105 |
106 | ## Pretrained Models
107 | Pretrained models for the NYUv2 & PBRS experiment are available [here](https://drive.google.com/drive/folders/1gB4dE3qoHrNGQqqU7cea7Z3MouPIJA9m?usp=sharing).
108 |
109 | Pretrained models for the Kitti & vKitti experiment are available [here](https://drive.google.com/drive/folders/1XzCXm91-HgXm1OKx358yKFN-aSVZGqpM?usp=sharing).
110 |
111 | ## Acknowledgments
112 | This code is developed based on [T2Net](https://github.com/lyndonzheng/Synthetic2Realistic) and [Pytorch-CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
113 |
114 | ## Questions
115 | Please feel free to email me at (yunhaz5 [at] ics [dot] uci [dot] edu) if you have any questions.
--------------------------------------------------------------------------------
/dataloader/Joint_xLabel_dataLoader.py:
--------------------------------------------------------------------------------
1 | import os, os.path as path, sys, random, time, copy
2 | from skimage import io, transform
3 | import numpy as np
4 | import scipy.io as sio
5 | import logging
6 | import matplotlib.pyplot as plt
7 | import PIL.Image
8 |
9 | import skimage.transform
10 | import blosc, struct
11 | log = logging.getLogger(__name__)  # module logger used by ensure_dir_exists / save_array_compressed
12 | import torch
13 | from torch.utils.data import Dataset, DataLoader
14 | import torch.nn as nn
15 | import torch.optim as optim
16 | from torch.optim import lr_scheduler
17 | import torch.nn.functional as F
18 | from torch.autograd import Variable
19 |
20 | import torchvision
21 | from torchvision import datasets, models, transforms
22 |
23 | IMG_EXTENSIONS = [
24 | '.jpg', '.JPG', '.jpeg', '.JPEG',
25 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.bin'
26 | ]
27 |
28 | def is_image_file(filename):
29 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
30 |
31 | class Joint_xLabel_train_dataLoader(Dataset):
32 | def __init__(self, real_root_dir, syn_root_dir, size=[240, 320], rgb=True, downsampleDepthFactor=1, paired_data=False):
33 | self.real_root_dir = real_root_dir
34 | self.syn_root_dir = syn_root_dir
35 | self.size = size
36 | self.rgb = rgb
37 | self.current_set_len = 0
38 | self.real_path2files = []
39 | self.syn_path2files = []
40 | self.downsampleDepthFactor = downsampleDepthFactor
41 | self.NYU_MIN_DEPTH_CLIP = 0.0
42 | self.NYU_MAX_DEPTH_CLIP = 10.0
43 |         self.paired_data = paired_data # whether real and synthetic samples are matched 1-to-1
44 |         self.augment = None # whether to horizontally flip the current sample; re-drawn per __getitem__ call
45 | self.x_labels = False # whether to collect extra labels in synthetic data, such as segmentation or instance boundaries
46 |
47 | self.set_name = 'train' # Joint_xLabel_train_dataLoader is only used in training phase
48 |
49 | real_curfilenamelist = os.listdir(os.path.join(self.real_root_dir, self.set_name, 'rgb'))
50 | for fname in sorted(real_curfilenamelist):
51 | if is_image_file(fname):
52 | path = os.path.join(self.real_root_dir, self.set_name, 'rgb', fname)
53 | self.real_path2files.append(path)
54 |
55 | self.real_set_len = len(self.real_path2files)
56 |
57 | syn_curfilenamelist = os.listdir(os.path.join(self.syn_root_dir, self.set_name, 'rgb'))
58 | for fname in sorted(syn_curfilenamelist):
59 | if is_image_file(fname):
60 | path = os.path.join(self.syn_root_dir, self.set_name, 'rgb', fname)
61 | self.syn_path2files.append(path)
62 |
63 | self.syn_set_len = len(self.syn_path2files)
64 |
65 | self.TF2tensor = transforms.ToTensor()
66 | self.TF2PIL = transforms.ToPILImage()
67 | self.TFNormalize = transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
68 | self.funcResizeTensor = nn.Upsample(size=self.size, mode='nearest', align_corners=None)
69 | self.funcResizeDepth = nn.Upsample(size=[int(self.size[0]*self.downsampleDepthFactor),
70 | int(self.size[1]*self.downsampleDepthFactor)],
71 | mode='nearest', align_corners=None)
72 |
73 | def __len__(self):
74 | # looping over real dataset
75 | return self.real_set_len
76 |
77 | def __getitem__(self, idx):
78 | real_filename = self.real_path2files[idx % self.real_set_len]
79 | rand_idx = random.randint(0, self.syn_set_len - 1)
80 | if self.paired_data:
81 | assert self.real_set_len == self.syn_set_len
82 | syn_filename = self.syn_path2files[idx]
83 |
84 | else:
85 | syn_filename = self.syn_path2files[rand_idx]
86 |
87 | if np.random.random(1) > 0.5:
88 | self.augment = True
89 | else:
90 | self.augment = False
91 |
92 | real_img, real_depth = self.fetch_img_depth(real_filename)
93 | syn_img, syn_depth = self.fetch_img_depth(syn_filename)
94 | return_dict = {'real': [real_img, real_depth], 'syn': [syn_img, syn_depth]}
95 |
96 | if self.x_labels:
97 | # not really used in this project
98 | extra_label_list = self.fetch_syn_extra_labels(syn_filename)
99 | return_dict = {'real': [real_img, real_depth], 'syn': [syn_img, syn_depth], 'syn_extra_labels': extra_label_list}
100 | return return_dict
101 |
102 | def fetch_img_depth(self, filename):
103 | image = PIL.Image.open(filename)
104 | image = np.array(image, dtype=np.float32) / 255.
105 |
106 | if self.set_name == 'train':
107 | depthname = filename.replace('rgb','depth_inpainted').replace('png','bin')
108 | else:
109 | # use real depth for validation and testing
110 | depthname = filename.replace('rgb','depth').replace('png','bin')
111 |
112 | depth = read_array_compressed(depthname)
113 |
114 | if self.set_name=='train' and self.augment:
115 | image = np.fliplr(image).copy()
116 | depth = np.fliplr(depth).copy()
117 |
118 | # rescale depth samples in training phase
119 | if self.set_name == 'train':
120 | depth = np.clip(depth, self.NYU_MIN_DEPTH_CLIP, self.NYU_MAX_DEPTH_CLIP) # [0, 10]
121 | depth = ((depth/self.NYU_MAX_DEPTH_CLIP) - 0.5) * 2.0 # [-1, 1]
122 |
123 | image = self.TF2tensor(image)
124 | image = self.TFNormalize(image)
125 | image = image.unsqueeze(0)
126 |
127 | depth = np.expand_dims(depth, 2)
128 | depth = self.TF2tensor(depth)
129 | depth = depth.unsqueeze(0)
130 |
131 | if "nyu" in filename:
132 | image = processNYU_tensor(image)
133 | depth = processNYU_tensor(depth)
134 |
135 | image = self.funcResizeTensor(image)
136 | depth = self.funcResizeTensor(depth)
137 |
138 | if self.downsampleDepthFactor != 1:
139 | depth = self.funcResizeDepth(depth)
140 |
141 | if self.rgb:
142 | image = image.squeeze(0)
143 | else:
144 | image = image.mean(1)
145 | image = image.squeeze(0).unsqueeze(0)
146 |
147 | depth = depth.squeeze(0)
148 | return image, depth
149 |
150 | def fetch_syn_extra_labels(self, filename):
151 | # currently only fetch segmentation labels and instance boundaries
152 | seg_name = filename.replace('rgb','semantic_seg')
153 | ib_name = filename.replace('rgb','instance_boundary')
154 |
155 | seg_np = np.array(PIL.Image.open(seg_name), dtype=np.float32)
156 | ib_np = np.array(PIL.Image.open(ib_name), dtype=np.float32)
157 |
158 | if self.set_name=='train' and self.augment:
159 | seg_np = np.fliplr(seg_np).copy()
160 | ib_np = np.fliplr(ib_np).copy()
161 |
162 | seg_np = np.expand_dims(seg_np, 2)
163 | seg_tensor = self.TF2tensor(seg_np)
164 |
165 | ib_np = np.expand_dims(ib_np, 2)
166 | ib_tensor = self.TF2tensor(ib_np) # size [1, 240, 320]
167 |
168 | return [seg_tensor, ib_tensor]
169 |
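# --- illustrative usage (paths are placeholders): each batch is a dict with paired
# 'real' and 'syn' entries, each an [image, depth] pair of tensors ---
# dataset = Joint_xLabel_train_dataLoader(real_root_dir='<path_to_NYUv2>', syn_root_dir='<path_to_PBRS>')
# loader = DataLoader(dataset, batch_size=4, shuffle=True)
# real_img, real_depth = next(iter(loader))['real']  # [B, 3, 240, 320], [B, 1, 240, 320]
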
170 | def ensure_dir_exists(dirname, log_mkdir=True):
171 | """
172 | Creates a directory if it does not already exist.
173 | :param dirname: Path to a directory.
174 | :param log_mkdir: If true, a debug message is logged when creating a new directory.
175 | :return: Same as `dirname`.
176 | """
177 | dirname = path.realpath(path.expanduser(dirname))
178 | if not path.isdir(dirname):
179 | # `exist_ok` in case of race condition.
180 | os.makedirs(dirname, exist_ok=True)
181 | if log_mkdir:
182 | log.debug('mkdir -p {}'.format(dirname))
183 | return dirname
184 |
185 | def read_array(filename, dtype=np.float32):
186 | """
187 | Reads a multi-dimensional array file with the following format:
188 | [int32_t number of dimensions n]
189 |     [int32_t dimension 0], [int32_t dimension 1], ..., [int32_t dimension n-1]
190 | [float or int data]
191 |
192 | :param filename: Path to the array file.
193 | :param dtype: This must be consistent with the saved data type.
194 | :return: A numpy array.
195 | """
196 | with open(filename, mode='rb') as f:
197 | content = f.read()
198 | return bytes_to_array(content, dtype=dtype)
199 |
200 | def read_array_compressed(filename, dtype=np.float32):
201 | """
202 | Reads a multi-dimensional array file compressed with Blosc.
203 |     Otherwise the same as `read_array`.
204 | """
205 | with open(filename, mode='rb') as f:
206 | compressed = f.read()
207 | decompressed = blosc.decompress(compressed)
208 | return bytes_to_array(decompressed, dtype=dtype)
209 |
210 | def save_array_compressed(filename, arr: np.ndarray):
211 | """
212 | See `read_array`.
213 | """
214 | encoded = array_to_bytes(arr)
215 | compressed = blosc.compress(encoded, arr.dtype.itemsize, clevel=7, shuffle=True, cname='lz4hc')
216 | with open(filename, mode='wb') as f:
217 | f.write(compressed)
218 | log.info('Saved {}'.format(filename))
219 |
220 | def array_to_bytes(arr: np.ndarray):
221 | """
222 | Dumps a numpy array into a raw byte string.
223 | :param arr: A numpy array.
224 | :return: A `bytes` string.
225 | """
226 | shape = arr.shape
227 | ndim = arr.ndim
228 | ret = struct.pack('i', ndim) + struct.pack('i' * ndim, *shape) + arr.tobytes(order='C')
229 | return ret
230 |
231 | def bytes_to_array(s: bytes, dtype=np.float32):
232 | """
233 | Unpacks a byte string into a numpy array.
234 | :param s: A byte string containing raw array data.
235 | :param dtype: Data type.
236 | :return: A numpy array.
237 | """
238 | dims = struct.unpack('i', s[:4])[0]
239 | assert 0 <= dims < 1000 # Sanity check.
240 | shape = struct.unpack('i' * dims, s[4:4 * dims + 4])
241 | for dim in shape:
242 | assert dim > 0
243 | ret = np.frombuffer(s[4 * dims + 4:], dtype=dtype)
244 | assert ret.size == np.prod(shape), (ret.size, shape)
245 | ret.shape = shape
246 | return ret.copy()
247 |
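# --- illustrative round trip through the binary array format above ---
# arr = np.random.rand(4, 5).astype(np.float32)
# assert np.array_equal(bytes_to_array(array_to_bytes(arr)), arr)
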
248 | def processNYU_tensor(X):
249 | X = X[:,:,45:471,41:601]
250 | return X
251 |
252 | def cropPBRS(X):
253 | if len(X.shape)==3: return X[45:471,41:601,:]
254 | else: return X[45:471,41:601]
255 |
--------------------------------------------------------------------------------
/dataloader/NYUv2_dataLoader.py:
--------------------------------------------------------------------------------
1 | import os, random, time, copy, sys
2 | from skimage import io, transform
3 | import numpy as np
4 | import os.path as path
5 | import scipy.io as sio
6 | import logging
7 | import matplotlib.pyplot as plt
8 | import PIL.Image
9 | log = logging.getLogger(__name__)  # module logger used by ensure_dir_exists / save_array_compressed
10 | import skimage.transform
11 | import blosc, struct
12 |
13 | import torch
14 | from torch.utils.data import Dataset, DataLoader
15 | import torch.nn as nn
16 | import torch.optim as optim
17 | from torch.optim import lr_scheduler
18 | import torch.nn.functional as F
19 | from torch.autograd import Variable
20 |
21 | import torchvision
22 | from torchvision import datasets, models, transforms
23 |
24 | class NYUv2_dataLoader(Dataset):
25 | def __init__(self, root_dir, set_name='train', size=[240, 320], rgb=True, downsampleDepthFactor=1, training_depth='inpaint'):
26 | # training depth option: inpaint | original
27 | self.root_dir = root_dir
28 | self.size = size
29 | self.set_name = set_name
30 | self.training_depth = training_depth
31 | self.rgb = rgb
32 | self.current_set_len = 0
33 | self.path2files = []
34 | self.downsampleDepthFactor = downsampleDepthFactor
35 | self.NYU_MIN_DEPTH_CLIP = 0.0
36 | self.NYU_MAX_DEPTH_CLIP = 10.0
37 |
38 | curfilenamelist = os.listdir(path.join(self.root_dir, self.set_name, 'rgb'))
39 |         self.path2files += [path.join(self.root_dir, self.set_name, 'rgb', curfilename) for curfilename in curfilenamelist]
40 | self.current_set_len = len(self.path2files)
41 |
42 | self.TF2tensor = transforms.ToTensor()
43 | self.TF2PIL = transforms.ToPILImage()
44 | self.TFNormalize = transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
45 | self.funcResizeTensor = nn.Upsample(size=self.size, mode='nearest', align_corners=None)
46 | self.funcResizeDepth = nn.Upsample(size=[int(self.size[0]*self.downsampleDepthFactor),
47 | int(self.size[1]*self.downsampleDepthFactor)],
48 | mode='nearest', align_corners=None)
49 |
50 | def __len__(self):
51 | return self.current_set_len
52 |
53 | def __getitem__(self, idx):
54 | filename = self.path2files[idx]
55 | image = PIL.Image.open(filename)
56 | image = np.array(image).astype(np.float32) / 255.
57 |
58 | if self.set_name == 'train':
59 | if self.training_depth == 'original':
60 | depthname = filename.replace('rgb','depth').replace('png','bin')
61 | else:
62 | depthname = filename.replace('rgb','depth_inpainted').replace('png','bin')
63 | else:
64 | # use real depth for validation and testing
65 | depthname = filename.replace('rgb','depth').replace('png','bin')
66 |
67 | depth = read_array_compressed(depthname)
68 |
69 | if self.set_name =='train' and np.random.random(1)>0.5:
70 | image = np.fliplr(image).copy()
71 | depth = np.fliplr(depth).copy()
72 |
73 | # rescale depth samples in training phase
74 | if self.set_name == 'train':
75 | depth = np.clip(depth, self.NYU_MIN_DEPTH_CLIP, self.NYU_MAX_DEPTH_CLIP) # [0, 10]
76 | depth = ((depth/self.NYU_MAX_DEPTH_CLIP) - 0.5) * 2.0 # [-1, 1]
77 |
78 | image = self.TF2tensor(image)
79 | image = self.TFNormalize(image)
80 | image = image.unsqueeze(0)
81 |
82 | depth = np.expand_dims(depth, 2)
83 | depth = self.TF2tensor(depth)
84 | depth = depth.unsqueeze(0)
85 |
86 | image = processNYU_tensor(image)
87 | depth = processNYU_tensor(depth)
88 |
89 | image = self.funcResizeTensor(image)
90 | depth = self.funcResizeTensor(depth)
91 |
92 | if self.downsampleDepthFactor != 1:
93 | depth = self.funcResizeDepth(depth)
94 |
95 | if self.rgb:
96 | image = image.squeeze(0)
97 | else:
98 | image = image.mean(1)
99 | image = image.squeeze(0).unsqueeze(0)
100 |
101 | depth = depth.squeeze(0)
102 | return image, depth
103 |
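# --- illustrative usage (root path is a placeholder) ---
# test_set = NYUv2_dataLoader(root_dir='<path_to_NYUv2>', set_name='test')
# image, depth = test_set[0]  # [3, 240, 320] image, [1, 240, 320] depth
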
104 | def ensure_dir_exists(dirname, log_mkdir=True):
105 | """
106 | Creates a directory if it does not already exist.
107 | :param dirname: Path to a directory.
108 | :param log_mkdir: If true, a debug message is logged when creating a new directory.
109 | :return: Same as `dirname`.
110 | """
111 | dirname = path.realpath(path.expanduser(dirname))
112 | if not path.isdir(dirname):
113 | # `exist_ok` in case of race condition.
114 | os.makedirs(dirname, exist_ok=True)
115 | if log_mkdir:
116 | log.debug('mkdir -p {}'.format(dirname))
117 | return dirname
118 |
119 |
120 | def read_array(filename, dtype=np.float32):
121 | """
122 | Reads a multi-dimensional array file with the following format:
123 | [int32_t number of dimensions n]
124 |     [int32_t dimension 0], [int32_t dimension 1], ..., [int32_t dimension n-1]
125 | [float or int data]
126 |
127 | :param filename: Path to the array file.
128 | :param dtype: This must be consistent with the saved data type.
129 | :return: A numpy array.
130 | """
131 | with open(filename, mode='rb') as f:
132 | content = f.read()
133 | return bytes_to_array(content, dtype=dtype)
134 |
135 |
136 | def read_array_compressed(filename, dtype=np.float32):
137 | """
138 | Reads a multi-dimensional array file compressed with Blosc.
139 |     Otherwise the same as `read_array`.
140 | """
141 | with open(filename, mode='rb') as f:
142 | compressed = f.read()
143 | decompressed = blosc.decompress(compressed)
144 | return bytes_to_array(decompressed, dtype=dtype)
145 |
146 |
147 | def save_array_compressed(filename, arr: np.ndarray):
148 | """
149 | See `read_array`.
150 | """
151 | encoded = array_to_bytes(arr)
152 | compressed = blosc.compress(encoded, arr.dtype.itemsize, clevel=7, shuffle=True, cname='lz4hc')
153 | with open(filename, mode='wb') as f:
154 | f.write(compressed)
155 | log.info('Saved {}'.format(filename))
156 |
157 |
158 | def array_to_bytes(arr: np.ndarray):
159 | """
160 | Dumps a numpy array into a raw byte string.
161 | :param arr: A numpy array.
162 | :return: A `bytes` string.
163 | """
164 | shape = arr.shape
165 | ndim = arr.ndim
166 | ret = struct.pack('i', ndim) + struct.pack('i' * ndim, *shape) + arr.tobytes(order='C')
167 | return ret
168 |
169 |
170 | def bytes_to_array(s: bytes, dtype=np.float32):
171 | """
172 | Unpacks a byte string into a numpy array.
173 | :param s: A byte string containing raw array data.
174 | :param dtype: Data type.
175 | :return: A numpy array.
176 | """
177 | dims = struct.unpack('i', s[:4])[0]
178 | assert 0 <= dims < 1000 # Sanity check.
179 | shape = struct.unpack('i' * dims, s[4:4 * dims + 4])
180 | for dim in shape:
181 | assert dim > 0
182 | ret = np.frombuffer(s[4 * dims + 4:], dtype=dtype)
183 | assert ret.size == np.prod(shape), (ret.size, shape)
184 | ret.shape = shape
185 | return ret.copy()
186 |
187 |
188 | def processNYU_tensor(X):
189 | X = X[:,:,45:471,41:601]
190 | return X
191 |
192 |
193 | def cropPBRS(X):
194 | if len(X.shape)==3: return X[45:471,41:601,:]
195 | else: return X[45:471,41:601]
196 |
--------------------------------------------------------------------------------
/dataloader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yunhan-zhao/ARC/13add94311bfa22660e34200ec8a1dd97a66faa3/dataloader/__init__.py
--------------------------------------------------------------------------------
/figures/4343-teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yunhan-zhao/ARC/13add94311bfa22660e34200ec8a1dd97a66faa3/figures/4343-teaser.gif
--------------------------------------------------------------------------------
/figures/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yunhan-zhao/ARC/13add94311bfa22660e34200ec8a1dd97a66faa3/figures/demo.png
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import os, random, time, copy
2 | import sys
3 | from skimage import io, transform
4 | import numpy as np
5 | import os.path as path
6 | import scipy.io as sio
7 | import matplotlib.pyplot as plt
8 |
9 | import torch
10 | from torch.utils.data import Dataset, DataLoader
11 | import torch.nn as nn
12 | import torch.optim as optim
13 | from torch.optim import lr_scheduler
14 | import torch.nn.functional as F
15 | from torch.autograd import Variable
16 |
17 | import torchvision
18 | from torchvision import datasets, models, transforms
19 | import torchvision.models as models
20 |
21 | class StyleLoss(nn.Module):
22 | r"""
23 |     Style loss (Gram-matrix matching on VGG features)
24 | https://arxiv.org/abs/1603.08155
25 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py
26 | """
27 |
28 | def __init__(self, vgg19=None):
29 | super(StyleLoss, self).__init__()
30 | self.add_module('vgg', vgg19)
31 | self.criterion = torch.nn.L1Loss()
32 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
33 | self.vgg.to(self.device)
34 |
35 | def compute_gram(self, x):
36 | b, ch, h, w = x.size()
37 | f = x.view(b, ch, w * h)
38 | f_T = f.transpose(1, 2)
39 |         # Gram matrix G = f f^T / (ch * h * w); normalizing by the layer size keeps
40 |         # magnitudes comparable across VGG layers
41 |         G = f.bmm(f_T) / (h * w * ch)
42 |         return G
61 |
62 | def __call__(self, x, y):
63 | # Compute features
64 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
65 |
66 | # Compute loss
67 | style_loss = 0.0
68 | style_loss += self.criterion(self.compute_gram(x_vgg['relu2_2']), self.compute_gram(y_vgg['relu2_2']))
69 | style_loss += self.criterion(self.compute_gram(x_vgg['relu3_4']), self.compute_gram(y_vgg['relu3_4']))
70 | style_loss += self.criterion(self.compute_gram(x_vgg['relu4_4']), self.compute_gram(y_vgg['relu4_4']))
71 | style_loss += self.criterion(self.compute_gram(x_vgg['relu5_2']), self.compute_gram(y_vgg['relu5_2']))
72 |
73 | return style_loss
74 |
75 | class PerceptualLoss(nn.Module):
76 | r"""
77 | Perceptual loss, VGG-based
78 | https://arxiv.org/abs/1603.08155
79 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py
80 | """
81 |
82 | def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0], vgg19=None):
83 | super(PerceptualLoss, self).__init__()
84 | self.add_module('vgg', vgg19)
85 | self.criterion = torch.nn.L1Loss()
86 | self.weights = weights
87 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
88 | self.vgg.to(self.device)
89 |
90 | def __call__(self, x, y):
91 | # Compute features
92 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
93 |
94 | content_loss = 0.0
95 | content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1'])
96 | content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1'])
97 | content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1'])
98 | content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1'])
99 | content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1'])
100 |
101 | return content_loss
102 |
103 | class VGG19(torch.nn.Module):
104 | def __init__(self):
105 | super(VGG19, self).__init__()
106 | features = models.vgg19(pretrained=True).features
107 | self.relu1_1 = torch.nn.Sequential()
108 | self.relu1_2 = torch.nn.Sequential()
109 |
110 | self.relu2_1 = torch.nn.Sequential()
111 | self.relu2_2 = torch.nn.Sequential()
112 |
113 | self.relu3_1 = torch.nn.Sequential()
114 | self.relu3_2 = torch.nn.Sequential()
115 | self.relu3_3 = torch.nn.Sequential()
116 | self.relu3_4 = torch.nn.Sequential()
117 |
118 | self.relu4_1 = torch.nn.Sequential()
119 | self.relu4_2 = torch.nn.Sequential()
120 | self.relu4_3 = torch.nn.Sequential()
121 | self.relu4_4 = torch.nn.Sequential()
122 |
123 | self.relu5_1 = torch.nn.Sequential()
124 | self.relu5_2 = torch.nn.Sequential()
125 | self.relu5_3 = torch.nn.Sequential()
126 | self.relu5_4 = torch.nn.Sequential()
127 |
128 | for x in range(2):
129 | self.relu1_1.add_module(str(x), features[x])
130 |
131 | for x in range(2, 4):
132 | self.relu1_2.add_module(str(x), features[x])
133 |
134 | for x in range(4, 7):
135 | self.relu2_1.add_module(str(x), features[x])
136 |
137 | for x in range(7, 9):
138 | self.relu2_2.add_module(str(x), features[x])
139 |
140 | for x in range(9, 12):
141 | self.relu3_1.add_module(str(x), features[x])
142 |
143 | for x in range(12, 14):
144 | self.relu3_2.add_module(str(x), features[x])
145 |
146 | for x in range(14, 16):
147 | self.relu3_3.add_module(str(x), features[x])
148 |
149 | for x in range(16, 18):
150 | self.relu3_4.add_module(str(x), features[x])
151 |
152 | for x in range(18, 21):
153 | self.relu4_1.add_module(str(x), features[x])
154 |
155 | for x in range(21, 23):
156 | self.relu4_2.add_module(str(x), features[x])
157 |
158 | for x in range(23, 25):
159 | self.relu4_3.add_module(str(x), features[x])
160 |
161 | for x in range(25, 27):
162 | self.relu4_4.add_module(str(x), features[x])
163 |
164 | for x in range(27, 30):
165 | self.relu5_1.add_module(str(x), features[x])
166 |
167 | for x in range(30, 32):
168 | self.relu5_2.add_module(str(x), features[x])
169 |
170 | for x in range(32, 34):
171 | self.relu5_3.add_module(str(x), features[x])
172 |
173 | for x in range(34, 36):
174 | self.relu5_4.add_module(str(x), features[x])
175 |
176 | # don't need the gradients, just want the features
177 | for param in self.parameters():
178 | param.requires_grad = False
179 |
180 | def forward(self, x):
181 | relu1_1 = self.relu1_1(x)
182 | relu1_2 = self.relu1_2(relu1_1)
183 |
184 | relu2_1 = self.relu2_1(relu1_2)
185 | relu2_2 = self.relu2_2(relu2_1)
186 |
187 | relu3_1 = self.relu3_1(relu2_2)
188 | relu3_2 = self.relu3_2(relu3_1)
189 | relu3_3 = self.relu3_3(relu3_2)
190 | relu3_4 = self.relu3_4(relu3_3)
191 |
192 | relu4_1 = self.relu4_1(relu3_4)
193 | relu4_2 = self.relu4_2(relu4_1)
194 | relu4_3 = self.relu4_3(relu4_2)
195 | relu4_4 = self.relu4_4(relu4_3)
196 |
197 | relu5_1 = self.relu5_1(relu4_4)
198 | relu5_2 = self.relu5_2(relu5_1)
199 | relu5_3 = self.relu5_3(relu5_2)
200 | relu5_4 = self.relu5_4(relu5_3)
201 |
202 | out = {
203 | 'relu1_1': relu1_1,
204 | 'relu1_2': relu1_2,
205 |
206 | 'relu2_1': relu2_1,
207 | 'relu2_2': relu2_2,
208 |
209 | 'relu3_1': relu3_1,
210 | 'relu3_2': relu3_2,
211 | 'relu3_3': relu3_3,
212 | 'relu3_4': relu3_4,
213 |
214 | 'relu4_1': relu4_1,
215 | 'relu4_2': relu4_2,
216 | 'relu4_3': relu4_3,
217 | 'relu4_4': relu4_4,
218 |
219 | 'relu5_1': relu5_1,
220 | 'relu5_2': relu5_2,
221 | 'relu5_3': relu5_3,
222 | 'relu5_4': relu5_4,
223 | }
224 | return out
225 |
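# --- illustrative usage (random tensors stand in for image batches) ---
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# vgg = VGG19()  # downloads ImageNet-pretrained weights on first use
# perceptual, style = PerceptualLoss(vgg19=vgg), StyleLoss(vgg19=vgg)
# x = torch.rand(2, 3, 240, 320).to(device); y = torch.rand(2, 3, 240, 320).to(device)
# total = perceptual(x, y) + style(x, y)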
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yunhan-zhao/ARC/13add94311bfa22660e34200ec8a1dd97a66faa3/models/__init__.py
--------------------------------------------------------------------------------
/models/depth_generator_networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import functools
5 | from torch.autograd import Variable
6 | from torchvision import models
7 | import torch.nn.functional as F
8 | from torch.optim import lr_scheduler
9 |
10 |
11 | ######################################################################################
12 | # Functions
13 | ######################################################################################
14 | def get_norm_layer(norm_type='batch'):
15 | if norm_type == 'batch':
16 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
17 | elif norm_type == 'instance':
18 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
19 | elif norm_type == 'none':
20 | norm_layer = None
21 | else:
22 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
23 | return norm_layer
24 |
25 |
26 | def get_nonlinearity_layer(activation_type='PReLU'):
27 | if activation_type == 'ReLU':
28 | nonlinearity_layer = nn.ReLU(True)
29 | elif activation_type == 'SELU':
30 | nonlinearity_layer = nn.SELU(True)
31 | elif activation_type == 'LeakyReLU':
32 | nonlinearity_layer = nn.LeakyReLU(0.1, True)
33 | elif activation_type == 'PReLU':
34 | nonlinearity_layer = nn.PReLU()
35 | else:
36 | raise NotImplementedError('activation layer [%s] is not found' % activation_type)
37 | return nonlinearity_layer
38 |
39 |
40 | def get_scheduler(optimizer, opt):
41 | if opt.lr_policy == 'lambda':
42 | def lambda_rule(epoch):
43 | lr_l = 1.0 - max(0, epoch+1+1+opt.epoch_count-opt.niter) / float(opt.niter_decay+1)
44 | return lr_l
45 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
46 | elif opt.lr_policy == 'step':
47 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
48 | elif opt.lr_policy == 'exponent':
49 | scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
50 | # scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
51 | else:
52 | raise NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
53 | return scheduler
54 |
55 |
56 | def init_weights(net, net_name=None, init_type='normal', gain=0.02):
57 | def init_func(m):
58 | classname = m.__class__.__name__
59 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
60 | if init_type == 'normal':
61 | init.normal_(m.weight.data, 0.0, gain)
62 | elif init_type == 'xavier':
63 | init.xavier_normal_(m.weight.data, gain=gain)
64 | elif init_type == 'kaiming':
65 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
66 | elif init_type == 'orthogonal':
67 | init.orthogonal_(m.weight.data, gain=gain)
68 | else:
69 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
70 | if hasattr(m, 'bias') and m.bias is not None:
71 | init.constant_(m.bias.data, 0.0)
72 | elif classname.find('BatchNorm2d') != -1:
73 |             init.normal_(m.weight.data, 1.0, gain)  # BatchNorm weights drawn around 1.0
74 | init.constant_(m.bias.data, 0.0)
75 |
76 | print('initialize network {} with {}'.format(net_name, init_type))
77 | net.apply(init_func)
78 |
79 |
80 | def print_network(net):
81 | num_params = 0
82 | for param in net.parameters():
83 | num_params += param.numel()
84 | print(net)
85 | print('total number of parameters: %.3f M' % (num_params / 1e6))
86 |
87 |
88 | def init_net(net, init_type='normal', gpu_ids=[]):
89 |
90 | print_network(net)
91 |
92 | if len(gpu_ids) > 0:
93 | assert(torch.cuda.is_available())
94 | net = torch.nn.DataParallel(net, gpu_ids)
95 | net.cuda()
96 | init_weights(net, init_type)
97 | return net
98 |
99 |
100 | def _freeze(*args):
101 | for module in args:
102 | if module:
103 | for p in module.parameters():
104 | p.requires_grad = False
105 |
106 |
107 | def _unfreeze(*args):
108 | for module in args:
109 | if module:
110 | for p in module.parameters():
111 | p.requires_grad = True
112 |
113 |
114 | # define the generator(transform, task) network
115 | def define_G(input_nc, output_nc, ngf=64, layers=4, norm='batch', activation='PReLU', model_type='UNet',
116 | init_type='xavier', drop_rate=0, add_noise=False, gpu_ids=[], weight=0.1):
117 |
118 | if model_type == 'ResNet':
119 | net = _ResGenerator(input_nc, output_nc, ngf, layers, norm, activation, drop_rate, add_noise, gpu_ids)
120 | elif model_type == 'UNet':
121 | net = _UNetGenerator(input_nc, output_nc, ngf, layers, norm, activation, drop_rate, add_noise, gpu_ids, weight)
122 | # net = _PreUNet16(input_nc, output_nc, ngf, layers, True, norm, activation, drop_rate, gpu_ids)
123 | else:
124 | raise NotImplementedError('model type [%s] is not implemented', model_type)
125 |
126 | return init_net(net, init_type, gpu_ids)
127 |
128 |
129 | # define the discriminator network
130 | def define_D(input_nc, ndf=64, n_layers=3, num_D=1, norm='batch', activation='PReLU', init_type='xavier', gpu_ids=[]):
131 |
132 | net = _MultiscaleDiscriminator(input_nc, ndf, n_layers, num_D, norm, activation, gpu_ids)
133 |
134 | return init_net(net, init_type, gpu_ids)
135 |
136 |
137 | # define the feature discriminator network
138 | def define_featureD(input_nc, n_layers=2, norm='batch', activation='PReLU', init_type='xavier', gpu_ids=[]):
139 |
140 | net = _FeatureDiscriminator(input_nc, n_layers, norm, activation, gpu_ids)
141 |
142 | return init_net(net, init_type, gpu_ids)
143 |
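# --- illustrative usage: a UNet generator mapping RGB (3ch) to depth (1ch) and a
# PatchGAN-style discriminator; the argument values are examples only ---
# netG = define_G(input_nc=3, output_nc=1, ngf=64, layers=4, model_type='UNet')
# netD = define_D(input_nc=3, ndf=64, n_layers=3, num_D=1)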
144 |
145 | ######################################################################################
146 | # Basic Operation
147 | ######################################################################################
148 |
149 | class GaussianNoiseLayer(nn.Module):
150 | def __init__(self):
151 | super(GaussianNoiseLayer, self).__init__()
152 |
153 | def forward(self, x):
154 |         if not self.training:
155 | return x
156 | noise = Variable((torch.randn(x.size()).cuda(x.data.get_device()) - 0.5) / 10.0)
157 | return x+noise
158 |
159 |
160 | class _InceptionBlock(nn.Module):
161 | def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.PReLU(), width=1, drop_rate=0, use_bias=False):
162 | super(_InceptionBlock, self).__init__()
163 |
164 | self.width = width
165 | self.drop_rate = drop_rate
166 |
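        # `width` parallel branches with increasing dilation (1, 3, 5, ...) give the
        # block a multi-scale receptive field before the fusion conv below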
167 | for i in range(width):
168 | layer = nn.Sequential(
169 | nn.ReflectionPad2d(i*2+1),
170 | nn.Conv2d(input_nc, output_nc, kernel_size=3, padding=0, dilation=i*2+1, bias=use_bias)
171 | )
172 | setattr(self, 'layer'+str(i), layer)
173 |
174 | self.norm1 = norm_layer(output_nc * width)
175 | self.norm2 = norm_layer(output_nc)
176 | self.nonlinearity = nonlinearity
177 | self.branch1x1 = nn.Sequential(
178 | nn.ReflectionPad2d(1),
179 | nn.Conv2d(output_nc * width, output_nc, kernel_size=3, padding=0, bias=use_bias)
180 | )
181 |
182 | def forward(self, x):
183 | result = []
184 | for i in range(self.width):
185 | layer = getattr(self, 'layer'+str(i))
186 | result.append(layer(x))
187 | output = torch.cat(result, 1)
188 | output = self.nonlinearity(self.norm1(output))
189 | output = self.norm2(self.branch1x1(output))
190 | if self.drop_rate > 0:
191 | output = F.dropout(output, p=self.drop_rate, training=self.training)
192 |
193 | return self.nonlinearity(output+x)
194 |
195 |
196 | class _EncoderBlock(nn.Module):
197 | def __init__(self, input_nc, middle_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.PReLU(), use_bias=False):
198 | super(_EncoderBlock, self).__init__()
199 |
200 | model = [
201 | nn.Conv2d(input_nc, middle_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
202 | norm_layer(middle_nc),
203 | nonlinearity,
204 | nn.Conv2d(middle_nc, output_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
205 | norm_layer(output_nc),
206 | nonlinearity
207 | ]
208 |
209 | self.model = nn.Sequential(*model)
210 |
211 | def forward(self, x):
212 | return self.model(x)
213 |
214 |
215 | class _DownBlock(nn.Module):
216 | def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.PReLU(), use_bias=False):
217 | super(_DownBlock, self).__init__()
218 |
219 | model = [
220 | nn.Conv2d(input_nc, output_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
221 | norm_layer(output_nc),
222 | nonlinearity,
223 | nn.MaxPool2d(kernel_size=2, stride=2),
224 | ]
225 |
226 | self.model = nn.Sequential(*model)
227 |
228 | def forward(self, x):
229 | return self.model(x)
230 |
231 |
232 | class _ShuffleUpBlock(nn.Module):
233 | def __init__(self, input_nc, up_scale, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.PReLU(), use_bias=False):
234 | super(_ShuffleUpBlock, self).__init__()
235 |
236 | model = [
237 | nn.Conv2d(input_nc, input_nc*up_scale**2, kernel_size=3, stride=1, padding=1, bias=use_bias),
238 | nn.PixelShuffle(up_scale),
239 | nonlinearity,
240 | nn.Conv2d(input_nc, output_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
241 | norm_layer(output_nc),
242 | nonlinearity
243 | ]
244 |
245 | self.model = nn.Sequential(*model)
246 |
247 | def forward(self, x):
248 | return self.model(x)
249 |
250 |
251 | class _DecoderUpBlock(nn.Module):
252 | def __init__(self, input_nc, middle_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.PReLU(), use_bias=False):
253 | super(_DecoderUpBlock, self).__init__()
254 |
255 | model = [
256 | nn.ReflectionPad2d(1),
257 | nn.Conv2d(input_nc, middle_nc, kernel_size=3, stride=1, padding=0, bias=use_bias),
258 | norm_layer(middle_nc),
259 | nonlinearity,
260 | nn.ConvTranspose2d(middle_nc, output_nc, kernel_size=3, stride=2, padding=1, output_padding=1),
261 | norm_layer(output_nc),
262 | nonlinearity
263 | ]
264 |
265 | self.model = nn.Sequential(*model)
266 |
267 | def forward(self, x):
268 | return self.model(x)
269 |
270 | class _DecoderUpBlock_Upsampling(nn.Module):
271 | def __init__(self, input_nc, middle_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.PReLU(), use_bias=False):
272 | super(_DecoderUpBlock_Upsampling, self).__init__()
273 |
274 | model = [
275 | nn.ReflectionPad2d(1),
276 | nn.Conv2d(input_nc, middle_nc, kernel_size=3, stride=1, padding=0, bias=use_bias),
277 | norm_layer(middle_nc),
278 | nonlinearity,
279 | # nn.ConvTranspose2d(middle_nc, output_nc, kernel_size=3, stride=2, padding=1, output_padding=1),
280 | nn.Upsample(scale_factor = 2, mode='bilinear'),
281 | nn.ReflectionPad2d(1),
282 | nn.Conv2d(middle_nc, output_nc, kernel_size=3, stride=1, padding=0),
283 | norm_layer(output_nc),
284 | nonlinearity
285 | ]
286 |
287 | self.model = nn.Sequential(*model)
288 |
289 | def forward(self, x):
290 | return self.model(x)
291 |
292 | class _OutputBlock(nn.Module):
293 | def __init__(self, input_nc, output_nc, kernel_size=3, use_bias=False):
294 | super(_OutputBlock, self).__init__()
295 |
296 | model = [
297 | nn.ReflectionPad2d(int(kernel_size/2)),
298 | nn.Conv2d(input_nc, output_nc, kernel_size=kernel_size, padding=0, bias=use_bias),
299 | nn.Tanh()
300 | ]
301 |
302 | self.model = nn.Sequential(*model)
303 |
304 | def forward(self, x):
305 | return self.model(x)
306 |
307 |
308 | ######################################################################################
309 | # Network structure
310 | ######################################################################################
311 |
312 | class _ResGenerator(nn.Module):
313 | def __init__(self, input_nc, output_nc, ngf=64, n_blocks=6, norm='batch', activation='PReLU', drop_rate=0, add_noise=False, gpu_ids=[]):
314 | super(_ResGenerator, self).__init__()
315 |
316 | self.gpu_ids = gpu_ids
317 |
318 | norm_layer = get_norm_layer(norm_type=norm)
319 | nonlinearity = get_nonlinearity_layer(activation_type=activation)
320 |
321 | if type(norm_layer) == functools.partial:
322 | use_bias = norm_layer.func == nn.InstanceNorm2d
323 | else:
324 | use_bias = norm_layer == nn.InstanceNorm2d
325 |
326 | encoder = [
327 | nn.ReflectionPad2d(3),
328 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
329 | norm_layer(ngf),
330 | nonlinearity
331 | ]
332 |
333 | n_downsampling = 2
334 | mult = 1
335 | for i in range(n_downsampling):
336 | mult_prev = mult
337 | mult = min(2 ** (i+1), 2)
338 | encoder += [
339 | _EncoderBlock(ngf * mult_prev, ngf*mult, ngf*mult, norm_layer, nonlinearity, use_bias),
340 | nn.AvgPool2d(kernel_size=2, stride=2)
341 | ]
342 |
343 | mult = min(2 ** n_downsampling, 2)
344 | for i in range(n_blocks-n_downsampling):
345 | encoder +=[
346 | _InceptionBlock(ngf*mult, ngf*mult, norm_layer=norm_layer, nonlinearity=nonlinearity, width=1,
347 | drop_rate=drop_rate, use_bias=use_bias)
348 | ]
349 |
350 | decoder = []
351 | if add_noise:
352 | decoder += [GaussianNoiseLayer()]
353 |
354 | for i in range(n_downsampling):
355 | mult_prev = mult
356 | mult = min(2 ** (n_downsampling - i -1), 2)
357 | decoder +=[
358 | _DecoderUpBlock(ngf*mult_prev, ngf*mult_prev, ngf*mult, norm_layer, nonlinearity, use_bias),
359 | ]
360 |
361 | decoder +=[
362 | nn.ReflectionPad2d(3),
363 | nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
364 | nn.Tanh()
365 | ]
366 |
367 | self.encoder = nn.Sequential(*encoder)
368 | self.decoder = nn.Sequential(*decoder)
369 |
370 | def forward(self, input):
371 | feature = self.encoder(input)
372 | result = [feature]
373 | output = self.decoder(feature)
374 | result.append(output)
375 | return result
376 |
377 | class _ResGenerator_Upsample_Conv2d(nn.Module):
378 | def __init__(self, input_nc, output_nc, ngf=64, n_blocks=6, norm='batch', activation='PReLU', drop_rate=0, add_noise=False, gpu_ids=[]):
379 | super(_ResGenerator_Upsample_Conv2d, self).__init__()
380 |
381 | self.gpu_ids = gpu_ids
382 |
383 | norm_layer = get_norm_layer(norm_type=norm)
384 | nonlinearity = get_nonlinearity_layer(activation_type=activation)
385 |
386 | if type(norm_layer) == functools.partial:
387 | use_bias = norm_layer.func == nn.InstanceNorm2d
388 | else:
389 | use_bias = norm_layer == nn.InstanceNorm2d
390 |
391 | encoder = [
392 | nn.ReflectionPad2d(3),
393 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
394 | norm_layer(ngf),
395 | nonlinearity
396 | ]
397 |
398 | n_downsampling = 2
399 | mult = 1
400 | for i in range(n_downsampling):
401 | mult_prev = mult
402 | mult = min(2 ** (i+1), 2)
403 | encoder += [
404 | _EncoderBlock(ngf * mult_prev, ngf*mult, ngf*mult, norm_layer, nonlinearity, use_bias),
405 | nn.AvgPool2d(kernel_size=2, stride=2)
406 | ]
407 |
408 | mult = min(2 ** n_downsampling, 2)
409 | for i in range(n_blocks-n_downsampling):
410 | encoder +=[
411 | _InceptionBlock(ngf*mult, ngf*mult, norm_layer=norm_layer, nonlinearity=nonlinearity, width=1,
412 | drop_rate=drop_rate, use_bias=use_bias)
413 | ]
414 |
415 | decoder = []
416 | if add_noise:
417 | decoder += [GaussianNoiseLayer()]
418 |
419 | for i in range(n_downsampling):
420 | mult_prev = mult
421 | mult = min(2 ** (n_downsampling - i -1), 2)
422 | decoder +=[
423 | _DecoderUpBlock_Upsampling(ngf*mult_prev, ngf*mult_prev, ngf*mult, norm_layer, nonlinearity, use_bias),
424 | ]
425 |
426 | decoder += [
427 | nn.ReflectionPad2d(3),
428 | nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)
429 | # nn.Conv2d(ngf, output_nc, kernel_size=1, padding=0)
430 | ]
431 |
432 | self.encoder = nn.Sequential(*encoder)
433 | self.decoder = nn.Sequential(*decoder)
434 |
435 | def forward(self, input):
436 | feature = self.encoder(input)
437 | result = [feature]
438 | output = self.decoder(feature)
442 | result.append(output)
443 | return result
444 |
445 | class _ResGenerator_Upsample_Conv2d_Pool(nn.Module):
446 | def __init__(self, input_nc, output_nc, output_size, ngf=64, n_blocks=6, norm='batch', activation='PReLU', drop_rate=0, add_noise=False, gpu_ids=[]):
447 | super(_ResGenerator_Upsample_Conv2d_Pool, self).__init__()
448 |
449 | self.gpu_ids = gpu_ids
450 | self.output_h = output_size[0]
451 | self.output_w = output_size[1]
452 |
453 | norm_layer = get_norm_layer(norm_type=norm)
454 | nonlinearity = get_nonlinearity_layer(activation_type=activation)
455 |
456 | if type(norm_layer) == functools.partial:
457 | use_bias = norm_layer.func == nn.InstanceNorm2d
458 | else:
459 | use_bias = norm_layer == nn.InstanceNorm2d
460 |
461 | encoder = [
462 | nn.ReflectionPad2d(3),
463 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
464 | norm_layer(ngf),
465 | nonlinearity
466 | ]
467 |
468 | n_downsampling = 2
469 | mult = 1
470 | for i in range(n_downsampling):
471 | mult_prev = mult
472 | mult = min(2 ** (i+1), 2)
473 | encoder += [
474 | _EncoderBlock(ngf * mult_prev, ngf*mult, ngf*mult, norm_layer, nonlinearity, use_bias),
475 | nn.AvgPool2d(kernel_size=2, stride=2)
476 | ]
477 |
478 | mult = min(2 ** n_downsampling, 2)
479 | for i in range(n_blocks-n_downsampling):
480 | encoder +=[
481 | _InceptionBlock(ngf*mult, ngf*mult, norm_layer=norm_layer, nonlinearity=nonlinearity, width=1,
482 | drop_rate=drop_rate, use_bias=use_bias)
483 | ]
484 |
485 | decoder = []
486 | if add_noise:
487 | decoder += [GaussianNoiseLayer()]
488 |
489 | for i in range(n_downsampling):
490 | mult_prev = mult
491 | mult = min(2 ** (n_downsampling - i -1), 2)
492 | decoder +=[
493 | _DecoderUpBlock_Upsampling(ngf*mult_prev, ngf*mult_prev, ngf*mult, norm_layer, nonlinearity, use_bias),
494 | ]
495 |
496 | final_proj = [
497 | # nn.ReflectionPad2d(3),
498 | # nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)
499 | nn.Conv2d(ngf, output_nc, kernel_size=1, padding=0)
500 | ]
501 |
502 | self.encoder = nn.Sequential(*encoder)
503 | self.decoder = nn.Sequential(*decoder)
504 | self.final_proj = nn.Sequential(*final_proj)
505 |
506 |     def forward(self, input):
507 |         feature = self.encoder(input)
508 |         result = [feature]
509 |         output = self.decoder(feature)
510 |         H, W = output.size()[2], output.size()[3]
511 |         # max-pool the decoder output down to the requested output size, then project
512 |         output_pool = F.max_pool2d(output, kernel_size=(int(H/self.output_h), int(W/self.output_w)),
513 |                                    stride=(int(H/self.output_h), int(W/self.output_w)))
514 |         output_pool = self.final_proj(output_pool)
515 |         result.append(output_pool)
516 |         return result
520 |
521 | class _ResGenerator_Upsample(nn.Module):
522 | def __init__(self, input_nc, output_nc, ngf=64, n_blocks=6, norm='batch', activation='PReLU', drop_rate=0, add_noise=False, gpu_ids=[]):
523 | super(_ResGenerator_Upsample, self).__init__()
524 |
525 | self.gpu_ids = gpu_ids
526 |
527 | norm_layer = get_norm_layer(norm_type=norm)
528 | nonlinearity = get_nonlinearity_layer(activation_type=activation)
529 |
530 | if type(norm_layer) == functools.partial:
531 | use_bias = norm_layer.func == nn.InstanceNorm2d
532 | else:
533 | use_bias = norm_layer == nn.InstanceNorm2d
534 |
535 | encoder = [
536 | nn.ReflectionPad2d(3),
537 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
538 | norm_layer(ngf),
539 | nonlinearity
540 | ]
541 |
542 | n_downsampling = 2
543 | mult = 1
544 | for i in range(n_downsampling):
545 | mult_prev = mult
546 | mult = min(2 ** (i+1), 2)
547 | encoder += [
548 | _EncoderBlock(ngf * mult_prev, ngf*mult, ngf*mult, norm_layer, nonlinearity, use_bias),
549 | nn.AvgPool2d(kernel_size=2, stride=2)
550 | ]
551 |
552 | mult = min(2 ** n_downsampling, 2)
553 | for i in range(n_blocks-n_downsampling):
554 | encoder +=[
555 | _InceptionBlock(ngf*mult, ngf*mult, norm_layer=norm_layer, nonlinearity=nonlinearity, width=1,
556 | drop_rate=drop_rate, use_bias=use_bias)
557 | ]
558 |
559 | decoder = []
560 | if add_noise:
561 | decoder += [GaussianNoiseLayer()]
562 |
563 | for i in range(n_downsampling):
564 | mult_prev = mult
565 | mult = min(2 ** (n_downsampling - i -1), 2)
566 | decoder +=[
567 | _DecoderUpBlock_Upsampling(ngf*mult_prev, ngf*mult_prev, ngf*mult, norm_layer, nonlinearity, use_bias),
568 | ]
569 |
570 | decoder +=[
571 | nn.ReflectionPad2d(3),
572 | nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
573 | nn.Tanh()
574 | ]
575 |
576 | self.encoder = nn.Sequential(*encoder)
577 | self.decoder = nn.Sequential(*decoder)
578 |
579 | def forward(self, input):
580 | feature = self.encoder(input)
581 | result = [feature]
582 | output = self.decoder(feature)
583 | result.append(output)
584 | return result
585 |
586 | class _PreUNet16(nn.Module):
587 |     def __init__(self, input_nc, output_nc, ngf=64, layers=5, pretrained=False, norm='batch', activation='PReLU',
588 | drop_rate=0, gpu_ids=[]):
589 | super(_PreUNet16, self).__init__()
590 |
591 | self.gpu_ids = gpu_ids
592 | self.layers = layers
593 | norm_layer = get_norm_layer(norm_type=norm)
594 | nonlinearity = get_nonlinearity_layer(activation_type=activation)
595 | if type(norm_layer) == functools.partial:
596 | use_bias = norm_layer.func == nn.InstanceNorm2d
597 | else:
598 | use_bias = norm_layer == nn.InstanceNorm2d
599 |
600 | encoder = models.vgg16(pretrained=pretrained).features
601 |
602 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
603 | self.relu = nn.ReLU(inplace=True)
604 |
605 | self.conv1 = nn.Sequential(encoder[0], self.relu, encoder[2], self.relu)
606 | self.conv2 = nn.Sequential(encoder[5], self.relu, encoder[7], self.relu)
607 | self.conv3 = nn.Sequential(encoder[10], self.relu, encoder[12], self.relu, encoder[14], self.relu)
608 | self.conv4 = nn.Sequential(encoder[17], self.relu, encoder[19], self.relu, encoder[21], self.relu)
609 |
610 | for i in range(layers - 4):
611 | conv = _EncoderBlock(ngf * 8, ngf * 8, ngf * 8, norm_layer, nonlinearity, use_bias)
612 | setattr(self, 'down' + str(i), conv.model)
613 |
614 | center = []
615 | for i in range(7 - layers):
616 | center += [
617 | _InceptionBlock(ngf * 8, ngf * 8, norm_layer, nonlinearity, 7 - layers, drop_rate, use_bias)
618 | ]
619 |
620 | center += [_DecoderUpBlock(ngf * 8, ngf * 8, ngf * 4, norm_layer, nonlinearity, use_bias)]
621 |
622 | for i in range(layers - 4):
623 | upconv = _DecoderUpBlock(ngf * (8 + 4), ngf * 8, ngf * 4, norm_layer, nonlinearity, use_bias)
624 | setattr(self, 'up' + str(i), upconv.model)
625 |
626 | self.deconv4 = _DecoderUpBlock(ngf * (4 + 4), ngf * 8, ngf * 2, norm_layer, nonlinearity, use_bias)
627 | self.deconv3 = _DecoderUpBlock(ngf * (2 + 2) + output_nc, ngf * 4, ngf, norm_layer, nonlinearity, use_bias)
628 | self.deconv2 = _DecoderUpBlock(ngf * (1 + 1) + output_nc, ngf * 2, int(ngf / 2), norm_layer, nonlinearity, use_bias)
629 |
630 | self.deconv1 = _OutputBlock(int(ngf / 2) + output_nc, output_nc, kernel_size=7, use_bias=use_bias)
631 |
632 | self.output4 = _OutputBlock(ngf * (4 + 4), output_nc, kernel_size=3, use_bias=use_bias)
633 | self.output3 = _OutputBlock(ngf * (2 + 2) + output_nc, output_nc, kernel_size=3, use_bias=use_bias)
634 | self.output2 = _OutputBlock(ngf * (1 + 1) + output_nc, output_nc, kernel_size=3, use_bias=use_bias)
635 |
636 | self.center = nn.Sequential(*center)
637 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
638 |
639 | def forward(self, input):
640 | conv1 = self.pool(self.conv1(input))
641 | conv2 = self.pool(self.conv2(conv1))
642 | conv3 = self.pool(self.conv3(conv2))
643 | center_in = self.pool(self.conv4(conv3))
644 |
645 | middle = [center_in]
646 | for i in range(self.layers - 4):
647 | model = getattr(self, 'down' + str(i))
648 | center_in = self.pool(model(center_in))
649 | middle.append(center_in)
650 |
651 | result = [center_in]
652 |
653 | center_out = self.center(center_in)
654 |
655 | for i in range(self.layers - 4):
656 | model = getattr(self, 'up' + str(i))
657 | center_out = model(torch.cat([center_out, middle[self.layers - 4 - i]], 1))
658 |
659 | deconv4 = self.deconv4.forward(torch.cat([center_out, conv3 * 0.1], 1))
660 | output4 = self.output4.forward(torch.cat([center_out, conv3 * 0.1], 1))
661 | result.append(output4)
662 | deconv3 = self.deconv3.forward(torch.cat([deconv4, conv2 * 0.05, self.upsample(output4)], 1))
663 | output3 = self.output3.forward(torch.cat([deconv4, conv2 * 0.05, self.upsample(output4)], 1))
664 | result.append(output3)
665 | deconv2 = self.deconv2.forward(torch.cat([deconv3, conv1 * 0.01, self.upsample(output3)], 1))
666 | output2 = self.output2.forward(torch.cat([deconv3, conv1 * 0.01, self.upsample(output3)], 1))
667 | result.append(output2)
668 |
669 | output1 = self.deconv1.forward(torch.cat([deconv2, self.upsample(output2)], 1))
670 | result.append(output1)
671 |
672 | return result
673 |
674 |
675 | class _UNetGenerator(nn.Module):
676 | def __init__(self, input_nc, output_nc, ngf=64, layers=4, norm='batch', activation='PReLU', drop_rate=0, add_noise=False, gpu_ids=[],
677 | weight=0.1):
678 | super(_UNetGenerator, self).__init__()
679 |
680 | self.gpu_ids = gpu_ids
681 | self.layers = layers
682 | self.weight = weight
683 | norm_layer = get_norm_layer(norm_type=norm)
684 | nonlinearity = get_nonlinearity_layer(activation_type=activation)
685 |
686 | if type(norm_layer) == functools.partial:
687 | use_bias = norm_layer.func == nn.InstanceNorm2d
688 | else:
689 | use_bias = norm_layer == nn.InstanceNorm2d
690 |
691 | # encoder part
692 | self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
693 | self.conv1 = nn.Sequential(
694 | nn.ReflectionPad2d(3),
695 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
696 | norm_layer(ngf),
697 | nonlinearity
698 | )
699 | self.conv2 = _EncoderBlock(ngf, ngf*2, ngf*2, norm_layer, nonlinearity, use_bias)
700 | self.conv3 = _EncoderBlock(ngf*2, ngf*4, ngf*4, norm_layer, nonlinearity, use_bias)
701 | self.conv4 = _EncoderBlock(ngf*4, ngf*8, ngf*8, norm_layer, nonlinearity, use_bias)
702 |
703 | for i in range(layers-4):
704 | conv = _EncoderBlock(ngf*8, ngf*8, ngf*8, norm_layer, nonlinearity, use_bias)
705 | setattr(self, 'down'+str(i), conv.model)
706 |
707 | center=[]
708 | for i in range(7-layers):
709 | center +=[
710 | _InceptionBlock(ngf*8, ngf*8, norm_layer, nonlinearity, 7-layers, drop_rate, use_bias)
711 | ]
712 |
713 | center += [
714 | _DecoderUpBlock(ngf*8, ngf*8, ngf*4, norm_layer, nonlinearity, use_bias)
715 | ]
716 | if add_noise:
717 | center += [GaussianNoiseLayer()]
718 | self.center = nn.Sequential(*center)
719 |
720 | for i in range(layers-4):
721 | upconv = _DecoderUpBlock(ngf*(8+4), ngf*8, ngf*4, norm_layer, nonlinearity, use_bias)
722 | setattr(self, 'up' + str(i), upconv.model)
723 |
724 | self.deconv4 = _DecoderUpBlock(ngf*(4+4), ngf*8, ngf*2, norm_layer, nonlinearity, use_bias)
725 | self.deconv3 = _DecoderUpBlock(ngf*(2+2)+output_nc, ngf*4, ngf, norm_layer, nonlinearity, use_bias)
726 | self.deconv2 = _DecoderUpBlock(ngf*(1+1)+output_nc, ngf*2, int(ngf/2), norm_layer, nonlinearity, use_bias)
727 |
728 | self.output4 = _OutputBlock(ngf*(4+4), output_nc, 3, use_bias)
729 | self.output3 = _OutputBlock(ngf*(2+2)+output_nc, output_nc, 3, use_bias)
730 | self.output2 = _OutputBlock(ngf*(1+1)+output_nc, output_nc, 3, use_bias)
731 | self.output1 = _OutputBlock(int(ngf/2)+output_nc, output_nc, 7, use_bias)
732 |
733 | self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
734 |
735 | def forward(self, input):
736 | conv1 = self.pool(self.conv1(input))
737 | conv2 = self.pool(self.conv2.forward(conv1))
738 | conv3 = self.pool(self.conv3.forward(conv2))
739 | center_in = self.pool(self.conv4.forward(conv3))
740 |
741 | middle = [center_in]
742 | for i in range(self.layers-4):
743 | model = getattr(self, 'down'+str(i))
744 | center_in = self.pool(model.forward(center_in))
745 | middle.append(center_in)
746 | center_out = self.center.forward(center_in)
747 | result = [center_in]
748 |
749 | for i in range(self.layers-4):
750 | model = getattr(self, 'up'+str(i))
751 | center_out = model.forward(torch.cat([center_out, middle[self.layers-5-i]], 1))
752 |
753 | result.append(center_out)
754 |
755 | deconv4 = self.deconv4.forward(torch.cat([center_out, conv3 * self.weight], 1))
756 | output4 = self.output4.forward(torch.cat([center_out, conv3 * self.weight], 1))
757 | result.append(output4)
758 | deconv3 = self.deconv3.forward(torch.cat([deconv4, conv2 * self.weight * 0.5, self.upsample(output4)], 1))
759 | output3 = self.output3.forward(torch.cat([deconv4, conv2 * self.weight * 0.5, self.upsample(output4)], 1))
760 | result.append(output3)
761 | deconv2 = self.deconv2.forward(torch.cat([deconv3, conv1 * self.weight * 0.1, self.upsample(output3)], 1))
762 | output2 = self.output2.forward(torch.cat([deconv3, conv1 * self.weight * 0.1, self.upsample(output3)], 1))
763 | result.append(output2)
764 | output1 = self.output1.forward(torch.cat([deconv2, self.upsample(output2)], 1))
765 | result.append(output1)
766 |
767 | return result
768 |
769 |
770 | class _SimplifiedUNetGenerator(nn.Module):
771 | def __init__(self, input_nc, output_nc, ngf=64, layers=4, norm='batch', activation='PReLU', drop_rate=0, add_noise=False, gpu_ids=[],
772 | weight=0.1):
773 | super(_SimplifiedUNetGenerator, self).__init__()
774 |
775 | self.gpu_ids = gpu_ids
776 | self.layers = layers
777 | self.weight = weight
778 | norm_layer = get_norm_layer(norm_type=norm)
779 | nonlinearity = get_nonlinearity_layer(activation_type=activation)
780 |
781 | if type(norm_layer) == functools.partial:
782 | use_bias = norm_layer.func == nn.InstanceNorm2d
783 | else:
784 | use_bias = norm_layer == nn.InstanceNorm2d
785 |
786 | # encoder part
787 | self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
788 | self.conv1 = nn.Sequential(
789 | nn.ReflectionPad2d(3),
790 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
791 | norm_layer(ngf),
792 | nonlinearity
793 | )
794 | self.conv2 = _EncoderBlock(ngf, ngf*2, ngf*2, norm_layer, nonlinearity, use_bias)
795 | self.conv3 = _EncoderBlock(ngf*2, ngf*4, ngf*4, norm_layer, nonlinearity, use_bias)
796 | self.conv4 = _EncoderBlock(ngf*4, ngf*8, ngf*8, norm_layer, nonlinearity, use_bias)
797 |
798 | for i in range(layers-4):
799 | conv = _EncoderBlock(ngf*8, ngf*8, ngf*8, norm_layer, nonlinearity, use_bias)
800 | setattr(self, 'down'+str(i), conv.model)
801 |
802 | center=[]
803 | for i in range(7-layers):
804 | center +=[
805 | _InceptionBlock(ngf*8, ngf*8, norm_layer, nonlinearity, 7-layers, drop_rate, use_bias)
806 | ]
807 |
808 | center += [
809 | _DecoderUpBlock(ngf*8, ngf*8, ngf*4, norm_layer, nonlinearity, use_bias)
810 | ]
811 | if add_noise:
812 | center += [GaussianNoiseLayer()]
813 | self.center = nn.Sequential(*center)
814 |
815 | for i in range(layers-4):
816 | upconv = _DecoderUpBlock(ngf*(8+4), ngf*8, ngf*4, norm_layer, nonlinearity, use_bias)
817 | setattr(self, 'up' + str(i), upconv.model)
818 |
819 | self.deconv4 = _DecoderUpBlock(ngf*(4+4), ngf*8, ngf*2, norm_layer, nonlinearity, use_bias)
820 | self.deconv3 = _DecoderUpBlock(ngf*(2+2)+output_nc, ngf*4, ngf, norm_layer, nonlinearity, use_bias)
821 | self.deconv2 = _DecoderUpBlock(ngf*(1+1)+output_nc, ngf*2, int(ngf/2), norm_layer, nonlinearity, use_bias)
822 |
823 | self.output4 = _OutputBlock(ngf*(4+4), output_nc, 3, use_bias)
824 | self.output3 = _OutputBlock(ngf*(2+2)+output_nc, output_nc, 3, use_bias)
825 | self.output2 = _OutputBlock(ngf*(1+1)+output_nc, output_nc, 3, use_bias)
826 | self.output1 = _OutputBlock(int(ngf/2)+output_nc, output_nc, 7, use_bias)
827 |
828 | self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
829 |
830 | def forward(self, input):
831 | conv1 = self.pool(self.conv1(input))
832 | conv2 = self.pool(self.conv2.forward(conv1))
833 | conv3 = self.pool(self.conv3.forward(conv2))
834 | # conv4 = self.pool(self.conv4.forward(conv3))
835 |
836 | # middle = [center_in]
837 | # for i in range(self.layers-4):
838 | # model = getattr(self, 'down'+str(i))
839 | # center_in = self.pool(model.forward(center_in))
840 | # middle.append(center_in)
841 | # center_out = self.center.forward(center_in)
842 | # result = [center_in]
843 |
844 | # for i in range(self.layers-4):
845 | # model = getattr(self, 'up'+str(i))
846 | # center_out = model.forward(torch.cat([center_out, middle[self.layers-5-i]], 1))
847 |
848 | result = []
849 |
850 | deconv4 = self.deconv4.forward(torch.cat([conv3, conv3 * self.weight], 1))
851 | output4 = self.output4.forward(torch.cat([conv3, conv3 * self.weight], 1))
852 | result.append(output4)
853 | deconv3 = self.deconv3.forward(torch.cat([deconv4, conv2 * self.weight * 0.5, self.upsample(output4)], 1))
854 | output3 = self.output3.forward(torch.cat([deconv4, conv2 * self.weight * 0.5, self.upsample(output4)], 1))
855 | result.append(output3)
856 | deconv2 = self.deconv2.forward(torch.cat([deconv3, conv1 * self.weight * 0.1, self.upsample(output3)], 1))
857 | output2 = self.output2.forward(torch.cat([deconv3, conv1 * self.weight * 0.1, self.upsample(output3)], 1))
858 | result.append(output2)
859 | output1 = self.output1.forward(torch.cat([deconv2, self.upsample(output2)], 1))
860 | result.append(output1)
861 |
862 | return result
863 |
864 |
865 | class _MultiscaleDiscriminator(nn.Module):
866 | def __init__(self, input_nc, ndf=64, n_layers=3, num_D=1, norm='batch', activation='PReLU', gpu_ids=[]):
867 | super(_MultiscaleDiscriminator, self).__init__()
868 |
869 | self.num_D = num_D
870 | self.gpu_ids = gpu_ids
871 |
872 | for i in range(num_D):
873 | netD = _Discriminator(input_nc, ndf, n_layers, norm, activation, gpu_ids)
874 | setattr(self, 'scale'+str(i), netD)
875 |
876 | self.downsample = nn.AvgPool2d(kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
877 |
878 | def forward(self, input):
879 | result = []
880 | for i in range(self.num_D):
881 | netD = getattr(self, 'scale'+str(i))
882 | output = netD.forward(input)
883 | result.append(output)
884 | if i != (self.num_D-1):
885 | input = self.downsample(input)
886 | return result
887 |
888 |
889 | class _Discriminator(nn.Module):
890 | def __init__(self, input_nc, ndf=64, n_layers=3, norm='batch', activation='PReLU', gpu_ids=[]):
891 | super(_Discriminator, self).__init__()
892 |
893 | self.gpu_ids = gpu_ids
894 |
895 | norm_layer = get_norm_layer(norm_type=norm)
896 | nonlinearity = get_nonlinearity_layer(activation_type=activation)
897 |
898 | if type(norm_layer) == functools.partial:
899 | use_bias = norm_layer.func == nn.InstanceNorm2d
900 | else:
901 | use_bias = norm_layer == nn.InstanceNorm2d
902 |
903 | model = [
904 | nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1, bias=use_bias),
905 | nonlinearity,
906 | ]
907 |
908 | nf_mult=1
909 | for i in range(1, n_layers):
910 | nf_mult_prev = nf_mult
911 | nf_mult = min(2**i, 8)
912 | model += [
913 | nn.Conv2d(ndf*nf_mult_prev, ndf*nf_mult, kernel_size=4, stride=2, padding=1, bias=use_bias),
914 | norm_layer(ndf*nf_mult),
915 | nonlinearity,
916 | ]
917 |
918 | nf_mult_prev = nf_mult
919 | nf_mult = min(2 ** n_layers, 8)
920 | model += [
921 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=4, stride=1, padding=1, bias=use_bias),
922 | norm_layer(ndf * 8),
923 | nonlinearity,
924 | nn.Conv2d(ndf*nf_mult, 1, kernel_size=4, stride=1, padding=1)
925 | ]
926 |
927 | self.model = nn.Sequential(*model)
928 |
929 | def forward(self, input):
930 | return self.model(input)
931 |
932 |
933 | class _FeatureDiscriminator(nn.Module):
934 | def __init__(self, input_nc, n_layers=2, norm='batch', activation='PReLU', gpu_ids=[]):
935 | super(_FeatureDiscriminator, self).__init__()
936 |
937 | self.gpu_ids = gpu_ids
938 |
939 | norm_layer = get_norm_layer(norm_type=norm)
940 | nonlinearity = get_nonlinearity_layer(activation_type=activation)
941 |
942 | if type(norm_layer) == functools.partial:
943 | use_bias = norm_layer.func == nn.InstanceNorm2d
944 | else:
945 | use_bias = norm_layer == nn.InstanceNorm2d
946 |
947 | model = [
948 | nn.Linear(input_nc * 40 * 30, input_nc),
949 | nonlinearity,
950 | ]
951 |
952 | # for i in range(1, n_layers):
953 | # model +=[
954 | # nn.Linear(input_nc, input_nc),
955 | # nonlinearity
956 | # ]
957 |
958 | model +=[nn.Linear(input_nc, 1)]
959 |
960 | self.model = nn.Sequential(*model)
961 |
962 | def forward(self, input):
963 | result = []
964 | # print(input.size())
965 | # input = input.view(-1, 512 * 40 * 12)
966 | input = input.view(-1, 512 * 30 * 40)
967 | output = self.model(input)
968 | result.append(output)
969 | return result
--------------------------------------------------------------------------------
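A minimal usage sketch (not part of the repo) may help before moving on to the discriminators: each generator above returns a list of intermediate and multi-scale outputs, and the training code in this repo always reads the finest-scale prediction as the last element of that list. The shapes below assume the repo's default 240x320 crop (height and width must be divisible by 16 for the four pooling stages).

import torch
from models.depth_generator_networks import _UNetGenerator  # assumes the repo root is on PYTHONPATH

net = _UNetGenerator(input_nc=3, output_nc=1)   # RGB in, 1-channel depth out
x = torch.randn(2, 3, 240, 320)                 # two 240x320 images
outputs = net(x)                                # [bottleneck, decoder feature, output4, output3, output2, output1]
pred = outputs[-1]                              # finest-scale prediction, the one consumed by the training code
print(pred.shape)                               # expected: torch.Size([2, 1, 240, 320])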
/models/discriminator_networks.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import Dataset, DataLoader
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | from torch.optim import lr_scheduler
7 | import torch.nn.functional as F
8 | from torch.autograd import Variable
9 | import torchvision
10 | from torchvision import datasets, models, transforms
11 |
12 | class Discriminator80x80InstNorm(nn.Module):
13 | def __init__(self, device='cpu', pretrained=False, patchSize=[64, 64], input_nc=3):
14 | super(Discriminator80x80InstNorm, self).__init__()
15 | self.device = device
16 | self.input_nc = input_nc
17 | self.patchSize = patchSize
18 | self.outputSize = [patchSize[0]//16, patchSize[1]//16]
19 |
20 | self.discriminator = nn.Sequential(
21 | # spatial sizes below assume an 80x80 input: 80 -> 38
22 | nn.Conv2d(self.input_nc, 64, kernel_size=5, padding=0, stride=2, bias=True),
23 | nn.LeakyReLU(0.2, inplace=True),
24 |
25 | # 38 -> 17
26 | nn.Conv2d(64, 128, kernel_size=5, padding=0, stride=2, bias=False),
27 | nn.InstanceNorm2d(128, momentum=0.001, affine=False, track_running_stats=False),
28 | nn.LeakyReLU(0.2, inplace=True),
29 | # 17 -> 8
30 | nn.Conv2d(128, 256, kernel_size=3, padding=0, stride=2, bias=False),
31 | nn.InstanceNorm2d(256, momentum=0.001, affine=False, track_running_stats=False),
32 | nn.LeakyReLU(0.2, inplace=True),
33 | # 8 -> 3
34 | nn.Conv2d(256, 512, kernel_size=3, padding=0, stride=2, bias=False),
35 | nn.InstanceNorm2d(512, momentum=0.001, affine=False, track_running_stats=False),
36 | nn.LeakyReLU(0.2, inplace=True),
37 | # final classification for 'real(1) vs. fake(0)'
38 | nn.Conv2d(512, 1, kernel_size=1, padding=0, stride=1, bias=True),
39 | )
40 |
41 | def forward(self, X):
42 | return self.discriminator(X)
43 |
44 | class Discriminator80x80InstNormDilation(nn.Module):
45 | # same as Discriminator80x80InstNorm except the kernel size of last layer is changed to 3x3
46 | # used to test receptive field
47 | def __init__(self, device='cpu', dialate_size=1, pretrained=False, patchSize=[64, 64], input_nc=3):
48 | super(Discriminator80x80InstNormDilation, self).__init__()
49 | self.device = device
50 | self.input_nc = input_nc
51 | self.patchSize = patchSize
52 | self.outputSize = [patchSize[0]//16, patchSize[1]//16]
53 | self.dialate_size = dialate_size
54 |
55 | self.discriminator = nn.Sequential(
56 | # spatial sizes below assume an 80x80 input: 80 -> 38
57 | nn.Conv2d(self.input_nc, 64, kernel_size=5, padding=0, stride=2, bias=True),
58 | nn.LeakyReLU(0.2, inplace=True),
59 |
60 | # 38 -> 17
61 | nn.Conv2d(64, 128, kernel_size=5, padding=0, stride=2, bias=False),
62 | nn.InstanceNorm2d(128, momentum=0.001, affine=False, track_running_stats=False),
63 | nn.LeakyReLU(0.2, inplace=True),
64 | # 17 -> 8
65 | nn.Conv2d(128, 256, kernel_size=3, padding=0, stride=2, bias=False),
66 | nn.InstanceNorm2d(256, momentum=0.001, affine=False, track_running_stats=False),
67 | nn.LeakyReLU(0.2, inplace=True),
68 | # 8 -> 3
69 | nn.Conv2d(256, 512, kernel_size=3, padding=0, stride=2, bias=False),
70 | nn.InstanceNorm2d(512, momentum=0.001, affine=False, track_running_stats=False),
71 | nn.LeakyReLU(0.2, inplace=True),
72 | # final classification for 'real(1) vs. fake(0)'
73 | nn.Conv2d(512, 1, kernel_size=3, padding=0, stride=1, bias=True, dilation=self.dialate_size),
74 | )
75 |
76 | def forward(self, X):
77 | return self.discriminator(X)
78 |
79 | class Discriminator5121520InstNorm(nn.Module):
80 | def __init__(self, device='cpu', pretrained=False, patchSize=[64, 64], input_nc=3):
81 | super(Discriminator5121520InstNorm, self).__init__()
82 | self.device = device
83 | self.input_nc = input_nc
84 | self.patchSize = patchSize
85 | self.outputSize = [patchSize[0]//16, patchSize[1]//16]
86 |
87 | self.discriminator = nn.Sequential(
88 | # spatial sizes below assume a 512x15x20 feature input (per the class name): 15x20 -> 13x18
89 | nn.Conv2d(self.input_nc, 256, kernel_size=3, padding=0, stride=1, bias=True),
90 | nn.LeakyReLU(0.2, inplace=True),
91 |
92 | # 13x18 -> 11x16
93 | nn.Conv2d(256, 128, kernel_size=3, padding=0, stride=1, bias=False),
94 | nn.InstanceNorm2d(128, momentum=0.001, affine=False, track_running_stats=False),
95 | nn.LeakyReLU(0.2, inplace=True),
96 | # 11x16 -> 9x14
97 | nn.Conv2d(128, 64, kernel_size=3, padding=0, stride=1, bias=False),
98 | nn.InstanceNorm2d(64, momentum=0.001, affine=False, track_running_stats=False),
99 | nn.LeakyReLU(0.2, inplace=True),
100 |
101 | # final classification for 'real(1) vs. fake(0)'
102 | nn.Conv2d(64, 1, kernel_size=1, padding=0, stride=1, bias=True),
103 | )
104 |
105 | def forward(self, X):
106 | return self.discriminator(X)
107 |
108 | class DiscriminatorGlobalLocal(nn.Module):
109 | """Discriminator. PatchGAN."""
110 | def __init__(self, image_size=128, bbox_size = 64, conv_dim=64, c_dim=5, repeat_num_global=6, repeat_num_local=5, nc=3):
111 | super(DiscriminatorGlobalLocal, self).__init__()
112 |
113 | maxFilt = 512 if image_size==128 else 128
114 | globalLayers = []
115 | globalLayers.append(nn.Conv2d(nc, conv_dim, kernel_size=4, stride=2, padding=1,bias=False))
116 | globalLayers.append(nn.LeakyReLU(0.2, inplace=True))
117 |
118 | localLayers = []
119 | localLayers.append(nn.Conv2d(nc, conv_dim, kernel_size=4, stride=2, padding=1, bias=False))
120 | localLayers.append(nn.LeakyReLU(0.2, inplace=True))
121 |
122 | curr_dim = conv_dim
123 | for i in range(1, repeat_num_global):
124 | globalLayers.append(nn.Conv2d(curr_dim, min(curr_dim*2,maxFilt), kernel_size=4, stride=2, padding=1, bias=False))
125 | globalLayers.append(nn.LeakyReLU(0.2, inplace=True))
126 | curr_dim = min(curr_dim * 2, maxFilt)
127 |
128 | curr_dim = conv_dim
129 | for i in range(1, repeat_num_local):
130 | localLayers.append(nn.Conv2d(curr_dim, min(curr_dim * 2, maxFilt), kernel_size=4, stride=2, padding=1, bias=False))
131 | localLayers.append(nn.LeakyReLU(0.2, inplace=True))
132 | curr_dim = min(curr_dim * 2, maxFilt)
133 |
134 | k_size_local = int(bbox_size/ np.power(2, repeat_num_local))
135 | k_size_global = int(image_size/ np.power(2, repeat_num_global))
136 |
137 | self.mainGlobal = nn.Sequential(*globalLayers)
138 | self.mainLocal = nn.Sequential(*localLayers)
139 |
140 | # FC 1 for doing real/fake
141 | # self.fc1 = nn.Linear(curr_dim*(k_size_local**2+k_size_global**2), 1, bias=False)
142 | self.fc1 = nn.Linear(10880, 1, bias=False)
143 |
144 | # FC 2 for doing classification only on local patch
145 | if c_dim > 0:
146 | self.fc2 = nn.Linear(curr_dim*(k_size_local**2), c_dim, bias=False)
147 | else:
148 | self.fc2 = None
149 |
150 | def forward(self, x, boxImg, classify=False):
151 | bsz = x.size(0)
152 | h_global = self.mainGlobal(x)
153 | h_local = self.mainLocal(boxImg)
154 | h_append = torch.cat([h_global.view(bsz,-1), h_local.view(bsz,-1)], dim=-1)
155 | out_rf = self.fc1(h_append)
156 | out_cls = self.fc2(h_local.view(bsz,-1)) if classify and (self.fc2 is not None) else None
157 | return out_rf.squeeze(), out_cls, h_append
--------------------------------------------------------------------------------
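A short sketch (not part of the repo) of how these patch discriminators are driven by the training code: Discriminator80x80InstNorm emits a grid of real/fake logits (no sigmoid inside the network), and the compute_D_loss methods in the training classes score that grid with BCEWithLogitsLoss against all-real or all-synthetic label maps of the same size. The 80x80 patch size and batch size here are illustrative.

import torch
import torch.nn as nn
from models.discriminator_networks import Discriminator80x80InstNorm

netD = Discriminator80x80InstNorm(input_nc=3)
real = torch.randn(4, 3, 80, 80)     # real image patches
fake = torch.randn(4, 3, 80, 80)     # e.g. translated/inpainted patches

bce = nn.BCEWithLogitsLoss()
logits_real = netD(real)             # (4, 1, 3, 3) grid of patch logits
logits_fake = netD(fake)

d_loss = bce(logits_real, torch.ones_like(logits_real)) \
       + bce(logits_fake, torch.zeros_like(logits_fake))
d_loss.backward()                    # gradients flow into netD's parameters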
/train.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import random, time, copy
3 | import argparse
4 | import torch
5 | from torch.utils.data import Dataset, DataLoader
6 |
7 | from dataloader.NYUv2_dataLoader import NYUv2_dataLoader
8 | from dataloader.Joint_xLabel_dataLoader import Joint_xLabel_train_dataLoader
9 |
10 | # step 1: Train Initial Depth Predictor D
11 | # from training.train_initial_depth_predictor_D import train_initial_depth_predictor_D as train_model
12 |
13 | # step 2: Train Style Translator T (pre-train T)
14 | # from training.train_style_translator_T import train_style_translator_T as train_model
15 |
16 | # step 3: Train Initial Attention Module A
17 | # from training.train_initial_attention_module_A import train_initial_attention_module_A as train_model
18 |
19 | # step 4: Train Inpainting Module I (pre-train I)
20 | # from training.train_inpainting_module_I import train_inpainting_module_I as train_model
21 |
22 | # step 5: Jointly Train Depth Predictor D and Attention Module A (pre-train A, D)
23 | # from training.jointly_train_depth_predictor_D_and_attention_module_A import jointly_train_depth_predictor_D_and_attention_module_A as train_model
24 |
25 | # step 6: Finetune the Whole System with Depth Loss (Modular Coordinate Descent)
26 | from training.finetune_the_whole_system_with_depth_loss import finetune_the_whole_system_with_depth_loss as train_model
27 |
28 | import warnings # ignore warnings
29 | warnings.filterwarnings("ignore")
30 |
31 | print(sys.version)
32 | print(torch.__version__)
33 |
34 | ################## set attributes for this project/experiment ##################
35 |
36 | parser = argparse.ArgumentParser()
37 | parser.add_argument('--exp_dir', type=str, default=os.path.join(os.getcwd(), 'experiments'),
38 | help='place to store all experiments')
39 | parser.add_argument('--project_name', type=str, help='name of this project/experiment')
40 | parser.add_argument('--path_to_NYUv2', type=str, default='your absolute path to NYUv2 data',
41 | help='absolute dir of NYUv2 dataset')
42 | parser.add_argument('--path_to_PBRS', type=str, default='your absolute path to PBRS data',
43 | help='absolute dir of PBRS dataset')
44 | parser.add_argument('--isTrain', action='store_true', help='whether this is training phase')
45 | parser.add_argument('--batch_size', type=int, default=16, help='batch size')
46 | parser.add_argument('--eval_batch_size', type=int, default=1, help='evaluation batch size')
47 | parser.add_argument('--cropSize', type=int, nargs=2, default=[240, 320], help='size (H W) of samples in experiments')
48 | parser.add_argument('--total_epoch_num', type=int, default=50, help='total number of epoch')
49 | parser.add_argument('--device', type=str, default='cpu', help='whether running on gpu')
50 | parser.add_argument('--num_workers', type=int, default=4, help='number of workers in dataLoaders')
51 | args = parser.parse_args()
52 |
53 | if torch.cuda.is_available():
54 | args.device='cuda'
55 | torch.cuda.empty_cache()
56 |
57 | # here only for evaluation purpose
58 | datasets_nyuv2 = {set_name: NYUv2_dataLoader(root_dir=args.path_to_NYUv2, set_name=set_name, size=args.cropSize, rgb=True)
59 | for set_name in ['train', 'test']}
60 | dataloaders_nyuv2 = {set_name: DataLoader(datasets_nyuv2[set_name],
61 | batch_size=args.batch_size if set_name=='train' else args.eval_batch_size,
62 | shuffle=set_name=='train',
63 | drop_last=set_name=='train',
64 | num_workers=args.num_workers)
65 | for set_name in ['train', 'test']}
66 |
67 | # for training purpose
68 | datasets_xLabels_joint = Joint_xLabel_train_dataLoader(real_root_dir=args.path_to_NYUv2, syn_root_dir=args.path_to_PBRS, paired_data=False)
69 | dataloaders_xLabels_joint = DataLoader(datasets_xLabels_joint,
70 | batch_size=args.batch_size,
71 | shuffle=True,
72 | drop_last=True,
73 | num_workers=args.num_workers)
74 |
75 | model = train_model(args, dataloaders_xLabels_joint, dataloaders_nyuv2)
76 |
77 | if args.isTrain:
78 | model.train()
79 | model.evaluate(mode='best')
80 | else:
81 | model.evaluate(mode='best')
--------------------------------------------------------------------------------
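The six commented import blocks above are how a training step is selected: exactly one `train_model` import is left uncommented. As a hedged convenience sketch (not part of the repo), the same selection can be written as a lookup over the module/class names listed in those comments, since each training module defines a class with the same name as the module:

import importlib

STEPS = {
    1: 'training.train_initial_depth_predictor_D',
    2: 'training.train_style_translator_T',
    3: 'training.train_initial_attention_module_A',
    4: 'training.train_inpainting_module_I',
    5: 'training.jointly_train_depth_predictor_D_and_attention_module_A',
    6: 'training.finetune_the_whole_system_with_depth_loss',
}

def get_train_model(step):
    # resolve the module path, then pull out the class of the same name
    module_path = STEPS[step]
    class_name = module_path.rsplit('.', 1)[-1]
    return getattr(importlib.import_module(module_path), class_name)

# train_model = get_train_model(6)  # equivalent to the uncommented import above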
/training/base_model.py:
--------------------------------------------------------------------------------
1 | import os, copy, torch
2 | from torch.utils.data import Dataset, DataLoader
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | from torch.optim import lr_scheduler
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 | from collections import OrderedDict
9 |
10 | import torchvision
11 | from torchvision import datasets, models, transforms
12 | from torchvision.utils import make_grid
13 | from tensorboardX import SummaryWriter
14 |
15 | from utils.metrics import *
16 |
17 | try:
18 | from apex import amp
19 | except ImportError:
20 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run with apex.")
21 |
22 | import torch.multiprocessing as mp
23 |
24 | def set_requires_grad(nets, requires_grad=False):
25 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
26 | Parameters:
27 | nets (network list) -- a list of networks
28 | requires_grad (bool) -- whether the networks require gradients or not
29 | """
30 | if not isinstance(nets, list):
31 | nets = [nets]
32 | for net in nets:
33 | if net is not None:
34 | for param in net.parameters():
35 | param.requires_grad = requires_grad
36 |
37 |
38 | def apply_scheduler(optimizer, lr_policy, num_epoch=None, total_num_epoch=None):
39 | if lr_policy == 'linear':
40 | # keep the initial lr for the first num_epoch epochs,
41 | # then linearly decay toward 0 over the remaining epochs (the last epoch is not exactly 0)
42 | def lambda_rule(epoch):
43 | # lr_l = 1.0 - max(0, epoch + 1 + epoch_count - niter) / float(niter_decay + 1)
44 | lr_l = 1.0 - max(0, epoch + 1 - num_epoch) / float(total_num_epoch - num_epoch + 1)
45 | return lr_l
46 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
47 | elif lr_policy == 'step':
48 | scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
49 | elif lr_policy == 'plateau':
50 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
51 | else:
52 | raise NotImplementedError('learning rate policy [%s] is not implemented' % lr_policy)
53 | return scheduler
54 |
55 | class base_model(nn.Module):
56 | def __init__(self, args):
57 | super(base_model, self).__init__()
58 | self.device = args.device
59 | self.isTrain = args.isTrain
60 | self.project_name = args.project_name
61 | self.exp_dir = args.exp_dir
62 |
63 | self.use_tensorboardX = True
64 | self.use_apex = True
65 |
66 | self.cropSize = args.cropSize # patch size for training the model. Default: [240, 320]
67 | self.cropSize_h, self.cropSize_w = self.cropSize[0], self.cropSize[1]
68 | self.batch_size = args.batch_size
69 | self.total_epoch_num = args.total_epoch_num # total number of epoch in training
70 | self.save_steps = 5
71 | self.task_lr = 1e-4 # default task learning rate
72 | self.D_lr = 5e-5 # default discriminator learning rate
73 | self.G_lr = 5e-5 # default generator learning rate
74 | self.real_label = 1
75 | self.syn_label = 0
76 |
77 | def _initialize_training(self):
78 | if self.project_name is not None:
79 | self.save_dir = os.path.join(self.exp_dir, self.project_name)
80 | else:
81 | self.project_name = self._get_project_name()
82 | self.save_dir = os.path.join(self.exp_dir, self.project_name)
83 | print('project name: {}'.format(self.project_name))
84 | print('save dir: {}'.format(self.save_dir))
85 | if not os.path.exists(self.save_dir): os.makedirs(self.save_dir)
86 |
87 | self.train_log = os.path.join(self.save_dir, 'train.log')
88 | self.evaluate_log = os.path.join(self.save_dir, 'evaluate.log')
89 | self.file_to_note_bestModel = os.path.join(self.save_dir,'note_bestModel.log')
90 |
91 | if self.use_tensorboardX:
92 | self.tensorboard_train_dir = os.path.join(self.save_dir, 'tensorboardX_train_logs')
93 | self.train_SummaryWriter = SummaryWriter(self.tensorboard_train_dir)
94 |
95 | self.tensorboard_eval_dir = os.path.join(self.save_dir, 'tensorboardX_eval_logs')
96 | self.eval_SummaryWriter = SummaryWriter(self.tensorboard_eval_dir)
97 |
98 | # self.train_display_freq = 500
99 | # self.val_write_freq = 10
100 | self.tensorboard_num_display_per_epoch = 5
101 | self.val_display_freq = 10
102 |
103 | def _initialize_networks(self):
104 | for name, model in self.model_dict.items():
105 | model.train().to(self.device)
106 | init_weights(model, net_name=name, init_type='normal', gain=0.02)
107 |
108 | def _get_scheduler(self, optim_type='linear'):
109 | '''
110 | if optim_type is a str -> all optimizers use this type of scheduler
111 | if optim_type is a list -> each optimizer uses its own scheduler type
112 | otherwise -> a RuntimeError is raised
113 | '''
114 | self.scheduler_list = []
115 | if isinstance(optim_type, str):
116 | for name in self.optim_name:
117 | self.scheduler_list.append(apply_scheduler(getattr(self, name), lr_policy=optim_type, num_epoch=0.6*self.total_epoch_num,
118 | total_num_epoch=self.total_epoch_num))
119 | elif isinstance(optim_type, list):
120 | for name, optim in zip(self.optim_name, optim_type):
121 | self.scheduler_list.append(apply_scheduler(getattr(self, name), lr_policy=optim, num_epoch=0.6*self.total_epoch_num,
122 | total_num_epoch=self.total_epoch_num))
123 | else:
124 | raise RuntimeError("optim type should be either string or list!")
125 |
126 | def _init_apex(self, Num_losses):
127 | model_list = []
128 | optim_list = []
129 | for m in self.model_name:
130 | model_list.append(getattr(self, m))
131 | for o in self.optim_name:
132 | optim_list.append(getattr(self, o))
133 | model_list, optim_list = amp.initialize(model_list, optim_list, opt_level="O1", num_losses=Num_losses)
134 |
135 | def _check_parallel(self):
136 | if torch.cuda.device_count() > 1:
137 | for name in self.model_name:
138 | setattr(self, name, nn.DataParallel(getattr(self, name)))
139 |
140 | def _check_distribute(self):
141 | # not ready to use yet
142 | if torch.cuda.device_count() > 1:
143 | # world_size is the number of processes participating in the job
144 | # torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
145 | # mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
146 | for name in self.model_name:
147 | if self.use_apex:
148 | setattr(self, name, apex.parallel.DistributedDataParallel(getattr(self, name)))
149 | else:
150 | setattr(self, name, nn.parallel.DistributedDataParallel(getattr(self, name)))
151 |
152 | def _set_models_train(self, model_name):
153 | for name in model_name:
154 | getattr(self, name).train()
155 |
156 | def _set_models_eval(self, model_name):
157 | for name in model_name:
158 | getattr(self, name).eval()
159 |
160 | def _set_models_float(self, model_name):
161 | for name in model_name:
162 | for layers in getattr(self, name).modules():
163 | layers.float()
164 |
165 | def save_models(self, model_list, mode, save_list=None):
166 | '''
167 | mode is 'best', 'latest', or an epoch number (int)
168 | weights are saved as a non-DataParallel state_dict
169 | save_list is used to save a model under a different name for later use
170 | '''
171 | if not save_list:
172 | for model_name in model_list:
173 | if mode == 'latest':
174 | path_to_save_paramOnly = os.path.join(self.save_dir, 'latest_{}.pth'.format(model_name))
175 | elif mode == 'best':
176 | path_to_save_paramOnly = os.path.join(self.save_dir, 'best_{}.pth'.format(model_name))
177 | elif isinstance(mode, int):
178 | path_to_save_paramOnly = os.path.join(self.save_dir, 'epoch-{}_{}.pth'.format(str(mode), model_name))
179 |
180 | try:
181 | state_dict = getattr(self, model_name).module.state_dict()
182 | except AttributeError:
183 | state_dict = getattr(self, model_name).state_dict()
184 |
185 | model_weights = copy.deepcopy(state_dict)
186 | torch.save(model_weights, path_to_save_paramOnly)
187 | else:
188 | assert len(model_list) == len(save_list)
189 | for save_name, model_name in zip(save_list, model_list):
190 | if mode == 'latest':
191 | path_to_save_paramOnly = os.path.join(self.save_dir, 'latest_{}.pth'.format(save_name))
192 | elif mode == 'best':
193 | path_to_save_paramOnly = os.path.join(self.save_dir, 'best_{}.pth'.format(save_name))
194 | elif isinstance(mode, int):
195 | path_to_save_paramOnly = os.path.join(self.save_dir, 'epoch-{}_{}.pth'.format(str(mode), save_name))
196 |
197 | try:
198 | state_dict = getattr(self, model_name).module.state_dict()
199 | except AttributeError:
200 | state_dict = getattr(self, model_name).state_dict()
201 |
202 | model_weights = copy.deepcopy(state_dict)
203 | torch.save(model_weights, path_to_save_paramOnly)
204 |
205 | def _load_models(self, model_list, mode, isTrain=False, model_path=None):
206 | if model_path is None:
207 | model_path = self.save_dir
208 |
209 | for model_name in model_list:
210 | if mode == 'latest':
211 | path = os.path.join(model_path, 'latest_{}.pth'.format(model_name))
212 | elif mode == 'best':
213 | path = os.path.join(model_path, 'best_{}.pth'.format(model_name))
214 | elif isinstance(mode, int):
215 | path = os.path.join(model_path, 'epoch-{}_{}.pth'.format(str(mode), model_name))
216 | else:
217 | raise RuntimeError("Mode not implemented")
218 |
219 | state_dict = torch.load(path)
220 |
221 | try:
222 | getattr(self, model_name).load_state_dict(state_dict)
223 | except RuntimeError:
224 | # a DataParallel model is loading a non-parallel state_dict: prepend 'module.' to all keys
225 | new_state_dict = OrderedDict()
226 | for k, v in state_dict.items():
227 | name = 'module.' + k # add `module.`
228 | new_state_dict[name] = v
229 |
230 | getattr(self, model_name).load_state_dict(new_state_dict)
231 |
232 | if isTrain:
233 | getattr(self, model_name).to(self.device).train()
234 | else:
235 | getattr(self, model_name).to(self.device).eval()
236 |
237 | def save_tensor2np(self, tensor, name, epoch, path=None):
238 | # not ready to use in this project
239 | if path is None:
240 | path = self.save_dir
241 | generated_sample = tensor.detach().cpu().numpy()
242 | generated_sample_save_path = os.path.join(path, 'tensor2np', 'Epoch-%s_%s.npy' % (epoch, name))
243 | if not os.path.exists(os.path.join(path, 'tensor2np')):
244 | os.makedirs(os.path.join(path, 'tensor2np'))
245 | np.save(generated_sample_save_path, generated_sample)
246 |
247 | def write_2_tensorboardX(self, writer, input_tensor, name, mode, count, nrow=None, normalize=True, value_range=(-1.0, 1.0)):
248 | if mode == 'image':
249 | if not nrow:
250 | raise RuntimeError('tensorboardX: must specify number of rows in image mode')
251 | grid = make_grid(input_tensor, nrow=nrow, normalize=normalize, range=value_range)
252 | writer.add_image(name, grid, count)
253 | elif mode == 'scalar':
254 | if isinstance(input_tensor, list) and isinstance(name, list):
255 | assert len(input_tensor) == len(name)
256 | for n, t in zip(name, input_tensor):
257 | writer.add_scalar(n, t, count)
258 | else:
259 | writer.add_scalar(name, input_tensor, count)
260 | else:
261 | raise RuntimeError('tensorboardX: this mode is not yet implemented')
262 |
263 |
--------------------------------------------------------------------------------
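To make the 'linear' policy in apply_scheduler above concrete, here is a tiny sketch (not part of the repo) of the LR multiplier it produces: the rate is held at 1.0 for the first num_epoch epochs (the repo's _get_scheduler passes 0.6 * total_epoch_num), then decays linearly toward zero over the remaining epochs.

# same lambda as apply_scheduler's 'linear' branch, with example arguments
def lambda_rule(epoch, num_epoch=30, total_num_epoch=50):
    return 1.0 - max(0, epoch + 1 - num_epoch) / float(total_num_epoch - num_epoch + 1)

print([round(lambda_rule(e), 3) for e in (0, 29, 30, 39, 49)])
# -> [1.0, 1.0, 0.952, 0.524, 0.048]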
/training/finetune_the_whole_system_with_depth_loss.py:
--------------------------------------------------------------------------------
1 | import os, time, sys
2 | import random
3 | import torch
4 | from torch.utils.data import Dataset, DataLoader
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | from torch.optim import lr_scheduler
8 | import torch.nn.functional as F
9 | from torch.autograd import Variable
10 |
11 | import torchvision
12 | from torchvision import datasets, models, transforms
13 | from torchvision.utils import make_grid
14 | from tensorboardX import SummaryWriter
15 |
16 | from models.depth_generator_networks import _UNetGenerator, init_weights, _ResGenerator_Upsample
17 | from models.discriminator_networks import Discriminator80x80InstNorm
18 | from models.attention_networks import _Attention_FullRes
19 |
20 | from utils.metrics import *
21 | from utils.image_pool import ImagePool
22 |
23 | from training.base_model import set_requires_grad, base_model
24 |
25 | try:
26 | from apex import amp
27 | except ImportError:
28 | print("\nPlease consider install apex from https://www.github.com/nvidia/apex to run with apex or set use_apex = False\n")
29 |
30 | import warnings # ignore warnings
31 | warnings.filterwarnings("ignore")
32 |
33 | class finetune_the_whole_system_with_depth_loss(base_model):
34 | def __init__(self, args, dataloaders_xLabels_joint, dataloaders_single):
35 | super(finetune_the_whole_system_with_depth_loss, self).__init__(args)
36 | self._initialize_training()
37 | # self.KITTI_MAX_DEPTH_CLIP = 80.0
38 | # self.EVAL_DEPTH_MIN = 1.0
39 | # self.EVAL_DEPTH_MAX = 50.0
40 |
41 | self.NYU_MAX_DEPTH_CLIP = 10.0
42 | self.EVAL_DEPTH_MIN = 1.0
43 | self.EVAL_DEPTH_MAX = 8.0
44 |
45 | self.dataloaders_single = dataloaders_single
46 | self.dataloaders_xLabels_joint = dataloaders_xLabels_joint
47 |
48 | self.tensorboard_num_display_per_epoch = 1
49 |
50 | self.attModule = _Attention_FullRes(input_nc = 3, output_nc = 1)
51 | self.inpaintNet = _ResGenerator_Upsample(input_nc = 3, output_nc = 3)
52 | self.styleTranslator = _ResGenerator_Upsample(input_nc = 3, output_nc = 3)
53 | self.depthEstModel = _UNetGenerator(input_nc = 3, output_nc = 1)
54 |
55 | self.tau_min = 0.05
56 | self.model_name = ['attModule', 'inpaintNet', 'styleTranslator', 'depthEstModel']
57 | self.L1loss = nn.L1Loss()
58 |
59 | if self.isTrain:
60 | self.optim_depth = optim.Adam(list(self.depthEstModel.parameters()) + list(self.inpaintNet.parameters()) + list(self.styleTranslator.parameters()), lr=self.task_lr, betas=(0.5, 0.999))
61 | self.optim_name = ['optim_depth']
62 | self._get_scheduler()
63 | self.loss_BCE = nn.BCEWithLogitsLoss()
64 |
65 | # load the "best" depth predictor D (from step 5)
66 | preTrain_path = os.path.join(os.getcwd(), 'experiments', 'jointly_train_depth_predictor_D_and_attention_module_A')
67 | self._load_models(model_list=['depthEstModel'], mode='best', isTrain=True, model_path=preTrain_path)
68 | print('Successfully loaded pre-trained {} model from {}'.format('depthEstModel', preTrain_path))
69 |
70 | # load the "best" style translator T (from step 2)
71 | preTrain_path = os.path.join(os.getcwd(), 'experiments', 'train_style_translator_T')
72 | self._load_models(model_list=['styleTranslator'], mode=480, isTrain=True, model_path=preTrain_path)
73 | print('Successfully loaded pre-trained {} model from {}'.format('styleTranslator', preTrain_path))
74 |
75 | # load the "best" attention module A (from step 5)
76 | preTrain_path = os.path.join(os.getcwd(), 'experiments', 'jointly_train_depth_predictor_D_and_attention_module_A')
77 | self._load_models(model_list=['attModule'], mode='best', isTrain=True, model_path=preTrain_path)
78 | print('Successfully loaded pre-trained {} model from {}'.format('attModule', preTrain_path))
79 |
80 | # load the "best" inpainting module I (from step 4)
81 | preTrain_path = os.path.join(os.getcwd(), 'experiments', 'train_inpainting_module_I')
82 | self._load_models(model_list=['inpaintNet'], mode=450, isTrain=True, model_path=preTrain_path)
83 | print('Successfully loaded pre-trained {} model from {}'.format('inpaintNet', preTrain_path))
84 |
85 | # apex can only be applied to CUDA models
86 | if self.use_apex:
87 | self._init_apex(Num_losses=2)
88 |
89 | self.EVAL_best_loss = float('inf')
90 | self.EVAL_best_model_epoch = 0
91 | self.EVAL_all_results = {}
92 |
93 | self._check_parallel()
94 |
95 | def _get_project_name(self):
96 | return 'finetune_the_whole_system_with_depth_loss'
97 |
98 | def _initialize_networks(self, model_name):
99 | for name in model_name:
100 | getattr(self, name).train().to(self.device)
101 | init_weights(getattr(self, name), net_name=name, init_type='normal', gain=0.02)
102 |
103 | def compute_D_loss(self, real_sample, fake_sample, netD):
104 | loss = 0
105 | syn_acc = 0
106 | real_acc = 0
107 |
108 | output = netD(fake_sample)
109 | label = torch.full((output.size()), self.syn_label, device=self.device)
110 | predSyn = (output > 0.5).to(self.device, dtype=torch.float32)
111 | total_num = torch.numel(output)
112 | syn_acc += (predSyn==label).type(torch.float32).sum().item()/total_num
113 |
114 | loss += self.loss_BCE(output, label)
115 |
116 | output = netD(real_sample)
117 | label = torch.full((output.size()), self.real_label, device=self.device)
118 | predReal = (output > 0.5).to(self.device, dtype=torch.float32)
119 | real_acc += (predReal==label).type(torch.float32).sum().item()/total_num
120 |
121 | loss += self.loss_BCE(output, label)
122 |
123 | return loss, syn_acc, real_acc
124 |
125 | def compute_depth_loss(self, input_rgb, depth_label, depthEstModel, valid_mask=None):
126 |
127 | prediction = depthEstModel(input_rgb)[-1]
128 | if valid_mask is not None:
129 | loss = self.L1loss(prediction[valid_mask], depth_label[valid_mask])
130 | else:
131 | assert valid_mask is None
132 | loss = self.L1loss(prediction, depth_label)
133 |
134 | return loss
135 |
136 | def compute_spare_attention(self, confident_score, t, isTrain=True):
137 | # t is the temperature (a scalar); with the Gumbel noise added at train time, this is a Gumbel-sigmoid relaxation of a binary mask
138 | if isTrain:
139 | noise = torch.rand(confident_score.size(), requires_grad=False).to(self.device)
140 | noise = (noise + 0.00001) / 1.001
141 | noise = - torch.log(- torch.log(noise))
142 |
143 | confident_score = (confident_score + 0.00001) / 1.001
144 | confident_score = (confident_score + noise) / t
145 | else:
146 | confident_score = confident_score / t
147 |
148 | confident_score = torch.sigmoid(confident_score)
149 |
150 | return confident_score
151 |
152 | def train(self):
153 | phase = 'train'
154 | since = time.time()
155 | best_loss = float('inf')
156 | set_requires_grad(self.attModule, requires_grad=False) # freeze attention module A
157 |
158 | tensorboardX_iter_count = 0
159 | for epoch in range(self.total_epoch_num):
160 | print('\nEpoch {}/{}'.format(epoch+1, self.total_epoch_num))
161 | print('-' * 10)
162 | fn = open(self.train_log,'a')
163 | fn.write('\nEpoch {}/{}\n'.format(epoch+1, self.total_epoch_num))
164 | fn.write('--'*5+'\n')
165 | fn.close()
166 |
167 | self._set_models_train(['attModule', 'inpaintNet', 'styleTranslator', 'depthEstModel'])
168 | iterCount,sampleCount = 0, 0
169 |
170 | for sample_dict in self.dataloaders_xLabels_joint:
171 | imageListReal, depthListReal = sample_dict['real']
172 | imageListSyn, depthListSyn = sample_dict['syn']
173 |
174 | imageListSyn = imageListSyn.to(self.device)
175 | depthListSyn = depthListSyn.to(self.device)
176 | imageListReal = imageListReal.to(self.device)
177 | depthListReal = depthListReal.to(self.device)
178 | valid_mask = (depthListReal > -1.)
179 |
180 | B, C, H, W = imageListReal.size()[0], imageListReal.size()[1], imageListReal.size()[2], imageListReal.size()[3]
181 |
182 | with torch.set_grad_enabled(phase=='train'):
183 | r2s_img = self.styleTranslator(imageListReal)[-1]
184 | confident_score = self.attModule(imageListReal)[-1]
185 | # convert to sparse confident score
186 | confident_score = self.compute_spare_attention(confident_score, t=self.tau_min, isTrain=False)
187 | # hard threshold
188 | confident_score[confident_score < 0.5] = 0.
189 | confident_score[confident_score >= 0.5] = 1.
190 |
191 | mod_r2s_img = r2s_img * confident_score
192 | inpainted_r2s = self.inpaintNet(mod_r2s_img)[-1]
193 |
194 | reconst_img = inpainted_r2s * (1. - confident_score) + confident_score * r2s_img
195 |
196 | # update
197 | self.optim_depth.zero_grad()
198 | total_loss = 0.
199 | inpainted_depth_loss = self.compute_depth_loss(reconst_img, depthListReal, self.depthEstModel, valid_mask)
200 | # adding the translated image while finetuning the whole system might give better results (this term can normally also be commented out)
201 | translated_depth_loss = self.compute_depth_loss(r2s_img, depthListReal, self.depthEstModel, valid_mask)
202 | syn_depth_loss = self.compute_depth_loss(imageListSyn, depthListSyn, self.depthEstModel)
203 | total_loss += (inpainted_depth_loss + translated_depth_loss + syn_depth_loss)
204 | if self.use_apex:
205 | with amp.scale_loss(total_loss, self.optim_depth, loss_id=0) as total_loss_scaled:
206 | total_loss_scaled.backward()
207 | else:
208 | total_loss.backward()
209 |
210 | self.optim_depth.step()
211 |
212 | iterCount += 1
213 |
214 | if self.use_tensorboardX:
215 | nrow = imageListReal.size()[0]
216 | self.train_display_freq = len(self.dataloaders_xLabels_joint) # feel free to adjust the display frequency
217 | if tensorboardX_iter_count % self.train_display_freq == 0:
218 | img_concat = torch.cat((imageListReal, r2s_img, mod_r2s_img, inpainted_r2s, reconst_img), dim=0)
219 | self.write_2_tensorboardX(self.train_SummaryWriter, img_concat, name='real, r2s, r2sMasked, inpaintedR2s, reconst', mode='image',
220 | count=tensorboardX_iter_count, nrow=nrow)
221 |
222 | self.write_2_tensorboardX(self.train_SummaryWriter, confident_score, name='Attention', mode='image',
223 | count=tensorboardX_iter_count, nrow=nrow, value_range=(0., 1.0))
224 |
225 | # add loss values
226 | loss_val_list = [total_loss, inpainted_depth_loss, translated_depth_loss, syn_depth_loss]
227 | loss_name_list = ['total_loss', 'inpainted_depth_loss', 'translated_depth_loss', 'syn_depth_loss']
228 | self.write_2_tensorboardX(self.train_SummaryWriter, loss_val_list, name=loss_name_list, mode='scalar', count=tensorboardX_iter_count)
229 |
230 | tensorboardX_iter_count += 1
231 |
232 | if iterCount % 20 == 0:
233 | loss_summary = '\t{}/{}, total_loss: {:.7f}, inpainted_depth_loss: {:.7f}, translated_depth_loss: {:.7f}, syn_depth_loss: {:.7f}'.format(
234 | iterCount, len(self.dataloaders_xLabels_joint), total_loss, inpainted_depth_loss, translated_depth_loss, syn_depth_loss)
235 |
236 | print(loss_summary)
237 |
238 | fn = open(self.train_log,'a')
239 | fn.write(loss_summary + '\n')
240 | fn.close()
241 |
242 | # take step in optimizer
243 | for scheduler in self.scheduler_list:
244 | scheduler.step()
245 | for optim in self.optim_name:
246 | lr = getattr(self, optim).param_groups[0]['lr']
247 | lr_update = 'Epoch {}/{} finished: {} learning rate = {:.7f}'.format(epoch+1, self.total_epoch_num, optim, lr)
248 | print(lr_update)
249 | fn = open(self.train_log,'a')
250 | fn.write(lr_update + '\n')
251 | fn.close()
252 |
253 | if (epoch+1) % self.save_steps == 0:
254 | self.save_models(self.model_name, mode=epoch+1)
255 | self.evaluate(epoch+1)
256 |
257 | time_elapsed = time.time() - since
258 | print('\nTraining complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
259 |
260 | fn = open(self.train_log,'a')
261 | fn.write('\nTraining complete in {:.0f}m {:.0f}s\n'.format(time_elapsed // 60, time_elapsed % 60))
262 | fn.close()
263 |
264 | best_model_summary = '\nOverall best model is epoch {}'.format(self.EVAL_best_model_epoch)
265 | print(best_model_summary)
266 | print(self.EVAL_all_results[str(self.EVAL_best_model_epoch)])
267 | fn = open(self.evaluate_log, 'a')
268 | fn.write(best_model_summary + '\n')
269 | fn.write(self.EVAL_all_results[str(self.EVAL_best_model_epoch)])
270 | fn.close()
271 |
272 | def evaluate(self, mode):
273 | '''
274 | mode is either an epoch number (int) or 'best'
275 | an int gives the epoch number and is used for evaluation during training
276 | 'best' is used for evaluation after training
277 | '''
278 | set_name = 'test'
279 | eval_model_list = ['attModule', 'inpaintNet', 'styleTranslator', 'depthEstModel']
280 |
281 | if isinstance(mode, int) and self.isTrain:
282 | self._set_models_eval(eval_model_list)
283 | if self.EVAL_best_loss == float('inf'):
284 | fn = open(self.evaluate_log, 'w')
285 | else:
286 | fn = open(self.evaluate_log, 'a')
287 |
288 | fn.write('Evaluating with mode: {}\n'.format(mode))
289 | fn.write('\tEvaluation range min: {} | max: {} \n'.format(self.EVAL_DEPTH_MIN, self.EVAL_DEPTH_MAX))
290 | fn.close()
291 |
292 | else:
293 | self._load_models(eval_model_list, mode)
294 |
295 | print('Evaluating with mode: {}'.format(mode))
296 | print('\tEvaluation range min: {} | max: {}'.format(self.EVAL_DEPTH_MIN, self.EVAL_DEPTH_MAX))
297 |
298 | total_loss, count = 0., 0
299 | predTensor = torch.zeros((1, 1, self.cropSize_h, self.cropSize_w)).to('cpu')
300 | grndTensor = torch.zeros((1, 1, self.cropSize_h, self.cropSize_w)).to('cpu')
301 | idx = 0
302 |
303 | tensorboardX_iter_count = 0
304 | for sample in self.dataloaders_single[set_name]:
305 | imageList, depthList = sample
306 | valid_mask = (depthList > self.EVAL_DEPTH_MIN) & (depthList < self.EVAL_DEPTH_MAX)
307 |
308 | idx += imageList.shape[0]
309 | print('epoch {}: processed {} samples in the {} set'.format(mode, str(idx), set_name))
310 | imageList = imageList.to(self.device)
311 | depthList = depthList.to(self.device)
312 |
313 | if self.isTrain and self.use_apex:
314 | with amp.disable_casts():
315 | r2s_img = self.styleTranslator(imageList)[-1]
316 | confident_score = self.attModule(imageList)[-1]
317 | # convert to sparse confident score
318 | confident_score = self.compute_spare_attention(confident_score, t=self.tau_min, isTrain=False)
319 | # hard threshold
320 | confident_score[confident_score < 0.5] = 0.
321 | confident_score[confident_score >= 0.5] = 1.
322 | mod_r2s_img = r2s_img * confident_score
323 | inpainted_r2s = self.inpaintNet(mod_r2s_img)[-1]
324 | reconst_img = inpainted_r2s * (1. - confident_score) + confident_score * r2s_img
325 | predList = self.depthEstModel(reconst_img)[-1].detach().to('cpu')
326 |
327 | else:
328 | r2s_img = self.styleTranslator(imageList)[-1]
329 | confident_score = self.attModule(imageList)[-1]
330 | # convert to sparse confident score
331 | confident_score = self.compute_spare_attention(confident_score, t=self.tau_min, isTrain=False)
332 | # hard threshold
333 | confident_score[confident_score < 0.5] = 0.
334 | confident_score[confident_score >= 0.5] = 1.
335 | mod_r2s_img = r2s_img * confident_score
336 | inpainted_r2s = self.inpaintNet(mod_r2s_img)[-1]
337 | reconst_img = inpainted_r2s * (1. - confident_score) + confident_score * r2s_img
338 | predList = self.depthEstModel(reconst_img)[-1].detach().to('cpu')
339 |
340 | # recover metric depth: map the tanh output from [-1, 1] to [0, NYU_MAX_DEPTH_CLIP] meters
341 | predList = (predList + 1.0) * 0.5 * self.NYU_MAX_DEPTH_CLIP
342 | depthList = depthList.detach().to('cpu')
343 | predTensor = torch.cat((predTensor, predList), dim=0)
344 | grndTensor = torch.cat((grndTensor, depthList), dim=0)
345 |
346 | if self.use_tensorboardX:
347 | nrow = imageList.size()[0]
348 | if tensorboardX_iter_count % self.val_display_freq == 0:
349 | depth_concat = torch.cat((depthList, predList), dim=0)
350 | self.write_2_tensorboardX(self.eval_SummaryWriter, depth_concat, name='{}: ground truth and depth prediction'.format(set_name),
351 | mode='image', count=tensorboardX_iter_count, nrow=nrow, value_range=(0.0, self.NYU_MAX_DEPTH_CLIP))
352 |
353 | tensorboardX_iter_count += 1
354 |
355 | if isinstance(mode, int) and self.isTrain:
356 | eval_depth_loss = self.L1loss(predList[valid_mask], depthList[valid_mask])
357 | total_loss += eval_depth_loss.detach().cpu()
358 |
359 | count += 1
360 |
361 | if isinstance(mode, int) and self.isTrain:
362 | validation_loss = (total_loss / count)
363 | print('validation loss is {:.7f}'.format(validation_loss))
364 | if self.use_tensorboardX:
365 | self.write_2_tensorboardX(self.eval_SummaryWriter, validation_loss, name='validation loss', mode='scalar', count=mode)
366 |
367 | results = Result(mask_min=self.EVAL_DEPTH_MIN, mask_max=self.EVAL_DEPTH_MAX)
368 | results.evaluate(predTensor[1:], grndTensor[1:])
369 |
370 | result1 = '\tabs_rel:{:.3f}, sq_rel:{:.3f}, rmse:{:.3f}, rmse_log:{:.3f}, mae:{:.3f}'.format(
371 | results.absrel, results.sqrel, results.rmse, results.rmselog, results.mae)
372 | result2 = '\t[<1.25]:{:.3f}, [<1.25^2]:{:.3f}, [<1.25^3]:{:.3f}'.format(results.delta1, results.delta2, results.delta3)
373 |
374 | print(result1)
375 | print(result2)
376 |
377 | if isinstance(mode, int) and self.isTrain:
378 | self.EVAL_all_results[str(mode)] = result1 + '\t' + result2
379 |
380 | if validation_loss.item() < self.EVAL_best_loss:
381 | self.EVAL_best_loss = validation_loss.item()
382 | self.EVAL_best_model_epoch = mode
383 | self.save_models(self.model_name, mode='best')
384 |
385 | best_model_summary = '\tCurrent best loss {:.7f}, current best model {}\n'.format(self.EVAL_best_loss, self.EVAL_best_model_epoch)
386 | print(best_model_summary)
387 |
388 | fn = open(self.evaluate_log, 'a')
389 | fn.write(result1 + '\n')
390 | fn.write(result2 + '\n')
391 | fn.write(best_model_summary + '\n')
392 | fn.close()
--------------------------------------------------------------------------------
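For reference, the inference path that evaluate() above spells out twice (once inside the apex cast guard, once without) distills to the following sketch (not part of the repo; arc_predict is a hypothetical helper name, and the model components are assumed to already be in eval mode):

import torch

@torch.no_grad()
def arc_predict(model, image):
    r2s = model.styleTranslator(image)[-1]                        # real-to-synthetic translation
    att = model.attModule(image)[-1]                              # raw confidence scores
    att = model.compute_spare_attention(att, t=model.tau_min,     # temperature-scaled sigmoid
                                        isTrain=False)
    att = (att >= 0.5).float()                                    # hard 0/1 mask, as in evaluate()
    inpainted = model.inpaintNet(r2s * att)[-1]                   # fill the removed regions
    reconst = inpainted * (1.0 - att) + att * r2s                 # composite kept + inpainted pixels
    pred = model.depthEstModel(reconst)[-1]
    return (pred + 1.0) * 0.5 * model.NYU_MAX_DEPTH_CLIP          # map [-1, 1] to meters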
/training/jointly_train_depth_predictor_D_and_attention_module_A.py:
--------------------------------------------------------------------------------
1 | import os, time, sys
2 | import random
3 | import torch
4 | from torch.utils.data import Dataset, DataLoader
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | from torch.optim import lr_scheduler
8 | import torch.nn.functional as F
9 | from torch.autograd import Variable
10 |
11 | import torchvision
12 | from torchvision import datasets, models, transforms
13 | from torchvision.utils import make_grid
14 | from tensorboardX import SummaryWriter
15 |
16 | from models.depth_generator_networks import _UNetGenerator, init_weights, _ResGenerator_Upsample
17 | from models.discriminator_networks import Discriminator80x80InstNorm
18 | from models.attention_networks import _Attention_FullRes
19 |
20 | from utils.metrics import *
21 | from utils.image_pool import ImagePool
22 |
23 | from training.base_model import set_requires_grad, base_model
24 |
25 | try:
26 | from apex import amp
27 | except ImportError:
28 | print("\nPlease consider install apex from https://www.github.com/nvidia/apex to run with apex or set use_apex = False\n")
29 |
30 | import warnings # ignore warnings
31 | warnings.filterwarnings("ignore")
32 |
33 | class jointly_train_depth_predictor_D_and_attention_module_A(base_model):
34 | def __init__(self, args, dataloaders_xLabels_joint, dataloaders_single):
35 | super(jointly_train_depth_predictor_D_and_attention_module_A, self).__init__(args)
36 | self._initialize_training()
37 | # self.KITTI_MAX_DEPTH_CLIP = 80.0
38 | # self.EVAL_DEPTH_MIN = 1.0
39 | # self.EVAL_DEPTH_MAX = 50.0
40 |
41 | self.NYU_MAX_DEPTH_CLIP = 10.0
42 | self.EVAL_DEPTH_MIN = 1.0
43 | self.EVAL_DEPTH_MAX = 8.0
44 |
45 | self.dataloaders_single = dataloaders_single
46 | self.dataloaders_xLabels_joint = dataloaders_xLabels_joint
47 |
48 | self.attModule = _Attention_FullRes(input_nc = 3, output_nc = 1)
49 | self.inpaintNet = _ResGenerator_Upsample(input_nc = 3, output_nc = 3)
50 | self.styleTranslator = _ResGenerator_Upsample(input_nc = 3, output_nc = 3)
51 | self.netD = Discriminator80x80InstNorm(input_nc = 3)
52 | self.depthEstModel = _UNetGenerator(input_nc = 3, output_nc = 1)
53 |
54 | self.tau_min = 0.05
55 | self.rho = 0.85
56 | self.KL_loss_weight = 1.0
57 | self.dis_weight = 1.0
58 | self.fake_loss_weight = 1e-3
59 |
60 | self.tensorboard_num_display_per_epoch = 1
61 | self.model_name = ['attModule', 'inpaintNet', 'styleTranslator', 'netD', 'depthEstModel']
62 | self.L1loss = nn.L1Loss()
63 |
64 | if self.isTrain:
65 | self.optim_netD = optim.Adam(self.netD.parameters(), lr=self.task_lr, betas=(0.5, 0.999))
66 | self.optim_depth = optim.Adam(list(self.depthEstModel.parameters()) + list(self.attModule.parameters()), lr=self.task_lr, betas=(0.5, 0.999))
67 | self.optim_name = ['optim_depth', 'optim_netD']
68 | self._get_scheduler()
69 | self.loss_BCE = nn.BCEWithLogitsLoss()
70 |
71 | self._initialize_networks(['netD'])
72 |
73 | # load the "best" depth predictor D (from step 1)
74 | preTrain_path = os.path.join(os.getcwd(), 'experiments', 'train_initial_depth_predictor_D')
75 | self._load_models(model_list=['depthEstModel'], mode='best', isTrain=True, model_path=preTrain_path)
76 | print('Successfully loaded pre-trained {} model from {}'.format('depthEstModel', preTrain_path))
77 |
78 | # load the "best" style translator T (from step 2)
79 | preTrain_path = os.path.join(os.getcwd(), 'experiments', 'train_style_translator_T')
80 | self._load_models(model_list=['styleTranslator'], mode=480, isTrain=True, model_path=preTrain_path)
81 | print('Successfully loaded pre-trained {} model from {}'.format('styleTranslator', preTrain_path))
82 |
83 | # load the "best" attention module A (from step 3)
84 | preTrain_path = os.path.join(os.getcwd(), 'experiments', 'train_initial_attention_module_A')
85 | self._load_models(model_list=['attModule'], mode=450, isTrain=True, model_path=preTrain_path)
86 | print('Successfully loaded pre-trained {} model from {}'.format('attModule', preTrain_path))
87 |
88 | # load the "best" inpainting module I (from step 4)
89 | preTrain_path = os.path.join(os.getcwd(), 'experiments', 'train_inpainting_module_I')
90 | self._load_models(model_list=['inpaintNet'], mode=450, isTrain=True, model_path=preTrain_path)
91 | print('Successfully loaded pre-trained {} model from {}'.format('inpaintNet', preTrain_path))
92 |
93 | # apex can only be applied to CUDA models
94 | if self.use_apex:
95 | self._init_apex(Num_losses=2)
96 |
97 | self.EVAL_best_loss = float('inf')
98 | self.EVAL_best_model_epoch = 0
99 | self.EVAL_all_results = {}
100 |
101 | self._check_parallel()
102 |
103 | def _get_project_name(self):
104 | return 'jointly_train_depth_predictor_D_and_attention_module_A'
105 |
106 | def _initialize_networks(self, model_name):
107 | for name in model_name:
108 | getattr(self, name).train().to(self.device)
109 | init_weights(getattr(self, name), net_name=name, init_type='normal', gain=0.02)
110 |
111 | def compute_D_loss(self, real_sample, fake_sample, netD):
112 | loss = 0
113 | syn_acc = 0
114 | real_acc = 0
115 |
116 | output = netD(fake_sample)
117 | label = torch.full(output.size(), self.syn_label, device=self.device)
118 | predSyn = (output > 0.5).to(self.device, dtype=torch.float32)
119 | total_num = torch.numel(output)
120 | syn_acc += (predSyn==label).type(torch.float32).sum().item()/total_num
121 |
122 | loss += self.loss_BCE(output, label)
123 |
124 | output = netD(real_sample)
125 | label = torch.full(output.size(), self.real_label, device=self.device)
126 | predReal = (output > 0.5).to(self.device, dtype=torch.float32)
127 | real_acc += (predReal==label).type(torch.float32).sum().item()/total_num
128 |
129 | loss += self.loss_BCE(output, label)
130 |
131 | return loss, syn_acc, real_acc
132 |
133 | def compute_depth_loss(self, input_rgb, depth_label, depthEstModel, valid_mask=None):
134 |
135 | prediction = depthEstModel(input_rgb)[-1]
136 | if valid_mask is not None:
137 | loss = self.L1loss(prediction[valid_mask], depth_label[valid_mask])
138 | else:
139 | assert valid_mask is None
140 | loss = self.L1loss(prediction, depth_label)
141 |
142 | return loss
143 |
144 | def compute_spare_attention(self, confident_score, t, isTrain=True):
145 | # Gumbel-Sigmoid relaxation of the attention logits; t is the temperature (scalar)
146 | if isTrain:
147 | noise = torch.rand(confident_score.size(), requires_grad=False).to(self.device)
148 | noise = (noise + 0.00001) / 1.001
149 | noise = - torch.log(- torch.log(noise))
150 |
151 | confident_score = (confident_score + 0.00001) / 1.001
152 | confident_score = (confident_score + noise) / t
153 | else:
154 | confident_score = confident_score / t
155 |
156 | confident_score = torch.sigmoid(confident_score)
157 |
158 | return confident_score
159 |
160 | def compute_KL_div(self, cf, target=0.5):
161 | g = cf.mean()
162 | g = (g + 0.00001) / 1.001 # prevent g = 0. or 1.
163 | y = target*torch.log(target/g) + (1-target)*torch.log((1-target)/(1-g))
164 | return y
165 |
166 | def compute_real_fake_loss(self, scores, loss_type, datasrc = 'real', loss_for='discr'):
167 | if loss_for == 'discr':
168 | if datasrc == 'real':
169 | if loss_type == 'lsgan':
170 | # least-squares GAN loss
171 | d_loss = torch.pow(scores - 1., 2).mean()
172 | elif loss_type == 'hinge':
173 | # hinge loss, as used in the spectral normalization GAN paper
174 | d_loss = - torch.mean(torch.clamp(scores-1.,max=0.))
175 | elif loss_type == 'wgan':
176 | # WGAN loss
177 | d_loss = - torch.mean(scores)
178 | else:
179 | scores = scores.view(scores.size(0),-1).mean(dim=1)
180 | d_loss = F.binary_cross_entropy_with_logits(scores, torch.ones_like(scores).detach())
181 | else:
182 | if loss_type == 'lsgan':
183 | # least-squares GAN loss
184 | d_loss = torch.pow((scores),2).mean()
185 | elif loss_type == 'hinge':
186 | # hinge loss, as used in the spectral normalization GAN paper
187 | d_loss = -torch.mean(torch.clamp(-scores-1.,max=0.))
188 | elif loss_type == 'wgan':
189 | # WGAN loss
190 | d_loss = torch.mean(scores)
191 | else:
192 | scores = scores.view(scores.size(0),-1).mean(dim=1)
193 | d_loss = F.binary_cross_entropy_with_logits(scores, torch.zeros_like(scores).detach())
194 |
195 | return d_loss
196 | else:
197 | if loss_type == 'lsgan':
198 | # least-squares GAN loss
199 | g_loss = torch.pow(scores - 1., 2).mean()
200 | elif (loss_type == 'wgan') or (loss_type == 'hinge') :
201 | g_loss = - torch.mean(scores)
202 | else:
203 | scores = scores.view(scores.size(0),-1).mean(dim=1)
204 | g_loss = F.binary_cross_entropy_with_logits(scores, torch.ones_like(scores).detach())
205 | return g_loss
206 |
207 | def train(self):
208 | phase = 'train'
209 | since = time.time()
210 | best_loss = float('inf')
211 |
212 | set_requires_grad(self.styleTranslator, requires_grad=False) # freeze style translator T
213 | set_requires_grad(self.inpaintNet, requires_grad=False) # freeze inpainting module I
214 |
215 | tensorboardX_iter_count = 0
216 | for epoch in range(self.total_epoch_num):
217 | print('\nEpoch {}/{}'.format(epoch+1, self.total_epoch_num))
218 | print('-' * 10)
219 | fn = open(self.train_log,'a')
220 | fn.write('\nEpoch {}/{}\n'.format(epoch+1, self.total_epoch_num))
221 | fn.write('--'*5+'\n')
222 | fn.close()
223 |
224 | self._set_models_train(['attModule', 'inpaintNet', 'styleTranslator', 'depthEstModel'])
225 | iterCount = 0
226 |
227 | for sample_dict in self.dataloaders_xLabels_joint:
228 | imageListReal, depthListReal = sample_dict['real']
229 | imageListSyn, depthListSyn = sample_dict['syn']
230 |
231 | imageListSyn = imageListSyn.to(self.device)
232 | depthListSyn = depthListSyn.to(self.device)
233 | imageListReal = imageListReal.to(self.device)
234 | depthListReal = depthListReal.to(self.device)
235 | valid_mask = (depthListReal > -1.)
236 |
237 | B, C, H, W = imageListReal.size()[0], imageListReal.size()[1], imageListReal.size()[2], imageListReal.size()[3]
238 |
239 | with torch.set_grad_enabled(phase=='train'):
240 | r2s_img = self.styleTranslator(imageListReal)[-1]
241 | confident_score = self.attModule(imageListReal)[-1]
242 | # convert to sparse confident score
243 | confident_score = self.compute_spare_attention(confident_score, t=self.tau_min, isTrain=True)
244 |
245 | mod_r2s_img = r2s_img * confident_score
246 | inpainted_r2s = self.inpaintNet(mod_r2s_img)[-1]
247 |
248 | reconst_img = inpainted_r2s * (1. - confident_score) + confident_score * r2s_img
249 |
250 | # update depth predictor and attention module
251 | self.optim_depth.zero_grad()
252 | total_loss = 0.
253 | real_depth_loss = self.compute_depth_loss(reconst_img, depthListReal, self.depthEstModel, valid_mask)
254 | syn_depth_loss = self.compute_depth_loss(imageListSyn, depthListSyn, self.depthEstModel)
255 | KL_loss = self.compute_KL_div(confident_score, target=self.rho) * self.KL_loss_weight
256 |
257 | fake_pred = self.netD(inpainted_r2s)
258 | fake_label = torch.full(fake_pred.size(), self.real_label, device=self.device)
259 | fake_loss = self.loss_BCE(fake_pred, fake_label) * self.fake_loss_weight
260 |
261 | total_loss += (real_depth_loss + syn_depth_loss + KL_loss + fake_loss)
262 | if self.use_apex:
263 | with amp.scale_loss(total_loss, self.optim_depth, loss_id=0) as total_loss_scaled:
264 | total_loss_scaled.backward()
265 | else:
266 | total_loss.backward()
267 |
268 | self.optim_depth.step()
269 |
270 | # stop updating the adversarial loss once training is stable (after epoch 100)
271 | if epoch <= 100:
272 | self.optim_netD.zero_grad()
273 | netD_loss = 0.
274 | netD_loss, _, _ = self.compute_D_loss(imageListSyn, inpainted_r2s.detach(), self.netD)
275 |
276 | if self.use_apex:
277 | with amp.scale_loss(netD_loss, self.optim_netD, loss_id=0) as netD_loss_scaled:
278 | netD_loss_scaled.backward()
279 | else:
280 | netD_loss.backward()
281 |
282 | self.optim_netD.step()
283 | else:
284 | netD_loss = 0.
285 | set_requires_grad(self.netD, requires_grad=False)
286 |
287 | iterCount += 1
288 |
289 | if self.use_tensorboardX:
290 | nrow = imageListReal.size()[0]
291 | self.train_display_freq = len(self.dataloaders_xLabels_joint) # feel free to adjust the display frequency
292 | if tensorboardX_iter_count % self.train_display_freq == 0:
293 | img_concat = torch.cat((imageListReal, r2s_img, mod_r2s_img, inpainted_r2s, reconst_img), dim=0)
294 | self.write_2_tensorboardX(self.train_SummaryWriter, img_concat, name='real, r2s, r2sMasked, inpaintedR2s, reconst', mode='image',
295 | count=tensorboardX_iter_count, nrow=nrow)
296 |
297 | self.write_2_tensorboardX(self.train_SummaryWriter, confident_score, name='Attention', mode='image',
298 | count=tensorboardX_iter_count, nrow=nrow, value_range=(0., 1.0))
299 |
300 | # add loss values
301 | loss_val_list = [total_loss, real_depth_loss, syn_depth_loss, KL_loss, fake_loss, netD_loss]
302 | loss_name_list = ['total_loss', 'real_depth_loss', 'syn_depth_loss', 'KL_loss', 'fake_loss', 'netD_loss']
303 | self.write_2_tensorboardX(self.train_SummaryWriter, loss_val_list, name=loss_name_list, mode='scalar', count=tensorboardX_iter_count)
304 |
305 | tensorboardX_iter_count += 1
306 |
307 | if iterCount % 20 == 0:
308 | loss_summary = '\t{}/{}, total_loss: {:.7f}, netD_loss: {:.7f}'.format(iterCount, len(self.dataloaders_xLabels_joint), total_loss, netD_loss)
309 | G_loss_summary = '\t\t G loss summary: real_depth_loss: {:.7f}, syn_depth_loss: {:.7f}, KL_loss: {:.7f} fake_loss: {:.7f}'.format(real_depth_loss, syn_depth_loss, KL_loss, fake_loss)
310 |
311 | print(loss_summary)
312 | print(G_loss_summary)
313 |
314 | fn = open(self.train_log,'a')
315 | fn.write(loss_summary + '\n')
316 | fn.write(G_loss_summary + '\n')
317 | fn.close()
318 |
319 | # take step in optimizer
320 | for scheduler in self.scheduler_list:
321 | scheduler.step()
322 | for optim in self.optim_name:
323 | lr = getattr(self, optim).param_groups[0]['lr']
324 | lr_update = 'Epoch {}/{} finished: {} learning rate = {:.7f}'.format(epoch+1, self.total_epoch_num, optim, lr)
325 | print(lr_update)
326 | fn = open(self.train_log,'a')
327 | fn.write(lr_update + '\n')
328 | fn.close()
329 |
330 | if (epoch+1) % self.save_steps == 0:
331 | self.save_models(['depthEstModel', 'attModule'], mode=epoch+1)
332 | self.evaluate(epoch+1)
333 |
334 | time_elapsed = time.time() - since
335 | print('\nTraining complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
336 |
337 | fn = open(self.train_log,'a')
338 | fn.write('\nTraining complete in {:.0f}m {:.0f}s\n'.format(time_elapsed // 60, time_elapsed % 60))
339 | fn.close()
340 |
341 | best_model_summary = '\nOverall best model is epoch {}'.format(self.EVAL_best_model_epoch)
342 | print(best_model_summary)
343 | print(self.EVAL_all_results[str(self.EVAL_best_model_epoch)])
344 | fn = open(self.evaluate_log, 'a')
345 | fn.write(best_model_summary + '\n')
346 | fn.write(self.EVAL_all_results[str(self.EVAL_best_model_epoch)])
347 | fn.close()
348 |
349 | def evaluate(self, mode):
350 | '''
351 | mode: an epoch number or 'best'
352 | an epoch number is used for in-training evaluation
353 | 'best' is used for evaluation after training
354 | '''
355 | set_name = 'test'
356 | eval_model_list = ['attModule', 'inpaintNet', 'styleTranslator', 'depthEstModel']
357 |
358 | if isinstance(mode, int) and self.isTrain:
359 | self._set_models_eval(eval_model_list)
360 | if self.EVAL_best_loss == float('inf'):
361 | fn = open(self.evaluate_log, 'w')
362 | else:
363 | fn = open(self.evaluate_log, 'a')
364 |
365 | fn.write('Evaluating with mode: {}\n'.format(mode))
366 | fn.write('\tEvaluation range min: {} | max: {} \n'.format(self.EVAL_DEPTH_MIN, self.EVAL_DEPTH_MAX))
367 | fn.close()
368 |
369 | else:
370 | self._load_models(eval_model_list, mode)
371 |
372 | print('Evaluating with mode: {}'.format(mode))
373 | print('\tEvaluation range min: {} | max: {}'.format(self.EVAL_DEPTH_MIN, self.EVAL_DEPTH_MAX))
374 |
375 | total_loss, count = 0., 0
376 | predTensor = torch.zeros((1, 1, self.cropSize_h, self.cropSize_w)).to('cpu')
377 | grndTensor = torch.zeros((1, 1, self.cropSize_h, self.cropSize_w)).to('cpu')
378 | idx = 0
379 |
380 | tensorboardX_iter_count = 0
381 | for sample in self.dataloaders_single[set_name]:
382 | imageList, depthList = sample
383 | valid_mask = np.logical_and(depthList > self.EVAL_DEPTH_MIN, depthList < self.EVAL_DEPTH_MAX)
384 |
385 | idx += imageList.shape[0]
386 | print('epoch {}: processed {} samples in the {} set'.format(mode, str(idx), set_name))
387 | imageList = imageList.to(self.device)
388 | depthList = depthList.to(self.device)
389 |
390 | if self.isTrain and self.use_apex:
391 | with amp.disable_casts():
392 | r2s_img = self.styleTranslator(imageList)[-1]
393 | confident_score = self.attModule(imageList)[-1]
394 | # convert to sparse confident score
395 | confident_score = self.compute_spare_attention(confident_score, t=self.tau_min, isTrain=False)
396 | # hard threshold
397 | confident_score[confident_score < 0.5] = 0.
398 | confident_score[confident_score >= 0.5] = 1.
399 | mod_r2s_img = r2s_img * confident_score
400 | inpainted_r2s = self.inpaintNet(mod_r2s_img)[-1]
401 | reconst_img = inpainted_r2s * (1. - confident_score) + confident_score * r2s_img
402 | predList = self.depthEstModel(reconst_img)[-1].detach().to('cpu') # [-1, 1]
403 |
404 | else:
405 | r2s_img = self.styleTranslator(imageList)[-1]
406 | confident_score = self.attModule(imageList)[-1]
407 | # convert to sparse confident score
408 | confident_score = self.compute_spare_attention(confident_score, t=self.tau_min, isTrain=False)
409 | # hard threshold
410 | confident_score[confident_score < 0.5] = 0.
411 | confident_score[confident_score >= 0.5] = 1.
412 | mod_r2s_img = r2s_img * confident_score
413 | inpainted_r2s = self.inpaintNet(mod_r2s_img)[-1]
414 | reconst_img = inpainted_r2s * (1. - confident_score) + confident_score * r2s_img
415 | predList = self.depthEstModel(reconst_img)[-1].detach().to('cpu') # [-1, 1]
416 |
417 | # recover real depth
418 | predList = (predList + 1.0) * 0.5 * self.NYU_MAX_DEPTH_CLIP
419 | depthList = depthList.detach().to('cpu')
420 | predTensor = torch.cat((predTensor, predList), dim=0)
421 | grndTensor = torch.cat((grndTensor, depthList), dim=0)
422 |
423 | if self.use_tensorboardX:
424 | nrow = imageList.size()[0]
425 | if tensorboardX_iter_count % self.val_display_freq == 0:
426 | depth_concat = torch.cat((depthList, predList), dim=0)
427 | self.write_2_tensorboardX(self.eval_SummaryWriter, depth_concat, name='{}: ground truth and depth prediction'.format(set_name),
428 | mode='image', count=tensorboardX_iter_count, nrow=nrow, value_range=(0.0, self.NYU_MAX_DEPTH_CLIP))
429 |
430 | tensorboardX_iter_count += 1
431 |
432 | if isinstance(mode, int) and self.isTrain:
433 | eval_depth_loss = self.L1loss(predList[valid_mask], depthList[valid_mask])
434 | total_loss += eval_depth_loss.detach().cpu()
435 |
436 | count += 1
437 |
438 | if isinstance(mode, int) and self.isTrain:
439 | validation_loss = (total_loss / count)
440 | print('validation loss is {:.7f}'.format(validation_loss))
441 | if self.use_tensorboardX:
442 | self.write_2_tensorboardX(self.eval_SummaryWriter, validation_loss, name='validation loss', mode='scalar', count=mode)
443 |
444 | results = Result(mask_min=self.EVAL_DEPTH_MIN, mask_max=self.EVAL_DEPTH_MAX)
445 | results.evaluate(predTensor[1:], grndTensor[1:])
446 |
447 | result1 = '\tabs_rel:{:.3f}, sq_rel:{:.3f}, rmse:{:.3f}, rmse_log:{:.3f}, mae:{:.3f} '.format(
448 | results.absrel,results.sqrel,results.rmse,results.rmselog,results.mae)
449 | result2 = '\t[<1.25]:{:.3f}, [<1.25^2]:{:.3f}, [<1.25^3]:{:.3f}'.format(results.delta1,results.delta2,results.delta3)
450 |
451 | print(result1)
452 | print(result2)
453 |
454 | if isinstance(mode, int) and self.isTrain:
455 | self.EVAL_all_results[str(mode)] = result1 + '\t' + result2
456 |
457 | if validation_loss.item() < self.EVAL_best_loss:
458 | self.EVAL_best_loss = validation_loss.item()
459 | self.EVAL_best_model_epoch = mode
460 | self.save_models(['depthEstModel', 'attModule'], mode='best')
461 |
462 | best_model_summary = '\tCurrent best loss {:.7f}, current best model {}\n'.format(self.EVAL_best_loss, self.EVAL_best_model_epoch)
463 | print(best_model_summary)
464 |
465 | fn = open(self.evaluate_log, 'a')
466 | fn.write(result1 + '\n')
467 | fn.write(result2 + '\n')
468 | fn.write(best_model_summary + '\n')
469 | fn.close()
--------------------------------------------------------------------------------
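
Note: the sparse attention used in this trainer follows a Gumbel-Sigmoid relaxation. During training, compute_spare_attention adds Gumbel noise to the attention logits, divides by a temperature t, and applies a sigmoid; at evaluation time the noise is dropped and the soft scores are hard-thresholded at 0.5. A minimal, self-contained sketch of that behavior (the tensor shape and t = 0.05 mirror the code above; gumbel_sigmoid is an illustrative name, not part of this repo):

    import torch

    def gumbel_sigmoid(logits, t, train=True):
        # sketch of compute_spare_attention(): Gumbel noise + temperature + sigmoid
        if train:
            u = torch.rand_like(logits)            # uniform noise in [0, 1)
            u = (u + 1e-5) / 1.001                 # keep both log() calls finite
            g = -torch.log(-torch.log(u))          # a Gumbel(0, 1) sample
            logits = ((logits + 1e-5) / 1.001 + g) / t
        else:
            logits = logits / t
        return torch.sigmoid(logits)

    scores = torch.randn(2, 1, 240, 320)                                 # stand-in attention logits
    soft = gumbel_sigmoid(scores, t=0.05, train=True)                    # stochastic, used in training
    hard = (gumbel_sigmoid(scores, t=0.05, train=False) >= 0.5).float()  # eval-time binary mask
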
/training/train_initial_attention_module_A.py:
--------------------------------------------------------------------------------
1 | import os, time, sys
2 | import torch
3 | from torch.utils.data import Dataset, DataLoader
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | from torch.optim import lr_scheduler
7 | import torch.nn.functional as F
8 | from torch.autograd import Variable
9 |
10 | import torchvision
11 | from torchvision import datasets, models, transforms
12 | from torchvision.utils import make_grid
13 | from tensorboardX import SummaryWriter
14 |
15 | from models.depth_generator_networks import _UNetGenerator, init_weights, _ResGenerator_Upsample
16 | from models.discriminator_networks import Discriminator80x80InstNorm
17 | from models.attention_networks import _Attention_FullRes
18 |
19 | from utils.metrics import *
20 | from utils.image_pool import ImagePool
21 |
22 | from training.base_model import set_requires_grad, base_model
23 |
24 | try:
25 | from apex import amp
26 | except ImportError:
27 | print("\nPlease consider install apex from https://www.github.com/nvidia/apex to run with apex or set use_apex = False\n")
28 |
29 | import warnings # ignore warnings
30 | warnings.filterwarnings("ignore")
31 |
32 | def value_scheduler(start, total_num_epoch, end=None, ratio=None, step_size=None, multiple=None, mode='linear'):
33 | if mode == 'linear':
34 | return np.linspace(start, end, total_num_epoch)
35 | elif mode == 'linear_ratio':
36 | assert ratio is not None
37 | linear = np.linspace(start, end, int(total_num_epoch * ratio))
38 | stable = np.repeat(end, total_num_epoch - int(total_num_epoch * ratio))
39 | return np.concatenate((linear, stable))
40 |
41 | elif mode == 'step_wise':
42 | assert step_size is not None
43 | times, res = divmod(total_num_epoch, step_size)
44 | for i in range(0, times):
45 | value = np.repeat(start * (multiple**i), step_size)
46 | if i == 0:
47 | final = value
48 | else:
49 | final = np.concatenate((final, value))
50 |
51 | if res != 0:
52 | final = np.concatenate((final, np.repeat(start * (multiple**(times)), res)))
53 | return final
54 |
55 | class train_initial_attention_module_A(base_model):
56 | def __init__(self, args, dataloaders_xLabels_joint, dataloaders_single):
57 | super(train_initial_attention_module_A, self).__init__(args)
58 | self._initialize_training()
59 |
60 | self.dataloaders_single = dataloaders_single
61 | self.dataloaders_xLabels_joint = dataloaders_xLabels_joint
62 |
63 | # define loss weights
64 | self.lambda_identity = 0.5 # coefficient of identity mapping score
65 | self.lambda_real = 10.0
66 | self.lambda_synthetic = 10.0
67 | self.lambda_GAN = 1.0
68 |
69 | self.KL_loss_weight_max = 1.
70 | self.rho = 0.99
71 | self.tau_min = 0.05
72 | self.tau_max = 0.9
73 |
74 | self.pool_size = 50
75 | self.generated_syn_pool = ImagePool(self.pool_size)
76 | self.generated_real_pool = ImagePool(self.pool_size)
77 |
78 | self.attModule = _Attention_FullRes(input_nc = 3, output_nc = 1)
79 | self.netD_s = Discriminator80x80InstNorm(input_nc = 3)
80 | self.netD_r = Discriminator80x80InstNorm(input_nc = 3)
81 | self.netG_s2r = _ResGenerator_Upsample(input_nc = 3, output_nc = 3)
82 | self.netG_r2s = _ResGenerator_Upsample(input_nc = 3, output_nc = 3)
83 |
84 | self.model_name = ['netD_s', 'netD_r', 'netG_s2r', 'netG_r2s', 'attModule']
85 | self.L1loss = nn.L1Loss()
86 |
87 | if self.isTrain:
88 | self.netD_optimizer = optim.Adam(list(self.netD_s.parameters()) + list(self.netD_r.parameters()), lr=self.D_lr, betas=(0.5, 0.999))
89 | self.netG_optimizer = optim.Adam(list(self.netG_r2s.parameters()) + list(self.netG_s2r.parameters()) + list(self.attModule.parameters()), lr=self.G_lr, betas=(0.5, 0.999))
90 | self.optim_name = ['netD_optimizer', 'netG_optimizer']
91 | self._get_scheduler()
92 | self.loss_BCE = nn.BCEWithLogitsLoss()
93 | self._initialize_networks()
94 |
95 | # apex can only be applied to CUDA models
96 | if self.use_apex:
97 | self._init_apex(Num_losses=3)
98 |
99 | self._check_parallel()
100 |
101 | def _get_project_name(self):
102 | return 'train_initial_attention_module_A'
103 |
104 | def _initialize_networks(self):
105 | for name in self.model_name:
106 | getattr(self, name).train().to(self.device)
107 | init_weights(getattr(self, name), net_name=name, init_type='normal', gain=0.02)
108 |
109 | def compute_D_loss(self, real_sample, fake_sample, netD):
110 | loss = 0
111 | syn_acc = 0
112 | real_acc = 0
113 |
114 | output = netD(fake_sample)
115 | label = torch.full(output.size(), self.syn_label, device=self.device)
116 |
117 | predSyn = (output > 0.5).to(self.device, dtype=torch.float32)
118 | total_num = torch.numel(output)
119 | syn_acc += (predSyn==label).type(torch.float32).sum().item()/total_num
120 | loss += self.loss_BCE(output, label)
121 |
122 | output = netD(real_sample)
123 | label = torch.full(output.size(), self.real_label, device=self.device)
124 |
125 | predReal = (output > 0.5).to(self.device, dtype=torch.float32)
126 | real_acc += (predReal==label).type(torch.float32).sum().item()/total_num
127 | loss += self.loss_BCE(output, label)
128 |
129 | return loss, syn_acc, real_acc
130 |
131 | def compute_G_loss(self, real_sample, synthetic_sample, r2s_rgb, s2r_rgb, rct_real, rct_syn, cs_imageListReal):
132 | '''
133 | real_sample: [batch_size, 3, 240, 320] real rgb
134 | synthetic_sample: [batch_size, 3, 240, 320] synthetic rgb
135 | r2s_rgb: netG_r2s(real)
136 | s2r_rgb: netG_s2r(synthetic)
137 | '''
138 | non_reduction_L1loss = nn.L1Loss(reduction='none')
139 | loss = 0
140 |
141 | # identity loss if applicable
142 | if self.lambda_identity > 0:
143 | idt_real = self.netG_s2r(real_sample)[-1]
144 | idt_synthetic = self.netG_r2s(synthetic_sample)[-1]
145 | idt_loss = (self.L1loss(idt_real, real_sample) * self.lambda_real +
146 | self.L1loss(idt_synthetic, synthetic_sample) * self.lambda_synthetic) * self.lambda_identity
147 | else:
148 | idt_loss = 0
149 |
150 | # GAN loss
151 | real_pred = self.netD_r(s2r_rgb)
152 | real_label = torch.full(real_pred.size(), self.real_label, device=self.device)
153 | GAN_loss_real = self.loss_BCE(real_pred, real_label)
154 |
155 | syn_pred = self.netD_s(r2s_rgb)
156 | syn_label = torch.full(syn_pred.size(), self.real_label, device=self.device)
157 | GAN_loss_syn = self.loss_BCE(syn_pred, syn_label)
158 |
159 | GAN_loss = (GAN_loss_real + GAN_loss_syn) * self.lambda_GAN
160 |
161 | # cycle consist loss
162 | rec_real_loss = cs_imageListReal * non_reduction_L1loss(rct_real, real_sample)
163 | rec_real_loss = rec_real_loss.mean() * self.lambda_real
164 |
165 | rec_syn_loss = self.L1loss(rct_syn, synthetic_sample) * self.lambda_synthetic
166 | rec_loss = rec_real_loss + rec_syn_loss
167 |
168 | loss += (idt_loss + GAN_loss + rec_loss)
169 |
170 | return loss, idt_loss, GAN_loss, rec_loss
171 |
172 | def compute_spare_attention(self, confident_score, t, isTrain=True):
173 | # Gumbel-Sigmoid relaxation of the attention logits; t is the temperature (scalar)
174 | if isTrain:
175 | noise = torch.rand(confident_score.size(), requires_grad=False).to(self.device)
176 | noise = (noise + 0.00001) / 1.001
177 | noise = - torch.log(- torch.log(noise))
178 |
179 | confident_score = (confident_score + 0.00001) / 1.001
180 | confident_score = (confident_score + noise) / t
181 | else:
182 | confident_score = confident_score / t
183 |
184 | confident_score = torch.sigmoid(confident_score)
185 |
186 | return confident_score
187 |
188 | def compute_KL_div(self, cf, target=0.5):
189 | g = cf.mean()
190 | g = (g + 0.00001) / 1.001 # prevent g = 0. or 1.
191 | y = target * torch.log(target/g) + (1-target) * torch.log((1-target)/(1-g))
192 | return y
193 |
194 | def train(self):
195 | phase = 'train'
196 | since = time.time()
197 | best_loss = float('inf')
198 |
199 | self.train_display_freq = len(self.dataloaders_xLabels_joint) // self.tensorboard_num_display_per_epoch
200 | tau_value_scheduler = value_scheduler(self.tau_max, self.total_epoch_num, end=self.tau_min, mode='linear')
201 |
202 | tensorboardX_iter_count = 0
203 | for epoch in range(self.total_epoch_num):
204 | print('\nEpoch {}/{}'.format(epoch+1, self.total_epoch_num))
205 | print('-' * 10)
206 | fn = open(self.train_log,'a')
207 | fn.write('\nEpoch {}/{}\n'.format(epoch+1, self.total_epoch_num))
208 | fn.write('--'*5+'\n')
209 | fn.close()
210 |
211 | iterCount = 0
212 |
213 | for sample_dict in self.dataloaders_xLabels_joint:
214 | imageListReal, depthListReal = sample_dict['real']
215 | imageListSyn, depthListSyn = sample_dict['syn']
216 |
217 | imageListSyn = imageListSyn.to(self.device)
218 | depthListSyn = depthListSyn.to(self.device)
219 | imageListReal = imageListReal.to(self.device)
220 | depthListReal = depthListReal.to(self.device)
221 |
222 | with torch.set_grad_enabled(phase=='train'):
223 | s2r_rgb = self.netG_s2r(imageListSyn)[-1]
224 | rct_syn = self.netG_r2s(s2r_rgb)[-1]
225 |
226 | cs_imageListReal = self.attModule(imageListReal)[-1]
227 | cs_imageListReal = self.compute_spare_attention(cs_imageListReal, t=tau_value_scheduler[epoch], isTrain=True)
228 | mod_imageListReal = imageListReal * cs_imageListReal
229 | r2s_rgb = self.netG_r2s(mod_imageListReal)[-1]
230 |
231 | rct_real = self.netG_s2r(r2s_rgb)[-1]
232 |
233 | ############# update generator
234 | set_requires_grad([self.netD_r, self.netD_s], False)
235 | netG_loss = 0.
236 | self.netG_optimizer.zero_grad()
237 | netG_loss, G_idt_loss, G_GAN_loss, G_rec_loss = self.compute_G_loss(imageListReal, imageListSyn,
238 | r2s_rgb, s2r_rgb, rct_real, rct_syn, cs_imageListReal)
239 |
240 | KL_loss = 0.
241 | KL_loss += self.compute_KL_div(cs_imageListReal, target=self.rho) * self.KL_loss_weight_max
242 | netG_loss += KL_loss
243 |
244 | if self.use_apex:
245 | with amp.scale_loss(netG_loss, self.netG_optimizer, loss_id=0) as netG_loss_scaled:
246 | netG_loss_scaled.backward()
247 | else:
248 | netG_loss.backward()
249 |
250 | self.netG_optimizer.step()
251 |
252 | ############# update discriminator
253 | set_requires_grad([self.netD_r, self.netD_s], True)
254 |
255 | self.netD_optimizer.zero_grad()
256 |
257 | r2s_rgb_pool = self.generated_syn_pool.query(r2s_rgb) # draw from the history pool for D updates
258 | netD_s_loss, netD_s_syn_acc, netD_s_real_acc = self.compute_D_loss(imageListSyn, r2s_rgb_pool.detach(), self.netD_s)
259 | s2r_rgb_pool = self.generated_real_pool.query(s2r_rgb)
260 | netD_r_loss, netD_r_syn_acc, netD_r_real_acc = self.compute_D_loss(imageListReal, s2r_rgb_pool.detach(), self.netD_r)
261 |
262 | netD_loss = netD_s_loss + netD_r_loss
263 |
264 | if self.use_apex:
265 | with amp.scale_loss(netD_loss, self.netD_optimizer, loss_id=1) as netD_loss_scaled:
266 | netD_loss_scaled.backward()
267 | else:
268 | netD_loss.backward()
269 | self.netD_optimizer.step()
270 |
271 | iterCount += 1
272 |
273 | if self.use_tensorboardX:
274 | self.train_display_freq = len(self.dataloaders_xLabels_joint) # feel free to adjust the display frequency
275 | nrow = imageListReal.size()[0]
276 | if tensorboardX_iter_count % self.train_display_freq == 0:
277 | s2r_rgb_concat = torch.cat((imageListSyn, s2r_rgb, imageListReal, rct_syn), dim=0)
278 | self.write_2_tensorboardX(self.train_SummaryWriter, s2r_rgb_concat, name='RGB: syn, s2r, real, reconstruct syn', mode='image',
279 | count=tensorboardX_iter_count, nrow=nrow)
280 |
281 | r2s_rgb_concat = torch.cat((imageListReal, r2s_rgb, imageListSyn, rct_real), dim=0)
282 | self.write_2_tensorboardX(self.train_SummaryWriter, r2s_rgb_concat, name='RGB: real, r2s, synthetic, reconstruct real', mode='image',
283 | count=tensorboardX_iter_count, nrow=nrow)
284 |
285 | self.write_2_tensorboardX(self.train_SummaryWriter, cs_imageListReal, name='Atten: real', mode='image',
286 | count=tensorboardX_iter_count, nrow=nrow, value_range=(0.0, 1.0))
287 |
288 | loss_val_list = [netD_loss, netG_loss, KL_loss]
289 | loss_name_list = ['netD_loss', 'netG_loss', 'KL_loss']
290 | self.write_2_tensorboardX(self.train_SummaryWriter, loss_val_list, name=loss_name_list, mode='scalar', count=tensorboardX_iter_count)
291 |
292 | tensorboardX_iter_count += 1
293 |
294 | if iterCount % 20 == 0:
295 | loss_summary = '\t{}/{} netD: {:.7f}, netG: {:.7f}'.format(iterCount, len(self.dataloaders_xLabels_joint), netD_loss, netG_loss)
296 | G_loss_summary = '\t\tG loss summary: netG: {:.7f}, idt_loss: {:.7f}, GAN_loss: {:.7f}, rec_loss: {:.7f}, KL_loss: {:.7f}'.format(
297 | netG_loss, G_idt_loss, G_GAN_loss, G_rec_loss, KL_loss)
298 |
299 | print(loss_summary)
300 | print(G_loss_summary)
301 |
302 | fn = open(self.train_log,'a')
303 | fn.write(loss_summary + '\n')
304 | fn.write(G_loss_summary + '\n')
305 | fn.close()
306 |
307 | if (epoch+1) % self.save_steps == 0:
308 | self.save_models(['attModule'], mode=epoch+1)
309 |
310 | # take step in optimizer
311 | for scheduler in self.scheduler_list:
312 | scheduler.step()
313 | for optim in self.optim_name:
314 | lr = getattr(self, optim).param_groups[0]['lr']
315 | lr_update = 'Epoch {}/{} finished: {} learning rate = {:.7f}'.format(epoch+1, self.total_epoch_num, optim, lr)
316 | print(lr_update)
317 |
318 | fn = open(self.train_log,'a')
319 | fn.write(lr_update + '\n')
320 | fn.close()
321 |
322 | time_elapsed = time.time() - since
323 | print('\nTraining complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
324 |
325 | fn = open(self.train_log,'a')
326 | fn.write('\nTraining complete in {:.0f}m {:.0f}s\n'.format(time_elapsed // 60, time_elapsed % 60))
327 | fn.close()
328 |
329 | def evaluate(self, mode):
330 | pass
--------------------------------------------------------------------------------
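
Note: the temperature of the Gumbel-Sigmoid relaxation is annealed linearly from tau_max = 0.9 down to tau_min = 0.05 over training via value_scheduler(..., mode='linear'), which reduces to a plain linspace. A short sketch of that schedule (450 epochs is an assumed value for illustration):

    import numpy as np

    # value_scheduler(tau_max, total_epoch_num, end=tau_min, mode='linear') above
    # reduces to a plain linspace; epoch e then reads its temperature from the array.
    taus = np.linspace(0.9, 0.05, 450)   # tau_max -> tau_min over 450 epochs
    print(taus[0], taus[225], taus[-1])  # 0.9 at the start, ~0.47 mid-way, 0.05 at the end
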
/training/train_initial_depth_predictor_D.py:
--------------------------------------------------------------------------------
1 | import os, time, sys
2 | import torch
3 | from torch.utils.data import Dataset, DataLoader
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | from torch.optim import lr_scheduler
7 | import torch.nn.functional as F
8 | from torch.autograd import Variable
9 |
10 | import torchvision
11 | from torchvision import datasets, models, transforms
12 | from torchvision.utils import make_grid
13 | from tensorboardX import SummaryWriter
14 |
15 | from models.depth_generator_networks import _UNetGenerator, init_weights
16 |
17 | from utils.metrics import *
18 | from utils.image_pool import ImagePool
19 |
20 | from training.base_model import set_requires_grad, base_model
21 |
22 | try:
23 | from apex import amp
24 | except ImportError:
25 | print("\nPlease consider install apex from https://www.github.com/nvidia/apex to run with apex or set use_apex = False\n")
26 |
27 | import warnings # ignore warnings
28 | warnings.filterwarnings("ignore")
29 |
30 | class train_initial_depth_predictor_D(base_model):
31 | def __init__(self, args, dataloaders_xLabels_joint, dataloaders_single):
32 | super(train_initial_depth_predictor_D, self).__init__(args)
33 | self._initialize_training()
34 | # self.KITTI_MAX_DEPTH_CLIP = 80.0
35 | # self.EVAL_DEPTH_MIN = 1.0
36 | # self.EVAL_DEPTH_MAX = 50.0
37 |
38 | self.NYU_MAX_DEPTH_CLIP = 10.0
39 | self.EVAL_DEPTH_MIN = 1.0
40 | self.EVAL_DEPTH_MAX = 8.0
41 |
42 | self.dataloaders_single = dataloaders_single
43 | self.dataloaders_xLabels_joint = dataloaders_xLabels_joint
44 |
45 | self.depthEstModel = _UNetGenerator(input_nc = 3, output_nc = 1)
46 | self.model_name = ['depthEstModel']
47 | self.L1loss = nn.L1Loss()
48 |
49 | if self.isTrain:
50 | self.depth_optimizer = optim.Adam(self.depthEstModel.parameters(), lr=self.task_lr, betas=(0.5, 0.999))
51 | self.optim_name = ['depth_optimizer']
52 | self._get_scheduler()
53 | self.loss_BCE = nn.BCEWithLogitsLoss()
54 | self._initialize_networks()
55 |
56 | # apex can only be applied to CUDA models
57 | if self.use_apex:
58 | self._init_apex(Num_losses=2)
59 |
60 | self.EVAL_best_loss = float('inf')
61 | self.EVAL_best_model_epoch = 0
62 | self.EVAL_all_results = {}
63 |
64 | self._check_parallel()
65 |
66 | def _get_project_name(self):
67 | return 'train_initial_depth_predictor_D'
68 |
69 | def _initialize_networks(self):
70 | for name in self.model_name:
71 | getattr(self, name).train().to(self.device)
72 | init_weights(getattr(self, name), net_name=name, init_type='normal', gain=0.02)
73 |
74 | def compute_depth_loss(self, input_rgb, depth_label, depthEstModel, valid_mask=None):
75 |
76 | prediction = depthEstModel(input_rgb)[-1]
77 | if valid_mask is not None:
78 | loss = self.L1loss(prediction[valid_mask], depth_label[valid_mask])
79 | else:
80 | assert valid_mask is None
81 | loss = self.L1loss(prediction, depth_label)
82 |
83 | return loss
84 |
85 | def train(self):
86 | phase = 'train'
87 | since = time.time()
88 | best_loss = float('inf')
89 |
90 | tensorboardX_iter_count = 0
91 | for epoch in range(self.total_epoch_num):
92 | print('\nEpoch {}/{}'.format(epoch+1, self.total_epoch_num))
93 | print('-' * 10)
94 | fn = open(self.train_log,'a')
95 | fn.write('\nEpoch {}/{}\n'.format(epoch+1, self.total_epoch_num))
96 | fn.write('--'*5+'\n')
97 | fn.close()
98 |
99 | self._set_models_train(['depthEstModel'])
100 | iterCount = 0
101 |
102 | for sample_dict in self.dataloaders_xLabels_joint:
103 | imageListReal, depthListReal = sample_dict['real']
104 | imageListSyn, depthListSyn = sample_dict['syn']
105 |
106 | imageListSyn = imageListSyn.to(self.device)
107 | depthListSyn = depthListSyn.to(self.device)
108 | imageListReal = imageListReal.to(self.device)
109 | depthListReal = depthListReal.to(self.device)
110 | valid_mask = (depthListReal > -1.) # remove undefined regions
111 |
112 | with torch.set_grad_enabled(phase=='train'):
113 | total_loss = 0.
114 | self.depth_optimizer.zero_grad()
115 | real_depth_loss = self.compute_depth_loss(imageListReal, depthListReal, self.depthEstModel, valid_mask)
116 | syn_depth_loss = self.compute_depth_loss(imageListSyn, depthListSyn, self.depthEstModel)
117 | total_loss += (real_depth_loss + syn_depth_loss)
118 |
119 | if self.use_apex:
120 | with amp.scale_loss(total_loss, self.depth_optimizer) as total_loss_scaled:
121 | total_loss_scaled.backward()
122 | else:
123 | total_loss.backward()
124 |
125 | self.depth_optimizer.step()
126 |
127 | iterCount += 1
128 |
129 | if self.use_tensorboardX:
130 | self.train_display_freq = len(self.dataloaders_xLabels_joint)
131 | nrow = imageListReal.size()[0]
132 | if tensorboardX_iter_count % self.train_display_freq == 0:
133 | pred_depth_real = self.depthEstModel(imageListReal)[-1]
134 |
135 | tensorboardX_grid_real_rgb = make_grid(imageListReal, nrow=nrow, normalize=True, range=(-1.0, 1.0))
136 | self.train_SummaryWriter.add_image('real rgb images', tensorboardX_grid_real_rgb, tensorboardX_iter_count)
137 |
138 | tensorboardX_depth_concat = torch.cat((depthListReal, pred_depth_real), dim=0)
139 | tensorboardX_grid_real_depth = make_grid(tensorboardX_depth_concat, nrow=nrow, normalize=True, range=(-1.0, 1.0))
140 | self.train_SummaryWriter.add_image('real depth and depth prediction', tensorboardX_grid_real_depth, tensorboardX_iter_count)
141 |
142 | # add loss values
143 | loss_val_list = [total_loss, real_depth_loss, syn_depth_loss]
144 | loss_name_list = ['total_loss', 'real_depth_loss', 'syn_depth_loss']
145 | self.write_2_tensorboardX(self.train_SummaryWriter, loss_val_list, name=loss_name_list, mode='scalar', count=tensorboardX_iter_count)
146 |
147 | tensorboardX_iter_count += 1
148 |
149 | if iterCount % 20 == 0:
150 | loss_summary = '\t{}/{} total_loss: {:.7f}, real_depth_loss: {:.7f}, syn_depth_loss: {:.7f}'.format(
151 | iterCount, len(self.dataloaders_xLabels_joint), total_loss, real_depth_loss, syn_depth_loss)
152 |
153 | print(loss_summary)
154 | fn = open(self.train_log,'a')
155 | fn.write(loss_summary + '\n')
156 | fn.close()
157 |
158 | # take step in optimizer
159 | for scheduler in self.scheduler_list:
160 | scheduler.step()
161 | for optim in self.optim_name:
162 | lr = getattr(self, optim).param_groups[0]['lr']
163 | lr_update = 'Epoch {}/{} finished: {} learning rate = {:.7f}'.format(epoch+1, self.total_epoch_num, optim, lr)
164 | print(lr_update)
165 |
166 | fn = open(self.train_log,'a')
167 | fn.write(lr_update + '\n')
168 | fn.close()
169 |
170 | if (epoch+1) % self.save_steps == 0:
171 | self.save_models(self.model_name, mode=epoch+1)
172 | self.evaluate(epoch+1)
173 |
174 | time_elapsed = time.time() - since
175 | print('\nTraining complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
176 |
177 | fn = open(self.train_log,'a')
178 | fn.write('\nTraining complete in {:.0f}m {:.0f}s\n'.format(time_elapsed // 60, time_elapsed % 60))
179 | fn.close()
180 |
181 | best_model_summary = '\nOverall best model is epoch {}'.format(self.EVAL_best_model_epoch)
182 | print(best_model_summary)
183 | print(self.EVAL_all_results[str(self.EVAL_best_model_epoch)])
184 | fn = open(self.evaluate_log, 'a')
185 | fn.write(best_model_summary + '\n')
186 | fn.write(self.EVAL_all_results[str(self.EVAL_best_model_epoch)])
187 | fn.close()
188 |
189 | def evaluate(self, mode):
190 | '''
191 | mode: an epoch number or 'best'
192 | an epoch number is used for in-training evaluation
193 | 'best' is used for evaluation after training
194 | '''
195 |
196 | set_name = 'test'
197 | eval_model_list = ['depthEstModel']
198 |
199 | if isinstance(mode, int) and self.isTrain:
200 | self._set_models_eval(eval_model_list)
201 | if self.EVAL_best_loss == float('inf'):
202 | fn = open(self.evaluate_log, 'w')
203 | else:
204 | fn = open(self.evaluate_log, 'a')
205 |
206 | fn.write('Evaluating with mode: {}\n'.format(mode))
207 | fn.write('\tEvaluation range min: {} | max: {} \n'.format(self.EVAL_DEPTH_MIN, self.EVAL_DEPTH_MAX))
208 | fn.close()
209 |
210 | else:
211 | self._load_models(eval_model_list, mode)
212 |
213 | print('Evaluating with mode: {}'.format(mode))
214 | print('\tEvaluation range min: {} | max: {}'.format(self.EVAL_DEPTH_MIN, self.EVAL_DEPTH_MAX))
215 |
216 | total_loss, count = 0., 0
217 | predTensor = torch.zeros((1, 1, self.cropSize_h, self.cropSize_w)).to('cpu')
218 | grndTensor = torch.zeros((1, 1, self.cropSize_h, self.cropSize_w)).to('cpu')
219 | idx = 0
220 |
221 | tensorboardX_iter_count = 0
222 | for sample in self.dataloaders_single[set_name]:
223 | imageList, depthList = sample
224 | valid_mask = np.logical_and(depthList > self.EVAL_DEPTH_MIN, depthList < self.EVAL_DEPTH_MAX)
225 |
226 | idx += imageList.shape[0]
227 | print('epoch {}: processed {} samples in the {} set'.format(mode, str(idx), set_name))
228 | imageList = imageList.to(self.device)
229 | depthList = depthList.to(self.device) # real depth
230 |
231 | if self.isTrain and self.use_apex:
232 | with amp.disable_casts():
233 | predList = self.depthEstModel(imageList)[-1].detach().to('cpu')
234 | else:
235 | predList = self.depthEstModel(imageList)[-1].detach().to('cpu')
236 |
237 | # recover real depth
238 | predList = (predList + 1.0) * 0.5 * self.NYU_MAX_DEPTH_CLIP
239 | depthList = depthList.detach().to('cpu')
240 | predTensor = torch.cat((predTensor, predList), dim=0)
241 | grndTensor = torch.cat((grndTensor, depthList), dim=0)
242 |
243 | if self.use_tensorboardX:
244 | nrow = imageList.size()[0]
245 | if tensorboardX_iter_count % self.val_display_freq == 0:
246 | depth_concat = torch.cat((depthList, predList), dim=0)
247 | self.write_2_tensorboardX(self.eval_SummaryWriter, depth_concat, name='{}: ground truth and depth prediction'.format(set_name),
248 | mode='image', count=tensorboardX_iter_count, nrow=nrow, value_range=(0.0, self.NYU_MAX_DEPTH_CLIP))
249 |
250 | tensorboardX_iter_count += 1
251 |
252 | if isinstance(mode, int) and self.isTrain:
253 | eval_depth_loss = self.L1loss(predList[valid_mask], depthList[valid_mask])
254 | total_loss += eval_depth_loss.detach().cpu()
255 |
256 | count += 1
257 |
258 | if isinstance(mode, int) and self.isTrain:
259 | validation_loss = (total_loss / count)
260 | print('validation loss is {:.7f}'.format(validation_loss))
261 | if self.use_tensorboardX:
262 | self.write_2_tensorboardX(self.eval_SummaryWriter, validation_loss, name='validation loss', mode='scalar', count=mode)
263 |
264 | results = Result(mask_min=self.EVAL_DEPTH_MIN, mask_max=self.EVAL_DEPTH_MAX)
265 | results.evaluate(predTensor[1:], grndTensor[1:])
266 |
267 | result1 = '\tabs_rel:{:.3f}, sq_rel:{:.3f}, rmse:{:.3f}, rmse_log:{:.3f}, mae:{:.3f} '.format(
268 | results.absrel,results.sqrel,results.rmse,results.rmselog,results.mae)
269 | result2 = '\t[<1.25]:{:.3f}, [<1.25^2]:{:.3f}, [<1.25^3]:{:.3f}'.format(results.delta1,results.delta2,results.delta3)
270 |
271 | print(result1)
272 | print(result2)
273 |
274 | if isinstance(mode, int) and self.isTrain:
275 | self.EVAL_all_results[str(mode)] = result1 + '\t' + result2
276 |
277 | if validation_loss.item() < self.EVAL_best_loss:
278 | self.EVAL_best_loss = validation_loss.item()
279 | self.EVAL_best_model_epoch = mode
280 | self.save_models(self.model_name, mode='best')
281 |
282 | best_model_summary = '\tCurrent best loss {:.7f}, current best model {}\n'.format(self.EVAL_best_loss, self.EVAL_best_model_epoch)
283 | print(best_model_summary)
284 |
285 | fn = open(self.evaluate_log, 'a')
286 | fn.write(result1 + '\n')
287 | fn.write(result2 + '\n')
288 | fn.write(best_model_summary + '\n')
289 | fn.close()
--------------------------------------------------------------------------------
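
Note: both depth trainers keep network outputs in [-1, 1] and only convert back to metric depth for evaluation, clipping NYUv2 depth at 10 m; pixels with undefined real depth (encoded as values <= -1) are masked out of the L1 loss. A compact sketch of these two conventions (to_meters and masked_l1 are illustrative helpers, not repo functions):

    import torch
    import torch.nn.functional as F

    NYU_MAX_DEPTH_CLIP = 10.0  # matches the constant used by both depth trainers

    def to_meters(pred):
        # invert the [-1, 1] output encoding used in evaluate(): depth = (pred + 1) / 2 * clip
        return (pred + 1.0) * 0.5 * NYU_MAX_DEPTH_CLIP

    def masked_l1(pred, target):
        # L1 over defined pixels only, mirroring compute_depth_loss() with valid_mask
        mask = target > -1.0
        return F.l1_loss(pred[mask], target[mask])

    pred = torch.rand(4, 1, 240, 320) * 2.0 - 1.0    # stand-in prediction in [-1, 1]
    target = torch.rand(4, 1, 240, 320) * 2.0 - 1.0  # stand-in label in the same encoding
    print(to_meters(pred).max() <= NYU_MAX_DEPTH_CLIP, masked_l1(pred, target).item())
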
/training/train_inpainting_module_I.py:
--------------------------------------------------------------------------------
1 | import os, time, sys
2 | import random
3 | import torch
4 | from torch.utils.data import Dataset, DataLoader
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | from torch.optim import lr_scheduler
8 | import torch.nn.functional as F
9 | from torch.autograd import Variable
10 |
11 | import torchvision
12 | from torchvision import datasets, models, transforms
13 | from torchvision.utils import make_grid
14 | from tensorboardX import SummaryWriter
15 |
16 | from loss import PerceptualLoss, StyleLoss, VGG19
17 | from models.depth_generator_networks import _UNetGenerator, init_weights, _ResGenerator_Upsample
18 | from models.attention_networks import _Attention_FullRes
19 | from models.discriminator_networks import Discriminator80x80InstNorm, DiscriminatorGlobalLocal
20 |
21 | from utils.metrics import *
22 | from utils.image_pool import ImagePool
23 |
24 | from training.base_model import set_requires_grad, base_model
25 |
26 | try:
27 | from apex import amp
28 | except ImportError:
29 | print("\nPlease consider install apex from https://www.github.com/nvidia/apex to run with apex or set use_apex = False\n")
30 |
31 | import warnings # ignore warnings
32 | warnings.filterwarnings("ignore")
33 |
34 | class Mask_Buffer():
35 | """This class implements an image buffer that stores previously generated images.
36 |
37 | This buffer enables us to update discriminators using a history of generated images
38 | rather than the ones produced by the latest generators.
39 | """
40 |
41 | def __init__(self, pool_size):
42 | """Initialize the ImagePool class
43 |
44 | Parameters:
45 | pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
46 | """
47 | self.pool_size = pool_size
48 | if self.pool_size > 0: # create an empty pool
49 | self.num_imgs = 0
50 | self.images = []
51 |
52 | def query(self, images):
53 | """Return an image from the pool.
54 |
55 | Parameters:
56 | images: the latest generated images from the generator
57 |
58 | Returns images from the buffer.
59 |
60 | If the buffer is not yet full, the current images are inserted and returned as-is.
61 | Otherwise, each current image is swapped with a randomly chosen image from the
62 | buffer, and the previously stored image is returned instead.
63 | """
64 | if self.pool_size == 0: # if the buffer size is 0, do nothing
65 | return images
66 | return_images = []
67 | for image in images:
68 | image = torch.unsqueeze(image.data, 0)
69 | if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer
70 | self.num_imgs = self.num_imgs + 1
71 | self.images.append(image)
72 | return_images.append(image)
73 | else:
74 | # p = random.uniform(0, 1)
75 | # if p > 0.5: # the buffer will always return a previously stored image, and insert the current image into the buffer
76 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
77 | tmp = self.images[random_id].clone()
78 | self.images[random_id] = image
79 | return_images.append(tmp)
80 | # else: # by another 50% chance, the buffer will return the current image
81 | # return_images.append(image)
82 | return_images = torch.cat(return_images, 0) # collect all the images and return
83 | return return_images
84 |
85 | class train_inpainting_module_I(base_model):
86 | def __init__(self, args, dataloaders_xLabels_joint, dataloaders_single):
87 | super(train_inpainting_module_I, self).__init__(args)
88 | self._initialize_training()
89 |
90 | self.dataloaders_single = dataloaders_single
91 | self.dataloaders_xLabels_joint = dataloaders_xLabels_joint
92 |
93 | self.use_apex = False # using apex may cause the style loss to be 0
94 |
95 | self.mask_buffer = Mask_Buffer(500)
96 |
97 | self.attModule = _Attention_FullRes(input_nc = 3, output_nc = 1) # logits, no tanh()
98 | self.inpaintNet = _ResGenerator_Upsample(input_nc = 3, output_nc = 3)
99 | self.styleTranslator = _ResGenerator_Upsample(input_nc = 3, output_nc = 3)
100 | self.netD = DiscriminatorGlobalLocal(image_size=240)
101 |
102 | self.tau_min = 0.05
103 | self.use_perceptual_loss = True
104 |
105 | self.p_vgg = VGG19()
106 | self.s_vgg = VGG19()
107 |
108 | self.perceptual_loss = PerceptualLoss(vgg19=self.p_vgg)
109 | self.style_loss = StyleLoss(vgg19=self.s_vgg)
110 |
111 | self.reconst_loss_weight = 1.0
112 | self.perceptual_loss_weight = 1.0
113 | self.style_loss_weight = 1.0
114 | self.fake_loss_weight = 0.01
115 |
116 | self.model_name = ['attModule', 'inpaintNet', 'styleTranslator', 'netD', 'p_vgg', 's_vgg']
117 | self.L1loss = nn.L1Loss()
118 |
119 | if self.isTrain:
120 | self.optim_inpaintNet = optim.Adam(self.inpaintNet.parameters(), lr=self.task_lr, betas=(0.5, 0.999))
121 | self.optim_netD = optim.Adam(self.netD.parameters(), lr=self.task_lr, betas=(0.5, 0.999))
122 | self.optim_name = ['optim_inpaintNet', 'optim_netD']
123 | self._get_scheduler()
124 | self.loss_BCE = nn.BCEWithLogitsLoss()
125 | self._initialize_networks(['inpaintNet', 'netD'])
126 |
127 | # load the "best" style translator T (from step 2)
128 | preTrain_path = os.path.join(os.getcwd(), 'experiments', 'train_style_translator_T')
129 | self._load_models(model_list=['styleTranslator'], mode=480, isTrain=True, model_path=preTrain_path)
130 | print('Successfully loaded pre-trained {} model from {}'.format('styleTranslator', preTrain_path))
131 |
132 | # load the "best" attention module A (from step 3)
133 | preTrain_path = os.path.join(os.getcwd(), 'experiments', 'train_initial_attention_module_A')
134 | self._load_models(model_list=['attModule'], mode=450, isTrain=True, model_path=preTrain_path)
135 | print('Successfully loaded pre-trained {} model from {}'.format('attModule', preTrain_path))
136 |
137 | # apex can only be applied to CUDA models
138 | if self.use_apex:
139 | self._init_apex(Num_losses=2)
140 |
141 | self._check_parallel()
142 |
143 | def _get_project_name(self):
144 | return 'train_inpainting_module_I'
145 |
146 | def _initialize_networks(self, model_name):
147 | for name in model_name:
148 | getattr(self, name).train().to(self.device)
149 | init_weights(getattr(self, name), net_name=name, init_type='normal', gain=0.02)
150 |
151 | def compute_spare_attention(self, confident_score, t, isTrain=True):
152 | # Gumbel-Sigmoid relaxation of the attention logits; t is the temperature (scalar)
153 | if isTrain:
154 | noise = torch.rand(confident_score.size(), requires_grad=False).to(self.device)
155 | noise = (noise + 0.00001) / 1.001
156 | noise = - torch.log(- torch.log(noise))
157 |
158 | confident_score = (confident_score + 0.00001) / 1.001
159 | confident_score = (confident_score + noise) / t
160 | else:
161 | confident_score = confident_score / t
162 |
163 | confident_score = torch.sigmoid(confident_score)
164 |
165 | return confident_score
166 |
167 | def compute_KL_div(self, cf, target=0.5):
168 | g = cf.mean()
169 | g = (g + 0.00001) / 1.001 # prevent g = 0. or 1.
170 | y = target*torch.log(target/g) + (1-target)*torch.log((1-target)/(1-g))
171 | return y
172 |
173 | def compute_real_fake_loss(self, scores, loss_type, datasrc = 'real', loss_for='discr'):
174 | if loss_for == 'discr':
175 | if datasrc == 'real':
176 | if loss_type == 'lsgan':
177 | # least-squares GAN loss
178 | d_loss = torch.pow(scores - 1., 2).mean()
179 | elif loss_type == 'hinge':
180 | # hinge loss, as used in the spectral normalization GAN paper
181 | d_loss = - torch.mean(torch.clamp(scores-1.,max=0.))
182 | elif loss_type == 'wgan':
183 | # WGAN loss
184 | d_loss = - torch.mean(scores)
185 | else:
186 | scores = scores.view(scores.size(0),-1).mean(dim=1)
187 | d_loss = F.binary_cross_entropy_with_logits(scores, torch.ones_like(scores).detach())
188 | else:
189 | if loss_type == 'lsgan':
190 | # least-squares GAN loss
191 | d_loss = torch.pow((scores),2).mean()
192 | elif loss_type == 'hinge':
193 | # hinge loss, as used in the spectral normalization GAN paper
194 | d_loss = -torch.mean(torch.clamp(-scores-1.,max=0.))
195 | elif loss_type == 'wgan':
196 | # WGAN loss
197 | d_loss = torch.mean(scores)
198 | else:
199 | scores = scores.view(scores.size(0),-1).mean(dim=1)
200 | d_loss = F.binary_cross_entropy_with_logits(scores, torch.zeros_like(scores).detach())
201 |
202 | return d_loss
203 | else:
204 | if loss_type == 'lsgan':
205 | # least-squares GAN loss
206 | g_loss = torch.pow(scores - 1., 2).mean()
207 | elif (loss_type == 'wgan') or (loss_type == 'hinge') :
208 | g_loss = - torch.mean(scores)
209 | else:
210 | scores = scores.view(scores.size(0),-1).mean(dim=1)
211 | g_loss = F.binary_cross_entropy_with_logits(scores, torch.ones_like(scores).detach())
212 | return g_loss
213 |
214 | def train(self):
215 | phase = 'train'
216 | since = time.time()
217 | best_loss = float('inf')
218 |
219 | set_requires_grad(self.attModule, requires_grad=False) # freeze attention module
220 | set_requires_grad(self.styleTranslator, requires_grad=False) # freeze style translator T
221 |
222 | tensorboardX_iter_count = 0
223 | for epoch in range(self.total_epoch_num):
224 | print('\nEpoch {}/{}'.format(epoch+1, self.total_epoch_num))
225 | print('-' * 10)
226 | fn = open(self.train_log,'a')
227 | fn.write('\nEpoch {}/{}\n'.format(epoch+1, self.total_epoch_num))
228 | fn.write('--'*5+'\n')
229 | fn.close()
230 |
231 | iterCount = 0
232 |
233 | for sample_dict in self.dataloaders_xLabels_joint:
234 | imageListReal, depthListReal = sample_dict['real']
235 | imageListSyn, depthListSyn = sample_dict['syn']
236 |
237 | imageListSyn = imageListSyn.to(self.device)
238 | depthListSyn = depthListSyn.to(self.device)
239 | imageListReal = imageListReal.to(self.device)
240 | depthListReal = depthListReal.to(self.device)
241 |
242 | B, C, H, W = imageListReal.size()[0], imageListReal.size()[1], imageListReal.size()[2], imageListReal.size()[3]
243 |
244 | with torch.set_grad_enabled(phase=='train'):
245 | r2s_img = self.styleTranslator(imageListReal)[-1]
246 | confident_score = self.attModule(imageListReal)[-1]
247 | # convert to sparse confident score
248 | confident_score = self.compute_spare_attention(confident_score, t=self.tau_min, isTrain=False)
249 | # hard threshold
250 | confident_score[confident_score < 0.5] = 0.
251 | confident_score[confident_score >= 0.5] = 1.
252 |
253 | confident_score = self.mask_buffer.query(confident_score)
254 |
255 | mod_r2s_img = r2s_img * confident_score
256 | inpainted_r2s = self.inpaintNet(mod_r2s_img)[-1]
257 |
258 | reconst_img = inpainted_r2s * (1. - confident_score) + confident_score * r2s_img
259 |
260 | # update generators
261 | self.optim_inpaintNet.zero_grad()
262 | total_loss = 0.
263 | reconst_loss = self.L1loss(inpainted_r2s, r2s_img) * self.reconst_loss_weight
264 | if self.use_perceptual_loss:
265 | perceptual_loss = self.perceptual_loss(inpainted_r2s, r2s_img) * self.perceptual_loss_weight
266 | style_loss = self.style_loss(inpainted_r2s * (1.-confident_score), r2s_img * (1.-confident_score)) * self.style_loss_weight
267 | total_loss += (perceptual_loss + style_loss)
268 |
269 | d_score, _, _ = self.netD(inpainted_r2s, boxImg=confident_score.expand(B, 3, H, W))
270 | fake_loss = self.compute_real_fake_loss(d_score, loss_type='lsgan', loss_for='generator') * self.fake_loss_weight
271 |
272 | total_loss += (reconst_loss + fake_loss)
273 | if self.use_apex:
274 | with amp.scale_loss(total_loss, self.optim_inpaintNet, loss_id=0) as total_loss_scaled:
275 | total_loss_scaled.backward()
276 | else:
277 | total_loss.backward()
278 |
279 | self.optim_inpaintNet.step()
280 |
281 | # update discriminator
282 | self.optim_netD.zero_grad()
283 |
284 | real_d_score, _, _ = self.netD(r2s_img, boxImg=confident_score.expand(B, 3, H, W))
285 | real_d_loss = self.compute_real_fake_loss(real_d_score, loss_type='lsgan', datasrc='real')
286 |
287 | fake_d_score, _, _ = self.netD(inpainted_r2s.detach(), boxImg=confident_score.expand(B, 3, H, W))
288 | fake_d_loss = self.compute_real_fake_loss(fake_d_score, loss_type='lsgan', datasrc='fake')
289 |
290 | total_d_loss = (real_d_loss + fake_d_loss)
291 |
292 | if self.use_apex:
293 | with amp.scale_loss(total_d_loss, self.optim_netD, loss_id=1) as total_d_loss_scaled:
294 | total_d_loss_scaled.backward()
295 | else:
296 | total_d_loss.backward()
297 |
298 | self.optim_netD.step()
299 |
300 | iterCount += 1
301 |
302 | if self.use_tensorboardX:
303 | nrow = imageListReal.size()[0]
304 | self.train_display_freq = len(self.dataloaders_xLabels_joint) # feel free to adjust the display frequency
305 | if tensorboardX_iter_count % self.train_display_freq == 0:
306 | img_concat = torch.cat((imageListReal, r2s_img, mod_r2s_img, inpainted_r2s, reconst_img), dim=0)
307 | self.write_2_tensorboardX(self.train_SummaryWriter, img_concat, name='real, r2s, r2sMasked, inpaintedR2s, reconst', mode='image',
308 | count=tensorboardX_iter_count, nrow=nrow)
309 |
310 | self.write_2_tensorboardX(self.train_SummaryWriter, confident_score, name='Attention', mode='image',
311 | count=tensorboardX_iter_count, nrow=nrow, value_range=(0., 1.0))
312 |
313 | # add loss values
314 | loss_val_list = [total_loss, total_d_loss]
315 | loss_name_list = ['total_loss', 'total_d_loss']
316 | self.write_2_tensorboardX(self.train_SummaryWriter, loss_val_list, name=loss_name_list, mode='scalar', count=tensorboardX_iter_count)
317 |
318 | tensorboardX_iter_count += 1
319 |
320 | if iterCount % 20 == 0:
321 | loss_summary = '\t{}/{}, total_loss: {:.7f}, total_d_loss: {:.7f}'.format(
322 | iterCount, len(self.dataloaders_xLabels_joint), total_loss, total_d_loss)
323 | G_loss_summary = '\t\t G loss summary: reconst_loss: {:.7f}, fake_loss: {:.7f}, perceptual_loss: {:.7f} style_loss: {:.7f}'.format(
324 | reconst_loss, fake_loss, perceptual_loss, style_loss)
325 | D_loss_summary = '\t\t D loss summary: real_d_loss: {:.7f}, fake_d_loss: {:.7f}'.format(real_d_loss, fake_d_loss)
326 |
327 | print(loss_summary)
328 | print(G_loss_summary)
329 | print(D_loss_summary)
330 |
331 | fn = open(self.train_log,'a')
332 | fn.write(loss_summary + '\n')
333 | fn.write(G_loss_summary + '\n')
334 | fn.write(D_loss_summary + '\n')
335 | fn.close()
336 |
337 | if (epoch+1) % self.save_steps == 0:
338 | self.save_models(['inpaintNet'], mode=epoch+1)
339 |
340 | # take step in optimizer
341 | for scheduler in self.scheduler_list:
342 | scheduler.step()
343 | # print learning rate
344 | for optim in self.optim_name:
345 | lr = getattr(self, optim).param_groups[0]['lr']
346 | lr_update = 'Epoch {}/{} finished: {} learning rate = {:.7f}'.format(epoch+1, self.total_epoch_num, optim, lr)
347 | print(lr_update)
348 |
349 | fn = open(self.train_log,'a')
350 | fn.write(lr_update + '\n')
351 | fn.close()
352 |
353 | time_elapsed = time.time() - since
354 | print('\nTraining complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
355 |
356 | fn = open(self.train_log,'a')
357 | fn.write('\nTraining complete in {:.0f}m {:.0f}s\n'.format(time_elapsed // 60, time_elapsed % 60))
358 | fn.close()
359 |
360 | def evaluate(self, mode):
361 | pass
--------------------------------------------------------------------------------
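The inpainting trainer above alternates a generator step (reconstruction, perceptual, style, and adversarial terms) with a discriminator step, optionally wrapping each backward pass in apex's `amp.scale_loss`. The helper `self.compute_real_fake_loss` is defined outside this excerpt; given the `loss_type='lsgan'` arguments, it presumably implements a least-squares GAN objective. Below is a minimal sketch of such an objective with the same keywords as the calls above, but with an assumed body:

import torch
import torch.nn.functional as F

def lsgan_real_fake_loss(d_score, loss_type='lsgan', loss_for='discriminator', datasrc='fake'):
    # Least-squares GAN: regress discriminator scores toward 1 for
    # "real" targets and toward 0 for "fake" targets (Mao et al., 2017).
    assert loss_type == 'lsgan'
    if loss_for == 'generator' or datasrc == 'real':
        target = torch.ones_like(d_score)   # the generator wants fakes scored as real
    else:
        target = torch.zeros_like(d_score)  # the discriminator pushes fakes toward 0
    return F.mse_loss(d_score, target)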
/training/train_style_translator_T.py:
--------------------------------------------------------------------------------
1 | import os, time, sys
2 | import torch
3 | from torch.utils.data import Dataset, DataLoader
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | from torch.optim import lr_scheduler
7 | import torch.nn.functional as F
8 | from torch.autograd import Variable
9 |
10 | import torchvision
11 | from torchvision import datasets, models, transforms
12 | from torchvision.utils import make_grid
13 | from tensorboardX import SummaryWriter
14 |
15 | from models.depth_generator_networks import _UNetGenerator, init_weights, _ResGenerator_Upsample
16 | from models.discriminator_networks import Discriminator80x80InstNorm
17 |
18 | from utils.metrics import *
19 | from utils.image_pool import ImagePool
20 |
21 | from training.base_model import set_requires_grad, base_model
22 |
23 | try:
24 | from apex import amp
25 | except ImportError:
26 | print("\nPlease consider install apex from https://www.github.com/nvidia/apex to run with apex or set use_apex = False\n")
27 |
28 | import warnings # ignore warnings
29 | warnings.filterwarnings("ignore")
30 |
31 | class train_style_translator_T(base_model):
32 | def __init__(self, args, dataloaders_xLabels_joint, dataloaders_single):
33 | super(train_style_translator_T, self).__init__(args)
34 | self._initialize_training()
35 |
36 | self.dataloaders_single = dataloaders_single
37 | self.dataloaders_xLabels_joint = dataloaders_xLabels_joint
38 |
39 | # define loss weights
40 | self.lambda_identity = 0.5 # coefficient of identity mapping score
41 | self.lambda_real = 10.0
42 | self.lambda_synthetic = 10.0
43 | self.lambda_GAN = 1.0
44 |
45 | # define pool size in adversarial loss
46 | self.pool_size = 50
47 | self.generated_syn_pool = ImagePool(self.pool_size)
48 | self.generated_real_pool = ImagePool(self.pool_size)
49 |
50 | self.netD_s = Discriminator80x80InstNorm(input_nc = 3)
51 | self.netD_r = Discriminator80x80InstNorm(input_nc = 3)
52 | self.netG_s2r = _ResGenerator_Upsample(input_nc = 3, output_nc = 3)
53 | self.netG_r2s = _ResGenerator_Upsample(input_nc = 3, output_nc = 3)
54 | self.model_name = ['netD_s', 'netD_r', 'netG_s2r', 'netG_r2s']
55 | self.L1loss = nn.L1Loss()
56 |
57 | if self.isTrain:
58 | self.netD_optimizer = optim.Adam(list(self.netD_s.parameters()) + list(self.netD_r.parameters()), lr=self.D_lr, betas=(0.5, 0.999))
59 | self.netG_optimizer = optim.Adam(list(self.netG_r2s.parameters()) + list(self.netG_s2r.parameters()), lr=self.G_lr, betas=(0.5, 0.999))
60 | self.optim_name = ['netD_optimizer', 'netG_optimizer']
61 | self._get_scheduler()
62 | self.loss_BCE = nn.BCEWithLogitsLoss()
63 | self._initialize_networks()
64 |
65 | # apex can only be applied to CUDA models
66 | if self.use_apex:
67 | self._init_apex(Num_losses=3)
68 |
69 | self._check_parallel()
70 |
71 | def _get_project_name(self):
72 | return 'train_style_translator_T'
73 |
74 | def _initialize_networks(self):
75 | for name in self.model_name:
76 | getattr(self, name).train().to(self.device)
77 | init_weights(getattr(self, name), net_name=name, init_type='normal', gain=0.02)
78 |
79 | def compute_D_loss(self, real_sample, fake_sample, netD):
80 | loss = 0
81 | syn_acc = 0
82 | real_acc = 0
83 |
84 | output = netD(fake_sample)
85 | label = torch.full(output.size(), self.syn_label, device=self.device)
86 |
87 | predSyn = (output > 0.5).to(self.device, dtype=torch.float32)
88 | total_num = torch.numel(output)
89 | syn_acc += (predSyn==label).type(torch.float32).sum().item()/total_num
90 | loss += self.loss_BCE(output, label)
91 |
92 | output = netD(real_sample)
93 | label = torch.full(output.size(), self.real_label, device=self.device)
94 |
95 | predReal = (output > 0.5).to(self.device, dtype=torch.float32)
96 | real_acc += (predReal==label).type(torch.float32).sum().item()/total_num
97 | loss += self.loss_BCE(output, label)
98 |
99 | return loss, syn_acc, real_acc
100 |
101 | def compute_G_loss(self, real_sample, synthetic_sample, r2s_rgb, s2r_rgb, reconstruct_real, reconstruct_syn):
102 | '''
103 | real_sample: [batch_size, 3, 240, 320] real rgb
104 | synthetic_sample: [batch_size, 3, 240, 320] synthetic rgb
105 | r2s_rgb: netG_r2s(real)
106 | s2r_rgb: netG_s2r(synthetic); reconstruct_real: netG_s2r(r2s_rgb), reconstruct_syn: netG_r2s(s2r_rgb)
107 | '''
108 | loss = 0
109 |
110 | # identity loss if applicable
111 | if self.lambda_identity > 0:
112 | idt_real = self.netG_s2r(real_sample)[-1]
113 | idt_synthetic = self.netG_r2s(synthetic_sample)[-1]
114 | idt_loss = (self.L1loss(idt_real, real_sample) * self.lambda_real +
115 | self.L1loss(idt_synthetic, synthetic_sample) * self.lambda_synthetic) * self.lambda_identity
116 | else:
117 | idt_loss = 0
118 |
119 | # GAN loss
120 | real_pred = self.netD_r(s2r_rgb)
121 | real_label = torch.full(real_pred.size(), self.real_label, device=self.device)
122 | GAN_loss_real = self.loss_BCE(real_pred, real_label)
123 |
124 | syn_pred = self.netD_s(r2s_rgb)
125 | syn_label = torch.full(syn_pred.size(), self.real_label, device=self.device)
126 | GAN_loss_syn = self.loss_BCE(syn_pred, syn_label)
127 |
128 | GAN_loss = (GAN_loss_real + GAN_loss_syn) * self.lambda_GAN
129 |
130 | # cycle consistency loss
131 | rec_real_loss = self.L1loss(reconstruct_real, real_sample) * self.lambda_real
132 | rec_syn_loss = self.L1loss(reconstruct_syn, synthetic_sample) * self.lambda_synthetic
133 | rec_loss = rec_real_loss + rec_syn_loss
134 |
135 | loss += (idt_loss + GAN_loss + rec_loss)
136 |
137 | return loss, idt_loss, GAN_loss, rec_loss
138 |
139 | def train(self):
140 | phase = 'train'
141 | since = time.time()
142 | best_loss = float('inf')
143 |
144 | tensorboardX_iter_count = 0
145 | for epoch in range(self.total_epoch_num):
146 | print('\nEpoch {}/{}'.format(epoch+1, self.total_epoch_num))
147 | print('-' * 10)
148 | fn = open(self.train_log,'a')
149 | fn.write('\nEpoch {}/{}\n'.format(epoch+1, self.total_epoch_num))
150 | fn.write('--'*5+'\n')
151 | fn.close()
152 |
153 | iterCount = 0
154 |
155 | for sample_dict in self.dataloaders_xLabels_joint:
156 | imageListReal, depthListReal = sample_dict['real']
157 | imageListSyn, depthListSyn = sample_dict['syn']
158 |
159 | imageListSyn = imageListSyn.to(self.device)
160 | depthListSyn = depthListSyn.to(self.device)
161 | imageListReal = imageListReal.to(self.device)
162 | depthListReal = depthListReal.to(self.device)
163 |
164 | with torch.set_grad_enabled(phase=='train'):
165 | s2r_rgb = self.netG_s2r(imageListSyn)[-1]
166 | reconstruct_syn = self.netG_r2s(s2r_rgb)[-1]
167 |
168 | r2s_rgb = self.netG_r2s(imageListReal)[-1]
169 | reconstruct_real = self.netG_s2r(r2s_rgb)[-1]
170 |
171 | ############# update generator
172 | set_requires_grad([self.netD_r, self.netD_s], False)
173 |
174 | netG_loss = 0.
175 | self.netG_optimizer.zero_grad()
176 | netG_loss, G_idt_loss, G_GAN_loss, G_rec_loss = self.compute_G_loss(imageListReal, imageListSyn,
177 | r2s_rgb, s2r_rgb, reconstruct_real, reconstruct_syn)
178 |
179 | if self.use_apex:
180 | with amp.scale_loss(netG_loss, self.netG_optimizer, loss_id=0) as netG_loss_scaled:
181 | netG_loss_scaled.backward()
182 | else:
183 | netG_loss.backward()
184 |
185 | self.netG_optimizer.step()
186 |
187 | ############# update discriminator
188 | set_requires_grad([self.netD_r, self.netD_s], True)
189 |
190 | self.netD_optimizer.zero_grad()
191 | r2s_rgb_pool = self.generated_syn_pool.query(r2s_rgb)
192 | netD_s_loss, netD_s_syn_acc, netD_s_real_acc = self.compute_D_loss(imageListSyn, r2s_rgb_pool.detach(), self.netD_s)
193 | s2r_rgb_pool = self.generated_real_pool.query(s2r_rgb)
194 | netD_r_loss, netD_r_syn_acc, netD_r_real_acc = self.compute_D_loss(imageListReal, s2r_rgb_pool.detach(), self.netD_r)
195 |
196 | netD_loss = netD_s_loss + netD_r_loss
197 |
198 | if self.use_apex:
199 | with amp.scale_loss(netD_loss, self.netD_optimizer, loss_id=1) as netD_loss_scaled:
200 | netD_loss_scaled.backward()
201 | else:
202 | netD_loss.backward()
203 | self.netD_optimizer.step()
204 |
205 | iterCount += 1
206 |
207 | if self.use_tensorboardX:
208 | self.train_display_freq = len(self.dataloaders_xLabels_joint) # feel free to adjust the display frequency
209 | nrow = imageListReal.size()[0]
210 | if tensorboardX_iter_count % self.train_display_freq == 0:
211 | s2r_rgb_concat = torch.cat((imageListSyn, s2r_rgb, imageListReal, reconstruct_syn), dim=0)
212 | self.write_2_tensorboardX(self.train_SummaryWriter, s2r_rgb_concat, name='RGB: syn, s2r, real, reconstruct syn', mode='image',
213 | count=tensorboardX_iter_count, nrow=nrow)
214 |
215 | r2s_rgb_concat = torch.cat((imageListReal, r2s_rgb, imageListSyn, reconstruct_real), dim=0)
216 | self.write_2_tensorboardX(self.train_SummaryWriter, r2s_rgb_concat, name='RGB: real, r2s, synthetic, reconstruct real', mode='image',
217 | count=tensorboardX_iter_count, nrow=nrow)
218 |
219 | loss_val_list = [netD_loss, netG_loss]
220 | loss_name_list = ['netD_loss', 'netG_loss']
221 | self.write_2_tensorboardX(self.train_SummaryWriter, loss_val_list, name=loss_name_list, mode='scalar', count=tensorboardX_iter_count)
222 |
223 | tensorboardX_iter_count += 1
224 |
225 | if iterCount % 20 == 0:
226 | loss_summary = '\t{}/{} netD: {:.7f}, netG: {:.7f}'.format(iterCount, len(self.dataloaders_xLabels_joint), netD_loss, netG_loss)
227 | G_loss_summary = '\t\tG loss summary: netG: {:.7f}, idt_loss: {:.7f}, GAN_loss: {:.7f}, rec_loss: {:.7f}'.format(
228 | netG_loss, G_idt_loss, G_GAN_loss, G_rec_loss)
229 |
230 | print(loss_summary)
231 | print(G_loss_summary)
232 |
233 | fn = open(self.train_log,'a')
234 | fn.write(loss_summary + '\n')
235 | fn.write(G_loss_summary + '\n')
236 | fn.close()
237 |
238 | if (epoch+1) % self.save_steps == 0:
239 | self.save_models(['netG_r2s'], mode=epoch+1, save_list=['styleTranslator'])
240 |
241 | # step the learning rate schedulers
242 | for scheduler in self.scheduler_list:
243 | scheduler.step()
244 | for optim in self.optim_name:
245 | lr = getattr(self, optim).param_groups[0]['lr']
246 | lr_update = 'Epoch {}/{} finished: {} learning rate = {:.7f}'.format(epoch+1, self.total_epoch_num, optim, lr)
247 | print(lr_update)
248 |
249 | fn = open(self.train_log,'a')
250 | fn.write(lr_update + '\n')
251 | fn.close()
252 |
253 | time_elapsed = time.time() - since
254 | print('\nTraining complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
255 |
256 | fn = open(self.train_log,'a')
257 | fn.write('\nTraining complete in {:.0f}m {:.0f}s\n'.format(time_elapsed // 60, time_elapsed % 60))
258 | fn.close()
259 |
260 | def evaluate(self, mode):
261 | pass
--------------------------------------------------------------------------------
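`train_style_translator_T` follows the CycleGAN recipe: two translators (`netG_s2r`, `netG_r2s`) and two discriminators are trained jointly with an identity term (weighted by `lambda_identity`), a GAN term (`lambda_GAN`), and L1 cycle-consistency terms (`lambda_real`, `lambda_synthetic`). As a point of reference, here is a minimal, self-contained sketch of the cycle-consistency term alone; the `[-1]` indexing mirrors this repo's generators, which return a list whose last element is the translated image:

import torch
import torch.nn as nn

l1 = nn.L1Loss()

def cycle_consistency_loss(real, syn, netG_s2r, netG_r2s,
                           lambda_real=10.0, lambda_synthetic=10.0):
    # real -> r2s -> back to real, and syn -> s2r -> back to syn
    rec_real = netG_s2r(netG_r2s(real)[-1])[-1]
    rec_syn = netG_r2s(netG_s2r(syn)[-1])[-1]
    return l1(rec_real, real) * lambda_real + l1(rec_syn, syn) * lambda_synthetic

# sanity check: with identity "generators" the cycle loss is zero
identity = lambda x: [x]
x = torch.rand(2, 3, 240, 320)
assert cycle_consistency_loss(x, x, identity, identity).item() == 0.0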
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .metrics import *
2 |
--------------------------------------------------------------------------------
/utils/image_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 |
4 |
5 | class ImagePool():
6 | """This class implements an image buffer that stores previously generated images.
7 |
8 | This buffer enables us to update discriminators using a history of generated images
9 | rather than the ones produced by the latest generators.
10 | """
11 |
12 | def __init__(self, pool_size):
13 | """Initialize the ImagePool class
14 |
15 | Parameters:
16 | pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
17 | """
18 | self.pool_size = pool_size
19 | if self.pool_size > 0: # create an empty pool
20 | self.num_imgs = 0
21 | self.images = []
22 |
23 | def query(self, images):
24 | """Return an image from the pool.
25 |
26 | Parameters:
27 | images: the latest generated images from the generator
28 |
29 | Returns images from the buffer.
30 |
31 | With probability 0.5, the buffer returns the input images unchanged.
32 | With probability 0.5, the buffer returns images previously stored in the buffer
33 | and inserts the current images into the buffer.
34 | """
35 | if self.pool_size == 0: # if the buffer size is 0, do nothing
36 | return images
37 | return_images = []
38 | for image in images:
39 | image = torch.unsqueeze(image.data, 0)
40 | if self.num_imgs < self.pool_size: # if the buffer is not full, keep inserting current images into the buffer
41 | self.num_imgs = self.num_imgs + 1
42 | self.images.append(image)
43 | return_images.append(image)
44 | else:
45 | p = random.uniform(0, 1)
46 | if p > 0.5: # with 50% chance, return a previously stored image and insert the current image into the buffer
47 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
48 | tmp = self.images[random_id].clone()
49 | self.images[random_id] = image
50 | return_images.append(tmp)
51 | else: # otherwise, return the current image
52 | return_images.append(image)
53 | return_images = torch.cat(return_images, 0) # collect all the images and return
54 | return return_images
--------------------------------------------------------------------------------
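`ImagePool` is the generated-image history buffer popularized by CycleGAN: updating the discriminators on a mix of current and previously generated images reduces oscillation during adversarial training. A short usage sketch (the batch shape is illustrative):

import torch

pool = ImagePool(pool_size=50)

for step in range(4):
    fake_batch = torch.randn(2, 3, 240, 320)  # stand-in for r2s_rgb or s2r_rgb
    pooled = pool.query(fake_batch)            # mix of fresh and stored fakes
    assert pooled.shape == fake_batch.shape
    # feed `pooled` to the discriminator; query() stores .data copies,
    # so the returned tensors carry no generator gradients

Note that `query` returns detached tensors, which is why the translator above can safely reuse them in the discriminator update.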
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import math
4 |
5 |
6 |
7 | def log10(x):
8 | """Convert a new tensor with the base-10 logarithm of the elements of x. """
9 | return torch.log(x) / math.log(10)
10 |
11 | class Result(object):
12 | def __init__(self, mask_min, mask_max):
13 | self.irmse, self.imae = 0, 0
14 | self.mse, self.rmse, self.mae = 0, 0, 0
15 | self.absrel, self.lg10 = 0, 0
16 | self.delta1, self.delta2, self.delta3 = 0, 0, 0
17 | self.data_time, self.gpu_time = 0, 0
18 | self.mask_min = mask_min
19 | self.mask_max = mask_max
20 |
21 | def set_to_worst(self):
22 | self.irmse, self.imae = np.inf, np.inf
23 | self.mse, self.rmse, self.mae = np.inf, np.inf, np.inf
24 | self.absrel, self.lg10 = np.inf, np.inf
25 | self.delta1, self.delta2, self.delta3 = 0, 0, 0
26 | self.data_time, self.gpu_time = 0, 0
27 |
28 | def update(self, irmse, imae, mse, rmse, mae, absrel, lg10, delta1, delta2, delta3, gpu_time, data_time):
29 | self.irmse, self.imae = irmse, imae
30 | self.mse, self.rmse, self.mae = mse, rmse, mae
31 | self.absrel, self.lg10 = absrel, lg10
32 | self.delta1, self.delta2, self.delta3 = delta1, delta2, delta3
33 | self.data_time, self.gpu_time = data_time, gpu_time
34 |
35 | def evaluate(self, output, target):
36 |
37 | # not quite sure whether this is useful
38 | # target[target < self.mask_min] = self.mask_min
39 | # target[target > self.mask_max] = self.mask_max
40 |
41 | valid_mask = np.logical_and(target > self.mask_min, target < self.mask_max)
42 | output = output[valid_mask]
43 | target = target[valid_mask]
44 |
45 | abs_diff = (output - target).abs()
46 | diff = (output - target)
47 |
48 | self.mse = float((torch.pow(abs_diff, 2)).mean())
49 | self.rmse = math.sqrt(self.mse)
50 | self.rmselog = math.sqrt(float(((torch.log(target) - torch.log(output)) ** 2).mean()))
51 |
52 | self.mae = float(abs_diff.mean())
53 | self.lg10 = float((log10(output) - log10(target)).abs().mean())
54 | self.absrel = float((abs_diff / target).mean())
55 | self.sqrel = float(((diff ** 2) / target).mean())
56 |
57 | maxRatio = torch.max(output / target, target / output)
58 | self.delta1 = float((maxRatio < 1.25).float().mean())
59 | self.delta2 = float((maxRatio < 1.25 ** 2).float().mean())
60 | self.delta3 = float((maxRatio < 1.25 ** 3).float().mean())
61 | self.data_time = 0
62 | self.gpu_time = 0
63 |
64 | inv_output = 1 / output
65 | inv_target = 1 / target
66 | abs_inv_diff = (inv_output - inv_target).abs()
67 | self.irmse = math.sqrt((torch.pow(abs_inv_diff, 2)).mean())
68 | self.imae = float(abs_inv_diff.mean())
69 |
70 |
71 | class Result_withIdx(object):
72 | def __init__(self, mask_min, mask_max):
73 | self.irmse, self.imae = 0, 0
74 | self.mse, self.rmse, self.mae = 0, 0, 0
75 | self.absrel, self.lg10 = 0, 0
76 | self.delta1, self.delta2, self.delta3 = 0, 0, 0
77 | self.data_time, self.gpu_time = 0, 0
78 | self.mask_min = mask_min
79 | self.mask_max = mask_max
80 |
81 | def set_to_worst(self):
82 | self.irmse, self.imae = np.inf, np.inf
83 | self.mse, self.rmse, self.mae = np.inf, np.inf, np.inf
84 | self.absrel, self.lg10 = np.inf, np.inf
85 | self.delta1, self.delta2, self.delta3 = 0, 0, 0
86 | self.data_time, self.gpu_time = 0, 0
87 |
88 | def update(self, irmse, imae, mse, rmse, mae, absrel, lg10, delta1, delta2, delta3, gpu_time, data_time):
89 | self.irmse, self.imae = irmse, imae
90 | self.mse, self.rmse, self.mae = mse, rmse, mae
91 | self.absrel, self.lg10 = absrel, lg10
92 | self.delta1, self.delta2, self.delta3 = delta1, delta2, delta3
93 | self.data_time, self.gpu_time = data_time, gpu_time
94 |
95 | def evaluate(self, output, target, idx_tensor):
96 | # idx_tensor should have the same size as output and target
97 |
98 | valid_mask = np.logical_and(target > self.mask_min, target < self.mask_max)
99 | # print(valid_mask.shape, type(valid_mask))
100 | # print(valid_mask)
101 | # print(valid_mask.shape, idx_tensor.shape)
102 | final_mask = valid_mask & idx_tensor
103 | # print(final_mask.shape)
104 | output = output[final_mask]
105 | target = target[final_mask]
106 |
107 | abs_diff = (output - target).abs()
108 | diff = (output - target)
109 |
110 | self.mse = float((torch.pow(abs_diff, 2)).mean())
111 | self.rmse = math.sqrt(self.mse)
112 | self.rmselog = math.sqrt(float(((torch.log(target) - torch.log(output)) ** 2).mean()))
113 |
114 | self.mae = float(abs_diff.mean())
115 | self.lg10 = float((log10(output) - log10(target)).abs().mean())
116 | self.absrel = float((abs_diff / target).mean())
117 | self.sqrel = float(((diff ** 2) / target).mean())
118 |
119 | maxRatio = torch.max(output / target, target / output)
120 | self.delta1 = float((maxRatio < 1.25).float().mean())
121 | self.delta2 = float((maxRatio < 1.25 ** 2).float().mean())
122 | self.delta3 = float((maxRatio < 1.25 ** 3).float().mean())
123 | self.data_time = 0
124 | self.gpu_time = 0
125 |
126 | inv_output = 1 / output
127 | inv_target = 1 / target
128 | abs_inv_diff = (inv_output - inv_target).abs()
129 | self.irmse = math.sqrt((torch.pow(abs_inv_diff, 2)).mean())
130 | self.imae = float(abs_inv_diff.mean())
131 |
132 |
133 | def miou(pred, target, n_classes=12):
134 | ious = []
135 | pred = pred.view(-1)
136 | target = target.view(-1)
137 |
138 | # Compute the IoU of every class (note: the loop starts at 0, so the background class "0" is included)
139 | for cls in range(0, n_classes):
140 | pred_inds = pred == cls
141 | target_inds = target == cls
142 | intersection = (pred_inds[target_inds]).long().sum().item() # Cast to long to prevent overflows
143 | union = pred_inds.long().sum().item() + target_inds.long().sum().item() - intersection
144 | if union == 0: ious.append(float('nan')) # If there is no ground truth, do not include in evaluation
145 | else: ious.append(float(intersection) / float(max(union, 1)))
146 | return np.array(ious)
147 |
148 |
149 | def im2col_sliding_broadcasting(A, BSZ, stepsize=1):
150 | # Parameters
151 | M,N = A.shape[0],A.shape[1]
152 | col_extent = N - BSZ[1] + 1
153 | row_extent = M - BSZ[0] + 1
154 |
155 | # Get Starting block indices
156 | start_idx = np.arange(BSZ[0])[:,None]*N + np.arange(BSZ[1])
157 |
158 | # Get offset indices across the height and width of the input array
159 | offset_idx = np.arange(row_extent)[:,None]*N + np.arange(col_extent)
160 |
161 | # Get all actual indices & index into input array for final output
162 | return np.take(A, start_idx.ravel()[:,None] + offset_idx.ravel()[::stepsize])
163 |
164 |
165 | def rgb2ycbcr(im):
166 | cbcr = np.empty_like(im)
167 | r = im[:,:,0]
168 | g = im[:,:,1]
169 | b = im[:,:,2]
170 | # Y
171 | cbcr[:,:,0] = .299 * r + .587 * g + .114 * b
172 | # Cb
173 | cbcr[:,:,1] = 128 - .169 * r - .331 * g + .5 * b
174 | # Cr
175 | cbcr[:,:,2] = 128 + .5 * r - .419 * g - .081 * b
176 | return cbcr # np.uint8(cbcr)
177 |
178 | def ycbcr2rgb(im):
179 | rgb = np.empty_like(im)
180 | y = im[:,:,0]
181 | cb = im[:,:,1] - 128
182 | cr = im[:,:,2] - 128
183 | # R
184 | rgb[:,:,0] = y + 1.402 * cr
185 | # G
186 | rgb[:,:,1] = y - .34414 * cb - .71414 * cr
187 | # B
188 | rgb[:,:,2] = y + 1.772 * cb
189 | return rgb # np.uint8(rgb)
190 |
191 |
192 | def img_greyscale(img):
193 | return 0.299 * img[:,:,0] + 0.587 * img[:,:,1] + 0.114 * img[:,:,2]
194 |
--------------------------------------------------------------------------------
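The metrics in `Result.evaluate` follow the standard monocular depth-estimation protocol: predictions and ground truth are masked to the valid range `(mask_min, mask_max)`, then RMSE, MAE, absolute/squared relative error, log10 error, inverse-depth errors, and the threshold accuracies delta_i (the fraction of pixels with max(pred/gt, gt/pred) < 1.25^i) are computed. A minimal usage sketch with dummy tensors follows; the depth range is hypothetical, and it assumes the mixed numpy/torch masking in `evaluate` behaves as in the PyTorch versions this repo targets:

import torch

result = Result(mask_min=0.5, mask_max=10.0)   # hypothetical valid depth range in meters
pred = torch.rand(1, 1, 240, 320) * 9.0 + 0.6  # all values fall inside the mask range
gt = torch.rand(1, 1, 240, 320) * 9.0 + 0.6
result.evaluate(pred, gt)
print(result.rmse, result.absrel, result.delta1, result.delta2, result.delta3)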