├── LICENSE
├── README.md
├── dataloader
│   ├── Joint_xLabel_dataLoader.py
│   ├── NYUv2_dataLoader.py
│   └── __init__.py
├── figures
│   ├── 4343-teaser.gif
│   └── demo.png
├── loss.py
├── models
│   ├── __init__.py
│   ├── attention_networks.py
│   ├── depth_generator_networks.py
│   └── discriminator_networks.py
├── train.py
├── training
│   ├── base_model.py
│   ├── finetune_the_whole_system_with_depth_loss.py
│   ├── jointly_train_depth_predictor_D_and_attention_module_A.py
│   ├── train_initial_attention_module_A.py
│   ├── train_initial_depth_predictor_D.py
│   ├── train_inpainting_module_I.py
│   └── train_style_translator_T.py
└── utils
    ├── __init__.py
    ├── image_pool.py
    └── metrics.py
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Yunhan Zhao

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ARC
This repo contains the PyTorch implementation of:

[Domain Decluttering: Simplifying Images to Mitigate Synthetic-Real Domain Shift and Improve Depth Estimation](http://openaccess.thecvf.com/content_CVPR_2020/html/Zhao_Domain_Decluttering_Simplifying_Images_to_Mitigate_Synthetic-Real_Domain_Shift_and_CVPR_2020_paper.html)

[Yunhan Zhao](https://www.ics.uci.edu/~yunhaz5/), [Shu Kong](http://www.cs.cmu.edu/~shuk/), [Daeyun Shin](https://research.dshin.org/) and [Charless Fowlkes](https://www.ics.uci.edu/~fowlkes/)

CVPR 2020

For more details, please check our [project website](https://www.ics.uci.edu/~yunhaz5/cvpr2020/domain_decluttering.html).


### Abstract
Leveraging synthetically rendered data offers great potential to improve monocular depth estimation and other geometric estimation tasks, but closing the synthetic-real domain gap is a non-trivial and important task. While much recent work has focused on unsupervised domain adaptation, we consider a more realistic scenario where a large amount of synthetic training data is supplemented by a small set of real images with ground-truth. In this setting, we find that existing domain translation approaches are difficult to train and offer little advantage over simple baselines that use a mix of real and synthetic data. A key failure mode is that real-world images contain novel objects and clutter not present in synthetic training. This high-level domain shift isn’t handled by existing image translation models.

Based on these observations, we develop an attention module that learns to identify and remove difficult out-of-domain regions in real images in order to improve depth prediction for a model trained primarily on synthetic data. We carry out extensive experiments to validate our attend-remove-complete approach (ARC) and find that it significantly outperforms state-of-the-art domain adaptation methods for depth prediction. Visualizing the removed regions provides interpretable insights into the synthetic-real domain gap.

## Reference
If you find our work useful in your research, please consider citing our paper:
```
@inproceedings{zhao2020domain,
  title={Domain Decluttering: Simplifying Images to Mitigate Synthetic-Real Domain Shift and Improve Depth Estimation},
  author={Zhao, Yunhan and Kong, Shu and Shin, Daeyun and Fowlkes, Charless},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={3330--3340},
  year={2020}
}
```

## Contents

- [Requirements](#requirements)
- [Datasets](#datasets)
- [Training Procedures](#training-procedures)
- [Evaluations](#evaluations)
- [Pretrained Models](#pretrained-models)


## Requirements
1. Python 3.6 with Ubuntu 16.04
2. PyTorch 1.1.0
3. Apex 0.1 (optional)

You also need other third-party libraries to run the code, such as numpy, pillow, torchvision, scikit-image, scipy, matplotlib, blosc (used by the dataloaders to read the compressed `.bin` depth files) and tensorboardX (optional). We used Apex when training all models, but it is not strictly required to run the code.

## Datasets
Download NYUv2 and PBRS and place them in the following structure so that the dataloaders can find them.
#### Dataset Structure
```
NYUv2 (real)
| train
    | rgb
    | depth
| test
    | rgb
    | depth
PBRS (synthetic)
| train
    | rgb
    | depth
```
Note that the dataloaders read training depth from a `depth_inpainted` folder next to `rgb` in each `train` split (`NYUv2_dataLoader` reads `depth` instead when `training_depth='original'`), while the `test` split always uses `depth`. Each depth map is a Blosc-compressed `.bin` file with the same base name as its RGB `.png`.

For the Kitti experiments, download Kitti and vKitti and follow the same structure.
## Training Procedures
- [1 Train Initial Depth Predictor D](#1-train-initial-depth-predictor-d)
- [2 Train Style Translator T (pretrain T)](#2-train-style-translator-t-pretrain-t)
- [3 Train Initial Attention Module A](#3-train-initial-attention-module-a)
- [4 Train Inpainting Module I (pretrain I)](#4-train-inpainting-module-i-pretrain-i)
- [5 Jointly Train Depth Predictor D and Attention Module A (pretrain A, D)](#5-jointly-train-depth-predictor-d-and-attention-module-a-pretrain-a-d)
- [6 Finetune the Whole System with Depth Loss](#6-finetune-the-whole-system-with-depth-loss-modular-coordinate-descent)

All training steps share one `train.py` file, so make sure to comment/uncomment the correct line for each step. Set `CUDA_VISIBLE_DEVICES` to the GPU(s) to use, and point `--path_to_NYUv2` and `--path_to_PBRS` at the dataset roots described above.
```bash
CUDA_VISIBLE_DEVICES= python train.py \
--path_to_NYUv2= \
--path_to_PBRS= \
--batch_size=4 --total_epoch_num=500 --isTrain --eval_batch_size=1
```
`batch_size` and `eval_batch_size` can be adjusted to fit your hardware.
#### 1 Train Initial Depth Predictor D
Train an initial depth predictor D with real and synthetic data. The best model is the one with the minimum L1 loss. Checkpoints are saved in `./experiments/train_initial_depth_predictor_D/`.
#### 2 Train Style Translator T (pretrain T)
Pretrain the style translator T to obtain a good initialization for it. The best model is selected by visual inspection and the training loss curves.
#### 3 Train Initial Attention Module A
Train an initial attention module A from scratch with descending $\tau$ values.
#### 4 Train Inpainting Module I (pretrain I)
Train the inpainting module I with T (from step 2) and A (from step 3). This gives a good initialization for I.
#### 5 Jointly Train Depth Predictor D and Attention Module A (pretrain A, D)
Jointly train the depth predictor D and the attention module A, using D (from step 1), T (from step 2), A (from step 3) and I (from step 4). The A and D learned in this step serve as the initialization for finetuning the whole system with the depth loss. In step 5 and step 6 we train for fewer epochs, i.e., `total_epoch_num = 150`.
#### 6 Finetune the Whole System with Depth Loss (Modular Coordinate Descent)
Lastly, we finetune the whole system with the depth loss terms, using D (from step 5), T (from step 2), A (from step 5) and I (from step 4). The NYUv2 results reported in the paper are the evaluation results from this step (one-step finetuning).

## Evaluations
Evaluate the final results with:
```bash
CUDA_VISIBLE_DEVICES= python train.py \
--path_to_NYUv2= \
--path_to_PBRS= \
--eval_batch_size=1
```
Make sure step 6 is uncommented in `train.py`. To evaluate on your own data, place it under `/test` following the dataset structure described above.

## Pretrained Models
Pretrained models for the NYUv2 & PBRS experiment are available [here](https://drive.google.com/drive/folders/1gB4dE3qoHrNGQqqU7cea7Z3MouPIJA9m?usp=sharing).

Pretrained models for the Kitti & vKitti experiment are available [here](https://drive.google.com/drive/folders/1XzCXm91-HgXm1OKx358yKFN-aSVZGqpM?usp=sharing).

## Acknowledgments
This code is developed based on [T2Net](https://github.com/lyndonzheng/Synthetic2Realistic) and [Pytorch-CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).

## Questions
Please feel free to email me at (yunhaz5 [at] ics [dot] uci [dot] edu) if you have any questions.
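Evaluating on your own data requires writing depth maps in the on-disk format the dataloaders expect, which the README leaves implicit. The following is a minimal sketch, assuming the conventions visible in `dataloader/NYUv2_dataLoader.py` and `dataloader/Joint_xLabel_dataLoader.py`: depth paths are derived from RGB paths by string replacement, the byte layout is `[int32 ndim][int32 dims...][C-order float32 data]`, and training depth is clipped to 10 and rescaled to [-1, 1]. `array_to_bytes`/`bytes_to_array` mirror the functions of the same name (without their sanity checks); `depth_path_for`, `normalize_depth` and `denormalize_depth` are hypothetical helpers. The block stops before compression: the loaders call `blosc.decompress` on every `.bin`, so a real file must hold the `blosc.compress` output of these bytes, as `save_array_compressed` does.

```python
import struct

import numpy as np

NYU_MAX_DEPTH_CLIP = 10.0  # clip value used by the dataloaders


def depth_path_for(rgb_path, split='test'):
    """Hypothetical helper: the depth path the loaders derive from an RGB path.

    Like the loaders, str.replace rewrites every 'rgb'/'png' substring in the path.
    """
    depth_dir = 'depth_inpainted' if split == 'train' else 'depth'
    return rgb_path.replace('rgb', depth_dir).replace('png', 'bin')


def array_to_bytes(arr):
    """Header [int32 ndim][int32 dim_0 ... dim_{n-1}] followed by C-order raw data."""
    header = struct.pack('i', arr.ndim) + struct.pack('i' * arr.ndim, *arr.shape)
    return header + arr.tobytes(order='C')


def bytes_to_array(s, dtype=np.float32):
    """Inverse of array_to_bytes; dtype must match the dtype that was saved."""
    ndim = struct.unpack('i', s[:4])[0]
    shape = struct.unpack('i' * ndim, s[4:4 * ndim + 4])
    return np.frombuffer(s[4 * ndim + 4:], dtype=dtype).reshape(shape).copy()


def normalize_depth(depth):
    """Training-time rescaling done by the loaders: clip to [0, 10], map to [-1, 1]."""
    depth = np.clip(depth, 0.0, NYU_MAX_DEPTH_CLIP)
    return (depth / NYU_MAX_DEPTH_CLIP - 0.5) * 2.0


def denormalize_depth(d):
    """Hypothetical inverse of normalize_depth (assumes values lie in [-1, 1])."""
    return (d / 2.0 + 0.5) * NYU_MAX_DEPTH_CLIP


if __name__ == '__main__':
    depth = np.array([[0.0, 2.5], [5.0, 12.0]], dtype=np.float32)
    restored = bytes_to_array(array_to_bytes(depth))
    print(restored.shape, normalize_depth(depth))
    print(depth_path_for('/data/NYUv2/test/rgb/0001.png'))
```

Note that `normalize_depth` is applied only to training samples; the loaders return test depth unscaled. Reading model outputs in [-1, 1] back as depth via `denormalize_depth` is an assumption based on that training target, not something the README states.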
--------------------------------------------------------------------------------
/dataloader/Joint_xLabel_dataLoader.py:
--------------------------------------------------------------------------------
import os, sys, random, time, copy
import os.path as path
import logging
from skimage import io, transform
import numpy as np
import scipy.io as sio
from scipy import misc
import matplotlib.pyplot as plt
import PIL.Image

import skimage.transform
import blosc, struct

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision
from torchvision import datasets, models, transforms

# `ensure_dir_exists` and `save_array_compressed` log through this module-level logger.
log = logging.getLogger(__name__)

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.bin'
]

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

class Joint_xLabel_train_dataLoader(Dataset):
    def __init__(self, real_root_dir, syn_root_dir, size=[240, 320], rgb=True, downsampleDepthFactor=1, paired_data=False):
        self.real_root_dir = real_root_dir
        self.syn_root_dir = syn_root_dir
        self.size = size
        self.rgb = rgb
        self.current_set_len = 0
        self.real_path2files = []
        self.syn_path2files = []
        self.downsampleDepthFactor = downsampleDepthFactor
        self.NYU_MIN_DEPTH_CLIP = 0.0
        self.NYU_MAX_DEPTH_CLIP = 10.0
        self.paired_data = paired_data  # whether real and synthetic samples are matched 1 to 1
        self.augment = None  # whether to augment the current sample
        self.x_labels = False  # whether to collect extra labels in synthetic data, such as segmentation or instance boundaries

        self.set_name = 'train'  # Joint_xLabel_train_dataLoader is only used in the training phase

        real_curfilenamelist = os.listdir(os.path.join(self.real_root_dir, self.set_name, 'rgb'))
        for fname in sorted(real_curfilenamelist):
            if is_image_file(fname):
                fpath = os.path.join(self.real_root_dir, self.set_name, 'rgb', fname)
                self.real_path2files.append(fpath)

        self.real_set_len = len(self.real_path2files)

        syn_curfilenamelist = os.listdir(os.path.join(self.syn_root_dir, self.set_name, 'rgb'))
        for fname in sorted(syn_curfilenamelist):
            if is_image_file(fname):
                fpath = os.path.join(self.syn_root_dir, self.set_name, 'rgb', fname)
                self.syn_path2files.append(fpath)

        self.syn_set_len = len(self.syn_path2files)

        self.TF2tensor = transforms.ToTensor()
        self.TF2PIL = transforms.ToPILImage()
        self.TFNormalize = transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
        self.funcResizeTensor = nn.Upsample(size=self.size, mode='nearest', align_corners=None)
        self.funcResizeDepth = nn.Upsample(size=[int(self.size[0]*self.downsampleDepthFactor),
                                                 int(self.size[1]*self.downsampleDepthFactor)],
                                           mode='nearest', align_corners=None)

    def __len__(self):
        # looping over real dataset
        return self.real_set_len

    def __getitem__(self, idx):
        real_filename = self.real_path2files[idx % self.real_set_len]
        rand_idx = random.randint(0, self.syn_set_len - 1)
        if self.paired_data:
            assert self.real_set_len == self.syn_set_len
            syn_filename = self.syn_path2files[idx]
        else:
            syn_filename = self.syn_path2files[rand_idx]

        # one flip decision per call, shared by the real and the synthetic sample
        if np.random.random(1) > 0.5:
            self.augment = True
        else:
            self.augment = False

        real_img, real_depth = self.fetch_img_depth(real_filename)
        syn_img, syn_depth = self.fetch_img_depth(syn_filename)
        return_dict = {'real': [real_img, real_depth], 'syn': [syn_img, syn_depth]}

        if self.x_labels:
            # not really used in this project
            extra_label_list = self.fetch_syn_extra_labels(syn_filename)
            return_dict = {'real': [real_img, real_depth], 'syn': [syn_img, syn_depth], 'syn_extra_labels': extra_label_list}
        return return_dict

    def fetch_img_depth(self, filename):
        image = PIL.Image.open(filename)
        image = np.array(image, dtype=np.float32) / 255.

        # note: str.replace rewrites every 'rgb'/'png' substring in the full path
        if self.set_name == 'train':
            depthname = filename.replace('rgb','depth_inpainted').replace('png','bin')
        else:
            # use real depth for validation and testing
            depthname = filename.replace('rgb','depth').replace('png','bin')

        depth = read_array_compressed(depthname)

        if self.set_name=='train' and self.augment:
            image = np.fliplr(image).copy()
            depth = np.fliplr(depth).copy()

        # rescale depth samples in training phase
        if self.set_name == 'train':
            depth = np.clip(depth, self.NYU_MIN_DEPTH_CLIP, self.NYU_MAX_DEPTH_CLIP)  # [0, 10]
            depth = ((depth/self.NYU_MAX_DEPTH_CLIP) - 0.5) * 2.0  # [-1, 1]

        image = self.TF2tensor(image)
        image = self.TFNormalize(image)
        image = image.unsqueeze(0)

        depth = np.expand_dims(depth, 2)
        depth = self.TF2tensor(depth)
        depth = depth.unsqueeze(0)

        # case-sensitive test on the full path: the NYU crop runs only if 'nyu' occurs in it
        if "nyu" in filename:
            image = processNYU_tensor(image)
            depth = processNYU_tensor(depth)

        image = self.funcResizeTensor(image)
        depth = self.funcResizeTensor(depth)

        if self.downsampleDepthFactor != 1:
            depth = self.funcResizeDepth(depth)

        if self.rgb:
            image = image.squeeze(0)
        else:
            image = image.mean(1)
            image = image.squeeze(0).unsqueeze(0)

        depth = depth.squeeze(0)
        return image, depth

    def fetch_syn_extra_labels(self, filename):
        # currently only fetch segmentation labels and instance boundaries
        seg_name = filename.replace('rgb','semantic_seg')
        ib_name = filename.replace('rgb','instance_boundary')

        seg_np = np.array(PIL.Image.open(seg_name), dtype=np.float32)
        ib_np = np.array(PIL.Image.open(ib_name), dtype=np.float32)

        if self.set_name=='train' and self.augment:
            seg_np = np.fliplr(seg_np).copy()
            ib_np = np.fliplr(ib_np).copy()

        seg_np = np.expand_dims(seg_np, 2)
        seg_tensor = self.TF2tensor(seg_np)

        ib_np = np.expand_dims(ib_np, 2)
        ib_tensor = self.TF2tensor(ib_np)  # size [1, 240, 320]

        return [seg_tensor, ib_tensor]

def ensure_dir_exists(dirname, log_mkdir=True):
    """
    Creates a directory if it does not already exist.
    :param dirname: Path to a directory.
    :param log_mkdir: If true, a debug message is logged when creating a new directory.
    :return: Same as `dirname`.
    """
    dirname = path.realpath(path.expanduser(dirname))
    if not path.isdir(dirname):
        # `exist_ok` in case of race condition.
        os.makedirs(dirname, exist_ok=True)
        if log_mkdir:
            log.debug('mkdir -p {}'.format(dirname))
    return dirname

def read_array(filename, dtype=np.float32):
    """
    Reads a multi-dimensional array file with the following format:
    [int32_t number of dimensions n]
    [int32_t dimension 0], [int32_t dimension 1], ..., [int32_t dimension n-1]
    [float or int data]

    :param filename: Path to the array file.
    :param dtype: This must be consistent with the saved data type.
    :return: A numpy array.
    """
    with open(filename, mode='rb') as f:
        content = f.read()
    return bytes_to_array(content, dtype=dtype)

def read_array_compressed(filename, dtype=np.float32):
    """
    Reads a multi-dimensional array file compressed with Blosc.
    Otherwise the same as `read_array`.
    """
    with open(filename, mode='rb') as f:
        compressed = f.read()
    decompressed = blosc.decompress(compressed)
    return bytes_to_array(decompressed, dtype=dtype)

def save_array_compressed(filename, arr: np.ndarray):
    """
    See `read_array`.
    """
    encoded = array_to_bytes(arr)
    compressed = blosc.compress(encoded, arr.dtype.itemsize, clevel=7, shuffle=True, cname='lz4hc')
    with open(filename, mode='wb') as f:
        f.write(compressed)
    log.info('Saved {}'.format(filename))

def array_to_bytes(arr: np.ndarray):
    """
    Dumps a numpy array into a raw byte string.
    :param arr: A numpy array.
    :return: A `bytes` string.
    """
    shape = arr.shape
    ndim = arr.ndim
    ret = struct.pack('i', ndim) + struct.pack('i' * ndim, *shape) + arr.tobytes(order='C')
    return ret

def bytes_to_array(s: bytes, dtype=np.float32):
    """
    Unpacks a byte string into a numpy array.
    :param s: A byte string containing raw array data.
    :param dtype: Data type.
    :return: A numpy array.
    """
    dims = struct.unpack('i', s[:4])[0]
    assert 0 <= dims < 1000  # Sanity check.
    shape = struct.unpack('i' * dims, s[4:4 * dims + 4])
    for dim in shape:
        assert dim > 0
    ret = np.frombuffer(s[4 * dims + 4:], dtype=dtype)
    assert ret.size == np.prod(shape), (ret.size, shape)
    ret.shape = shape
    return ret.copy()

def processNYU_tensor(X):
    X = X[:,:,45:471,41:601]
    return X

def cropPBRS(X):
    if len(X.shape)==3: return X[45:471,41:601,:]
    else: return X[45:471,41:601]
--------------------------------------------------------------------------------
/dataloader/NYUv2_dataLoader.py:
--------------------------------------------------------------------------------
import os, random, time, copy, sys
import logging
from skimage import io, transform
import numpy as np
import os.path as path
import scipy.io as sio
from scipy import misc
import matplotlib.pyplot as plt
import PIL.Image

import skimage.transform
import blosc, struct

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision
from torchvision import datasets, models, transforms

# `ensure_dir_exists` and `save_array_compressed` log through this module-level logger.
log = logging.getLogger(__name__)

class NYUv2_dataLoader(Dataset):
    def __init__(self, root_dir, set_name='train', size=[240, 320], rgb=True, downsampleDepthFactor=1, training_depth='inpaint'):
        # training depth option: inpaint | original
        self.root_dir = root_dir
        self.size = size
        self.set_name = set_name
        self.training_depth = training_depth
        self.rgb = rgb
        self.current_set_len = 0
        self.path2files = []
        self.downsampleDepthFactor = downsampleDepthFactor
        self.NYU_MIN_DEPTH_CLIP = 0.0
        self.NYU_MAX_DEPTH_CLIP = 10.0

        curfilenamelist = os.listdir(path.join(self.root_dir, self.set_name, 'rgb'))
        self.path2files += [path.join(self.root_dir, self.set_name, 'rgb', curfilename) for curfilename in curfilenamelist]
        self.current_set_len = len(self.path2files)

        self.TF2tensor = transforms.ToTensor()
        self.TF2PIL = transforms.ToPILImage()
        self.TFNormalize = transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
        self.funcResizeTensor = nn.Upsample(size=self.size, mode='nearest', align_corners=None)
        self.funcResizeDepth = nn.Upsample(size=[int(self.size[0]*self.downsampleDepthFactor),
                                                 int(self.size[1]*self.downsampleDepthFactor)],
                                           mode='nearest', align_corners=None)

    def __len__(self):
        return self.current_set_len

    def __getitem__(self, idx):
        filename = self.path2files[idx]
        image = PIL.Image.open(filename)
        image = np.array(image).astype(np.float32) / 255.

        if self.set_name == 'train':
            if self.training_depth == 'original':
                depthname = filename.replace('rgb','depth').replace('png','bin')
            else:
                depthname = filename.replace('rgb','depth_inpainted').replace('png','bin')
        else:
            # use real depth for validation and testing
            depthname = filename.replace('rgb','depth').replace('png','bin')

        depth = read_array_compressed(depthname)

        if self.set_name =='train' and np.random.random(1)>0.5:
            image = np.fliplr(image).copy()
            depth = np.fliplr(depth).copy()

        # rescale depth samples in training phase
        if self.set_name == 'train':
            depth = np.clip(depth, self.NYU_MIN_DEPTH_CLIP, self.NYU_MAX_DEPTH_CLIP)  # [0, 10]
            depth = ((depth/self.NYU_MAX_DEPTH_CLIP) - 0.5) * 2.0  # [-1, 1]

        image = self.TF2tensor(image)
        image = self.TFNormalize(image)
        image = image.unsqueeze(0)

        depth = np.expand_dims(depth, 2)
        depth = self.TF2tensor(depth)
        depth = depth.unsqueeze(0)

        image = processNYU_tensor(image)
        depth = processNYU_tensor(depth)

        image = self.funcResizeTensor(image)
        depth = self.funcResizeTensor(depth)

        if self.downsampleDepthFactor != 1:
            depth = self.funcResizeDepth(depth)

        if self.rgb:
            image = image.squeeze(0)
        else:
            image = image.mean(1)
            image = image.squeeze(0).unsqueeze(0)

        depth = depth.squeeze(0)
        return image, depth

def ensure_dir_exists(dirname, log_mkdir=True):
    """
    Creates a directory if it does not already exist.
    :param dirname: Path to a directory.
    :param log_mkdir: If true, a debug message is logged when creating a new directory.
    :return: Same as `dirname`.
    """
    dirname = path.realpath(path.expanduser(dirname))
    if not path.isdir(dirname):
        # `exist_ok` in case of race condition.
        os.makedirs(dirname, exist_ok=True)
        if log_mkdir:
            log.debug('mkdir -p {}'.format(dirname))
    return dirname


def read_array(filename, dtype=np.float32):
    """
    Reads a multi-dimensional array file with the following format:
    [int32_t number of dimensions n]
    [int32_t dimension 0], [int32_t dimension 1], ..., [int32_t dimension n-1]
    [float or int data]

    :param filename: Path to the array file.
    :param dtype: This must be consistent with the saved data type.
    :return: A numpy array.
    """
    with open(filename, mode='rb') as f:
        content = f.read()
    return bytes_to_array(content, dtype=dtype)


def read_array_compressed(filename, dtype=np.float32):
    """
    Reads a multi-dimensional array file compressed with Blosc.
    Otherwise the same as `read_array`.
    """
    with open(filename, mode='rb') as f:
        compressed = f.read()
    decompressed = blosc.decompress(compressed)
    return bytes_to_array(decompressed, dtype=dtype)


def save_array_compressed(filename, arr: np.ndarray):
    """
    See `read_array`.
    """
    encoded = array_to_bytes(arr)
    compressed = blosc.compress(encoded, arr.dtype.itemsize, clevel=7, shuffle=True, cname='lz4hc')
    with open(filename, mode='wb') as f:
        f.write(compressed)
    log.info('Saved {}'.format(filename))


def array_to_bytes(arr: np.ndarray):
    """
    Dumps a numpy array into a raw byte string.
    :param arr: A numpy array.
    :return: A `bytes` string.
    """
    shape = arr.shape
    ndim = arr.ndim
    ret = struct.pack('i', ndim) + struct.pack('i' * ndim, *shape) + arr.tobytes(order='C')
    return ret


def bytes_to_array(s: bytes, dtype=np.float32):
    """
    Unpacks a byte string into a numpy array.
    :param s: A byte string containing raw array data.
    :param dtype: Data type.
    :return: A numpy array.
    """
    dims = struct.unpack('i', s[:4])[0]
    assert 0 <= dims < 1000  # Sanity check.
    shape = struct.unpack('i' * dims, s[4:4 * dims + 4])
    for dim in shape:
        assert dim > 0
    ret = np.frombuffer(s[4 * dims + 4:], dtype=dtype)
    assert ret.size == np.prod(shape), (ret.size, shape)
    ret.shape = shape
    return ret.copy()


def processNYU_tensor(X):
    X = X[:,:,45:471,41:601]
    return X


def cropPBRS(X):
    if len(X.shape)==3: return X[45:471,41:601,:]
    else: return X[45:471,41:601]
--------------------------------------------------------------------------------
/dataloader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yunhan-zhao/ARC/13add94311bfa22660e34200ec8a1dd97a66faa3/dataloader/__init__.py
--------------------------------------------------------------------------------
/figures/4343-teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yunhan-zhao/ARC/13add94311bfa22660e34200ec8a1dd97a66faa3/figures/4343-teaser.gif -------------------------------------------------------------------------------- /figures/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunhan-zhao/ARC/13add94311bfa22660e34200ec8a1dd97a66faa3/figures/demo.png -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import os, random, time, copy 2 | import sys 3 | from skimage import io, transform 4 | import numpy as np 5 | import os.path as path 6 | import scipy.io as sio 7 | import matplotlib.pyplot as plt 8 | 9 | import torch 10 | from torch.utils.data import Dataset, DataLoader 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | from torch.optim import lr_scheduler 14 | import torch.nn.functional as F 15 | from torch.autograd import Variable 16 | 17 | import torchvision 18 | from torchvision import datasets, models, transforms 19 | import torchvision.models as models 20 | 21 | class StyleLoss(nn.Module): 22 | r""" 23 | Perceptual loss, VGG-based 24 | https://arxiv.org/abs/1603.08155 25 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py 26 | """ 27 | 28 | def __init__(self, vgg19=None): 29 | super(StyleLoss, self).__init__() 30 | self.add_module('vgg', vgg19) 31 | self.criterion = torch.nn.L1Loss() 32 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 33 | self.vgg.to(self.device) 34 | 35 | def compute_gram(self, x): 36 | b, ch, h, w = x.size() 37 | f = x.view(b, ch, w * h) 38 | f_T = f.transpose(1, 2) 39 | # print(f_T) 40 | G = f.bmm(f_T) / (h * w * ch) 41 | 42 | # test = f.bmm(f_T) / (h * w * ch) 43 | # cond = torch.isnan(test) 44 | # print(torch.sum(cond)) 45 | # if torch.sum(cond) > 0: 46 | # idx = np.argwhere(np.isnan(test.to('cpu').detach().numpy())) 47 | # 
print(idx[0], test.shape) 48 | # print(test[idx[0]]) 49 | # print(f[torch.isinf(f)], f[torch.isnan(f)]) 50 | # print(f_T[torch.isinf(f_T)], f_T[torch.isnan(f_T)]) 51 | # print(f.bmm(f_T)[cond]) 52 | # print(torch.bmm(f, f_T)[cond]) 53 | # print(h * w * ch) 54 | # sys.exit(1) 55 | 56 | # print(f.bmm(f_T)) 57 | # print(h * w * ch) 58 | # print(G) 59 | 60 | return G 61 | 62 | def __call__(self, x, y): 63 | # Compute features 64 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 65 | 66 | # Compute loss 67 | style_loss = 0.0 68 | style_loss += self.criterion(self.compute_gram(x_vgg['relu2_2']), self.compute_gram(y_vgg['relu2_2'])) 69 | style_loss += self.criterion(self.compute_gram(x_vgg['relu3_4']), self.compute_gram(y_vgg['relu3_4'])) 70 | style_loss += self.criterion(self.compute_gram(x_vgg['relu4_4']), self.compute_gram(y_vgg['relu4_4'])) 71 | style_loss += self.criterion(self.compute_gram(x_vgg['relu5_2']), self.compute_gram(y_vgg['relu5_2'])) 72 | 73 | return style_loss 74 | 75 | class PerceptualLoss(nn.Module): 76 | r""" 77 | Perceptual loss, VGG-based 78 | https://arxiv.org/abs/1603.08155 79 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py 80 | """ 81 | 82 | def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0], vgg19=None): 83 | super(PerceptualLoss, self).__init__() 84 | self.add_module('vgg', vgg19) 85 | self.criterion = torch.nn.L1Loss() 86 | self.weights = weights 87 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 88 | self.vgg.to(self.device) 89 | 90 | def __call__(self, x, y): 91 | # Compute features 92 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 93 | 94 | content_loss = 0.0 95 | content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1']) 96 | content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1']) 97 | content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1']) 98 | content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1']) 99 | 
content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1']) 100 | 101 | return content_loss 102 | 103 | class VGG19(torch.nn.Module): 104 | def __init__(self): 105 | super(VGG19, self).__init__() 106 | features = models.vgg19(pretrained=True).features 107 | self.relu1_1 = torch.nn.Sequential() 108 | self.relu1_2 = torch.nn.Sequential() 109 | 110 | self.relu2_1 = torch.nn.Sequential() 111 | self.relu2_2 = torch.nn.Sequential() 112 | 113 | self.relu3_1 = torch.nn.Sequential() 114 | self.relu3_2 = torch.nn.Sequential() 115 | self.relu3_3 = torch.nn.Sequential() 116 | self.relu3_4 = torch.nn.Sequential() 117 | 118 | self.relu4_1 = torch.nn.Sequential() 119 | self.relu4_2 = torch.nn.Sequential() 120 | self.relu4_3 = torch.nn.Sequential() 121 | self.relu4_4 = torch.nn.Sequential() 122 | 123 | self.relu5_1 = torch.nn.Sequential() 124 | self.relu5_2 = torch.nn.Sequential() 125 | self.relu5_3 = torch.nn.Sequential() 126 | self.relu5_4 = torch.nn.Sequential() 127 | 128 | for x in range(2): 129 | self.relu1_1.add_module(str(x), features[x]) 130 | 131 | for x in range(2, 4): 132 | self.relu1_2.add_module(str(x), features[x]) 133 | 134 | for x in range(4, 7): 135 | self.relu2_1.add_module(str(x), features[x]) 136 | 137 | for x in range(7, 9): 138 | self.relu2_2.add_module(str(x), features[x]) 139 | 140 | for x in range(9, 12): 141 | self.relu3_1.add_module(str(x), features[x]) 142 | 143 | for x in range(12, 14): 144 | self.relu3_2.add_module(str(x), features[x]) 145 | 146 | for x in range(14, 16): 147 | self.relu3_3.add_module(str(x), features[x]) 148 | 149 | for x in range(16, 18): 150 | self.relu3_4.add_module(str(x), features[x]) 151 | 152 | for x in range(18, 21): 153 | self.relu4_1.add_module(str(x), features[x]) 154 | 155 | for x in range(21, 23): 156 | self.relu4_2.add_module(str(x), features[x]) 157 | 158 | for x in range(23, 25): 159 | self.relu4_3.add_module(str(x), features[x]) 160 | 161 | for x in range(25, 27): 162 | 
self.relu4_4.add_module(str(x), features[x]) 163 | 164 | for x in range(27, 30): 165 | self.relu5_1.add_module(str(x), features[x]) 166 | 167 | for x in range(30, 32): 168 | self.relu5_2.add_module(str(x), features[x]) 169 | 170 | for x in range(32, 34): 171 | self.relu5_3.add_module(str(x), features[x]) 172 | 173 | for x in range(34, 36): 174 | self.relu5_4.add_module(str(x), features[x]) 175 | 176 | # don't need the gradients, just want the features 177 | for param in self.parameters(): 178 | param.requires_grad = False 179 | 180 | def forward(self, x): 181 | relu1_1 = self.relu1_1(x) 182 | relu1_2 = self.relu1_2(relu1_1) 183 | 184 | relu2_1 = self.relu2_1(relu1_2) 185 | relu2_2 = self.relu2_2(relu2_1) 186 | 187 | relu3_1 = self.relu3_1(relu2_2) 188 | relu3_2 = self.relu3_2(relu3_1) 189 | relu3_3 = self.relu3_3(relu3_2) 190 | relu3_4 = self.relu3_4(relu3_3) 191 | 192 | relu4_1 = self.relu4_1(relu3_4) 193 | relu4_2 = self.relu4_2(relu4_1) 194 | relu4_3 = self.relu4_3(relu4_2) 195 | relu4_4 = self.relu4_4(relu4_3) 196 | 197 | relu5_1 = self.relu5_1(relu4_4) 198 | relu5_2 = self.relu5_2(relu5_1) 199 | relu5_3 = self.relu5_3(relu5_2) 200 | relu5_4 = self.relu5_4(relu5_3) 201 | 202 | out = { 203 | 'relu1_1': relu1_1, 204 | 'relu1_2': relu1_2, 205 | 206 | 'relu2_1': relu2_1, 207 | 'relu2_2': relu2_2, 208 | 209 | 'relu3_1': relu3_1, 210 | 'relu3_2': relu3_2, 211 | 'relu3_3': relu3_3, 212 | 'relu3_4': relu3_4, 213 | 214 | 'relu4_1': relu4_1, 215 | 'relu4_2': relu4_2, 216 | 'relu4_3': relu4_3, 217 | 'relu4_4': relu4_4, 218 | 219 | 'relu5_1': relu5_1, 220 | 'relu5_2': relu5_2, 221 | 'relu5_3': relu5_3, 222 | 'relu5_4': relu5_4, 223 | } 224 | return out 225 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunhan-zhao/ARC/13add94311bfa22660e34200ec8a1dd97a66faa3/models/__init__.py 
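The hard-coded slice boundaries in the VGG19 feature extractor in loss.py (range(2), range(2, 4), range(4, 7), ...) follow from the flat layout of torchvision's `vgg19().features`: every 3x3 convolution is immediately followed by a ReLU, and every stage ends with a max-pool. The sketch below recomputes those boundaries without downloading any weights. It assumes that the standard VGG19 configuration (no batch norm) matches torchvision's; `VGG19_CFG` and `relu_slices` are illustrative names and are not part of this repository.

```python
# Layer layout of VGG19 ("configuration E"): an int is a 3x3 Conv2d with that
# many output channels (followed by a ReLU); "M" is a 2x2 max-pool.
# Assumption: this matches torchvision's vgg19().features (no batch norm).
VGG19_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
             512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]


def relu_slices(cfg=VGG19_CFG):
    """Map 'reluS_K' to (start, stop) so that features[start:stop] runs from
    the layer after the previous ReLU up to and including this ReLU."""
    slices = {}
    idx = 0                # next free index in the flat features Sequential
    start = 0              # first index of the slice being built
    stage, conv_in_stage = 1, 0
    for item in cfg:
        if item == "M":
            idx += 1       # the pool is absorbed into the next slice
            stage += 1
            conv_in_stage = 0
            continue
        conv_in_stage += 1
        idx += 2           # Conv2d followed by ReLU
        slices["relu%d_%d" % (stage, conv_in_stage)] = (start, idx)
        start = idx
    return slices


if __name__ == "__main__":
    for name, (lo, hi) in relu_slices().items():
        print(name, "range(%d, %d)" % (lo, hi))
```

Each max-pool is folded into the slice that follows it. For example, relu2_1 covers indices 4 to 6: the pool, then conv2_1, then its ReLU. The final max-pool at index 36 falls outside every slice, which is consistent with the class stopping at relu5_4.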
-------------------------------------------------------------------------------- /models/depth_generator_networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | from torch.autograd import Variable 6 | from torchvision import models 7 | import torch.nn.functional as F 8 | from torch.optim import lr_scheduler 9 | 10 | 11 | ###################################################################################### 12 | # Functions 13 | ###################################################################################### 14 | def get_norm_layer(norm_type='batch'): 15 | if norm_type == 'batch': 16 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 17 | elif norm_type == 'instance': 18 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) 19 | elif norm_type == 'none': 20 | norm_layer = None 21 | else: 22 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 23 | return norm_layer 24 | 25 | 26 | def get_nonlinearity_layer(activation_type='PReLU'): 27 | if activation_type == 'ReLU': 28 | nonlinearity_layer = nn.ReLU(True) 29 | elif activation_type == 'SELU': 30 | nonlinearity_layer = nn.SELU(True) 31 | elif activation_type == 'LeakyReLU': 32 | nonlinearity_layer = nn.LeakyReLU(0.1, True) 33 | elif activation_type == 'PReLU': 34 | nonlinearity_layer = nn.PReLU() 35 | else: 36 | raise NotImplementedError('activation layer [%s] is not found' % activation_type) 37 | return nonlinearity_layer 38 | 39 | 40 | def get_scheduler(optimizer, opt): 41 | if opt.lr_policy == 'lambda': 42 | def lambda_rule(epoch): 43 | lr_l = 1.0 - max(0, epoch+1+1+opt.epoch_count-opt.niter) / float(opt.niter_decay+1) 44 | return lr_l 45 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 46 | elif opt.lr_policy == 'step': 47 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, 
gamma=0.1) 48 | elif opt.lr_policy == 'exponent': 49 | scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.95) 50 | # scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.9) 51 | else: 52 | raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy) 53 | return scheduler 54 | 55 | 56 | def init_weights(net, net_name=None, init_type='normal', gain=0.02): 57 | def init_func(m): 58 | classname = m.__class__.__name__ 59 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 60 | if init_type == 'normal': 61 | init.normal_(m.weight.data, 0.0, gain) 62 | elif init_type == 'xavier': 63 | init.xavier_normal_(m.weight.data, gain=gain) 64 | elif init_type == 'kaiming': 65 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 66 | elif init_type == 'orthogonal': 67 | init.orthogonal_(m.weight.data, gain=gain) 68 | else: 69 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 70 | if hasattr(m, 'bias') and m.bias is not None: 71 | init.constant_(m.bias.data, 0.0) 72 | elif classname.find('BatchNorm2d') != -1: 73 | init.normal_(m.weight.data, 1.0, gain) # BN scale ~ N(1, gain); uniform_(1.0, gain) would require low <= high 74 | init.constant_(m.bias.data, 0.0) 75 | 76 | print('initialize network {} with {}'.format(net_name, init_type)) 77 | net.apply(init_func) 78 | 79 | 80 | def print_network(net): 81 | num_params = 0 82 | for param in net.parameters(): 83 | num_params += param.numel() 84 | print(net) 85 | print('total number of parameters: %.3f M' % (num_params / 1e6)) 86 | 87 | 88 | def init_net(net, init_type='normal', gpu_ids=[]): 89 | 90 | print_network(net) 91 | 92 | if len(gpu_ids) > 0: 93 | assert(torch.cuda.is_available()) 94 | net = torch.nn.DataParallel(net, gpu_ids) 95 | net.cuda() 96 | init_weights(net, init_type=init_type) # pass by keyword: positionally it would bind to net_name 97 | return net 98 | 99 | 100 | def _freeze(*args): 101 | for module in args: 102 | if module: 103 | for p in module.parameters(): 104 | p.requires_grad = False 105 | 106 | 107 | def _unfreeze(*args):
108 | for module in args: 109 | if module: 110 | for p in module.parameters(): 111 | p.requires_grad = True 112 | 113 | 114 | # define the generator(transform, task) network 115 | def define_G(input_nc, output_nc, ngf=64, layers=4, norm='batch', activation='PReLU', model_type='UNet', 116 | init_type='xavier', drop_rate=0, add_noise=False, gpu_ids=[], weight=0.1): 117 | 118 | if model_type == 'ResNet': 119 | net = _ResGenerator(input_nc, output_nc, ngf, layers, norm, activation, drop_rate, add_noise, gpu_ids) 120 | elif model_type == 'UNet': 121 | net = _UNetGenerator(input_nc, output_nc, ngf, layers, norm, activation, drop_rate, add_noise, gpu_ids, weight) 122 | # net = _PreUNet16(input_nc, output_nc, ngf, layers, True, norm, activation, drop_rate, gpu_ids) 123 | else: 124 | raise NotImplementedError('model type [%s] is not implemented' % model_type) 125 | 126 | return init_net(net, init_type, gpu_ids) 127 | 128 | 129 | # define the discriminator network 130 | def define_D(input_nc, ndf = 64, n_layers = 3, num_D = 1, norm = 'batch', activation = 'PReLU', init_type='xavier', gpu_ids = []): 131 | 132 | net = _MultiscaleDiscriminator(input_nc, ndf, n_layers, num_D, norm, activation, gpu_ids) 133 | 134 | return init_net(net, init_type, gpu_ids) 135 | 136 | 137 | # define the feature discriminator network 138 | def define_featureD(input_nc, n_layers=2, norm='batch', activation='PReLU', init_type='xavier', gpu_ids=[]): 139 | 140 | net = _FeatureDiscriminator(input_nc, n_layers, norm, activation, gpu_ids) 141 | 142 | return init_net(net, init_type, gpu_ids) 143 | 144 | 145 | ###################################################################################### 146 | # Basic Operation 147 | ###################################################################################### 148 | 149 | class GaussianNoiseLayer(nn.Module): 150 | def __init__(self): 151 | super(GaussianNoiseLayer, self).__init__() 152 | 153 | def forward(self, x): 154 | if not self.training: 155 |
return x 156 | noise = (torch.randn_like(x) - 0.5) / 10.0 # randn_like matches x's device and dtype, so this also runs on CPU 157 | return x+noise 158 | 159 | 160 | class _InceptionBlock(nn.Module): 161 | def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.PReLU(), width=1, drop_rate=0, use_bias=False): 162 | super(_InceptionBlock, self).__init__() 163 | 164 | self.width = width 165 | self.drop_rate = drop_rate 166 | 167 | for i in range(width): 168 | layer = nn.Sequential( 169 | nn.ReflectionPad2d(i*2+1), 170 | nn.Conv2d(input_nc, output_nc, kernel_size=3, padding=0, dilation=i*2+1, bias=use_bias) 171 | ) 172 | setattr(self, 'layer'+str(i), layer) 173 | 174 | self.norm1 = norm_layer(output_nc * width) 175 | self.norm2 = norm_layer(output_nc) 176 | self.nonlinearity = nonlinearity 177 | self.branch1x1 = nn.Sequential( 178 | nn.ReflectionPad2d(1), 179 | nn.Conv2d(output_nc * width, output_nc, kernel_size=3, padding=0, bias=use_bias) 180 | ) 181 | 182 | def forward(self, x): 183 | result = [] 184 | for i in range(self.width): 185 | layer = getattr(self, 'layer'+str(i)) 186 | result.append(layer(x)) 187 | output = torch.cat(result, 1) 188 | output = self.nonlinearity(self.norm1(output)) 189 | output = self.norm2(self.branch1x1(output)) 190 | if self.drop_rate > 0: 191 | output = F.dropout(output, p=self.drop_rate, training=self.training) 192 | 193 | return self.nonlinearity(output+x) 194 | 195 | 196 | class _EncoderBlock(nn.Module): 197 | def __init__(self, input_nc, middle_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.PReLU(), use_bias=False): 198 | super(_EncoderBlock, self).__init__() 199 | 200 | model = [ 201 | nn.Conv2d(input_nc, middle_nc, kernel_size=3, stride=1, padding=1, bias=use_bias), 202 | norm_layer(middle_nc), 203 | nonlinearity, 204 | nn.Conv2d(middle_nc, output_nc, kernel_size=3, stride=1, padding=1, bias=use_bias), 205 | norm_layer(output_nc), 206 | nonlinearity 207 | ] 208 | 209 | self.model = nn.Sequential(*model) 210 | 211 | def
forward(self, x): 212 | return self.model(x) 213 | 214 | 215 | class _DownBlock(nn.Module): 216 | def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.PReLU(), use_bias=False): 217 | super(_DownBlock, self).__init__() 218 | 219 | model = [ 220 | nn.Conv2d(input_nc, output_nc, kernel_size=3, stride=1, padding=1, bias=use_bias), 221 | norm_layer(output_nc), 222 | nonlinearity, 223 | nn.MaxPool2d(kernel_size=2, stride=2), 224 | ] 225 | 226 | self.model = nn.Sequential(*model) 227 | 228 | def forward(self, x): 229 | return self.model(x) 230 | 231 | 232 | class _ShuffleUpBlock(nn.Module): 233 | def __init__(self, input_nc, up_scale, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.PReLU(), use_bias=False): 234 | super(_ShuffleUpBlock, self).__init__() 235 | 236 | model = [ 237 | nn.Conv2d(input_nc, input_nc*up_scale**2, kernel_size=3, stride=1, padding=1, bias=use_bias), 238 | nn.PixelShuffle(up_scale), 239 | nonlinearity, 240 | nn.Conv2d(input_nc, output_nc, kernel_size=3, stride=1, padding=1, bias=use_bias), 241 | norm_layer(output_nc), 242 | nonlinearity 243 | ] 244 | 245 | self.model = nn.Sequential(*model) 246 | 247 | def forward(self, x): 248 | return self.model(x) 249 | 250 | 251 | class _DecoderUpBlock(nn.Module): 252 | def __init__(self, input_nc, middle_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.PReLU(), use_bias=False): 253 | super(_DecoderUpBlock, self).__init__() 254 | 255 | model = [ 256 | nn.ReflectionPad2d(1), 257 | nn.Conv2d(input_nc, middle_nc, kernel_size=3, stride=1, padding=0, bias=use_bias), 258 | norm_layer(middle_nc), 259 | nonlinearity, 260 | nn.ConvTranspose2d(middle_nc, output_nc, kernel_size=3, stride=2, padding=1, output_padding=1), 261 | norm_layer(output_nc), 262 | nonlinearity 263 | ] 264 | 265 | self.model = nn.Sequential(*model) 266 | 267 | def forward(self, x): 268 | return self.model(x) 269 | 270 | class _DecoderUpBlock_Upsampling(nn.Module): 271 | def __init__(self, input_nc, 
middle_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.PReLU(), use_bias=False): 272 | super(_DecoderUpBlock_Upsampling, self).__init__() 273 | 274 | model = [ 275 | nn.ReflectionPad2d(1), 276 | nn.Conv2d(input_nc, middle_nc, kernel_size=3, stride=1, padding=0, bias=use_bias), 277 | norm_layer(middle_nc), 278 | nonlinearity, 279 | # nn.ConvTranspose2d(middle_nc, output_nc, kernel_size=3, stride=2, padding=1, output_padding=1), 280 | nn.Upsample(scale_factor = 2, mode='bilinear'), 281 | nn.ReflectionPad2d(1), 282 | nn.Conv2d(middle_nc, output_nc, kernel_size=3, stride=1, padding=0), 283 | norm_layer(output_nc), 284 | nonlinearity 285 | ] 286 | 287 | self.model = nn.Sequential(*model) 288 | 289 | def forward(self, x): 290 | return self.model(x) 291 | 292 | class _OutputBlock(nn.Module): 293 | def __init__(self, input_nc, output_nc, kernel_size=3, use_bias=False): 294 | super(_OutputBlock, self).__init__() 295 | 296 | model = [ 297 | nn.ReflectionPad2d(int(kernel_size/2)), 298 | nn.Conv2d(input_nc, output_nc, kernel_size=kernel_size, padding=0, bias=use_bias), 299 | nn.Tanh() 300 | ] 301 | 302 | self.model = nn.Sequential(*model) 303 | 304 | def forward(self, x): 305 | return self.model(x) 306 | 307 | 308 | ###################################################################################### 309 | # Network structure 310 | ###################################################################################### 311 | 312 | class _ResGenerator(nn.Module): 313 | def __init__(self, input_nc, output_nc, ngf=64, n_blocks=6, norm='batch', activation='PReLU', drop_rate=0, add_noise=False, gpu_ids=[]): 314 | super(_ResGenerator, self).__init__() 315 | 316 | self.gpu_ids = gpu_ids 317 | 318 | norm_layer = get_norm_layer(norm_type=norm) 319 | nonlinearity = get_nonlinearity_layer(activation_type=activation) 320 | 321 | if type(norm_layer) == functools.partial: 322 | use_bias = norm_layer.func == nn.InstanceNorm2d 323 | else: 324 | use_bias = norm_layer == 
nn.InstanceNorm2d 325 | 326 | encoder = [ 327 | nn.ReflectionPad2d(3), 328 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), 329 | norm_layer(ngf), 330 | nonlinearity 331 | ] 332 | 333 | n_downsampling = 2 334 | mult = 1 335 | for i in range(n_downsampling): 336 | mult_prev = mult 337 | mult = min(2 ** (i+1), 2) 338 | encoder += [ 339 | _EncoderBlock(ngf * mult_prev, ngf*mult, ngf*mult, norm_layer, nonlinearity, use_bias), 340 | nn.AvgPool2d(kernel_size=2, stride=2) 341 | ] 342 | 343 | mult = min(2 ** n_downsampling, 2) 344 | for i in range(n_blocks-n_downsampling): 345 | encoder +=[ 346 | _InceptionBlock(ngf*mult, ngf*mult, norm_layer=norm_layer, nonlinearity=nonlinearity, width=1, 347 | drop_rate=drop_rate, use_bias=use_bias) 348 | ] 349 | 350 | decoder = [] 351 | if add_noise: 352 | decoder += [GaussianNoiseLayer()] 353 | 354 | for i in range(n_downsampling): 355 | mult_prev = mult 356 | mult = min(2 ** (n_downsampling - i -1), 2) 357 | decoder +=[ 358 | _DecoderUpBlock(ngf*mult_prev, ngf*mult_prev, ngf*mult, norm_layer, nonlinearity, use_bias), 359 | ] 360 | 361 | decoder +=[ 362 | nn.ReflectionPad2d(3), 363 | nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), 364 | nn.Tanh() 365 | ] 366 | 367 | self.encoder = nn.Sequential(*encoder) 368 | self.decoder = nn.Sequential(*decoder) 369 | 370 | def forward(self, input): 371 | feature = self.encoder(input) 372 | result = [feature] 373 | output = self.decoder(feature) 374 | result.append(output) 375 | return result 376 | 377 | class _ResGenerator_Upsample_Conv2d(nn.Module): 378 | def __init__(self, input_nc, output_nc, ngf=64, n_blocks=6, norm='batch', activation='PReLU', drop_rate=0, add_noise=False, gpu_ids=[]): 379 | super(_ResGenerator_Upsample_Conv2d, self).__init__() 380 | 381 | self.gpu_ids = gpu_ids 382 | 383 | norm_layer = get_norm_layer(norm_type=norm) 384 | nonlinearity = get_nonlinearity_layer(activation_type=activation) 385 | 386 | if type(norm_layer) == functools.partial: 387 | 
use_bias = norm_layer.func == nn.InstanceNorm2d 388 | else: 389 | use_bias = norm_layer == nn.InstanceNorm2d 390 | 391 | encoder = [ 392 | nn.ReflectionPad2d(3), 393 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), 394 | norm_layer(ngf), 395 | nonlinearity 396 | ] 397 | 398 | n_downsampling = 2 399 | mult = 1 400 | for i in range(n_downsampling): 401 | mult_prev = mult 402 | mult = min(2 ** (i+1), 2) 403 | encoder += [ 404 | _EncoderBlock(ngf * mult_prev, ngf*mult, ngf*mult, norm_layer, nonlinearity, use_bias), 405 | nn.AvgPool2d(kernel_size=2, stride=2) 406 | ] 407 | 408 | mult = min(2 ** n_downsampling, 2) 409 | for i in range(n_blocks-n_downsampling): 410 | encoder +=[ 411 | _InceptionBlock(ngf*mult, ngf*mult, norm_layer=norm_layer, nonlinearity=nonlinearity, width=1, 412 | drop_rate=drop_rate, use_bias=use_bias) 413 | ] 414 | 415 | decoder = [] 416 | if add_noise: 417 | decoder += [GaussianNoiseLayer()] 418 | 419 | for i in range(n_downsampling): 420 | mult_prev = mult 421 | mult = min(2 ** (n_downsampling - i -1), 2) 422 | decoder +=[ 423 | _DecoderUpBlock_Upsampling(ngf*mult_prev, ngf*mult_prev, ngf*mult, norm_layer, nonlinearity, use_bias), 424 | ] 425 | 426 | decoder += [ 427 | nn.ReflectionPad2d(3), 428 | nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0) 429 | # nn.Conv2d(ngf, output_nc, kernel_size=1, padding=0) 430 | ] 431 | 432 | self.encoder = nn.Sequential(*encoder) 433 | self.decoder = nn.Sequential(*decoder) 434 | 435 | def forward(self, input): 436 | feature = self.encoder(input) 437 | result = [feature] 438 | output = self.decoder(feature) 439 | # print('before first sigmoid before final projection:', output) 440 | # output = self.final_proj(output_pool) 441 | # print('before first sigmoid after final projection:', output) 442 | result.append(output) 443 | return result 444 | 445 | class _ResGenerator_Upsample_Conv2d_Pool(nn.Module): 446 | def __init__(self, input_nc, output_nc, output_size, ngf=64, n_blocks=6, norm='batch',
activation='PReLU', drop_rate=0, add_noise=False, gpu_ids=[]): 447 | super(_ResGenerator_Upsample_Conv2d_Pool, self).__init__() 448 | 449 | self.gpu_ids = gpu_ids 450 | self.output_h = output_size[0] 451 | self.output_w = output_size[1] 452 | 453 | norm_layer = get_norm_layer(norm_type=norm) 454 | nonlinearity = get_nonlinearity_layer(activation_type=activation) 455 | 456 | if type(norm_layer) == functools.partial: 457 | use_bias = norm_layer.func == nn.InstanceNorm2d 458 | else: 459 | use_bias = norm_layer == nn.InstanceNorm2d 460 | 461 | encoder = [ 462 | nn.ReflectionPad2d(3), 463 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), 464 | norm_layer(ngf), 465 | nonlinearity 466 | ] 467 | 468 | n_downsampling = 2 469 | mult = 1 470 | for i in range(n_downsampling): 471 | mult_prev = mult 472 | mult = min(2 ** (i+1), 2) 473 | encoder += [ 474 | _EncoderBlock(ngf * mult_prev, ngf*mult, ngf*mult, norm_layer, nonlinearity, use_bias), 475 | nn.AvgPool2d(kernel_size=2, stride=2) 476 | ] 477 | 478 | mult = min(2 ** n_downsampling, 2) 479 | for i in range(n_blocks-n_downsampling): 480 | encoder +=[ 481 | _InceptionBlock(ngf*mult, ngf*mult, norm_layer=norm_layer, nonlinearity=nonlinearity, width=1, 482 | drop_rate=drop_rate, use_bias=use_bias) 483 | ] 484 | 485 | decoder = [] 486 | if add_noise: 487 | decoder += [GaussianNoiseLayer()] 488 | 489 | for i in range(n_downsampling): 490 | mult_prev = mult 491 | mult = min(2 ** (n_downsampling - i -1), 2) 492 | decoder +=[ 493 | _DecoderUpBlock_Upsampling(ngf*mult_prev, ngf*mult_prev, ngf*mult, norm_layer, nonlinearity, use_bias), 494 | ] 495 | 496 | final_proj = [ 497 | # nn.ReflectionPad2d(3), 498 | # nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0) 499 | nn.Conv2d(ngf, output_nc, kernel_size=1, padding=0) 500 | ] 501 | 502 | self.encoder = nn.Sequential(*encoder) 503 | self.decoder = nn.Sequential(*decoder) 504 | self.final_proj = nn.Sequential(*final_proj) 505 | 506 | def forward(self, input): 507 | 
feature = self.encoder(input) 508 | # print(feature) 509 | result = [feature] 510 | output = self.decoder(feature) 511 | H, W = output.size()[2], output.size()[3] 512 | output_pool = F.max_pool2d(output, kernel_size=(int(H/self.output_h), int(W/self.output_w)), 513 | stride=(int(H/self.output_h), int(W/self.output_w))) 514 | # print('before first sigmoid, pooling:', output_pool) 515 | output_pool = self.final_proj(output_pool) 516 | # print('before first sigmoid, after projection: ', output_pool) 517 | # print(output_pool.size()) 518 | result.append(output_pool) 519 | return result 520 | 521 | class _ResGenerator_Upsample(nn.Module): 522 | def __init__(self, input_nc, output_nc, ngf=64, n_blocks=6, norm='batch', activation='PReLU', drop_rate=0, add_noise=False, gpu_ids=[]): 523 | super(_ResGenerator_Upsample, self).__init__() 524 | 525 | self.gpu_ids = gpu_ids 526 | 527 | norm_layer = get_norm_layer(norm_type=norm) 528 | nonlinearity = get_nonlinearity_layer(activation_type=activation) 529 | 530 | if type(norm_layer) == functools.partial: 531 | use_bias = norm_layer.func == nn.InstanceNorm2d 532 | else: 533 | use_bias = norm_layer == nn.InstanceNorm2d 534 | 535 | encoder = [ 536 | nn.ReflectionPad2d(3), 537 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), 538 | norm_layer(ngf), 539 | nonlinearity 540 | ] 541 | 542 | n_downsampling = 2 543 | mult = 1 544 | for i in range(n_downsampling): 545 | mult_prev = mult 546 | mult = min(2 ** (i+1), 2) 547 | encoder += [ 548 | _EncoderBlock(ngf * mult_prev, ngf*mult, ngf*mult, norm_layer, nonlinearity, use_bias), 549 | nn.AvgPool2d(kernel_size=2, stride=2) 550 | ] 551 | 552 | mult = min(2 ** n_downsampling, 2) 553 | for i in range(n_blocks-n_downsampling): 554 | encoder +=[ 555 | _InceptionBlock(ngf*mult, ngf*mult, norm_layer=norm_layer, nonlinearity=nonlinearity, width=1, 556 | drop_rate=drop_rate, use_bias=use_bias) 557 | ] 558 | 559 | decoder = [] 560 | if add_noise: 561 | decoder += [GaussianNoiseLayer()]
562 | 563 | for i in range(n_downsampling): 564 | mult_prev = mult 565 | mult = min(2 ** (n_downsampling - i -1), 2) 566 | decoder +=[ 567 | _DecoderUpBlock_Upsampling(ngf*mult_prev, ngf*mult_prev, ngf*mult, norm_layer, nonlinearity, use_bias), 568 | ] 569 | 570 | decoder +=[ 571 | nn.ReflectionPad2d(3), 572 | nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), 573 | nn.Tanh() 574 | ] 575 | 576 | self.encoder = nn.Sequential(*encoder) 577 | self.decoder = nn.Sequential(*decoder) 578 | 579 | def forward(self, input): 580 | feature = self.encoder(input) 581 | result = [feature] 582 | output = self.decoder(feature) 583 | result.append(output) 584 | return result 585 | 586 | class _PreUNet16(nn.Module): 587 | def __init__(self, input_nc, output_nc, ngf=64, layers=5, pretrained=False, norm='batch', activation='PReLU', 588 | drop_rate=0, gpu_ids=[]): 589 | super(_PreUNet16, self).__init__() 590 | 591 | self.gpu_ids = gpu_ids 592 | self.layers = layers 593 | norm_layer = get_norm_layer(norm_type=norm) 594 | nonlinearity = get_nonlinearity_layer(activation_type=activation) 595 | if type(norm_layer) == functools.partial: 596 | use_bias = norm_layer.func == nn.InstanceNorm2d 597 | else: 598 | use_bias = norm_layer == nn.InstanceNorm2d 599 | 600 | encoder = models.vgg16(pretrained=pretrained).features 601 | 602 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 603 | self.relu = nn.ReLU(inplace=True) 604 | 605 | self.conv1 = nn.Sequential(encoder[0], self.relu, encoder[2], self.relu) 606 | self.conv2 = nn.Sequential(encoder[5], self.relu, encoder[7], self.relu) 607 | self.conv3 = nn.Sequential(encoder[10], self.relu, encoder[12], self.relu, encoder[14], self.relu) 608 | self.conv4 = nn.Sequential(encoder[17], self.relu, encoder[19], self.relu, encoder[21], self.relu) 609 | 610 | for i in range(layers - 4): 611 | conv = _EncoderBlock(ngf * 8, ngf * 8, ngf * 8, norm_layer, nonlinearity, use_bias) 612 | setattr(self, 'down' + str(i), conv.model) 613 | 614 | center = [] 615 |
for i in range(7 - layers): 616 | center += [ 617 | _InceptionBlock(ngf * 8, ngf * 8, norm_layer, nonlinearity, 7 - layers, drop_rate, use_bias) 618 | ] 619 | 620 | center += [_DecoderUpBlock(ngf * 8, ngf * 8, ngf * 4, norm_layer, nonlinearity, use_bias)] 621 | 622 | for i in range(layers - 4): 623 | upconv = _DecoderUpBlock(ngf * (8 + 4), ngf * 8, ngf * 4, norm_layer, nonlinearity, use_bias) 624 | setattr(self, 'up' + str(i), upconv.model) 625 | 626 | self.deconv4 = _DecoderUpBlock(ngf * (4 + 4), ngf * 8, ngf * 2, norm_layer, nonlinearity, use_bias) 627 | self.deconv3 = _DecoderUpBlock(ngf * (2 + 2) + output_nc, ngf * 4, ngf, norm_layer, nonlinearity, use_bias) 628 | self.deconv2 = _DecoderUpBlock(ngf * (1 + 1) + output_nc, ngf * 2, int(ngf / 2), norm_layer, nonlinearity, use_bias) 629 | 630 | self.deconv1 = _OutputBlock(int(ngf / 2) + output_nc, output_nc, kernel_size=7, use_bias=use_bias) 631 | 632 | self.output4 = _OutputBlock(ngf * (4 + 4), output_nc, kernel_size=3, use_bias=use_bias) 633 | self.output3 = _OutputBlock(ngf * (2 + 2) + output_nc, output_nc, kernel_size=3, use_bias=use_bias) 634 | self.output2 = _OutputBlock(ngf * (1 + 1) + output_nc, output_nc, kernel_size=3, use_bias=use_bias) 635 | 636 | self.center = nn.Sequential(*center) 637 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 638 | 639 | def forward(self, input): 640 | conv1 = self.pool(self.conv1(input)) 641 | conv2 = self.pool(self.conv2(conv1)) 642 | conv3 = self.pool(self.conv3(conv2)) 643 | center_in = self.pool(self.conv4(conv3)) 644 | 645 | middle = [center_in] 646 | for i in range(self.layers - 4): 647 | model = getattr(self, 'down' + str(i)) 648 | center_in = self.pool(model(center_in)) 649 | middle.append(center_in) 650 | 651 | result = [center_in] 652 | 653 | center_out = self.center(center_in) 654 | 655 | for i in range(self.layers - 4): 656 | model = getattr(self, 'up' + str(i)) 657 | center_out = model(torch.cat([center_out, middle[self.layers - 
4 - i]], 1)) 658 | 659 | deconv4 = self.deconv4.forward(torch.cat([center_out, conv3 * 0.1], 1)) 660 | output4 = self.output4.forward(torch.cat([center_out, conv3 * 0.1], 1)) 661 | result.append(output4) 662 | deconv3 = self.deconv3.forward(torch.cat([deconv4, conv2 * 0.05, self.upsample(output4)], 1)) 663 | output3 = self.output3.forward(torch.cat([deconv4, conv2 * 0.05, self.upsample(output4)], 1)) 664 | result.append(output3) 665 | deconv2 = self.deconv2.forward(torch.cat([deconv3, conv1 * 0.01, self.upsample(output3)], 1)) 666 | output2 = self.output2.forward(torch.cat([deconv3, conv1 * 0.01, self.upsample(output3)], 1)) 667 | result.append(output2) 668 | 669 | output1 = self.deconv1.forward(torch.cat([deconv2, self.upsample(output2)], 1)) 670 | result.append(output1) 671 | 672 | return result 673 | 674 | 675 | class _UNetGenerator(nn.Module): 676 | def __init__(self, input_nc, output_nc, ngf=64, layers=4, norm='batch', activation='PReLU', drop_rate=0, add_noise=False, gpu_ids=[], 677 | weight=0.1): 678 | super(_UNetGenerator, self).__init__() 679 | 680 | self.gpu_ids = gpu_ids 681 | self.layers = layers 682 | self.weight = weight 683 | norm_layer = get_norm_layer(norm_type=norm) 684 | nonlinearity = get_nonlinearity_layer(activation_type=activation) 685 | 686 | if type(norm_layer) == functools.partial: 687 | use_bias = norm_layer.func == nn.InstanceNorm2d 688 | else: 689 | use_bias = norm_layer == nn.InstanceNorm2d 690 | 691 | # encoder part 692 | self.pool = nn.AvgPool2d(kernel_size=2, stride=2) 693 | self.conv1 = nn.Sequential( 694 | nn.ReflectionPad2d(3), 695 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), 696 | norm_layer(ngf), 697 | nonlinearity 698 | ) 699 | self.conv2 = _EncoderBlock(ngf, ngf*2, ngf*2, norm_layer, nonlinearity, use_bias) 700 | self.conv3 = _EncoderBlock(ngf*2, ngf*4, ngf*4, norm_layer, nonlinearity, use_bias) 701 | self.conv4 = _EncoderBlock(ngf*4, ngf*8, ngf*8, norm_layer, nonlinearity, use_bias) 702 | 703 | for i 
in range(layers-4): 704 | conv = _EncoderBlock(ngf*8, ngf*8, ngf*8, norm_layer, nonlinearity, use_bias) 705 | setattr(self, 'down'+str(i), conv.model) 706 | 707 | center=[] 708 | for i in range(7-layers): 709 | center +=[ 710 | _InceptionBlock(ngf*8, ngf*8, norm_layer, nonlinearity, 7-layers, drop_rate, use_bias) 711 | ] 712 | 713 | center += [ 714 | _DecoderUpBlock(ngf*8, ngf*8, ngf*4, norm_layer, nonlinearity, use_bias) 715 | ] 716 | if add_noise: 717 | center += [GaussianNoiseLayer()] 718 | self.center = nn.Sequential(*center) 719 | 720 | for i in range(layers-4): 721 | upconv = _DecoderUpBlock(ngf*(8+4), ngf*8, ngf*4, norm_layer, nonlinearity, use_bias) 722 | setattr(self, 'up' + str(i), upconv.model) 723 | 724 | self.deconv4 = _DecoderUpBlock(ngf*(4+4), ngf*8, ngf*2, norm_layer, nonlinearity, use_bias) 725 | self.deconv3 = _DecoderUpBlock(ngf*(2+2)+output_nc, ngf*4, ngf, norm_layer, nonlinearity, use_bias) 726 | self.deconv2 = _DecoderUpBlock(ngf*(1+1)+output_nc, ngf*2, int(ngf/2), norm_layer, nonlinearity, use_bias) 727 | 728 | self.output4 = _OutputBlock(ngf*(4+4), output_nc, 3, use_bias) 729 | self.output3 = _OutputBlock(ngf*(2+2)+output_nc, output_nc, 3, use_bias) 730 | self.output2 = _OutputBlock(ngf*(1+1)+output_nc, output_nc, 3, use_bias) 731 | self.output1 = _OutputBlock(int(ngf/2)+output_nc, output_nc, 7, use_bias) 732 | 733 | self.upsample = nn.Upsample(scale_factor=2, mode='nearest') 734 | 735 | def forward(self, input): 736 | conv1 = self.pool(self.conv1(input)) 737 | conv2 = self.pool(self.conv2.forward(conv1)) 738 | conv3 = self.pool(self.conv3.forward(conv2)) 739 | center_in = self.pool(self.conv4.forward(conv3)) 740 | 741 | middle = [center_in] 742 | for i in range(self.layers-4): 743 | model = getattr(self, 'down'+str(i)) 744 | center_in = self.pool(model.forward(center_in)) 745 | middle.append(center_in) 746 | center_out = self.center.forward(center_in) 747 | result = [center_in] 748 | 749 | for i in range(self.layers-4): 750 | model = 
getattr(self, 'up'+str(i)) 751 | center_out = model.forward(torch.cat([center_out, middle[self.layers-5-i]], 1)) 752 | 753 | result.append(center_out) 754 | 755 | deconv4 = self.deconv4.forward(torch.cat([center_out, conv3 * self.weight], 1)) 756 | output4 = self.output4.forward(torch.cat([center_out, conv3 * self.weight], 1)) 757 | result.append(output4) 758 | deconv3 = self.deconv3.forward(torch.cat([deconv4, conv2 * self.weight * 0.5, self.upsample(output4)], 1)) 759 | output3 = self.output3.forward(torch.cat([deconv4, conv2 * self.weight * 0.5, self.upsample(output4)], 1)) 760 | result.append(output3) 761 | deconv2 = self.deconv2.forward(torch.cat([deconv3, conv1 * self.weight * 0.1, self.upsample(output3)], 1)) 762 | output2 = self.output2.forward(torch.cat([deconv3, conv1 * self.weight * 0.1, self.upsample(output3)], 1)) 763 | result.append(output2) 764 | output1 = self.output1.forward(torch.cat([deconv2, self.upsample(output2)], 1)) 765 | result.append(output1) 766 | 767 | return result 768 | 769 | 770 | class _SimplifiedUNetGenerator(nn.Module): 771 | def __init__(self, input_nc, output_nc, ngf=64, layers=4, norm='batch', activation='PReLU', drop_rate=0, add_noise=False, gpu_ids=[], 772 | weight=0.1): 773 | super(_SimplifiedUNetGenerator, self).__init__() 774 | 775 | self.gpu_ids = gpu_ids 776 | self.layers = layers 777 | self.weight = weight 778 | norm_layer = get_norm_layer(norm_type=norm) 779 | nonlinearity = get_nonlinearity_layer(activation_type=activation) 780 | 781 | if type(norm_layer) == functools.partial: 782 | use_bias = norm_layer.func == nn.InstanceNorm2d 783 | else: 784 | use_bias = norm_layer == nn.InstanceNorm2d 785 | 786 | # encoder part 787 | self.pool = nn.AvgPool2d(kernel_size=2, stride=2) 788 | self.conv1 = nn.Sequential( 789 | nn.ReflectionPad2d(3), 790 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), 791 | norm_layer(ngf), 792 | nonlinearity 793 | ) 794 | self.conv2 = _EncoderBlock(ngf, ngf*2, ngf*2, norm_layer, 
nonlinearity, use_bias) 795 | self.conv3 = _EncoderBlock(ngf*2, ngf*4, ngf*4, norm_layer, nonlinearity, use_bias) 796 | self.conv4 = _EncoderBlock(ngf*4, ngf*8, ngf*8, norm_layer, nonlinearity, use_bias) 797 | 798 | for i in range(layers-4): 799 | conv = _EncoderBlock(ngf*8, ngf*8, ngf*8, norm_layer, nonlinearity, use_bias) 800 | setattr(self, 'down'+str(i), conv.model) 801 | 802 | center=[] 803 | for i in range(7-layers): 804 | center +=[ 805 | _InceptionBlock(ngf*8, ngf*8, norm_layer, nonlinearity, 7-layers, drop_rate, use_bias) 806 | ] 807 | 808 | center += [ 809 | _DecoderUpBlock(ngf*8, ngf*8, ngf*4, norm_layer, nonlinearity, use_bias) 810 | ] 811 | if add_noise: 812 | center += [GaussianNoiseLayer()] 813 | self.center = nn.Sequential(*center) 814 | 815 | for i in range(layers-4): 816 | upconv = _DecoderUpBlock(ngf*(8+4), ngf*8, ngf*4, norm_layer, nonlinearity, use_bias) 817 | setattr(self, 'up' + str(i), upconv.model) 818 | 819 | self.deconv4 = _DecoderUpBlock(ngf*(4+4), ngf*8, ngf*2, norm_layer, nonlinearity, use_bias) 820 | self.deconv3 = _DecoderUpBlock(ngf*(2+2)+output_nc, ngf*4, ngf, norm_layer, nonlinearity, use_bias) 821 | self.deconv2 = _DecoderUpBlock(ngf*(1+1)+output_nc, ngf*2, int(ngf/2), norm_layer, nonlinearity, use_bias) 822 | 823 | self.output4 = _OutputBlock(ngf*(4+4), output_nc, 3, use_bias) 824 | self.output3 = _OutputBlock(ngf*(2+2)+output_nc, output_nc, 3, use_bias) 825 | self.output2 = _OutputBlock(ngf*(1+1)+output_nc, output_nc, 3, use_bias) 826 | self.output1 = _OutputBlock(int(ngf/2)+output_nc, output_nc, 7, use_bias) 827 | 828 | self.upsample = nn.Upsample(scale_factor=2, mode='nearest') 829 | 830 | def forward(self, input): 831 | conv1 = self.pool(self.conv1(input)) 832 | conv2 = self.pool(self.conv2.forward(conv1)) 833 | conv3 = self.pool(self.conv3.forward(conv2)) 834 | # conv4 = self.pool(self.conv4.forward(conv3)) 835 | 836 | # middle = [center_in] 837 | # for i in range(self.layers-4): 838 | # model = getattr(self, 'down'+str(i)) 
839 | # center_in = self.pool(model.forward(center_in)) 840 | # middle.append(center_in) 841 | # center_out = self.center.forward(center_in) 842 | # result = [center_in] 843 | 844 | # for i in range(self.layers-4): 845 | # model = getattr(self, 'up'+str(i)) 846 | # center_out = model.forward(torch.cat([center_out, middle[self.layers-5-i]], 1)) 847 | 848 | result = [] 849 | 850 | deconv4 = self.deconv4.forward(torch.cat([conv3, conv3 * self.weight], 1)) 851 | output4 = self.output4.forward(torch.cat([conv3, conv3 * self.weight], 1)) 852 | result.append(output4) 853 | deconv3 = self.deconv3.forward(torch.cat([deconv4, conv2 * self.weight * 0.5, self.upsample(output4)], 1)) 854 | output3 = self.output3.forward(torch.cat([deconv4, conv2 * self.weight * 0.5, self.upsample(output4)], 1)) 855 | result.append(output3) 856 | deconv2 = self.deconv2.forward(torch.cat([deconv3, conv1 * self.weight * 0.1, self.upsample(output3)], 1)) 857 | output2 = self.output2.forward(torch.cat([deconv3, conv1 * self.weight * 0.1, self.upsample(output3)], 1)) 858 | result.append(output2) 859 | output1 = self.output1.forward(torch.cat([deconv2, self.upsample(output2)], 1)) 860 | result.append(output1) 861 | 862 | return result 863 | 864 | 865 | class _MultiscaleDiscriminator(nn.Module): 866 | def __init__(self, input_nc, ndf=64, n_layers=3, num_D=1, norm='batch', activation='PReLU', gpu_ids=[]): 867 | super(_MultiscaleDiscriminator, self).__init__() 868 | 869 | self.num_D = num_D 870 | self.gpu_ids = gpu_ids 871 | 872 | for i in range(num_D): 873 | netD = _Discriminator(input_nc, ndf, n_layers, norm, activation, gpu_ids) 874 | setattr(self, 'scale'+str(i), netD) 875 | 876 | self.downsample = nn.AvgPool2d(kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False) 877 | 878 | def forward(self, input): 879 | result = [] 880 | for i in range(self.num_D): 881 | netD = getattr(self, 'scale'+str(i)) 882 | output = netD.forward(input) 883 | result.append(output) 884 | if i != (self.num_D-1): 
885 | input = self.downsample(input) 886 | return result 887 | 888 | 889 | class _Discriminator(nn.Module): 890 | def __init__(self, input_nc, ndf=64, n_layers=3, norm='batch', activation='PReLU', gpu_ids=[]): 891 | super(_Discriminator, self).__init__() 892 | 893 | self.gpu_ids = gpu_ids 894 | 895 | norm_layer = get_norm_layer(norm_type=norm) 896 | nonlinearity = get_nonlinearity_layer(activation_type=activation) 897 | 898 | if type(norm_layer) == functools.partial: 899 | use_bias = norm_layer.func == nn.InstanceNorm2d 900 | else: 901 | use_bias = norm_layer == nn.InstanceNorm2d 902 | 903 | model = [ 904 | nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1, bias=use_bias), 905 | nonlinearity, 906 | ] 907 | 908 | nf_mult=1 909 | for i in range(1, n_layers): 910 | nf_mult_prev = nf_mult 911 | nf_mult = min(2**i, 8) 912 | model += [ 913 | nn.Conv2d(ndf*nf_mult_prev, ndf*nf_mult, kernel_size=4, stride=2, padding=1, bias=use_bias), 914 | norm_layer(ndf*nf_mult), 915 | nonlinearity, 916 | ] 917 | 918 | nf_mult_prev = nf_mult 919 | nf_mult = min(2 ** n_layers, 8) 920 | model += [ 921 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=4, stride=1, padding=1, bias=use_bias), 922 | norm_layer(ndf * nf_mult), 923 | nonlinearity, 924 | nn.Conv2d(ndf*nf_mult, 1, kernel_size=4, stride=1, padding=1) 925 | ] 926 | 927 | self.model = nn.Sequential(*model) 928 | 929 | def forward(self, input): 930 | return self.model(input) 931 | 932 | 933 | class _FeatureDiscriminator(nn.Module): 934 | def __init__(self, input_nc, n_layers=2, norm='batch', activation='PReLU', gpu_ids=[]): 935 | super(_FeatureDiscriminator, self).__init__() 936 | 937 | self.gpu_ids = gpu_ids 938 | 939 | norm_layer = get_norm_layer(norm_type=norm) 940 | nonlinearity = get_nonlinearity_layer(activation_type=activation) 941 | 942 | if type(norm_layer) == functools.partial: 943 | use_bias = norm_layer.func == nn.InstanceNorm2d 944 | else: 945 | use_bias = norm_layer == nn.InstanceNorm2d 946 | 947 | model = [
948 | nn.Linear(input_nc * 40 * 30, input_nc), 949 | nonlinearity, 950 | ] 951 | 952 | # for i in range(1, n_layers): 953 | # model +=[ 954 | # nn.Linear(input_nc, input_nc), 955 | # nonlinearity 956 | # ] 957 | 958 | model +=[nn.Linear(input_nc, 1)] 959 | 960 | self.model = nn.Sequential(*model) 961 | 962 | def forward(self, input): 963 | result = [] 964 | # print(input.size()) 965 | # input = input.view(-1, 512 * 40 * 12) 966 | input = input.view(input.size(0), -1) # flatten each sample to input_nc * 30 * 40 values for the first nn.Linear 967 | output = self.model(input) 968 | result.append(output) 969 | return result -------------------------------------------------------------------------------- /models/discriminator_networks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.optim import lr_scheduler 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | import torchvision 10 | from torchvision import datasets, models, transforms 11 | 12 | class Discriminator80x80InstNorm(nn.Module): 13 | def __init__(self, device='cpu', pretrained=False, patchSize=[64, 64], input_nc=3): 14 | super(Discriminator80x80InstNorm, self).__init__() 15 | self.device = device 16 | self.input_nc = input_nc 17 | self.patchSize = patchSize 18 | self.outputSize = [patchSize[0]/16, patchSize[1]/16] 19 | 20 | self.discriminator = nn.Sequential( 21 | # 128-->62 22 | nn.Conv2d(self.input_nc, 64, kernel_size=5, padding=0, stride=2, bias=True), 23 | nn.LeakyReLU(0.2, inplace=True), 24 | 25 | # 62-->29 26 | nn.Conv2d(64, 128, kernel_size=5, padding=0, stride=2, bias=False), 27 | nn.InstanceNorm2d(128, momentum=0.001, affine=False, track_running_stats=False), 28 | nn.LeakyReLU(0.2, inplace=True), 29 | # 29-->14 30 | nn.Conv2d(128, 256, kernel_size=3, padding=0, stride=2, bias=False), 31 | nn.InstanceNorm2d(256, momentum=0.001, affine=False,
track_running_stats=False), 32 | nn.LeakyReLU(0.2, inplace=True), 33 | # 14-->6 34 | nn.Conv2d(256, 512, kernel_size=3, padding=0, stride=2, bias=False), 35 | nn.InstanceNorm2d(512, momentum=0.001, affine=False, track_running_stats=False), 36 | nn.LeakyReLU(0.2, inplace=True), 37 | # final classification for 'real(1) vs. fake(0)' 38 | nn.Conv2d(512, 1, kernel_size=1, padding=0, stride=1, bias=True), 39 | ) 40 | 41 | def forward(self, X): 42 | return self.discriminator(X) 43 | 44 | class Discriminator80x80InstNormDilation(nn.Module): 45 | # same as Discriminator80x80InstNorm except that the last layer uses a 3x3 kernel with dilation dialate_size 46 | # (used to test the receptive field) 47 | def __init__(self, device='cpu', dialate_size=1, pretrained=False, patchSize=[64, 64], input_nc=3): 48 | super(Discriminator80x80InstNormDilation, self).__init__() 49 | self.device = device 50 | self.input_nc = input_nc 51 | self.patchSize = patchSize 52 | self.outputSize = [patchSize[0]/16, patchSize[1]/16] 53 | self.dialate_size = dialate_size 54 | 55 | self.discriminator = nn.Sequential( 56 | # 128-->62 57 | nn.Conv2d(self.input_nc, 64, kernel_size=5, padding=0, stride=2, bias=True), 58 | nn.LeakyReLU(0.2, inplace=True), 59 | 60 | # 62-->29 61 | nn.Conv2d(64, 128, kernel_size=5, padding=0, stride=2, bias=False), 62 | nn.InstanceNorm2d(128, momentum=0.001, affine=False, track_running_stats=False), 63 | nn.LeakyReLU(0.2, inplace=True), 64 | # 29-->14 65 | nn.Conv2d(128, 256, kernel_size=3, padding=0, stride=2, bias=False), 66 | nn.InstanceNorm2d(256, momentum=0.001, affine=False, track_running_stats=False), 67 | nn.LeakyReLU(0.2, inplace=True), 68 | # 14-->6 69 | nn.Conv2d(256, 512, kernel_size=3, padding=0, stride=2, bias=False), 70 | nn.InstanceNorm2d(512, momentum=0.001, affine=False, track_running_stats=False), 71 | nn.LeakyReLU(0.2, inplace=True), 72 | # final classification for 'real(1) vs.
fake(0)' 73 | nn.Conv2d(512, 1, kernel_size=3, padding=0, stride=1, bias=True, dilation=self.dialate_size), 74 | ) 75 | 76 | def forward(self, X): 77 | return self.discriminator(X) 78 | 79 | class Discriminator5121520InstNorm(nn.Module): 80 | def __init__(self, device='cpu', pretrained=False, patchSize=[64, 64], input_nc=3): 81 | super(Discriminator5121520InstNorm, self).__init__() 82 | self.device = device 83 | self.input_nc = input_nc 84 | self.patchSize = patchSize 85 | self.outputSize = [patchSize[0]/16, patchSize[1]/16] 86 | 87 | self.discriminator = nn.Sequential( 88 | # H x W --> (H-2) x (W-2) 89 | nn.Conv2d(self.input_nc, 256, kernel_size=3, padding=0, stride=1, bias=True), 90 | nn.LeakyReLU(0.2, inplace=True), 91 | 92 | # (H-2) x (W-2) --> (H-4) x (W-4) 93 | nn.Conv2d(256, 128, kernel_size=3, padding=0, stride=1, bias=False), 94 | nn.InstanceNorm2d(128, momentum=0.001, affine=False, track_running_stats=False), 95 | nn.LeakyReLU(0.2, inplace=True), 96 | # (H-4) x (W-4) --> (H-6) x (W-6) 97 | nn.Conv2d(128, 64, kernel_size=3, padding=0, stride=1, bias=False), 98 | nn.InstanceNorm2d(64, momentum=0.001, affine=False, track_running_stats=False), 99 | nn.LeakyReLU(0.2, inplace=True), 100 | 101 | # final classification for 'real(1) vs. fake(0)' 102 | nn.Conv2d(64, 1, kernel_size=1, padding=0, stride=1, bias=True), 103 | ) 104 | 105 | def forward(self, X): 106 | return self.discriminator(X) 107 | 108 | class DiscriminatorGlobalLocal(nn.Module): 109 | """Discriminator.
PatchGAN.""" 110 | def __init__(self, image_size=128, bbox_size = 64, conv_dim=64, c_dim=5, repeat_num_global=6, repeat_num_local=5, nc=3): 111 | super(DiscriminatorGlobalLocal, self).__init__() 112 | 113 | maxFilt = 512 if image_size==128 else 128 114 | globalLayers = [] 115 | globalLayers.append(nn.Conv2d(nc, conv_dim, kernel_size=4, stride=2, padding=1,bias=False)) 116 | globalLayers.append(nn.LeakyReLU(0.2, inplace=True)) 117 | 118 | localLayers = [] 119 | localLayers.append(nn.Conv2d(nc, conv_dim, kernel_size=4, stride=2, padding=1, bias=False)) 120 | localLayers.append(nn.LeakyReLU(0.2, inplace=True)) 121 | 122 | curr_dim = conv_dim 123 | for i in range(1, repeat_num_global): 124 | globalLayers.append(nn.Conv2d(curr_dim, min(curr_dim*2,maxFilt), kernel_size=4, stride=2, padding=1, bias=False)) 125 | globalLayers.append(nn.LeakyReLU(0.2, inplace=True)) 126 | curr_dim = min(curr_dim * 2, maxFilt) 127 | 128 | curr_dim = conv_dim 129 | for i in range(1, repeat_num_local): 130 | localLayers.append(nn.Conv2d(curr_dim, min(curr_dim * 2, maxFilt), kernel_size=4, stride=2, padding=1, bias=False)) 131 | localLayers.append(nn.LeakyReLU(0.2, inplace=True)) 132 | curr_dim = min(curr_dim * 2, maxFilt) 133 | 134 | k_size_local = int(bbox_size/ np.power(2, repeat_num_local)) 135 | k_size_global = int(image_size/ np.power(2, repeat_num_global)) 136 | 137 | self.mainGlobal = nn.Sequential(*globalLayers) 138 | self.mainLocal = nn.Sequential(*localLayers) 139 | 140 | # FC 1 for doing real/fake 141 | # self.fc1 = nn.Linear(curr_dim*(k_size_local**2+k_size_global**2), 1, bias=False) 142 | self.fc1 = nn.Linear(10880, 1, bias=False) 143 | 144 | # FC 2 for doing classification only on local patch 145 | if c_dim > 0: 146 | self.fc2 = nn.Linear(curr_dim*(k_size_local**2), c_dim, bias=False) 147 | else: 148 | self.fc2 = None 149 | 150 | def forward(self, x, boxImg, classify=False): 151 | bsz = x.size(0) 152 | h_global = self.mainGlobal(x) 153 | h_local = self.mainLocal(boxImg) 154 | 
h_append = torch.cat([h_global.view(bsz,-1), h_local.view(bsz,-1)], dim=-1) 155 | out_rf = self.fc1(h_append) 156 | out_cls = self.fc2(h_local.view(bsz,-1)) if classify and (self.fc2 is not None) else None 157 | return out_rf.squeeze(), out_cls, h_append -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import random, time, copy 3 | import argparse 4 | import torch 5 | from torch.utils.data import Dataset, DataLoader 6 | 7 | from dataloader.NYUv2_dataLoader import NYUv2_dataLoader 8 | from dataloader.Joint_xLabel_dataLoader import Joint_xLabel_train_dataLoader 9 | 10 | # step 1: Train Initial Depth Predictor D 11 | # from training.train_initial_depth_predictor_D import train_initial_depth_predictor_D as train_model 12 | 13 | # step 2: Train Style Translator T (pre-train T) 14 | # from training.train_style_translator_T import train_style_translator_T as train_model 15 | 16 | # step 3: Train Initial Attention Module A 17 | # from training.train_initial_attention_module_A import train_initial_attention_module_A as train_model 18 | 19 | # step 4: Train Inpainting Module I (pre-train I) 20 | # from training.train_inpainting_module_I import train_inpainting_module_I as train_model 21 | 22 | # step 5: Jointly Train Depth Predictor D and Attention Module A (pre-train A, D) 23 | # from training.jointly_train_depth_predictor_D_and_attention_module_A import jointly_train_depth_predictor_D_and_attention_module_A as train_model 24 | 25 | # step 6: Finetune the Whole System with Depth Loss (Modular Coordinate Descent) 26 | from training.finetune_the_whole_system_with_depth_loss import finetune_the_whole_system_with_depth_loss as train_model 27 | 28 | import warnings # ignore warnings 29 | warnings.filterwarnings("ignore") 30 | 31 | print(sys.version) 32 | print(torch.__version__) 33 | 34 | ################## set attributes for this 
project/experiment ################## 35 | 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('--exp_dir', type=str, default=os.path.join(os.getcwd(), 'experiments'), 38 | help='place to store all experiments') 39 | parser.add_argument('--project_name', type=str, help='name of this experiment; results are saved to exp_dir/project_name (defaults to the name of the training step)') 40 | parser.add_argument('--path_to_NYUv2', type=str, default='your absolute path to NYUv2 data', 41 | help='absolute path to the NYUv2 dataset') 42 | parser.add_argument('--path_to_PBRS', type=str, default='your absolute path to PBRS data', 43 | help='absolute path to the PBRS dataset') 44 | parser.add_argument('--isTrain', action='store_true', help='run the training phase; otherwise only evaluation is run') 45 | parser.add_argument('--batch_size', type=int, default=16, help='batch size') 46 | parser.add_argument('--eval_batch_size', type=int, default=1, help='batch size for evaluation') 47 | parser.add_argument('--cropSize', type=int, nargs=2, default=[240, 320], help='height and width of samples in experiments') 48 | parser.add_argument('--total_epoch_num', type=int, default=50, help='total number of epochs') 49 | parser.add_argument('--device', type=str, default='cpu', help="device to run on (set to 'cuda' automatically when available)") 50 | parser.add_argument('--num_workers', type=int, default=4, help='number of workers in dataLoaders') 51 | args = parser.parse_args() 52 | 53 | if torch.cuda.is_available(): 54 | args.device='cuda' 55 | torch.cuda.empty_cache() 56 | 57 | # NYUv2 loaders, used here only for evaluation 58 | datasets_nyuv2 = {set_name: NYUv2_dataLoader(root_dir=args.path_to_NYUv2, set_name=set_name, size=args.cropSize, rgb=True) 59 | for set_name in ['train', 'test']} 60 | dataloaders_nyuv2 = {set_name: DataLoader(datasets_nyuv2[set_name], 61 | batch_size=args.batch_size if set_name=='train' else args.eval_batch_size, 62 | shuffle=set_name=='train', 63 | drop_last=set_name=='train', 64 | num_workers=args.num_workers) 65 | for set_name in ['train', 'test']} 66 | 67 | # joint real (NYUv2) + synthetic (PBRS) loader, used for training 68 | datasets_xLabels_joint =
Joint_xLabel_train_dataLoader(real_root_dir=args.path_to_NYUv2, syn_root_dir=args.path_to_PBRS, paired_data=False) 69 | dataloaders_xLabels_joint = DataLoader(datasets_xLabels_joint, 70 | batch_size=args.batch_size, 71 | shuffle=True, 72 | drop_last=True, 73 | num_workers=args.num_workers) 74 | 75 | model = train_model(args, dataloaders_xLabels_joint, dataloaders_nyuv2) 76 | 77 | if args.isTrain: 78 | model.train() 79 | model.evaluate(mode='best') 80 | else: 81 | model.evaluate(mode='best') -------------------------------------------------------------------------------- /training/base_model.py: -------------------------------------------------------------------------------- 1 | import os, copy, torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from torch.optim import lr_scheduler 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | from collections import OrderedDict 9 | import numpy as np 10 | import torchvision 11 | from torchvision import datasets, models, transforms 12 | from torchvision.utils import make_grid 13 | from tensorboardX import SummaryWriter 14 | from models.depth_generator_networks import init_weights 15 | from utils.metrics import * 16 | 17 | try: 18 | from apex import amp 19 | except ImportError: 20 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run with apex.") 21 | 22 | import torch.multiprocessing as mp 23 | 24 | def set_requires_grad(nets, requires_grad=False): 25 | """Set requires_grad for all parameters of the given networks (e.g. False to freeze them and avoid unnecessary computations) 26 | Parameters: 27 | nets (network list) -- a list of networks 28 | requires_grad (bool) -- whether the networks require gradients or not 29 | """ 30 | if not isinstance(nets, list): 31 | nets = [nets] 32 | for net in nets: 33 | if net is not None: 34 | for param in net.parameters(): 35 | param.requires_grad = requires_grad 36 | 37 | 38 | def apply_scheduler(optimizer, lr_policy, num_epoch=None, total_num_epoch=None): 39 | if lr_policy ==
'linear': 40 | # keep the initial lr for the first num_epoch epochs, 41 | # then decay it linearly towards 0 (it never reaches exactly 0, even at the last epoch) 42 | def lambda_rule(epoch): 43 | # lr_l = 1.0 - max(0, epoch + 1 + epoch_count - niter) / float(niter_decay + 1) 44 | lr_l = 1.0 - max(0, epoch + 1 - num_epoch) / float(total_num_epoch - num_epoch + 1) 45 | return lr_l 46 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 47 | elif lr_policy == 'step': 48 | scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) 49 | elif lr_policy == 'plateau': 50 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 51 | else: 52 | raise NotImplementedError('learning rate policy [%s] is not implemented' % lr_policy) 53 | return scheduler 54 | 55 | class base_model(nn.Module): 56 | def __init__(self, args): 57 | super(base_model, self).__init__() 58 | self.device = args.device 59 | self.isTrain = args.isTrain 60 | self.project_name = args.project_name 61 | self.exp_dir = args.exp_dir 62 | 63 | self.use_tensorboardX = True 64 | self.use_apex = True 65 | 66 | self.cropSize = args.cropSize # patch size for training the model.
Default: [240, 320] 67 | self.cropSize_h, self.cropSize_w = self.cropSize[0], self.cropSize[1] 68 | self.batch_size = args.batch_size 69 | self.total_epoch_num = args.total_epoch_num # total number of epochs in training 70 | self.save_steps = 5 71 | self.task_lr = 1e-4 # default task learning rate 72 | self.D_lr = 5e-5 # default discriminator learning rate 73 | self.G_lr = 5e-5 # default generator learning rate 74 | self.real_label = 1 75 | self.syn_label = 0 76 | 77 | def _initialize_training(self): 78 | if self.project_name is not None: 79 | self.save_dir = os.path.join(self.exp_dir, self.project_name) 80 | else: 81 | self.project_name = self._get_project_name() 82 | self.save_dir = os.path.join(self.exp_dir, self.project_name) 83 | print('project name: {}'.format(self.project_name)) 84 | print('save dir: {}'.format(self.save_dir)) 85 | if not os.path.exists(self.save_dir): os.makedirs(self.save_dir) 86 | 87 | self.train_log = os.path.join(self.save_dir, 'train.log') 88 | self.evaluate_log = os.path.join(self.save_dir, 'evaluate.log') 89 | self.file_to_note_bestModel = os.path.join(self.save_dir,'note_bestModel.log') 90 | 91 | if self.use_tensorboardX: 92 | self.tensorboard_train_dir = os.path.join(self.save_dir, 'tensorboardX_train_logs') 93 | self.train_SummaryWriter = SummaryWriter(self.tensorboard_train_dir) 94 | 95 | self.tensorboard_eval_dir = os.path.join(self.save_dir, 'tensorboardX_eval_logs') 96 | self.eval_SummaryWriter = SummaryWriter(self.tensorboard_eval_dir) 97 | 98 | # self.train_display_freq = 500 99 | # self.val_write_freq = 10 100 | self.tensorboard_num_display_per_epoch = 5 101 | self.val_display_freq = 10 102 | 103 | def _initialize_networks(self): 104 | for name, model in self.model_dict.items(): 105 | model.train().to(self.device) 106 | init_weights(model, net_name=name, init_type='normal', gain=0.02) 107 | 108 | def _get_scheduler(self, optim_type='linear'): 109 | ''' 110 | if optim_type is a str -> every optimizer uses this type of scheduler
111 | if optim_type is a list -> each optimizer uses its own scheduler type 112 | any other value raises a RuntimeError 113 | ''' 114 | self.scheduler_list = [] 115 | if isinstance(optim_type, str): 116 | for name in self.optim_name: 117 | self.scheduler_list.append(apply_scheduler(getattr(self, name), lr_policy=optim_type, num_epoch=0.6*self.total_epoch_num, 118 | total_num_epoch=self.total_epoch_num)) 119 | elif isinstance(optim_type, list): 120 | for name, optim in zip(self.optim_name, optim_type): 121 | self.scheduler_list.append(apply_scheduler(getattr(self, name), lr_policy=optim, num_epoch=0.6*self.total_epoch_num, 122 | total_num_epoch=self.total_epoch_num)) 123 | else: 124 | raise RuntimeError("optim type should be either string or list!") 125 | 126 | def _init_apex(self, Num_losses): 127 | model_list = [] 128 | optim_list = [] 129 | for m in self.model_name: 130 | model_list.append(getattr(self, m)) 131 | for o in self.optim_name: 132 | optim_list.append(getattr(self, o)) 133 | model_list, optim_list = amp.initialize(model_list, optim_list, opt_level="O1", num_losses=Num_losses) 134 | 135 | def _check_parallel(self): 136 | if torch.cuda.device_count() > 1: 137 | for name in self.model_name: 138 | setattr(self, name, nn.DataParallel(getattr(self, name))) 139 | 140 | def _check_distribute(self): 141 | # not ready to use yet 142 | if torch.cuda.device_count() > 1: 143 | # world size is the number of processes participating in the job 144 | # torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...') 145 | # mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 146 | for name in self.model_name: 147 | if self.use_apex: # assumes apex.parallel has been imported 148 | setattr(self, name, apex.parallel.DistributedDataParallel(getattr(self, name))) 149 | else: 150 | setattr(self, name, nn.parallel.DistributedDataParallel(getattr(self, name))) 151 | 152 | def _set_models_train(self, model_name): 153 | for name in model_name: 154 | getattr(self, name).train() 155 | 156 | def
_set_models_eval(self, model_name): 157 | for name in model_name: 158 | getattr(self, name).eval() 159 | 160 | def _set_models_float(self, model_name): 161 | for name in model_name: 162 | for layers in getattr(self, name).modules(): 163 | layers.float() 164 | 165 | def save_models(self, model_list, mode, save_list=None): 166 | ''' 167 | mode is 'best', 'latest', or an int (epoch number) 168 | weights are saved as a non-DataParallel state_dict 169 | save_list optionally gives each model a different file name for later use 170 | ''' 171 | if not save_list: 172 | for model_name in model_list: 173 | if mode == 'latest': 174 | path_to_save_paramOnly = os.path.join(self.save_dir, 'latest_{}.pth'.format(model_name)) 175 | elif mode == 'best': 176 | path_to_save_paramOnly = os.path.join(self.save_dir, 'best_{}.pth'.format(model_name)) 177 | elif isinstance(mode, int): 178 | path_to_save_paramOnly = os.path.join(self.save_dir, 'epoch-{}_{}.pth'.format(str(mode), model_name)) 179 | 180 | try: 181 | state_dict = getattr(self, model_name).module.state_dict() 182 | except AttributeError: 183 | state_dict = getattr(self, model_name).state_dict() 184 | 185 | model_weights = copy.deepcopy(state_dict) 186 | torch.save(model_weights, path_to_save_paramOnly) 187 | else: 188 | assert len(model_list) == len(save_list) 189 | for save_name, model_name in zip(save_list, model_list): 190 | if mode == 'latest': 191 | path_to_save_paramOnly = os.path.join(self.save_dir, 'latest_{}.pth'.format(save_name)) 192 | elif mode == 'best': 193 | path_to_save_paramOnly = os.path.join(self.save_dir, 'best_{}.pth'.format(save_name)) 194 | elif isinstance(mode, int): 195 | path_to_save_paramOnly = os.path.join(self.save_dir, 'epoch-{}_{}.pth'.format(str(mode), save_name)) 196 | 197 | try: 198 | state_dict = getattr(self, model_name).module.state_dict() 199 | except AttributeError: 200 | state_dict = getattr(self, model_name).state_dict() 201 | 202 | model_weights = copy.deepcopy(state_dict) 203 | torch.save(model_weights,
path_to_save_paramOnly) 204 | 205 | def _load_models(self, model_list, mode, isTrain=False, model_path=None): 206 | if model_path is None: 207 | model_path = self.save_dir 208 | 209 | for model_name in model_list: 210 | if mode == 'latest': 211 | path = os.path.join(model_path, 'latest_{}.pth'.format(model_name)) 212 | elif mode == 'best': 213 | path = os.path.join(model_path, 'best_{}.pth'.format(model_name)) 214 | elif isinstance(mode, int): 215 | path = os.path.join(model_path, 'epoch-{}_{}.pth'.format(str(mode), model_name)) 216 | else: 217 | raise RuntimeError("Mode not implemented") 218 | 219 | state_dict = torch.load(path, map_location=self.device) 220 | 221 | try: 222 | getattr(self, model_name).load_state_dict(state_dict) 223 | except RuntimeError: 224 | # a DataParallel-wrapped model expects keys prefixed with 'module.', so add that prefix to every key of the non-parallel state_dict 225 | new_state_dict = OrderedDict() 226 | for k, v in state_dict.items(): 227 | name = 'module.' + k # add `module.` 228 | new_state_dict[name] = v 229 | 230 | getattr(self, model_name).load_state_dict(new_state_dict) 231 | 232 | if isTrain: 233 | getattr(self, model_name).to(self.device).train() 234 | else: 235 | getattr(self, model_name).to(self.device).eval() 236 | 237 | def save_tensor2np(self, tensor, name, epoch, path=None): 238 | # not ready to use in this project 239 | if path is None: 240 | path = self.save_dir 241 | generated_sample = tensor.detach().cpu().numpy() 242 | generated_sample_save_path = os.path.join(path, 'tensor2np', 'Epoch-%s_%s.npy' % (epoch, name)) 243 | if not os.path.exists(os.path.join(path, 'tensor2np')): 244 | os.makedirs(os.path.join(path, 'tensor2np')) 245 | np.save(generated_sample_save_path, generated_sample) 246 | 247 | def write_2_tensorboardX(self, writer, input_tensor, name, mode, count, nrow=None, normalize=True, value_range=(-1.0, 1.0)): 248 | if mode == 'image': 249 | if not nrow: 250 | raise RuntimeError('tensorboardX: must specify number of rows in image mode') 251 | grid = make_grid(input_tensor,
nrow=nrow, normalize=normalize, range=value_range) 252 | writer.add_image(name, grid, count) 253 | elif mode == 'scalar': 254 | if isinstance(input_tensor, list) and isinstance(name, list): 255 | assert len(input_tensor) == len(name) 256 | for n, t in zip(name, input_tensor): 257 | writer.add_scalar(n, t, count) 258 | else: 259 | writer.add_scalar(name, input_tensor, count) 260 | else: 261 | raise RuntimeError('tensorboardX: this mode is not yet implemented') 262 | 263 | -------------------------------------------------------------------------------- /training/finetune_the_whole_system_with_depth_loss.py: -------------------------------------------------------------------------------- 1 | import os, time, sys 2 | import random 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from torch.optim import lr_scheduler 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | 11 | import torchvision 12 | from torchvision import datasets, models, transforms 13 | from torchvision.utils import make_grid 14 | from tensorboardX import SummaryWriter 15 | 16 | from models.depth_generator_networks import _UNetGenerator, init_weights, _ResGenerator_Upsample 17 | from models.discriminator_networks import Discriminator80x80InstNorm 18 | from models.attention_networks import _Attention_FullRes 19 | 20 | from utils.metrics import * 21 | from utils.image_pool import ImagePool 22 | 23 | from training.base_model import set_requires_grad, base_model 24 | 25 | try: 26 | from apex import amp 27 | except ImportError: 28 | print("\nPlease consider installing apex from https://www.github.com/nvidia/apex to run with apex, or set use_apex = False\n") 29 | 30 | import warnings # ignore warnings 31 | warnings.filterwarnings("ignore") 32 | 33 | class finetune_the_whole_system_with_depth_loss(base_model): 34 | def __init__(self, args, dataloaders_xLabels_joint, dataloaders_single): 35 |
super(finetune_the_whole_system_with_depth_loss, self).__init__(args) 36 | self._initialize_training() 37 | # self.KITTI_MAX_DEPTH_CLIP = 80.0 38 | # self.EVAL_DEPTH_MIN = 1.0 39 | # self.EVAL_DEPTH_MAX = 50.0 40 | 41 | self.NYU_MAX_DEPTH_CLIP = 10.0 42 | self.EVAL_DEPTH_MIN = 1.0 43 | self.EVAL_DEPTH_MAX = 8.0 44 | 45 | self.dataloaders_single = dataloaders_single 46 | self.dataloaders_xLabels_joint = dataloaders_xLabels_joint 47 | 48 | self.tensorboard_num_display_per_epoch = 1 49 | 50 | self.attModule = _Attention_FullRes(input_nc = 3, output_nc = 1) 51 | self.inpaintNet = _ResGenerator_Upsample(input_nc = 3, output_nc = 3) 52 | self.styleTranslator = _ResGenerator_Upsample(input_nc = 3, output_nc = 3) 53 | self.depthEstModel = _UNetGenerator(input_nc = 3, output_nc = 1) 54 | 55 | self.tau_min = 0.05 56 | self.model_name = ['attModule', 'inpaintNet', 'styleTranslator', 'depthEstModel'] 57 | self.L1loss = nn.L1Loss() 58 | 59 | if self.isTrain: 60 | self.optim_depth = optim.Adam(list(self.depthEstModel.parameters()) + list(self.inpaintNet.parameters()) + list(self.styleTranslator.parameters()), lr=self.task_lr, betas=(0.5, 0.999)) 61 | self.optim_name = ['optim_depth'] 62 | self._get_scheduler() 63 | self.loss_BCE = nn.BCEWithLogitsLoss() 64 | 65 | # load the "best" depth predictor D (from step 5) 66 | preTrain_path = os.path.join(os.getcwd(), 'experiments', 'jointly_train_depth_predictor_D_and_attention_module_A') 67 | self._load_models(model_list=['depthEstModel'], mode='best', isTrain=True, model_path=preTrain_path) 68 | print('Successfully loaded pre-trained {} model from {}'.format('depthEstModel', preTrain_path)) 69 | 70 | # load the "best" style translator T (from step 2) 71 | preTrain_path = os.path.join(os.getcwd(), 'experiments', 'train_style_translator_T') 72 | self._load_models(model_list=['styleTranslator'], mode=480, isTrain=True, model_path=preTrain_path) 73 | print('Successfully loaded pre-trained {} model from {}'.format('styleTranslator', 
preTrain_path)) 74 | 75 | # load the "best" attention module A (from step 5) 76 | preTrain_path = os.path.join(os.getcwd(), 'experiments', 'jointly_train_depth_predictor_D_and_attention_module_A') 77 | self._load_models(model_list=['attModule'], mode='best', isTrain=True, model_path=preTrain_path) 78 | print('Successfully loaded pre-trained {} model from {}'.format('attModule', preTrain_path)) 79 | 80 | # load the "best" inpainting module I (from step 4) 81 | preTrain_path = os.path.join(os.getcwd(), 'experiments', 'train_inpainting_module_I') 82 | self._load_models(model_list=['inpaintNet'], mode=450, isTrain=True, model_path=preTrain_path) 83 | print('Successfully loaded pre-trained {} model from {}'.format('inpaintNet', preTrain_path)) 84 | 85 | # apex can only be applied to CUDA models 86 | if self.use_apex: 87 | self._init_apex(Num_losses=2) 88 | 89 | self.EVAL_best_loss = float('inf') 90 | self.EVAL_best_model_epoch = 0 91 | self.EVAL_all_results = {} 92 | 93 | self._check_parallel() 94 | 95 | def _get_project_name(self): 96 | return 'finetune_the_whole_system_with_depth_loss' 97 | 98 | def _initialize_networks(self, model_name): 99 | for name in model_name: 100 | getattr(self, name).train().to(self.device) 101 | init_weights(getattr(self, name), net_name=name, init_type='normal', gain=0.02) 102 | 103 | def compute_D_loss(self, real_sample, fake_sample, netD): 104 | loss = 0 105 | syn_acc = 0 106 | real_acc = 0 107 | 108 | output = netD(fake_sample) 109 | label = torch.full_like(output, self.syn_label) # float targets, as BCEWithLogitsLoss requires 110 | predSyn = (output > 0.).float() # output is a logit, so logit > 0 <=> probability > 0.5 111 | total_num = torch.numel(output) 112 | syn_acc += (predSyn==label).type(torch.float32).sum().item()/total_num 113 | 114 | loss += self.loss_BCE(output, label) 115 | 116 | output = netD(real_sample) 117 | label = torch.full_like(output, self.real_label) # float targets, as BCEWithLogitsLoss requires 118 | predReal = (output > 0.).float() # output is a logit, so logit > 0 <=> probability > 0.5 119 | real_acc +=
(predReal==label).type(torch.float32).sum().item()/total_num 120 | 121 | loss += self.loss_BCE(output, label) 122 | 123 | return loss, syn_acc, real_acc 124 | 125 | def compute_depth_loss(self, input_rgb, depth_label, depthEstModel, valid_mask=None): 126 | 127 | prediction = depthEstModel(input_rgb)[-1] 128 | if valid_mask is not None: 129 | loss = self.L1loss(prediction[valid_mask], depth_label[valid_mask]) 130 | else: 131 | assert valid_mask is None 132 | loss = self.L1loss(prediction, depth_label) 133 | 134 | return loss 135 | 136 | def compute_spare_attention(self, confident_score, t, isTrain=True): 137 | # t is the temperature --> scalar 138 | if isTrain: 139 | noise = torch.rand(confident_score.size(), requires_grad=False).to(self.device) 140 | noise = (noise + 0.00001) / 1.001 141 | noise = - torch.log(- torch.log(noise)) 142 | 143 | confident_score = (confident_score + 0.00001) / 1.001 144 | confident_score = (confident_score + noise) / t 145 | else: 146 | confident_score = confident_score / t 147 | 148 | confident_score = torch.sigmoid(confident_score) 149 | 150 | return confident_score 151 | 152 | def train(self): 153 | phase = 'train' 154 | since = time.time() 155 | best_loss = float('inf') 156 | set_requires_grad(self.attModule, requires_grad=False) # freeze attention module A 157 | 158 | tensorboardX_iter_count = 0 159 | for epoch in range(self.total_epoch_num): 160 | print('\nEpoch {}/{}'.format(epoch+1, self.total_epoch_num)) 161 | print('-' * 10) 162 | fn = open(self.train_log,'a') 163 | fn.write('\nEpoch {}/{}\n'.format(epoch+1, self.total_epoch_num)) 164 | fn.write('--'*5+'\n') 165 | fn.close() 166 | 167 | self._set_models_train(['attModule', 'inpaintNet', 'styleTranslator', 'depthEstModel']) 168 | iterCount,sampleCount = 0, 0 169 | 170 | for sample_dict in self.dataloaders_xLabels_joint: 171 | imageListReal, depthListReal = sample_dict['real'] 172 | imageListSyn, depthListSyn = sample_dict['syn'] 173 | 174 | imageListSyn = 
imageListSyn.to(self.device) 175 | depthListSyn = depthListSyn.to(self.device) 176 | imageListReal = imageListReal.to(self.device) 177 | depthListReal = depthListReal.to(self.device) 178 | valid_mask = (depthListReal > -1.) 179 | 180 | B, C, H, W = imageListReal.size()[0], imageListReal.size()[1], imageListReal.size()[2], imageListReal.size()[3] 181 | 182 | with torch.set_grad_enabled(phase=='train'): 183 | r2s_img = self.styleTranslator(imageListReal)[-1] 184 | confident_score = self.attModule(imageListReal)[-1] 185 | # convert to sparse confident score 186 | confident_score = self.compute_spare_attention(confident_score, t=self.tau_min, isTrain=False) 187 | # hard threshold 188 | confident_score[confident_score < 0.5] = 0. 189 | confident_score[confident_score >= 0.5] = 1. 190 | 191 | mod_r2s_img = r2s_img * confident_score 192 | inpainted_r2s = self.inpaintNet(mod_r2s_img)[-1] 193 | 194 | reconst_img = inpainted_r2s * (1. - confident_score) + confident_score * r2s_img 195 | 196 | # update 197 | self.optim_depth.zero_grad() 198 | total_loss = 0. 
199 | inpainted_depth_loss = self.compute_depth_loss(reconst_img, depthListReal, self.depthEstModel, valid_mask) 200 | # adding the depth loss on the translated image may give better results when finetuning the whole system (it can usually be commented out as well) 201 | translated_depth_loss = self.compute_depth_loss(r2s_img, depthListReal, self.depthEstModel, valid_mask) 202 | syn_depth_loss = self.compute_depth_loss(imageListSyn, depthListSyn, self.depthEstModel) 203 | total_loss += (inpainted_depth_loss + translated_depth_loss + syn_depth_loss) 204 | if self.use_apex: 205 | with amp.scale_loss(total_loss, self.optim_depth, loss_id=0) as total_loss_scaled: 206 | total_loss_scaled.backward() 207 | else: 208 | total_loss.backward() 209 | 210 | self.optim_depth.step() 211 | 212 | iterCount += 1 213 | 214 | if self.use_tensorboardX: 215 | nrow = imageListReal.size()[0] 216 | self.train_display_freq = len(self.dataloaders_xLabels_joint) # feel free to adjust the display frequency 217 | if tensorboardX_iter_count % self.train_display_freq == 0: 218 | img_concat = torch.cat((imageListReal, r2s_img, mod_r2s_img, inpainted_r2s, reconst_img), dim=0) 219 | self.write_2_tensorboardX(self.train_SummaryWriter, img_concat, name='real, r2s, r2sMasked, inpaintedR2s, reconst', mode='image', 220 | count=tensorboardX_iter_count, nrow=nrow) 221 | 222 | self.write_2_tensorboardX(self.train_SummaryWriter, confident_score, name='Attention', mode='image', 223 | count=tensorboardX_iter_count, nrow=nrow, value_range=(0., 1.0)) 224 | 225 | # add loss values 226 | loss_val_list = [total_loss, inpainted_depth_loss, translated_depth_loss, syn_depth_loss] 227 | loss_name_list = ['total_loss', 'inpainted_depth_loss', 'translated_depth_loss', 'syn_depth_loss'] 228 | self.write_2_tensorboardX(self.train_SummaryWriter, loss_val_list, name=loss_name_list, mode='scalar', count=tensorboardX_iter_count) 229 | 230 | tensorboardX_iter_count += 1 231 | 232 | if iterCount % 20 == 0: 233 | loss_summary = '\t{}/{}, total_loss: {:.7f}, 
inpainted_depth_loss: {:.7f}, translated_depth_loss: {:.7f}, syn_depth_loss: {:.7f}'.format( 234 | iterCount, len(self.dataloaders_xLabels_joint), total_loss, inpainted_depth_loss, translated_depth_loss, syn_depth_loss) 235 | 236 | print(loss_summary) 237 | 238 | fn = open(self.train_log,'a') 239 | fn.write(loss_summary + '\n') 240 | fn.close() 241 | 242 | # take step in optimizer 243 | for scheduler in self.scheduler_list: 244 | scheduler.step() 245 | for optim in self.optim_name: 246 | lr = getattr(self, optim).param_groups[0]['lr'] 247 | lr_update = 'Epoch {}/{} finished: {} learning rate = {:.7f}'.format(epoch+1, self.total_epoch_num, optim, lr) 248 | print(lr_update) 249 | fn = open(self.train_log,'a') 250 | fn.write(lr_update + '\n') 251 | fn.close() 252 | 253 | if (epoch+1) % self.save_steps == 0: 254 | self.save_models(self.model_name, mode=epoch+1) 255 | self.evaluate(epoch+1) 256 | 257 | time_elapsed = time.time() - since 258 | print('\nTraining complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 259 | 260 | fn = open(self.train_log,'a') 261 | fn.write('\nTraining complete in {:.0f}m {:.0f}s\n'.format(time_elapsed // 60, time_elapsed % 60)) 262 | fn.close() 263 | 264 | best_model_summary = '\nOverall best model is epoch {}'.format(self.EVAL_best_model_epoch) 265 | print(best_model_summary) 266 | print(self.EVAL_all_results[str(self.EVAL_best_model_epoch)]) 267 | fn = open(self.evaluate_log, 'a') 268 | fn.write(best_model_summary + '\n') 269 | fn.write(self.EVAL_all_results[str(self.EVAL_best_model_epoch)]) 270 | fn.close() 271 | 272 | def evaluate(self, mode): 273 | ''' 274 | mode: either an int (epoch number) or 'best' 275 | an int is the current epoch number and is used for evaluation during training 276 | 'best' is used after training to evaluate the best saved checkpoints 277 | ''' 278 | set_name = 'test' 279 | eval_model_list = ['attModule', 'inpaintNet', 'styleTranslator', 'depthEstModel'] 280 | 281 | if isinstance(mode, int) and self.isTrain: 282 | 
self._set_models_eval(eval_model_list) 283 | if self.EVAL_best_loss == float('inf'): 284 | fn = open(self.evaluate_log, 'w') 285 | else: 286 | fn = open(self.evaluate_log, 'a') 287 | 288 | fn.write('Evaluating with mode: {}\n'.format(mode)) 289 | fn.write('\tEvaluation range min: {} | max: {} \n'.format(self.EVAL_DEPTH_MIN, self.EVAL_DEPTH_MAX)) 290 | fn.close() 291 | 292 | else: 293 | self._load_models(eval_model_list, mode) 294 | 295 | print('Evaluating with mode: {}'.format(mode)) 296 | print('\tEvaluation range min: {} | max: {}'.format(self.EVAL_DEPTH_MIN, self.EVAL_DEPTH_MAX)) 297 | 298 | total_loss, count = 0., 0 299 | predTensor = torch.zeros((1, 1, self.cropSize_h, self.cropSize_w)).to('cpu') 300 | grndTensor = torch.zeros((1, 1, self.cropSize_h, self.cropSize_w)).to('cpu') 301 | idx = 0 302 | 303 | tensorboardX_iter_count = 0 304 | for sample in self.dataloaders_single[set_name]: 305 | imageList, depthList = sample 306 | valid_mask = np.logical_and(depthList > self.EVAL_DEPTH_MIN, depthList < self.EVAL_DEPTH_MAX) 307 | 308 | idx += imageList.shape[0] 309 | print('epoch {}: have processed {} number samples in {} set'.format(mode, str(idx), set_name)) 310 | imageList = imageList.to(self.device) 311 | depthList = depthList.to(self.device) 312 | 313 | if self.isTrain and self.use_apex: 314 | with amp.disable_casts(): 315 | r2s_img = self.styleTranslator(imageList)[-1] 316 | confident_score = self.attModule(imageList)[-1] 317 | # convert to sparse confident score 318 | confident_score = self.compute_spare_attention(confident_score, t=self.tau_min, isTrain=False) 319 | # hard threshold 320 | confident_score[confident_score < 0.5] = 0. 321 | confident_score[confident_score >= 0.5] = 1. 322 | mod_r2s_img = r2s_img * confident_score 323 | inpainted_r2s = self.inpaintNet(mod_r2s_img)[-1] 324 | reconst_img = inpainted_r2s * (1. 
- confident_score) + confident_score * r2s_img 325 | predList = self.depthEstModel(reconst_img)[-1].detach().to('cpu') 326 | 327 | else: 328 | r2s_img = self.styleTranslator(imageList)[-1] 329 | confident_score = self.attModule(imageList)[-1] 330 | # convert to sparse confident score 331 | confident_score = self.compute_spare_attention(confident_score, t=self.tau_min, isTrain=False) 332 | # hard threshold 333 | confident_score[confident_score < 0.5] = 0. 334 | confident_score[confident_score >= 0.5] = 1. 335 | mod_r2s_img = r2s_img * confident_score 336 | inpainted_r2s = self.inpaintNet(mod_r2s_img)[-1] 337 | reconst_img = inpainted_r2s * (1. - confident_score) + confident_score * r2s_img 338 | predList = self.depthEstModel(reconst_img)[-1].detach().to('cpu') 339 | 340 | # recover real depth 341 | predList = (predList + 1.0) * 0.5 * self.NYU_MAX_DEPTH_CLIP 342 | depthList = depthList.detach().to('cpu') 343 | predTensor = torch.cat((predTensor, predList), dim=0) 344 | grndTensor = torch.cat((grndTensor, depthList), dim=0) 345 | 346 | if self.use_tensorboardX: 347 | nrow = imageList.size()[0] 348 | if tensorboardX_iter_count % self.val_display_freq == 0: 349 | depth_concat = torch.cat((depthList, predList), dim=0) 350 | self.write_2_tensorboardX(self.eval_SummaryWriter, depth_concat, name='{}: ground truth and depth prediction'.format(set_name), 351 | mode='image', count=tensorboardX_iter_count, nrow=nrow, value_range=(0.0, self.NYU_MAX_DEPTH_CLIP)) 352 | 353 | tensorboardX_iter_count += 1 354 | 355 | if isinstance(mode, int) and self.isTrain: 356 | eval_depth_loss = self.L1loss(predList[valid_mask], depthList[valid_mask]) 357 | total_loss += eval_depth_loss.detach().cpu() 358 | 359 | count += 1 360 | 361 | if isinstance(mode, int) and self.isTrain: 362 | validation_loss = (total_loss / count) 363 | print('validation loss is {:.7f}'.format(validation_loss)) 364 | if self.use_tensorboardX: 365 | self.write_2_tensorboardX(self.eval_SummaryWriter, validation_loss, 
name='validation loss', mode='scalar', count=mode) 366 | 367 | results = Result(mask_min=self.EVAL_DEPTH_MIN, mask_max=self.EVAL_DEPTH_MAX) 368 | results.evaluate(predTensor[1:], grndTensor[1:]) 369 | 370 | result1 = '\tabs_rel:{:.3f}, sq_rel:{:.3f}, rmse:{:.3f}, rmse_log:{:.3f}, mae:{:.3f} '.format( 371 | results.absrel,results.sqrel,results.rmse,results.rmselog,results.mae) 372 | result2 = '\t[<1.25]:{:.3f}, [<1.25^2]:{:.3f}, [<1.25^3]:{:.3f}'.format(results.delta1,results.delta2,results.delta3) 373 | 374 | print(result1) 375 | print(result2) 376 | 377 | if isinstance(mode, int) and self.isTrain: 378 | self.EVAL_all_results[str(mode)] = result1 + '\t' + result2 379 | 380 | if validation_loss.item() < self.EVAL_best_loss: 381 | self.EVAL_best_loss = validation_loss.item() 382 | self.EVAL_best_model_epoch = mode 383 | self.save_models(self.model_name, mode='best') 384 | 385 | best_model_summary = '\tCurrent best loss {:.7f}, current best model {}\n'.format(self.EVAL_best_loss, self.EVAL_best_model_epoch) 386 | print(best_model_summary) 387 | 388 | fn = open(self.evaluate_log, 'a') 389 | fn.write(result1 + '\n') 390 | fn.write(result2 + '\n') 391 | fn.write(best_model_summary + '\n') 392 | fn.close() -------------------------------------------------------------------------------- /training/jointly_train_depth_predictor_D_and_attention_module_A.py: -------------------------------------------------------------------------------- 1 | import os, time, sys 2 | import random 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from torch.optim import lr_scheduler 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | 11 | import torchvision 12 | from torchvision import datasets, models, transforms 13 | from torchvision.utils import make_grid 14 | from tensorboardX import SummaryWriter 15 | 16 | from models.depth_generator_networks import _UNetGenerator, init_weights, 
_ResGenerator_Upsample 17 | from models.discriminator_networks import Discriminator80x80InstNorm 18 | from models.attention_networks import _Attention_FullRes 19 | 20 | from utils.metrics import * 21 | from utils.image_pool import ImagePool 22 | 23 | from training.base_model import set_requires_grad, base_model 24 | 25 | try: 26 | from apex import amp 27 | except ImportError: 28 | print("\nPlease consider installing apex from https://www.github.com/nvidia/apex to run with apex or set use_apex = False\n") 29 | 30 | import warnings # ignore warnings 31 | warnings.filterwarnings("ignore") 32 | 33 | class jointly_train_depth_predictor_D_and_attention_module_A(base_model): 34 | def __init__(self, args, dataloaders_xLabels_joint, dataloaders_single): 35 | super(jointly_train_depth_predictor_D_and_attention_module_A, self).__init__(args) 36 | self._initialize_training() 37 | # self.KITTI_MAX_DEPTH_CLIP = 80.0 38 | # self.EVAL_DEPTH_MIN = 1.0 39 | # self.EVAL_DEPTH_MAX = 50.0 40 | 41 | self.NYU_MAX_DEPTH_CLIP = 10.0 42 | self.EVAL_DEPTH_MIN = 1.0 43 | self.EVAL_DEPTH_MAX = 8.0 44 | 45 | self.dataloaders_single = dataloaders_single 46 | self.dataloaders_xLabels_joint = dataloaders_xLabels_joint 47 | 48 | self.attModule = _Attention_FullRes(input_nc = 3, output_nc = 1) 49 | self.inpaintNet = _ResGenerator_Upsample(input_nc = 3, output_nc = 3) 50 | self.styleTranslator = _ResGenerator_Upsample(input_nc = 3, output_nc = 3) 51 | self.netD = Discriminator80x80InstNorm(input_nc = 3) 52 | self.depthEstModel = _UNetGenerator(input_nc = 3, output_nc = 1) 53 | 54 | self.tau_min = 0.05 55 | self.rho = 0.85 56 | self.KL_loss_weight = 1.0 57 | self.dis_weight = 1.0 58 | self.fake_loss_weight = 1e-3 59 | 60 | self.tensorboard_num_display_per_epoch = 1 61 | self.model_name = ['attModule', 'inpaintNet', 'styleTranslator', 'netD', 'depthEstModel'] 62 | self.L1loss = nn.L1Loss() 63 | 64 | if self.isTrain: 65 | self.optim_netD = optim.Adam(self.netD.parameters(), lr=self.task_lr, betas=(0.5, 
0.999)) 66 | self.optim_depth = optim.Adam(list(self.depthEstModel.parameters()) + list(self.attModule.parameters()), lr=self.task_lr, betas=(0.5, 0.999)) 67 | self.optim_name = ['optim_depth', 'optim_netD'] 68 | self._get_scheduler() 69 | self.loss_BCE = nn.BCEWithLogitsLoss() 70 | 71 | self._initialize_networks(['netD']) 72 | 73 | # load the "best" depth predictor D (from step 1) 74 | preTrain_path = os.path.join(os.getcwd(), 'experiments', 'train_initial_depth_predictor_D') 75 | self._load_models(model_list=['depthEstModel'], mode='best', isTrain=True, model_path=preTrain_path) 76 | print('Successfully loaded pre-trained {} model from {}'.format('depthEstModel', preTrain_path)) 77 | 78 | # load the "best" style translator T (from step 2) 79 | preTrain_path = os.path.join(os.getcwd(), 'experiments', 'train_style_translator_T') 80 | self._load_models(model_list=['styleTranslator'], mode=480, isTrain=True, model_path=preTrain_path) 81 | print('Successfully loaded pre-trained {} model from {}'.format('styleTranslator', preTrain_path)) 82 | 83 | # load the "best" attention module A (from step 3) 84 | preTrain_path = os.path.join(os.getcwd(), 'experiments', 'train_initial_attention_module_A') 85 | self._load_models(model_list=['attModule'], mode=450, isTrain=True, model_path=preTrain_path) 86 | print('Successfully loaded pre-trained {} model from {}'.format('attModule', preTrain_path)) 87 | 88 | # load the "best" inpainting module I (from step 4) 89 | preTrain_path = os.path.join(os.getcwd(), 'experiments', 'train_inpainting_module_I') 90 | self._load_models(model_list=['inpaintNet'], mode=450, isTrain=True, model_path=preTrain_path) 91 | print('Successfully loaded pre-trained {} model from {}'.format('inpaintNet', preTrain_path)) 92 | 93 | # apex can only be applied to CUDA models 94 | if self.use_apex: 95 | self._init_apex(Num_losses=2) 96 | 97 | self.EVAL_best_loss = float('inf') 98 | self.EVAL_best_model_epoch = 0 99 | self.EVAL_all_results = {} 100 | 101 | 
self._check_parallel() 102 | 103 | def _get_project_name(self): 104 | return 'jointly_train_depth_predictor_D_and_attention_module_A' 105 | 106 | def _initialize_networks(self, model_name): 107 | for name in model_name: 108 | getattr(self, name).train().to(self.device) 109 | init_weights(getattr(self, name), net_name=name, init_type='normal', gain=0.02) 110 | 111 | def compute_D_loss(self, real_sample, fake_sample, netD): 112 | loss = 0 113 | syn_acc = 0 114 | real_acc = 0 115 | 116 | output = netD(fake_sample) 117 | label = torch.full((output.size()), self.syn_label, device=self.device) 118 | predSyn = (output > 0.5).to(self.device, dtype=torch.float32) 119 | total_num = torch.numel(output) 120 | syn_acc += (predSyn==label).type(torch.float32).sum().item()/total_num 121 | 122 | loss += self.loss_BCE(output, label) 123 | 124 | output = netD(real_sample) 125 | label = torch.full((output.size()), self.real_label, device=self.device) 126 | predReal = (output > 0.5).to(self.device, dtype=torch.float32) 127 | real_acc += (predReal==label).type(torch.float32).sum().item()/total_num 128 | 129 | loss += self.loss_BCE(output, label) 130 | 131 | return loss, syn_acc, real_acc 132 | 133 | def compute_depth_loss(self, input_rgb, depth_label, depthEstModel, valid_mask=None): 134 | 135 | prediction = depthEstModel(input_rgb)[-1] 136 | if valid_mask is not None: 137 | loss = self.L1loss(prediction[valid_mask], depth_label[valid_mask]) 138 | else: 139 | assert valid_mask is None 140 | loss = self.L1loss(prediction, depth_label) 141 | 142 | return loss 143 | 144 | def compute_spare_attention(self, confident_score, t, isTrain=True): 145 | # t is the temperature --> scalar 146 | if isTrain: 147 | noise = torch.rand(confident_score.size(), requires_grad=False).to(self.device) 148 | noise = (noise + 0.00001) / 1.001 149 | noise = - torch.log(- torch.log(noise)) 150 | 151 | confident_score = (confident_score + 0.00001) / 1.001 152 | confident_score = (confident_score + noise) / t 153 | 
else: 154 | confident_score = confident_score / t 155 | 156 | confident_score = torch.sigmoid(confident_score) 157 | 158 | return confident_score 159 | 160 | def compute_KL_div(self, cf, target=0.5): 161 | g = cf.mean() 162 | g = (g + 0.00001) / 1.001 # prevent g = 0. or 1. 163 | y = target*torch.log(target/g) + (1-target)*torch.log((1-target)/(1-g)) # KL(Bernoulli(target) || Bernoulli(g)) 164 | return y 165 | 166 | def compute_real_fake_loss(self, scores, loss_type, datasrc = 'real', loss_for='discr'): 167 | if loss_for == 'discr': 168 | if datasrc == 'real': 169 | if loss_type == 'lsgan': 170 | # least-squares GAN loss 171 | d_loss = torch.pow(scores - 1., 2).mean() 172 | elif loss_type == 'hinge': 173 | # hinge loss, as used in the spectral normalization GAN paper 174 | d_loss = - torch.mean(torch.clamp(scores-1.,max=0.)) 175 | elif loss_type == 'wgan': 176 | # Wasserstein GAN (WGAN) loss 177 | d_loss = - torch.mean(scores) 178 | else: 179 | scores = scores.view(scores.size(0),-1).mean(dim=1) 180 | d_loss = F.binary_cross_entropy_with_logits(scores, torch.ones_like(scores).detach()) 181 | else: 182 | if loss_type == 'lsgan': 183 | # least-squares GAN loss 184 | d_loss = torch.pow((scores),2).mean() 185 | elif loss_type == 'hinge': 186 | # hinge loss, as used in the spectral normalization GAN paper 187 | d_loss = -torch.mean(torch.clamp(-scores-1.,max=0.)) 188 | elif loss_type == 'wgan': 189 | # Wasserstein GAN (WGAN) loss 190 | d_loss = torch.mean(scores) 191 | else: 192 | scores = scores.view(scores.size(0),-1).mean(dim=1) 193 | d_loss = F.binary_cross_entropy_with_logits(scores, torch.zeros_like(scores).detach()) 194 | 195 | return d_loss 196 | else: 197 | if loss_type == 'lsgan': 198 | # least-squares GAN loss 199 | g_loss = torch.pow(scores - 1., 2).mean() 200 | elif (loss_type == 'wgan') or (loss_type == 'hinge'): 201 | g_loss = - torch.mean(scores) 202 | else: 203 | scores = scores.view(scores.size(0),-1).mean(dim=1) 204 | g_loss = F.binary_cross_entropy_with_logits(scores, torch.ones_like(scores).detach()) 205 | return g_loss 206 | 
207 | def train(self): 208 | phase = 'train' 209 | since = time.time() 210 | best_loss = float('inf') 211 | 212 | set_requires_grad(self.styleTranslator, requires_grad=False) # freeze style translator T 213 | set_requires_grad(self.inpaintNet, requires_grad=False) # freeze inpainting module I 214 | 215 | tensorboardX_iter_count = 0 216 | for epoch in range(self.total_epoch_num): 217 | print('\nEpoch {}/{}'.format(epoch+1, self.total_epoch_num)) 218 | print('-' * 10) 219 | fn = open(self.train_log,'a') 220 | fn.write('\nEpoch {}/{}\n'.format(epoch+1, self.total_epoch_num)) 221 | fn.write('--'*5+'\n') 222 | fn.close() 223 | 224 | self._set_models_train(['attModule', 'inpaintNet', 'styleTranslator', 'depthEstModel']) 225 | iterCount = 0 226 | 227 | for sample_dict in self.dataloaders_xLabels_joint: 228 | imageListReal, depthListReal = sample_dict['real'] 229 | imageListSyn, depthListSyn = sample_dict['syn'] 230 | 231 | imageListSyn = imageListSyn.to(self.device) 232 | depthListSyn = depthListSyn.to(self.device) 233 | imageListReal = imageListReal.to(self.device) 234 | depthListReal = depthListReal.to(self.device) 235 | valid_mask = (depthListReal > -1.) 236 | 237 | B, C, H, W = imageListReal.size()[0], imageListReal.size()[1], imageListReal.size()[2], imageListReal.size()[3] 238 | 239 | with torch.set_grad_enabled(phase=='train'): 240 | r2s_img = self.styleTranslator(imageListReal)[-1] 241 | confident_score = self.attModule(imageListReal)[-1] 242 | # convert to sparse confident score 243 | confident_score = self.compute_spare_attention(confident_score, t=self.tau_min, isTrain=True) 244 | 245 | mod_r2s_img = r2s_img * confident_score 246 | inpainted_r2s = self.inpaintNet(mod_r2s_img)[-1] 247 | 248 | reconst_img = inpainted_r2s * (1. - confident_score) + confident_score * r2s_img 249 | 250 | # update depth predictor and attention module 251 | self.optim_depth.zero_grad() 252 | total_loss = 0. 
253 | real_depth_loss = self.compute_depth_loss(reconst_img, depthListReal, self.depthEstModel, valid_mask) 254 | syn_depth_loss = self.compute_depth_loss(imageListSyn, depthListSyn, self.depthEstModel) 255 | KL_loss = self.compute_KL_div(confident_score, target=self.rho) * self.KL_loss_weight 256 | 257 | fake_pred = self.netD(inpainted_r2s) 258 | fake_label = torch.full(fake_pred.size(), self.real_label, device=self.device) 259 | fake_loss = self.loss_BCE(fake_pred, fake_label) * self.fake_loss_weight 260 | 261 | total_loss += (real_depth_loss + syn_depth_loss + KL_loss + fake_loss) 262 | if self.use_apex: 263 | with amp.scale_loss(total_loss, self.optim_depth, loss_id=0) as total_loss_scaled: 264 | total_loss_scaled.backward() 265 | else: 266 | total_loss.backward() 267 | 268 | self.optim_depth.step() 269 | 270 | # update the discriminator netD only while epoch <= 100; afterwards netD is frozen (fake_loss is still added to total_loss) 271 | if epoch <= 100: 272 | self.optim_netD.zero_grad() 273 | netD_loss = 0. 274 | netD_loss, _, _ = self.compute_D_loss(imageListSyn, inpainted_r2s.detach(), self.netD) 275 | 276 | if self.use_apex: 277 | with amp.scale_loss(netD_loss, self.optim_netD, loss_id=0) as netD_loss_scaled: 278 | netD_loss_scaled.backward() 279 | else: 280 | netD_loss.backward() 281 | 282 | self.optim_netD.step() 283 | else: 284 | netD_loss = 0. 
285 | set_requires_grad(self.netD, requires_grad=False) 286 | 287 | iterCount += 1 288 | 289 | if self.use_tensorboardX: 290 | nrow = imageListReal.size()[0] 291 | self.train_display_freq = len(self.dataloaders_xLabels_joint) # feel free to adjust the display frequency 292 | if tensorboardX_iter_count % self.train_display_freq == 0: 293 | img_concat = torch.cat((imageListReal, r2s_img, mod_r2s_img, inpainted_r2s, reconst_img), dim=0) 294 | self.write_2_tensorboardX(self.train_SummaryWriter, img_concat, name='real, r2s, r2sMasked, inpaintedR2s, reconst', mode='image', 295 | count=tensorboardX_iter_count, nrow=nrow) 296 | 297 | self.write_2_tensorboardX(self.train_SummaryWriter, confident_score, name='Attention', mode='image', 298 | count=tensorboardX_iter_count, nrow=nrow, value_range=(0., 1.0)) 299 | 300 | # add loss values 301 | loss_val_list = [total_loss, real_depth_loss, syn_depth_loss, KL_loss, fake_loss, netD_loss] 302 | loss_name_list = ['total_loss', 'real_depth_loss', 'syn_depth_loss', 'KL_loss', 'fake_loss', 'netD_loss'] 303 | self.write_2_tensorboardX(self.train_SummaryWriter, loss_val_list, name=loss_name_list, mode='scalar', count=tensorboardX_iter_count) 304 | 305 | tensorboardX_iter_count += 1 306 | 307 | if iterCount % 20 == 0: 308 | loss_summary = '\t{}/{}, total_loss: {:.7f}, netD_loss: {:.7f}'.format(iterCount, len(self.dataloaders_xLabels_joint), total_loss, netD_loss) 309 | G_loss_summary = '\t\t G loss summary: real_depth_loss: {:.7f}, syn_depth_loss: {:.7f}, KL_loss: {:.7f} fake_loss: {:.7f}'.format(real_depth_loss, syn_depth_loss, KL_loss, fake_loss) 310 | 311 | print(loss_summary) 312 | print(G_loss_summary) 313 | 314 | fn = open(self.train_log,'a') 315 | fn.write(loss_summary + '\n') 316 | fn.write(G_loss_summary + '\n') 317 | fn.close() 318 | 319 | # take step in optimizer 320 | for scheduler in self.scheduler_list: 321 | scheduler.step() 322 | for optim in self.optim_name: 323 | lr = getattr(self, optim).param_groups[0]['lr'] 324 | 
lr_update = 'Epoch {}/{} finished: {} learning rate = {:.7f}'.format(epoch+1, self.total_epoch_num, optim, lr) 325 | print(lr_update) 326 | fn = open(self.train_log,'a') 327 | fn.write(lr_update + '\n') 328 | fn.close() 329 | 330 | if (epoch+1) % self.save_steps == 0: 331 | self.save_models(['depthEstModel', 'attModule'], mode=epoch+1) 332 | self.evaluate(epoch+1) 333 | 334 | time_elapsed = time.time() - since 335 | print('\nTraining complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 336 | 337 | fn = open(self.train_log,'a') 338 | fn.write('\nTraining complete in {:.0f}m {:.0f}s\n'.format(time_elapsed // 60, time_elapsed % 60)) 339 | fn.close() 340 | 341 | best_model_summary = '\nOverall best model is epoch {}'.format(self.EVAL_best_model_epoch) 342 | print(best_model_summary) 343 | print(self.EVAL_all_results[str(self.EVAL_best_model_epoch)]) 344 | fn = open(self.evaluate_log, 'a') 345 | fn.write(best_model_summary + '\n') 346 | fn.write(self.EVAL_all_results[str(self.EVAL_best_model_epoch)]) 347 | fn.close() 348 | 349 | def evaluate(self, mode): 350 | ''' 351 | mode: either an int (epoch number) or 'best' 352 | an int is the current epoch number and is used for evaluation during training 353 | 'best' is used after training to evaluate the best saved checkpoints 354 | ''' 355 | set_name = 'test' 356 | eval_model_list = ['attModule', 'inpaintNet', 'styleTranslator', 'depthEstModel'] 357 | 358 | if isinstance(mode, int) and self.isTrain: 359 | self._set_models_eval(eval_model_list) 360 | if self.EVAL_best_loss == float('inf'): 361 | fn = open(self.evaluate_log, 'w') 362 | else: 363 | fn = open(self.evaluate_log, 'a') 364 | 365 | fn.write('Evaluating with mode: {}\n'.format(mode)) 366 | fn.write('\tEvaluation range min: {} | max: {} \n'.format(self.EVAL_DEPTH_MIN, self.EVAL_DEPTH_MAX)) 367 | fn.close() 368 | 369 | else: 370 | self._load_models(eval_model_list, mode) 371 | 372 | print('Evaluating with mode: {}'.format(mode)) 373 | print('\tEvaluation range min: {} | max: 
{}'.format(self.EVAL_DEPTH_MIN, self.EVAL_DEPTH_MAX)) 374 | 375 | total_loss, count = 0., 0 376 | predTensor = torch.zeros((1, 1, self.cropSize_h, self.cropSize_w)).to('cpu') 377 | grndTensor = torch.zeros((1, 1, self.cropSize_h, self.cropSize_w)).to('cpu') 378 | idx = 0 379 | 380 | tensorboardX_iter_count = 0 381 | for sample in self.dataloaders_single[set_name]: 382 | imageList, depthList = sample 383 | valid_mask = np.logical_and(depthList > self.EVAL_DEPTH_MIN, depthList < self.EVAL_DEPTH_MAX) 384 | 385 | idx += imageList.shape[0] 386 | print('epoch {}: have processed {} number samples in {} set'.format(mode, str(idx), set_name)) 387 | imageList = imageList.to(self.device) 388 | depthList = depthList.to(self.device) 389 | 390 | if self.isTrain and self.use_apex: 391 | with amp.disable_casts(): 392 | r2s_img = self.styleTranslator(imageList)[-1] 393 | confident_score = self.attModule(imageList)[-1] 394 | # convert to sparse confident score 395 | confident_score = self.compute_spare_attention(confident_score, t=self.tau_min, isTrain=False) 396 | # hard threshold 397 | confident_score[confident_score < 0.5] = 0. 398 | confident_score[confident_score >= 0.5] = 1. 399 | mod_r2s_img = r2s_img * confident_score 400 | inpainted_r2s = self.inpaintNet(mod_r2s_img)[-1] 401 | reconst_img = inpainted_r2s * (1. - confident_score) + confident_score * r2s_img 402 | predList = self.depthEstModel(reconst_img)[-1].detach().to('cpu') # [-1, 1] 403 | 404 | else: 405 | r2s_img = self.styleTranslator(imageList)[-1] 406 | confident_score = self.attModule(imageList)[-1] 407 | # convert to sparse confident score 408 | confident_score = self.compute_spare_attention(confident_score, t=self.tau_min, isTrain=False) 409 | # hard threshold 410 | confident_score[confident_score < 0.5] = 0. 411 | confident_score[confident_score >= 0.5] = 1. 412 | mod_r2s_img = r2s_img * confident_score 413 | inpainted_r2s = self.inpaintNet(mod_r2s_img)[-1] 414 | reconst_img = inpainted_r2s * (1. 
- confident_score) + confident_score * r2s_img 415 | predList = self.depthEstModel(reconst_img)[-1].detach().to('cpu') # [-1, 1] 416 | 417 | # recover real depth 418 | predList = (predList + 1.0) * 0.5 * self.NYU_MAX_DEPTH_CLIP 419 | depthList = depthList.detach().to('cpu') 420 | predTensor = torch.cat((predTensor, predList), dim=0) 421 | grndTensor = torch.cat((grndTensor, depthList), dim=0) 422 | 423 | if self.use_tensorboardX: 424 | nrow = imageList.size()[0] 425 | if tensorboardX_iter_count % self.val_display_freq == 0: 426 | depth_concat = torch.cat((depthList, predList), dim=0) 427 | self.write_2_tensorboardX(self.eval_SummaryWriter, depth_concat, name='{}: ground truth and depth prediction'.format(set_name), 428 | mode='image', count=tensorboardX_iter_count, nrow=nrow, value_range=(0.0, self.NYU_MAX_DEPTH_CLIP)) 429 | 430 | tensorboardX_iter_count += 1 431 | 432 | if isinstance(mode, int) and self.isTrain: 433 | eval_depth_loss = self.L1loss(predList[valid_mask], depthList[valid_mask]) 434 | total_loss += eval_depth_loss.detach().cpu() 435 | 436 | count += 1 437 | 438 | if isinstance(mode, int) and self.isTrain: 439 | validation_loss = (total_loss / count) 440 | print('validation loss is {:.7f}'.format(validation_loss)) 441 | if self.use_tensorboardX: 442 | self.write_2_tensorboardX(self.eval_SummaryWriter, validation_loss, name='validation loss', mode='scalar', count=mode) 443 | 444 | results = Result(mask_min=self.EVAL_DEPTH_MIN, mask_max=self.EVAL_DEPTH_MAX) 445 | results.evaluate(predTensor[1:], grndTensor[1:]) 446 | 447 | result1 = '\tabs_rel:{:.3f}, sq_rel:{:.3f}, rmse:{:.3f}, rmse_log:{:.3f}, mae:{:.3f} '.format( 448 | results.absrel,results.sqrel,results.rmse,results.rmselog,results.mae) 449 | result2 = '\t[<1.25]:{:.3f}, [<1.25^2]:{:.3f}, [<1.25^3]:{:.3f}'.format(results.delta1,results.delta2,results.delta3) 450 | 451 | print(result1) 452 | print(result2) 453 | 454 | if isinstance(mode, int) and self.isTrain: 455 | self.EVAL_all_results[str(mode)] 
= result1 + '\t' + result2 456 | 457 | if validation_loss.item() < self.EVAL_best_loss: 458 | self.EVAL_best_loss = validation_loss.item() 459 | self.EVAL_best_model_epoch = mode 460 | self.save_models(['depthEstModel', 'attModule'], mode='best') 461 | 462 | best_model_summary = '\tCurrent best loss {:.7f}, current best model {}\n'.format(self.EVAL_best_loss, self.EVAL_best_model_epoch) 463 | print(best_model_summary) 464 | 465 | fn = open(self.evaluate_log, 'a') 466 | fn.write(result1 + '\n') 467 | fn.write(result2 + '\n') 468 | fn.write(best_model_summary + '\n') 469 | fn.close() -------------------------------------------------------------------------------- /training/train_initial_attention_module_A.py: -------------------------------------------------------------------------------- 1 | import os, time, sys 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.optim import lr_scheduler 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | 10 | import torchvision 11 | from torchvision import datasets, models, transforms 12 | from torchvision.utils import make_grid 13 | from tensorboardX import SummaryWriter 14 | 15 | from models.depth_generator_networks import _UNetGenerator, init_weights, _ResGenerator_Upsample 16 | from models.discriminator_networks import Discriminator80x80InstNorm 17 | from models.attention_networks import _Attention_FullRes 18 | 19 | from utils.metrics import * 20 | from utils.image_pool import ImagePool 21 | 22 | from training.base_model import set_requires_grad, base_model 23 | 24 | try: 25 | from apex import amp 26 | except ImportError: 27 | print("\nPlease consider installing apex from https://www.github.com/nvidia/apex to run with apex, or set use_apex = False\n") 28 | 29 | import warnings # ignore warnings 30 | warnings.filterwarnings("ignore") 31 | 32 | def value_scheduler(start, total_num_epoch, end=None, ratio=None, 
step_size=None, multiple=None, mode='linear'): 33 | if mode == 'linear': 34 | return np.linspace(start, end, total_num_epoch) 35 | elif mode == 'linear_ratio': 36 | assert ratio is not None 37 | linear = np.linspace(start, end, int(total_num_epoch * ratio)) # linspace/repeat need integer counts 38 | stable = np.repeat(end, total_num_epoch - int(total_num_epoch * ratio)) 39 | return np.concatenate((linear, stable)) 40 | 41 | elif mode == 'step_wise': 42 | assert step_size is not None 43 | times, res = divmod(total_num_epoch, step_size) 44 | for i in range(0, times): 45 | value = np.repeat(start * (multiple**i), step_size) 46 | if i == 0: 47 | final = value 48 | else: 49 | final = np.concatenate((final, value)) 50 | 51 | if res != 0: 52 | final = np.concatenate((final, np.repeat(start * (multiple**(times)), res))) 53 | return final 54 | 55 | class train_initial_attention_module_A(base_model): 56 | def __init__(self, args, dataloaders_xLabels_joint, dataloaders_single): 57 | super(train_initial_attention_module_A, self).__init__(args) 58 | self._initialize_training() 59 | 60 | self.dataloaders_single = dataloaders_single 61 | self.dataloaders_xLabels_joint = dataloaders_xLabels_joint 62 | 63 | # define loss weights 64 | self.lambda_identity = 0.5 # weight of the identity-mapping loss 65 | self.lambda_real = 10.0 66 | self.lambda_synthetic = 10.0 67 | self.lambda_GAN = 1.0 68 | 69 | self.KL_loss_weight_max = 1. 
70 | self.rho = 0.99 # target mean of the attention map in the KL term 71 | self.tau_min = 0.05 # final attention temperature 72 | self.tau_max = 0.9 # initial attention temperature (annealed linearly to tau_min) 73 | 74 | self.pool_size = 50 75 | self.generated_syn_pool = ImagePool(self.pool_size) 76 | self.generated_real_pool = ImagePool(self.pool_size) 77 | 78 | self.attModule = _Attention_FullRes(input_nc = 3, output_nc = 1) 79 | self.netD_s = Discriminator80x80InstNorm(input_nc = 3) 80 | self.netD_r = Discriminator80x80InstNorm(input_nc = 3) 81 | self.netG_s2r = _ResGenerator_Upsample(input_nc = 3, output_nc = 3) 82 | self.netG_r2s = _ResGenerator_Upsample(input_nc = 3, output_nc = 3) 83 | 84 | self.model_name = ['netD_s', 'netD_r', 'netG_s2r', 'netG_r2s', 'attModule'] 85 | self.L1loss = nn.L1Loss() 86 | 87 | if self.isTrain: 88 | self.netD_optimizer = optim.Adam(list(self.netD_s.parameters()) + list(self.netD_r.parameters()), lr=self.D_lr, betas=(0.5, 0.999)) 89 | self.netG_optimizer = optim.Adam(list(self.netG_r2s.parameters()) + list(self.netG_s2r.parameters()) + list(self.attModule.parameters()), lr=self.G_lr, betas=(0.5, 0.999)) 90 | self.optim_name = ['netD_optimizer', 'netG_optimizer'] 91 | self._get_scheduler() 92 | self.loss_BCE = nn.BCEWithLogitsLoss() 93 | self._initialize_networks() 94 | 95 | # apex can only be applied to CUDA models 96 | if self.use_apex: 97 | self._init_apex(Num_losses=3) 98 | 99 | self._check_parallel() 100 | 101 | def _get_project_name(self): 102 | return 'train_initial_attention_module_A' 103 | 104 | def _initialize_networks(self): 105 | for name in self.model_name: 106 | getattr(self, name).train().to(self.device) 107 | init_weights(getattr(self, name), net_name=name, init_type='normal', gain=0.02) 108 | 109 | def compute_D_loss(self, real_sample, fake_sample, netD): 110 | loss = 0 111 | syn_acc = 0 112 | real_acc = 0 113 | 114 | output = netD(fake_sample) 115 | label = torch.full((output.size()), 
self.syn_label, device=self.device) 116 | 117 | predSyn = (output > 0.).to(self.device, dtype=torch.float32) # netD outputs logits, so 0 is the decision threshold 118 | total_num = torch.numel(output) 119 | syn_acc += (predSyn==label).type(torch.float32).sum().item()/total_num 120 | loss += self.loss_BCE(output, label) 121 | 122 | output = netD(real_sample) 123 | label = torch.full((output.size()), self.real_label, device=self.device) 124 | 125 | predReal = (output > 0.).to(self.device, dtype=torch.float32) 126 | real_acc += (predReal==label).type(torch.float32).sum().item()/total_num 127 | loss += self.loss_BCE(output, label) 128 | 129 | return loss, syn_acc, real_acc 130 | 131 | def compute_G_loss(self, real_sample, synthetic_sample, r2s_rgb, s2r_rgb, rct_real, rct_syn, cs_imageListReal): 132 | ''' 133 | real_sample: [batch_size, 3, 240, 320] real RGB images 134 | synthetic_sample: [batch_size, 3, 240, 320] synthetic RGB images 135 | r2s_rgb: netG_r2s(real) 136 | s2r_rgb: netG_s2r(synthetic) 137 | ''' 138 | non_reduction_L1loss = nn.L1Loss(reduction='none') 139 | loss = 0 140 | 141 | # identity loss if applicable 142 | if self.lambda_identity > 0: 143 | idt_real = self.netG_s2r(real_sample)[-1] 144 | idt_synthetic = self.netG_r2s(synthetic_sample)[-1] 145 | idt_loss = (self.L1loss(idt_real, real_sample) * self.lambda_real + 146 | self.L1loss(idt_synthetic, synthetic_sample) * self.lambda_synthetic) * self.lambda_identity 147 | else: 148 | idt_loss = 0 149 | 150 | # GAN loss 151 | real_pred = self.netD_r(s2r_rgb) 152 | real_label = torch.full(real_pred.size(), self.real_label, device=self.device) 153 | GAN_loss_real = self.loss_BCE(real_pred, real_label) 154 | 155 | syn_pred = self.netD_s(r2s_rgb) 156 | syn_label = torch.full(syn_pred.size(), self.real_label, device=self.device) 157 | GAN_loss_syn = self.loss_BCE(syn_pred, syn_label) 158 | 159 | GAN_loss = (GAN_loss_real + GAN_loss_syn) * self.lambda_GAN 160 | 161 | # cycle-consistency loss (the real branch is weighted per 
pixel by the attention map) 162 | rec_real_loss = cs_imageListReal * non_reduction_L1loss(rct_real, real_sample) 163 | rec_real_loss = rec_real_loss.mean() * self.lambda_real 164 | 165 | rec_syn_loss = self.L1loss(rct_syn, synthetic_sample) * self.lambda_synthetic 166 | rec_loss = 
rec_real_loss + rec_syn_loss 167 | 168 | loss += (idt_loss + GAN_loss + rec_loss) 169 | 170 | return loss, idt_loss, GAN_loss, rec_loss 171 | 172 | def compute_spare_attention(self, confident_score, t, isTrain=True): 173 | # t is the (scalar) temperature; Gumbel noise is added only during training 174 | if isTrain: 175 | noise = torch.rand(confident_score.size(), requires_grad=False).to(self.device) 176 | noise = (noise + 0.00001) / 1.001 177 | noise = - torch.log(- torch.log(noise)) 178 | 179 | confident_score = (confident_score + 0.00001) / 1.001 180 | confident_score = (confident_score + noise) / t 181 | else: 182 | confident_score = confident_score / t 183 | 184 | confident_score = torch.sigmoid(confident_score) 185 | 186 | return confident_score 187 | 188 | def compute_KL_div(self, cf, target=0.5): 189 | g = cf.mean() 190 | g = (g + 0.00001) / 1.001 # prevent g = 0. or 1. 191 | y = target * torch.log(target/g) + (1-target) * torch.log((1-target)/(1-g)) # KL(Bernoulli(target) || Bernoulli(g)) 192 | return y 193 | 194 | def train(self): 195 | phase = 'train' 196 | since = time.time() 197 | best_loss = float('inf') 198 | 199 | self.train_display_freq = len(self.dataloaders_xLabels_joint) // self.tensorboard_num_display_per_epoch 200 | tau_value_scheduler = value_scheduler(self.tau_max, self.total_epoch_num, end=self.tau_min, mode='linear') 201 | 202 | tensorboardX_iter_count = 0 203 | for epoch in range(self.total_epoch_num): 204 | print('\nEpoch {}/{}'.format(epoch+1, self.total_epoch_num)) 205 | print('-' * 10) 206 | fn = open(self.train_log,'a') 207 | fn.write('\nEpoch {}/{}\n'.format(epoch+1, self.total_epoch_num)) 208 | fn.write('--'*5+'\n') 209 | fn.close() 210 | 211 | iterCount = 0 212 | 213 | for sample_dict in self.dataloaders_xLabels_joint: 214 | imageListReal, depthListReal = sample_dict['real'] 215 | imageListSyn, depthListSyn = sample_dict['syn'] 216 | 217 | imageListSyn = imageListSyn.to(self.device) 218 | depthListSyn = 
depthListSyn.to(self.device) 219 | imageListReal = imageListReal.to(self.device) 220 | depthListReal = 
depthListReal.to(self.device) 221 | 222 | with torch.set_grad_enabled(phase=='train'): 223 | s2r_rgb = self.netG_s2r(imageListSyn)[-1] 224 | rct_syn = self.netG_r2s(s2r_rgb)[-1] 225 | 226 | cs_imageListReal = self.attModule(imageListReal)[-1] 227 | cs_imageListReal = self.compute_spare_attention(cs_imageListReal, t=tau_value_scheduler[epoch], isTrain=True) 228 | mod_imageListReal = imageListReal * cs_imageListReal 229 | r2s_rgb = self.netG_r2s(mod_imageListReal)[-1] 230 | 231 | rct_real = self.netG_s2r(r2s_rgb)[-1] 232 | 233 | ############# update generator 234 | set_requires_grad([self.netD_r, self.netD_s], False) 235 | netG_loss = 0. 236 | self.netG_optimizer.zero_grad() 237 | netG_loss, G_idt_loss, G_GAN_loss, G_rec_loss = self.compute_G_loss(imageListReal, imageListSyn, 238 | r2s_rgb, s2r_rgb, rct_real, rct_syn, cs_imageListReal) 239 | 240 | KL_loss = 0. 241 | KL_loss += self.compute_KL_div(cs_imageListReal, target=self.rho) * self.KL_loss_weight_max 242 | netG_loss += KL_loss 243 | 244 | if self.use_apex: 245 | with amp.scale_loss(netG_loss, self.netG_optimizer, loss_id=0) as netG_loss_scaled: 246 | netG_loss_scaled.backward() 247 | else: 248 | netG_loss.backward() 249 | 250 | self.netG_optimizer.step() 251 | 252 | ############# update discriminator 253 | set_requires_grad([self.netD_r, self.netD_s], True) 254 | 255 | self.netD_optimizer.zero_grad() 256 | 257 | r2s_rgb_pool = self.generated_syn_pool.query(r2s_rgb) 258 | netD_s_loss, netD_s_syn_acc, netD_s_real_acc = self.compute_D_loss(imageListSyn, r2s_rgb.detach(), self.netD_s) 259 | s2r_rgb_pool = self.generated_real_pool.query(s2r_rgb) 260 | netD_r_loss, netD_r_syn_acc, netD_r_real_acc = self.compute_D_loss(imageListReal, s2r_rgb.detach(), self.netD_r) 261 | 262 | netD_loss = netD_s_loss + netD_r_loss 263 | 264 | if self.use_apex: 265 | with amp.scale_loss(netD_loss, self.netD_optimizer, loss_id=1) as netD_loss_scaled: 266 | netD_loss_scaled.backward() 267 | else: 268 | netD_loss.backward() 269 | 
self.netD_optimizer.step() 270 | 271 | iterCount += 1 272 | 273 | if self.use_tensorboardX: 274 | self.train_display_freq = len(self.dataloaders_xLabels_joint) # feel free to adjust the display frequency 275 | nrow = imageListReal.size()[0] 276 | if tensorboardX_iter_count % self.train_display_freq == 0: 277 | s2r_rgb_concat = torch.cat((imageListSyn, s2r_rgb, imageListReal, rct_syn), dim=0) 278 | self.write_2_tensorboardX(self.train_SummaryWriter, s2r_rgb_concat, name='RGB: syn, s2r, real, reconstruct syn', mode='image', 279 | count=tensorboardX_iter_count, nrow=nrow) 280 | 281 | r2s_rgb_concat = torch.cat((imageListReal, r2s_rgb, imageListSyn, rct_real), dim=0) 282 | self.write_2_tensorboardX(self.train_SummaryWriter, r2s_rgb_concat, name='RGB: real, r2s, synthetic, reconstruct real', mode='image', 283 | count=tensorboardX_iter_count, nrow=nrow) 284 | 285 | self.write_2_tensorboardX(self.train_SummaryWriter, cs_imageListReal, name='Atten: real', mode='image', 286 | count=tensorboardX_iter_count, nrow=nrow, value_range=(0.0, 1.0)) 287 | 288 | loss_val_list = [netD_loss, netG_loss, KL_loss] 289 | loss_name_list = ['netD_loss', 'netG_loss', 'KL_loss'] 290 | self.write_2_tensorboardX(self.train_SummaryWriter, loss_val_list, name=loss_name_list, mode='scalar', count=tensorboardX_iter_count) 291 | 292 | tensorboardX_iter_count += 1 293 | 294 | if iterCount % 20 == 0: 295 | loss_summary = '\t{}/{} netD: {:.7f}, netG: {:.7f}'.format(iterCount, len(self.dataloaders_xLabels_joint), netD_loss, netG_loss) 296 | G_loss_summary = '\t\tG loss summary: netG: {:.7f}, idt_loss: {:.7f}, GAN_loss: {:.7f}, rec_loss: {:.7f}, KL_loss: {:.7f}'.format( 297 | netG_loss, G_idt_loss, G_GAN_loss, G_rec_loss, KL_loss) 298 | 299 | print(loss_summary) 300 | print(G_loss_summary) 301 | 302 | fn = open(self.train_log,'a') 303 | fn.write(loss_summary + '\n') 304 | fn.write(G_loss_summary + '\n') 305 | fn.close() 306 | 307 | if (epoch+1) % self.save_steps == 0: 308 | self.save_models(['attModule'], 
mode=epoch+1) 309 | 310 | # step the learning-rate schedulers 311 | for scheduler in self.scheduler_list: 312 | scheduler.step() 313 | for optim in self.optim_name: 314 | lr = getattr(self, optim).param_groups[0]['lr'] 315 | lr_update = 'Epoch {}/{} finished: {} learning rate = {:.7f}'.format(epoch+1, self.total_epoch_num, optim, lr) 316 | print(lr_update) 317 | 318 | fn = open(self.train_log,'a') 319 | fn.write(lr_update + '\n') 320 | fn.close() 321 | 322 | time_elapsed = time.time() - since 323 | print('\nTraining complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 324 | 325 | fn = open(self.train_log,'a') 326 | fn.write('\nTraining complete in {:.0f}m {:.0f}s\n'.format(time_elapsed // 60, time_elapsed % 60)) 327 | fn.close() 328 | 329 | def evaluate(self, mode): 330 | pass -------------------------------------------------------------------------------- /training/train_initial_depth_predictor_D.py: -------------------------------------------------------------------------------- 1 | import os, time, sys 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.optim import lr_scheduler 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | 10 | import torchvision 11 | from torchvision import datasets, models, transforms 12 | from torchvision.utils import make_grid 13 | from tensorboardX import SummaryWriter 14 | 15 | from models.depth_generator_networks import _UNetGenerator, init_weights 16 | 17 | from utils.metrics import * 18 | from utils.image_pool import ImagePool 19 | 20 | from training.base_model import set_requires_grad, base_model 21 | 22 | try: 23 | from apex import amp 24 | except ImportError: 25 | print("\nPlease consider installing apex from https://www.github.com/nvidia/apex to run with apex, or set use_apex = False\n") 26 | 27 | import warnings # ignore warnings 28 | warnings.filterwarnings("ignore") 29 | 30 | class 
train_initial_depth_predictor_D(base_model): 31 | def __init__(self, args, dataloaders_xLabels_joint, dataloaders_single): 32 | super(train_initial_depth_predictor_D, self).__init__(args) 33 | self._initialize_training() 34 | # self.KITTI_MAX_DEPTH_CLIP = 80.0 35 | # self.EVAL_DEPTH_MIN = 1.0 36 | # self.EVAL_DEPTH_MAX = 50.0 37 | 38 | self.NYU_MAX_DEPTH_CLIP = 10.0 39 | self.EVAL_DEPTH_MIN = 1.0 40 | self.EVAL_DEPTH_MAX = 8.0 41 | 42 | self.dataloaders_single = dataloaders_single 43 | self.dataloaders_xLabels_joint = dataloaders_xLabels_joint 44 | 45 | self.depthEstModel = _UNetGenerator(input_nc = 3, output_nc = 1) 46 | self.model_name = ['depthEstModel'] 47 | self.L1loss = nn.L1Loss() 48 | 49 | if self.isTrain: 50 | self.depth_optimizer = optim.Adam(self.depthEstModel.parameters(), lr=self.task_lr, betas=(0.5, 0.999)) 51 | self.optim_name = ['depth_optimizer'] 52 | self._get_scheduler() 53 | self.loss_BCE = nn.BCEWithLogitsLoss() 54 | self._initialize_networks() 55 | 56 | # apex can only be applied to CUDA models 57 | if self.use_apex: 58 | self._init_apex(Num_losses=2) 59 | 60 | self.EVAL_best_loss = float('inf') 61 | self.EVAL_best_model_epoch = 0 62 | self.EVAL_all_results = {} 63 | 64 | self._check_parallel() 65 | 66 | def _get_project_name(self): 67 | return 'train_initial_depth_predictor_D' 68 | 69 | def _initialize_networks(self): 70 | for name in self.model_name: 71 | getattr(self, name).train().to(self.device) 72 | init_weights(getattr(self, name), net_name=name, init_type='normal', gain=0.02) 73 | 74 | def compute_depth_loss(self, input_rgb, depth_label, depthEstModel, valid_mask=None): 75 | # L1 depth loss, restricted to valid pixels when a mask is given 76 | prediction = depthEstModel(input_rgb)[-1] 77 | if valid_mask is not None: 78 | loss = self.L1loss(prediction[valid_mask], depth_label[valid_mask]) 79 | else: 80 | assert valid_mask is None 81 | loss = self.L1loss(prediction, depth_label) 82 | 83 | return loss 84 | 85 | def train(self): 86 | phase 
= 'train' 87 | since = time.time() 88 | best_loss = float('inf') 
89 | 90 | tensorboardX_iter_count = 0 91 | for epoch in range(self.total_epoch_num): 92 | print('\nEpoch {}/{}'.format(epoch+1, self.total_epoch_num)) 93 | print('-' * 10) 94 | fn = open(self.train_log,'a') 95 | fn.write('\nEpoch {}/{}\n'.format(epoch+1, self.total_epoch_num)) 96 | fn.write('--'*5+'\n') 97 | fn.close() 98 | 99 | self._set_models_train(['depthEstModel']) 100 | iterCount = 0 101 | 102 | for sample_dict in self.dataloaders_xLabels_joint: 103 | imageListReal, depthListReal = sample_dict['real'] 104 | imageListSyn, depthListSyn = sample_dict['syn'] 105 | 106 | imageListSyn = imageListSyn.to(self.device) 107 | depthListSyn = depthListSyn.to(self.device) 108 | imageListReal = imageListReal.to(self.device) 109 | depthListReal = depthListReal.to(self.device) 110 | valid_mask = (depthListReal > -1.) # remove undefined regions 111 | 112 | with torch.set_grad_enabled(phase=='train'): 113 | total_loss = 0. 114 | self.depth_optimizer.zero_grad() 115 | real_depth_loss = self.compute_depth_loss(imageListReal, depthListReal, self.depthEstModel, valid_mask) 116 | syn_depth_loss = self.compute_depth_loss(imageListSyn, depthListSyn, self.depthEstModel) 117 | total_loss += (real_depth_loss + syn_depth_loss) 118 | 119 | if self.use_apex: 120 | with amp.scale_loss(total_loss, self.depth_optimizer) as total_loss_scaled: 121 | total_loss_scaled.backward() 122 | else: 123 | total_loss.backward() 124 | 125 | self.depth_optimizer.step() 126 | 127 | iterCount += 1 128 | 129 | if self.use_tensorboardX: 130 | self.train_display_freq = len(self.dataloaders_xLabels_joint) 131 | nrow = imageListReal.size()[0] 132 | if tensorboardX_iter_count % self.train_display_freq == 0: 133 | pred_depth_real = self.depthEstModel(imageListReal)[-1] 134 | 135 | tensorboardX_grid_real_rgb = make_grid(imageListReal, nrow=nrow, normalize=True, range=(-1.0, 1.0)) 136 | self.train_SummaryWriter.add_image('real rgb images', tensorboardX_grid_real_rgb, tensorboardX_iter_count) 137 | 138 | 
tensorboardX_depth_concat = torch.cat((depthListReal, pred_depth_real), dim=0) 139 | tensorboardX_grid_real_depth = make_grid(tensorboardX_depth_concat, nrow=nrow, normalize=True, range=(-1.0, 1.0)) 140 | self.train_SummaryWriter.add_image('real depth and depth prediction', tensorboardX_grid_real_depth, tensorboardX_iter_count) 141 | 142 | # add loss values 143 | loss_val_list = [total_loss, real_depth_loss, syn_depth_loss] 144 | loss_name_list = ['total_loss', 'real_depth_loss', 'syn_depth_loss'] 145 | self.write_2_tensorboardX(self.train_SummaryWriter, loss_val_list, name=loss_name_list, mode='scalar', count=tensorboardX_iter_count) 146 | 147 | tensorboardX_iter_count += 1 148 | 149 | if iterCount % 20 == 0: 150 | loss_summary = '\t{}/{} total_loss: {:.7f}, real_depth_loss: {:.7f}, syn_depth_loss: {:.7f}'.format( 151 | iterCount, len(self.dataloaders_xLabels_joint), total_loss, real_depth_loss, syn_depth_loss) 152 | 153 | print(loss_summary) 154 | fn = open(self.train_log,'a') 155 | fn.write(loss_summary + '\n') 156 | fn.close() 157 | 158 | # step the learning-rate schedulers 159 | for scheduler in self.scheduler_list: 160 | scheduler.step() 161 | for optim in self.optim_name: 162 | lr = getattr(self, optim).param_groups[0]['lr'] 163 | lr_update = 'Epoch {}/{} finished: {} learning rate = {:.7f}'.format(epoch+1, self.total_epoch_num, optim, lr) 164 | print(lr_update) 165 | 166 | fn = open(self.train_log,'a') 167 | fn.write(lr_update + '\n') 168 | fn.close() 169 | 170 | if (epoch+1) % self.save_steps == 0: 171 | self.save_models(self.model_name, mode=epoch+1) 172 | self.evaluate(epoch+1) 173 | 174 | time_elapsed = time.time() - since 175 | print('\nTraining complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 176 | 177 | fn = open(self.train_log,'a') 178 | fn.write('\nTraining complete in {:.0f}m {:.0f}s\n'.format(time_elapsed // 60, time_elapsed % 60)) 179 | fn.close() 180 | 181 | best_model_summary = '\nOverall best model is epoch 
{}'.format(self.EVAL_best_model_epoch) 182 | print(best_model_summary) 183 | print(self.EVAL_all_results[str(self.EVAL_best_model_epoch)]) 184 | fn = open(self.evaluate_log, 'a') 185 | fn.write(best_model_summary + '\n') 186 | fn.write(self.EVAL_all_results[str(self.EVAL_best_model_epoch)]) 187 | fn.close() 188 | 189 | def evaluate(self, mode): 190 | ''' 191 | mode: either an int (epoch number) or 'best' 192 | an int is used for evaluation during training, at the given epoch 193 | 'best' is used after training; the saved weights are loaded before evaluating 194 | ''' 195 | 196 | set_name = 'test' 197 | eval_model_list = ['depthEstModel'] 198 | 199 | if isinstance(mode, int) and self.isTrain: 200 | self._set_models_eval(eval_model_list) 201 | if self.EVAL_best_loss == float('inf'): 202 | fn = open(self.evaluate_log, 'w') 203 | else: 204 | fn = open(self.evaluate_log, 'a') 205 | 206 | fn.write('Evaluating with mode: {}\n'.format(mode)) 207 | fn.write('\tEvaluation range min: {} | max: {} \n'.format(self.EVAL_DEPTH_MIN, self.EVAL_DEPTH_MAX)) 208 | fn.close() 209 | 210 | else: 211 | self._load_models(eval_model_list, mode) 212 | 213 | print('Evaluating with mode: {}'.format(mode)) 214 | print('\tEvaluation range min: {} | max: {}'.format(self.EVAL_DEPTH_MIN, self.EVAL_DEPTH_MAX)) 215 | 216 | total_loss, count = 0., 0 217 | predTensor = torch.zeros((1, 1, self.cropSize_h, self.cropSize_w)).to('cpu') 218 | grndTensor = torch.zeros((1, 1, self.cropSize_h, self.cropSize_w)).to('cpu') 219 | idx = 0 220 | 221 | tensorboardX_iter_count = 0 222 | for sample in self.dataloaders_single[set_name]: 223 | imageList, depthList = sample 224 | valid_mask = np.logical_and(depthList > self.EVAL_DEPTH_MIN, depthList < self.EVAL_DEPTH_MAX) 225 | 226 | idx += imageList.shape[0] 227 | print('epoch {}: processed {} samples in the {} set'.format(mode, str(idx), set_name)) 228 | imageList = imageList.to(self.device) 229 | depthList = depthList.to(self.device) # real depth 230 | 
231 | if self.isTrain and self.use_apex: 232 
| with amp.disable_casts(): 233 | predList = self.depthEstModel(imageList)[-1].detach().to('cpu') 234 | else: 235 | predList = self.depthEstModel(imageList)[-1].detach().to('cpu') 236 | 237 | # map predictions from [-1, 1] back to metric depth 238 | predList = (predList + 1.0) * 0.5 * self.NYU_MAX_DEPTH_CLIP 239 | depthList = depthList.detach().to('cpu') 240 | predTensor = torch.cat((predTensor, predList), dim=0) 241 | grndTensor = torch.cat((grndTensor, depthList), dim=0) 242 | 243 | if self.use_tensorboardX: 244 | nrow = imageList.size()[0] 245 | if tensorboardX_iter_count % self.val_display_freq == 0: 246 | depth_concat = torch.cat((depthList, predList), dim=0) 247 | self.write_2_tensorboardX(self.eval_SummaryWriter, depth_concat, name='{}: ground truth and depth prediction'.format(set_name), 248 | mode='image', count=tensorboardX_iter_count, nrow=nrow, value_range=(0.0, self.NYU_MAX_DEPTH_CLIP)) 249 | 250 | tensorboardX_iter_count += 1 251 | 252 | if isinstance(mode, int) and self.isTrain: 253 | eval_depth_loss = self.L1loss(predList[valid_mask], depthList[valid_mask]) 254 | total_loss += eval_depth_loss.detach().cpu() 255 | 256 | count += 1 257 | 258 | if isinstance(mode, int) and self.isTrain: 259 | validation_loss = (total_loss / count) 260 | print('validation loss is {:.7f}'.format(validation_loss)) 261 | if self.use_tensorboardX: 262 | self.write_2_tensorboardX(self.eval_SummaryWriter, validation_loss, name='validation loss', mode='scalar', count=mode) 263 | 264 | results = Result(mask_min=self.EVAL_DEPTH_MIN, mask_max=self.EVAL_DEPTH_MAX) 265 | results.evaluate(predTensor[1:], grndTensor[1:]) 266 | 267 | result1 = '\tabs_rel:{:.3f}, sq_rel:{:.3f}, rmse:{:.3f}, rmse_log:{:.3f}, mae:{:.3f} '.format( 268 | results.absrel,results.sqrel,results.rmse,results.rmselog,results.mae) 269 | result2 = '\t[<1.25]:{:.3f}, [<1.25^2]:{:.3f}, [<1.25^3]:{:.3f}'.format(results.delta1,results.delta2,results.delta3) 270 | 271 | print(result1) 272 | print(result2) 273 | 274 | if isinstance(mode, int) 
and 
self.isTrain: 275 | self.EVAL_all_results[str(mode)] = result1 + '\t' + result2 276 | 277 | if validation_loss.item() < self.EVAL_best_loss: 278 | self.EVAL_best_loss = validation_loss.item() 279 | self.EVAL_best_model_epoch = mode 280 | self.save_models(self.model_name, mode='best') 281 | 282 | best_model_summary = '\tCurrent best loss {:.7f}, current best model {}\n'.format(self.EVAL_best_loss, self.EVAL_best_model_epoch) 283 | print(best_model_summary) 284 | 285 | fn = open(self.evaluate_log, 'a') 286 | fn.write(result1 + '\n') 287 | fn.write(result2 + '\n') 288 | fn.write(best_model_summary + '\n') 289 | fn.close() -------------------------------------------------------------------------------- /training/train_inpainting_module_I.py: -------------------------------------------------------------------------------- 1 | import os, time, sys 2 | import random 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from torch.optim import lr_scheduler 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | 11 | import torchvision 12 | from torchvision import datasets, models, transforms 13 | from torchvision.utils import make_grid 14 | from tensorboardX import SummaryWriter 15 | 16 | from loss import PerceptualLoss, StyleLoss, VGG19 17 | from models.depth_generator_networks import _UNetGenerator, init_weights, _ResGenerator_Upsample 18 | from models.attention_networks import _Attention_FullRes 19 | from models.discriminator_networks import Discriminator80x80InstNorm, DiscriminatorGlobalLocal 20 | 21 | from utils.metrics import * 22 | from utils.image_pool import ImagePool 23 | 24 | from training.base_model import set_requires_grad, base_model 25 | 26 | try: 27 | from apex import amp 28 | except ImportError: 29 | print("\nPlease consider installing apex from https://www.github.com/nvidia/apex to run with apex, or set use_apex = False\n") 30 | 31 | import warnings # ignore 
warnings 32 | warnings.filterwarnings("ignore") 33 | 34 | class Mask_Buffer(): 35 | """This class implements an image buffer that stores previously generated images. 36 | 37 | This buffer enables us to update discriminators using a history of generated images 38 | rather than the ones produced by the latest generators. 39 | """ 40 | 41 | def __init__(self, pool_size): 42 | """Initialize the Mask_Buffer class 43 | 44 | Parameters: 45 | pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created 46 | """ 47 | self.pool_size = pool_size 48 | if self.pool_size > 0: # create an empty pool 49 | self.num_imgs = 0 50 | self.images = [] 51 | 52 | def query(self, images): 53 | """Return images from the pool. 54 | 55 | Parameters: 56 | images: the latest generated images from the generator 57 | 58 | Returns images from the buffer. 59 | 60 | While the buffer is not full, each input image is inserted into it and returned unchanged. 61 | Once it is full, each input image replaces a randomly chosen stored image, which is returned instead 62 | (the 50/50 choice of the original ImagePool is commented out below). 
63 | """ 64 | if self.pool_size == 0: # if the buffer size is 0, do nothing 65 | return images 66 | return_images = [] 67 | for image in images: 68 | image = torch.unsqueeze(image.data, 0) 69 | if self.num_imgs < self.pool_size: # if the buffer is not full, keep inserting current images into the buffer 70 | self.num_imgs = self.num_imgs + 1 71 | self.images.append(image) 72 | return_images.append(image) 73 | else: 74 | # p = random.uniform(0, 1) 75 | # if p > 0.5: # the buffer will always return a previously stored image, and insert the current image into the buffer 76 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive 77 | tmp = self.images[random_id].clone() 78 | self.images[random_id] = image 79 | return_images.append(tmp) 80 | # else: # by another 50% chance, the buffer will return the current image 81 | # return_images.append(image) 82 | return_images = torch.cat(return_images, 0) # collect all the images and return 83 | return return_images 84 | 85 | class train_inpainting_module_I(base_model): 86 | def __init__(self, args, dataloaders_xLabels_joint, dataloaders_single): 87 | super(train_inpainting_module_I, self).__init__(args) 88 | self._initialize_training() 89 | 90 | self.dataloaders_single = dataloaders_single 91 | self.dataloaders_xLabels_joint = dataloaders_xLabels_joint 92 | 93 | self.use_apex = False # using apex might cause the style loss to be 0 94 | 95 | self.mask_buffer = Mask_Buffer(500) 96 | 97 | self.attModule = _Attention_FullRes(input_nc = 3, output_nc = 1) # logits, no tanh() 98 | self.inpaintNet = _ResGenerator_Upsample(input_nc = 3, output_nc = 3) 99 | self.styleTranslator = _ResGenerator_Upsample(input_nc = 3, output_nc = 3) 100 | self.netD = DiscriminatorGlobalLocal(image_size=240) 101 | 102 | self.tau_min = 0.05 103 | self.use_perceptual_loss = True 104 | 105 | self.p_vgg = VGG19() 106 | self.s_vgg = VGG19() 107 | 108 | self.perceptual_loss = PerceptualLoss(vgg19=self.p_vgg) 109 | self.style_loss = 
StyleLoss(vgg19=self.s_vgg) 110 | 111 | self.reconst_loss_weight = 1.0 112 | self.perceptual_loss_weight = 1.0 113 | self.style_loss_weight = 1.0 114 | self.fake_loss_weight = 0.01 115 | 116 | self.model_name = ['attModule', 'inpaintNet', 'styleTranslator', 'netD', 'p_vgg', 's_vgg'] 117 | self.L1loss = nn.L1Loss() 118 | 119 | if self.isTrain: 120 | self.optim_inpaintNet = optim.Adam(self.inpaintNet.parameters(), lr=self.task_lr, betas=(0.5, 0.999)) 121 | self.optim_netD = optim.Adam(self.netD.parameters(), lr=self.task_lr, betas=(0.5, 0.999)) 122 | self.optim_name = ['optim_inpaintNet', 'optim_netD'] 123 | self._get_scheduler() 124 | self.loss_BCE = nn.BCEWithLogitsLoss() 125 | self._initialize_networks(['inpaintNet', 'netD']) 126 | 127 | # load the "best" style translator T (from step 2) 128 | preTrain_path = os.path.join(os.getcwd(), 'experiments', 'train_style_translator_T') 129 | self._load_models(model_list=['styleTranslator'], mode=480, isTrain=True, model_path=preTrain_path) 130 | print('Successfully loaded pre-trained {} model from {}'.format('styleTranslator', preTrain_path)) 131 | 132 | # load the "best" attention module A (from step 3) 133 | preTrain_path = os.path.join(os.getcwd(), 'experiments', 'train_initial_attention_module_A') 134 | self._load_models(model_list=['attModule'], mode=450, isTrain=True, model_path=preTrain_path) 135 | print('Successfully loaded pre-trained {} model from {}'.format('attModule', preTrain_path)) 136 | 137 | # apex can only be applied to CUDA models 138 | if self.use_apex: 139 | self._init_apex(Num_losses=2) 140 | 141 | self._check_parallel() 142 | 143 | def _get_project_name(self): 144 | return 'train_inpainting_module_I' 145 | 146 | def _initialize_networks(self, model_name): 147 | for name in model_name: 148 | getattr(self, name).train().to(self.device) 149 | init_weights(getattr(self, name), net_name=name, init_type='normal', gain=0.02) 150 | 151 | def compute_spare_attention(self, confident_score, t, isTrain=True): 
        # t is the temperature (a scalar); during training, Gumbel noise -log(-log(u)) is added before the sigmoid
        if isTrain:
            noise = torch.rand(confident_score.size(), requires_grad=False).to(self.device)
            noise = (noise + 0.00001) / 1.001
            noise = - torch.log(- torch.log(noise))

            confident_score = (confident_score + 0.00001) / 1.001
            confident_score = (confident_score + noise) / t
        else:
            confident_score = confident_score / t

        confident_score = torch.sigmoid(confident_score)  # F.sigmoid is deprecated; same result

        return confident_score

    def compute_KL_div(self, cf, target=0.5):
        g = cf.mean()
        g = (g + 0.00001) / 1.001  # prevent g = 0. or 1.
        y = target*torch.log(target/g) + (1-target)*torch.log((1-target)/(1-g))
        return y

    def compute_real_fake_loss(self, scores, loss_type, datasrc='real', loss_for='discr'):
        if loss_for == 'discr':
            if datasrc == 'real':
                if loss_type == 'lsgan':
                    # least-squares GAN loss
                    d_loss = torch.pow(scores - 1., 2).mean()
                elif loss_type == 'hinge':
                    # hinge loss used in the spectral-normalization GAN paper
                    d_loss = - torch.mean(torch.clamp(scores-1., max=0.))
                elif loss_type == 'wgan':
                    # WGAN loss
                    d_loss = - torch.mean(scores)
                else:
                    scores = scores.view(scores.size(0), -1).mean(dim=1)
                    d_loss = F.binary_cross_entropy_with_logits(scores, torch.ones_like(scores).detach())
            else:
                if loss_type == 'lsgan':
                    # least-squares GAN loss
                    d_loss = torch.pow((scores), 2).mean()
                elif loss_type == 'hinge':
                    # hinge loss used in the spectral-normalization GAN paper
                    d_loss = -torch.mean(torch.clamp(-scores-1., max=0.))
                elif loss_type == 'wgan':
                    # WGAN loss
                    d_loss = torch.mean(scores)
                else:
                    scores = scores.view(scores.size(0), -1).mean(dim=1)
                    d_loss = F.binary_cross_entropy_with_logits(scores, torch.zeros_like(scores).detach())

            return d_loss
        else:
            if loss_type == 'lsgan':
                # least-squares GAN loss
                g_loss = torch.pow(scores - 1., 2).mean()
            elif (loss_type == 'wgan') or (loss_type == 'hinge'):
                g_loss = - torch.mean(scores)
            else:
                scores = scores.view(scores.size(0), -1).mean(dim=1)
                g_loss = F.binary_cross_entropy_with_logits(scores, torch.ones_like(scores).detach())
            return g_loss

    def train(self):
        phase = 'train'
        since = time.time()
        best_loss = float('inf')

        set_requires_grad(self.attModule, requires_grad=False)  # freeze attention module
        set_requires_grad(self.styleTranslator, requires_grad=False)  # freeze style translator

        tensorboardX_iter_count = 0
        for epoch in range(self.total_epoch_num):
            print('\nEpoch {}/{}'.format(epoch+1, self.total_epoch_num))
            print('-' * 10)
            fn = open(self.train_log, 'a')
            fn.write('\nEpoch {}/{}\n'.format(epoch+1, self.total_epoch_num))
            fn.write('--'*5+'\n')
            fn.close()

            iterCount = 0

            for sample_dict in self.dataloaders_xLabels_joint:
                imageListReal, depthListReal = sample_dict['real']
                imageListSyn, depthListSyn = sample_dict['syn']

                imageListSyn = imageListSyn.to(self.device)
                depthListSyn = depthListSyn.to(self.device)
                imageListReal = imageListReal.to(self.device)
                depthListReal = depthListReal.to(self.device)

                B, C, H, W = imageListReal.size()

                with torch.set_grad_enabled(phase=='train'):
                    r2s_img = self.styleTranslator(imageListReal)[-1]
                    confident_score = self.attModule(imageListReal)[-1]
                    # convert to sparse confident score
                    confident_score = self.compute_spare_attention(confident_score, t=self.tau_min, isTrain=False)
                    # hard threshold
                    confident_score[confident_score < 0.5] = 0.
                    confident_score[confident_score >= 0.5] = 1.

                    confident_score = self.mask_buffer.query(confident_score)

                    mod_r2s_img = r2s_img * confident_score
                    inpainted_r2s = self.inpaintNet(mod_r2s_img)[-1]

                    reconst_img = inpainted_r2s * (1. - confident_score) + confident_score * r2s_img

                    # update generators
                    self.optim_inpaintNet.zero_grad()
                    total_loss = 0.
                    reconst_loss = self.L1loss(inpainted_r2s, r2s_img) * self.reconst_loss_weight
                    perceptual_loss, style_loss = 0., 0.  # defined for logging even if use_perceptual_loss is False
                    if self.use_perceptual_loss:
                        perceptual_loss = self.perceptual_loss(inpainted_r2s, r2s_img) * self.perceptual_loss_weight
                        style_loss = self.style_loss(inpainted_r2s * (1.-confident_score), r2s_img * (1.-confident_score)) * self.style_loss_weight
                        total_loss += (perceptual_loss + style_loss)

                    d_score, _, _ = self.netD(inpainted_r2s, boxImg=confident_score.expand(B, 3, H, W))
                    fake_loss = self.compute_real_fake_loss(d_score, loss_type='lsgan', loss_for='generator') * self.fake_loss_weight

                    total_loss += (reconst_loss + fake_loss)
                    if self.use_apex:
                        with amp.scale_loss(total_loss, self.optim_inpaintNet, loss_id=0) as total_loss_scaled:
                            total_loss_scaled.backward()
                    else:
                        total_loss.backward()

                    self.optim_inpaintNet.step()

                    # update discriminator
                    self.optim_netD.zero_grad()

                    real_d_score, _, _ = self.netD(r2s_img, boxImg=confident_score.expand(B, 3, H, W))
                    real_d_loss = self.compute_real_fake_loss(real_d_score, loss_type='lsgan', datasrc='real')

                    fake_d_score, _, _ = self.netD(inpainted_r2s.detach(), boxImg=confident_score.expand(B, 3, H, W))
                    fake_d_loss = self.compute_real_fake_loss(fake_d_score, loss_type='lsgan', datasrc='fake')

                    total_d_loss = (real_d_loss + fake_d_loss)

                    if self.use_apex:
                        with amp.scale_loss(total_d_loss, self.optim_netD, loss_id=1) as total_d_loss_scaled:
                            total_d_loss_scaled.backward()
                    else:
                        total_d_loss.backward()

                    self.optim_netD.step()

                iterCount += 1

                if self.use_tensorboardX:
                    nrow = imageListReal.size()[0]
                    self.train_display_freq = len(self.dataloaders_xLabels_joint)  # feel free to adjust the display frequency
                    if tensorboardX_iter_count % self.train_display_freq == 0:
                        img_concat = torch.cat((imageListReal, r2s_img, mod_r2s_img, inpainted_r2s, reconst_img), dim=0)
                        self.write_2_tensorboardX(self.train_SummaryWriter, img_concat, name='real, r2s, r2sMasked, inpaintedR2s, reconst', mode='image',
                                                  count=tensorboardX_iter_count, nrow=nrow)

                        self.write_2_tensorboardX(self.train_SummaryWriter, confident_score, name='Attention', mode='image',
                                                  count=tensorboardX_iter_count, nrow=nrow, value_range=(0., 1.0))

                    # add loss values
                    loss_val_list = [total_loss, total_d_loss]
                    loss_name_list = ['total_loss', 'total_d_loss']
                    self.write_2_tensorboardX(self.train_SummaryWriter, loss_val_list, name=loss_name_list, mode='scalar', count=tensorboardX_iter_count)

                    tensorboardX_iter_count += 1

                if iterCount % 20 == 0:
                    loss_summary = '\t{}/{}, total_loss: {:.7f}, total_d_loss: {:.7f}'.format(
                        iterCount, len(self.dataloaders_xLabels_joint), total_loss, total_d_loss)
                    G_loss_summary = '\t\t G loss summary: reconst_loss: {:.7f}, fake_loss: {:.7f}, perceptual_loss: {:.7f} style_loss: {:.7f}'.format(
                        reconst_loss, fake_loss, perceptual_loss, style_loss)
                    D_loss_summary = '\t\t D loss summary: real_d_loss: {:.7f}, fake_d_loss: {:.7f}'.format(real_d_loss, fake_d_loss)

                    print(loss_summary)
                    print(G_loss_summary)
                    print(D_loss_summary)

                    fn = open(self.train_log, 'a')
                    fn.write(loss_summary + '\n')
                    fn.write(G_loss_summary + '\n')
                    fn.write(D_loss_summary + '\n')
                    fn.close()

            if (epoch+1) % self.save_steps == 0:
                self.save_models(['inpaintNet'], mode=epoch+1)

            # take step in optimizer
            for scheduler in self.scheduler_list:
                scheduler.step()
            # print learning rate (loop variable renamed so it does not shadow the `optim` module)
            for opt_name in self.optim_name:
                lr = getattr(self, opt_name).param_groups[0]['lr']
                lr_update = 'Epoch {}/{} finished: {} learning rate = {:.7f}'.format(epoch+1, self.total_epoch_num, opt_name, lr)
                print(lr_update)

                fn = open(self.train_log, 'a')
                fn.write(lr_update + '\n')
                fn.close()

        time_elapsed = time.time() - since
        print('\nTraining complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

        fn = open(self.train_log, 'a')
        fn.write('\nTraining complete in {:.0f}m {:.0f}s\n'.format(time_elapsed // 60, time_elapsed % 60))
        fn.close()

    def evaluate(self, mode):
        pass
--------------------------------------------------------------------------------
/training/train_style_translator_T.py:
--------------------------------------------------------------------------------
import os, time, sys
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision
from torchvision import datasets, models, transforms
from torchvision.utils import make_grid
from tensorboardX import SummaryWriter

from models.depth_generator_networks import _UNetGenerator, init_weights, _ResGenerator_Upsample
from models.discriminator_networks import Discriminator80x80InstNorm

from utils.metrics import *
from utils.image_pool import ImagePool

from training.base_model import set_requires_grad, base_model

try:
    from apex import amp
except ImportError:
    print("\nPlease consider installing apex from https://www.github.com/nvidia/apex to run with apex, or set use_apex = False\n")

import warnings  # ignore warnings
warnings.filterwarnings("ignore")


class train_style_translator_T(base_model):
    def __init__(self, args, dataloaders_xLabels_joint, dataloaders_single):
        super(train_style_translator_T, self).__init__(args)
        self._initialize_training()

        self.dataloaders_single = dataloaders_single
        self.dataloaders_xLabels_joint = dataloaders_xLabels_joint

        # define loss weights
        self.lambda_identity = 0.5  # coefficient of the identity-mapping loss
        self.lambda_real = 10.0
        self.lambda_synthetic = 10.0
        self.lambda_GAN = 1.0

        # define pool size in adversarial loss
        self.pool_size = 50
        self.generated_syn_pool = ImagePool(self.pool_size)
        self.generated_real_pool = ImagePool(self.pool_size)

        self.netD_s = Discriminator80x80InstNorm(input_nc=3)
        self.netD_r = Discriminator80x80InstNorm(input_nc=3)
        self.netG_s2r = _ResGenerator_Upsample(input_nc=3, output_nc=3)
        self.netG_r2s = _ResGenerator_Upsample(input_nc=3, output_nc=3)
        self.model_name = ['netD_s', 'netD_r', 'netG_s2r', 'netG_r2s']
        self.L1loss = nn.L1Loss()

        if self.isTrain:
            self.netD_optimizer = optim.Adam(list(self.netD_s.parameters()) + list(self.netD_r.parameters()), lr=self.D_lr, betas=(0.5, 0.999))
            self.netG_optimizer = optim.Adam(list(self.netG_r2s.parameters()) + list(self.netG_s2r.parameters()), lr=self.G_lr, betas=(0.5, 0.999))
            self.optim_name = ['netD_optimizer', 'netG_optimizer']
            self._get_scheduler()
            self.loss_BCE = nn.BCEWithLogitsLoss()
            self._initialize_networks()

            # apex can only be applied to CUDA models
            if self.use_apex:
                self._init_apex(Num_losses=3)

        self._check_parallel()

    def _get_project_name(self):
        return 'train_style_translator_T'

    def _initialize_networks(self):
        for name in self.model_name:
            getattr(self, name).train().to(self.device)
            init_weights(getattr(self, name), net_name=name, init_type='normal', gain=0.02)

    def compute_D_loss(self, real_sample, fake_sample, netD):
        loss = 0
        syn_acc = 0
        real_acc = 0

        output = netD(fake_sample)
        label = torch.full(output.size(), self.syn_label, device=self.device)

        predSyn = (output > 0.5).to(self.device, dtype=torch.float32)
        total_num = torch.numel(output)
        syn_acc += (predSyn==label).type(torch.float32).sum().item()/total_num
        loss += self.loss_BCE(output, label)

        output = netD(real_sample)
        label = torch.full(output.size(), self.real_label, device=self.device)

        predReal = (output > 0.5).to(self.device, dtype=torch.float32)
        real_acc += (predReal==label).type(torch.float32).sum().item()/total_num
        loss += self.loss_BCE(output, label)

        return loss, syn_acc, real_acc

    def compute_G_loss(self, real_sample, synthetic_sample, r2s_rgb, s2r_rgb, reconstruct_real, reconstruct_syn):
        '''
        real_sample: [batch_size, 3, 240, 320] real rgb
        synthetic_sample: [batch_size, 3, 240, 320] synthetic rgb
        r2s_rgb: netG_r2s(real)
        s2r_rgb: netG_s2r(synthetic)
        reconstruct_real: netG_s2r(netG_r2s(real))
        reconstruct_syn: netG_r2s(netG_s2r(synthetic))
        '''
        loss = 0

        # identity loss if applicable
        if self.lambda_identity > 0:
            idt_real = self.netG_s2r(real_sample)[-1]
            idt_synthetic = self.netG_r2s(synthetic_sample)[-1]
            idt_loss = (self.L1loss(idt_real, real_sample) * self.lambda_real +
                        self.L1loss(idt_synthetic, synthetic_sample) * self.lambda_synthetic) * self.lambda_identity
        else:
            idt_loss = 0

        # GAN loss: each generator tries to make its discriminator label the translated images as genuine (real_label)
        real_pred = self.netD_r(s2r_rgb)
        real_label = torch.full(real_pred.size(), self.real_label, device=self.device)
        GAN_loss_real = self.loss_BCE(real_pred, real_label)

        syn_pred = self.netD_s(r2s_rgb)
        syn_label = torch.full(syn_pred.size(), self.real_label, device=self.device)
        GAN_loss_syn = self.loss_BCE(syn_pred, syn_label)

        GAN_loss = (GAN_loss_real + GAN_loss_syn) * self.lambda_GAN

        # cycle consistency loss
        rec_real_loss = self.L1loss(reconstruct_real, real_sample) * self.lambda_real
        rec_syn_loss = self.L1loss(reconstruct_syn, synthetic_sample) * self.lambda_synthetic
        rec_loss = rec_real_loss + rec_syn_loss

        loss += (idt_loss + GAN_loss + rec_loss)

        return loss, idt_loss, GAN_loss, rec_loss

    def train(self):
        phase = 'train'
        since = time.time()
        best_loss = float('inf')

        tensorboardX_iter_count = 0
        for epoch in range(self.total_epoch_num):
            print('\nEpoch {}/{}'.format(epoch+1, self.total_epoch_num))
            print('-' * 10)
            fn = open(self.train_log, 'a')
            fn.write('\nEpoch {}/{}\n'.format(epoch+1, self.total_epoch_num))
            fn.write('--'*5+'\n')
            fn.close()

            iterCount = 0

            for sample_dict in self.dataloaders_xLabels_joint:
                imageListReal, depthListReal = sample_dict['real']
                imageListSyn, depthListSyn = sample_dict['syn']

                imageListSyn = imageListSyn.to(self.device)
                depthListSyn = depthListSyn.to(self.device)
                imageListReal = imageListReal.to(self.device)
                depthListReal = depthListReal.to(self.device)

                with torch.set_grad_enabled(phase=='train'):
                    s2r_rgb = self.netG_s2r(imageListSyn)[-1]
                    reconstruct_syn = self.netG_r2s(s2r_rgb)[-1]

                    r2s_rgb = self.netG_r2s(imageListReal)[-1]
                    reconstruct_real = self.netG_s2r(r2s_rgb)[-1]

                    ############# update generator
                    set_requires_grad([self.netD_r, self.netD_s], False)

                    netG_loss = 0.
                    self.netG_optimizer.zero_grad()
                    netG_loss, G_idt_loss, G_GAN_loss, G_rec_loss = self.compute_G_loss(imageListReal, imageListSyn,
                                                                                        r2s_rgb, s2r_rgb, reconstruct_real, reconstruct_syn)

                    if self.use_apex:
                        with amp.scale_loss(netG_loss, self.netG_optimizer, loss_id=0) as netG_loss_scaled:
                            netG_loss_scaled.backward()
                    else:
                        netG_loss.backward()

                    self.netG_optimizer.step()

                    ############# update discriminator
                    set_requires_grad([self.netD_r, self.netD_s], True)

                    self.netD_optimizer.zero_grad()
                    # NOTE: r2s_rgb_pool and s2r_rgb_pool are computed (which updates the pools) but are not used;
                    # the discriminators are trained on the current batch's translated images.
                    r2s_rgb_pool = self.generated_syn_pool.query(r2s_rgb)
                    netD_s_loss, netD_s_syn_acc, netD_s_real_acc = self.compute_D_loss(imageListSyn, r2s_rgb.detach(), self.netD_s)
                    s2r_rgb_pool = self.generated_real_pool.query(s2r_rgb)
                    netD_r_loss, netD_r_syn_acc, netD_r_real_acc = self.compute_D_loss(imageListReal, s2r_rgb.detach(), self.netD_r)

                    netD_loss = netD_s_loss + netD_r_loss

                    if self.use_apex:
                        with amp.scale_loss(netD_loss, self.netD_optimizer, loss_id=1) as netD_loss_scaled:
                            netD_loss_scaled.backward()
                    else:
                        netD_loss.backward()
                    self.netD_optimizer.step()

                iterCount += 1

                if self.use_tensorboardX:
                    self.train_display_freq = len(self.dataloaders_xLabels_joint)  # feel free to adjust the display frequency
                    nrow = imageListReal.size()[0]
                    if tensorboardX_iter_count % self.train_display_freq == 0:
                        s2r_rgb_concat = torch.cat((imageListSyn, s2r_rgb, imageListReal, reconstruct_syn), dim=0)
                        self.write_2_tensorboardX(self.train_SummaryWriter, s2r_rgb_concat, name='RGB: syn, s2r, real, reconstruct syn', mode='image',
                                                  count=tensorboardX_iter_count, nrow=nrow)

                        r2s_rgb_concat = torch.cat((imageListReal, r2s_rgb, imageListSyn, reconstruct_real), dim=0)
                        self.write_2_tensorboardX(self.train_SummaryWriter, r2s_rgb_concat, name='RGB: real, r2s, synthetic, reconstruct real', mode='image',
                                                  count=tensorboardX_iter_count, nrow=nrow)

                    loss_val_list = [netD_loss, netG_loss]
                    loss_name_list = ['netD_loss', 'netG_loss']
                    self.write_2_tensorboardX(self.train_SummaryWriter, loss_val_list, name=loss_name_list, mode='scalar', count=tensorboardX_iter_count)

                    tensorboardX_iter_count += 1

                if iterCount % 20 == 0:
                    loss_summary = '\t{}/{} netD: {:.7f}, netG: {:.7f}'.format(iterCount, len(self.dataloaders_xLabels_joint), netD_loss, netG_loss)
                    G_loss_summary = '\t\tG loss summary: netG: {:.7f}, idt_loss: {:.7f}, GAN_loss: {:.7f}, rec_loss: {:.7f}'.format(
                        netG_loss, G_idt_loss, G_GAN_loss, G_rec_loss)

                    print(loss_summary)
                    print(G_loss_summary)

                    fn = open(self.train_log, 'a')
                    fn.write(loss_summary + '\n')
                    fn.write(G_loss_summary + '\n')
                    fn.close()

            if (epoch+1) % self.save_steps == 0:
                self.save_models(['netG_r2s'], mode=epoch+1, save_list=['styleTranslator'])

            # take step in optimizer
            for scheduler in self.scheduler_list:
                scheduler.step()
            # print learning rate (loop variable renamed so it does not shadow the `optim` module)
            for opt_name in self.optim_name:
                lr = getattr(self, opt_name).param_groups[0]['lr']
                lr_update = 'Epoch {}/{} finished: {} learning rate = {:.7f}'.format(epoch+1, self.total_epoch_num, opt_name, lr)
                print(lr_update)

                fn = open(self.train_log, 'a')
                fn.write(lr_update + '\n')
                fn.close()

        time_elapsed = time.time() - since
        print('\nTraining complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

        fn = open(self.train_log, 'a')
        fn.write('\nTraining complete in {:.0f}m {:.0f}s\n'.format(time_elapsed // 60, time_elapsed % 60))
        fn.close()

    def evaluate(self, mode):
        pass
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
from .metrics import *

--------------------------------------------------------------------------------
/utils/image_pool.py:
--------------------------------------------------------------------------------
import random
import torch


class ImagePool():
    """This class implements an image buffer that stores previously generated images.

    This buffer enables us to update discriminators using a history of generated images
    rather than the ones produced by the latest generators.
    """

    def __init__(self, pool_size):
        """Initialize the ImagePool class

        Parameters:
            pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
        """
        self.pool_size = pool_size
        if self.pool_size > 0:  # create an empty pool
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return images from the pool.

        Parameters:
            images: the latest generated images from the generator

        Returns a batch with the same number of images as the input.

        While the buffer is not full, each input image is stored and returned unchanged.
        Once it is full, each input image is handled independently: with 50% probability
        the input image is returned as is; otherwise a randomly chosen stored image is
        returned and the input image takes its place in the buffer.
        """
        if self.pool_size == 0:  # if the buffer size is 0, do nothing
            return images
        return_images = []
        for image in images:
            image = torch.unsqueeze(image.data, 0)
            if self.num_imgs < self.pool_size:  # if the buffer is not full, keep inserting current images into the buffer
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                if p > 0.5:  # with 50% probability, return a previously stored image and insert the current image into the buffer
                    random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:  # otherwise, return the current image
                    return_images.append(image)
        return_images = torch.cat(return_images, 0)  # collect all the images and return
        return return_images
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import math


def log10(x):
    """Return a new tensor with the base-10 logarithm of the elements of x."""
    return torch.log(x) / math.log(10)


class Result(object):
    def __init__(self, mask_min, mask_max):
        self.irmse, self.imae = 0, 0
        self.mse, self.rmse, self.mae = 0, 0, 0
        self.absrel, self.lg10 = 0, 0
        self.delta1, self.delta2, self.delta3 = 0, 0, 0
        self.data_time, self.gpu_time = 0, 0
        self.mask_min = mask_min
        self.mask_max = mask_max

    def set_to_worst(self):
        self.irmse, self.imae = np.inf, np.inf
        self.mse, self.rmse, self.mae = np.inf, np.inf, np.inf
        self.absrel, self.lg10 = np.inf, np.inf
        self.delta1, self.delta2, self.delta3 = 0, 0, 0
        self.data_time, self.gpu_time = 0, 0

    def update(self, irmse, imae, mse, rmse, mae, absrel, lg10, delta1, delta2, delta3, gpu_time, data_time):
        self.irmse, self.imae = irmse, imae
        self.mse, self.rmse, self.mae = mse, rmse, mae
        self.absrel, self.lg10 = absrel, lg10
        self.delta1, self.delta2, self.delta3 = delta1, delta2, delta3
        self.data_time, self.gpu_time = data_time, gpu_time

    def evaluate(self, output, target):
        # Clamping the target to [mask_min, mask_max] is disabled; its usefulness is unclear.
        # target[target < self.mask_min] = self.mask_min
        # target[target > self.mask_max] = self.mask_max

        # Only pixels whose target depth lies strictly inside (mask_min, mask_max) are evaluated.
        # Torch operators (instead of np.logical_and) keep this working for CUDA tensors too.
        valid_mask = (target > self.mask_min) & (target < self.mask_max)
        output = output[valid_mask]
        target = target[valid_mask]

        abs_diff = (output - target).abs()
        diff = (output - target)

        self.mse = float((torch.pow(abs_diff, 2)).mean())
        self.rmse = math.sqrt(self.mse)
        self.rmselog = math.sqrt(float(((torch.log(target) - torch.log(output)) ** 2).mean()))

        self.mae = float(abs_diff.mean())
        self.lg10 = float((log10(output) - log10(target)).abs().mean())
        self.absrel = float((abs_diff / target).mean())
        self.sqrel = float(((diff ** 2) / target).mean())

        maxRatio = torch.max(output / target, target / output)
        self.delta1 = float((maxRatio < 1.25).float().mean())
        self.delta2 = float((maxRatio < 1.25 ** 2).float().mean())
        self.delta3 = float((maxRatio < 1.25 ** 3).float().mean())
        self.data_time = 0
        self.gpu_time = 0

        inv_output = 1 / output
        inv_target = 1 / target
        abs_inv_diff = (inv_output - inv_target).abs()
        self.irmse = math.sqrt(float((torch.pow(abs_inv_diff, 2)).mean()))
        self.imae = float(abs_inv_diff.mean())


class Result_withIdx(object):
    def __init__(self, mask_min, mask_max):
        self.irmse, self.imae = 0, 0
        self.mse, self.rmse, self.mae = 0, 0, 0
        self.absrel, self.lg10 = 0, 0
        self.delta1, self.delta2, self.delta3 = 0, 0, 0
        self.data_time, self.gpu_time = 0, 0
        self.mask_min = mask_min
        self.mask_max = mask_max

    def set_to_worst(self):
        self.irmse, self.imae = np.inf, np.inf
        self.mse, self.rmse, self.mae = np.inf, np.inf, np.inf
        self.absrel, self.lg10 = np.inf, np.inf
        self.delta1, self.delta2, self.delta3 = 0, 0, 0
        self.data_time, self.gpu_time = 0, 0

    def update(self, irmse, imae, mse, rmse, mae, absrel, lg10, delta1, delta2, delta3, gpu_time, data_time):
        self.irmse, self.imae = irmse, imae
        self.mse, self.rmse, self.mae = mse, rmse, mae
        self.absrel, self.lg10 = absrel, lg10
        self.delta1, self.delta2, self.delta3 = delta1, delta2, delta3
        self.data_time, self.gpu_time = data_time, gpu_time

    def evaluate(self, output, target, idx_tensor):
        # idx_tensor should be a boolean tensor with the same size as output and target
        valid_mask = (target > self.mask_min) & (target < self.mask_max)
        final_mask = valid_mask & idx_tensor
        output = output[final_mask]
        target = target[final_mask]

        abs_diff = (output - target).abs()
        diff = (output - target)

        self.mse = float((torch.pow(abs_diff, 2)).mean())
        self.rmse = math.sqrt(self.mse)
        self.rmselog = math.sqrt(float(((torch.log(target) - torch.log(output)) ** 2).mean()))

        self.mae = float(abs_diff.mean())
        self.lg10 = float((log10(output) - log10(target)).abs().mean())
        self.absrel = float((abs_diff / target).mean())
        self.sqrel = float(((diff ** 2) / target).mean())

        maxRatio = torch.max(output / target, target / output)
        self.delta1 = float((maxRatio < 1.25).float().mean())
        self.delta2 = float((maxRatio < 1.25 ** 2).float().mean())
        self.delta3 = float((maxRatio < 1.25 ** 3).float().mean())
        self.data_time = 0
        self.gpu_time = 0

        inv_output = 1 / output
        inv_target = 1 / target
        abs_inv_diff = (inv_output - inv_target).abs()
        self.irmse = math.sqrt(float((torch.pow(abs_inv_diff, 2)).mean()))
        self.imae = float(abs_inv_diff.mean())


def miou(pred, target, n_classes=12):
    ious = []
    pred = pred.view(-1)
    target = target.view(-1)

    # Every class in range(n_classes) is evaluated, including class "0".
    for cls in range(0, n_classes):
        pred_inds = pred == cls
        target_inds = target == cls
        # .item() replaces .data.cpu()[0], which fails on 0-dim tensors in current PyTorch
        intersection = pred_inds[target_inds].long().sum().item()  # cast to long to prevent overflows
        union = pred_inds.long().sum().item() + target_inds.long().sum().item() - intersection
        if union == 0:
            ious.append(float('nan'))  # class absent from both prediction and ground truth: exclude from evaluation
        else:
            ious.append(float(intersection) / float(max(union, 1)))
    return np.array(ious)


def im2col_sliding_broadcasting(A, BSZ, stepsize=1):
    # input height/width and number of valid block positions along each axis
    M, N = A.shape[0], A.shape[1]
    col_extent = N - BSZ[1] + 1
    row_extent = M - BSZ[0] + 1

    # get starting block indices
    start_idx = np.arange(BSZ[0])[:, None]*N + np.arange(BSZ[1])

    # get offset indices across the height and width of the input array
    offset_idx = np.arange(row_extent)[:, None]*N + np.arange(col_extent)

    # get all actual indices and index into the input array for the final output
    return np.take(A, start_idx.ravel()[:, None] + offset_idx.ravel()[::stepsize])


def rgb2ycbcr(im):
    # im: H x W x 3 RGB array with values in [0, 255] (the 128 offsets assume this range)
    cbcr = np.empty_like(im)
    r = im[:, :, 0]
    g = im[:, :, 1]
    b = im[:, :, 2]
    # Y
    cbcr[:, :, 0] = .299 * r + .587 * g + .114 * b
    # Cb
    cbcr[:, :, 1] = 128 - .169 * r - .331 * g + .5 * b
    # Cr
    cbcr[:, :, 2] = 128 + .5 * r - .419 * g - .081 * b
    return cbcr  # np.uint8(cbcr)


def ycbcr2rgb(im):
    rgb = np.empty_like(im)
    y = im[:, :, 0]
    cb = im[:, :, 1] - 128
    cr = im[:, :, 2] - 128
    # R
    rgb[:, :, 0] = y + 1.402 * cr
    # G
    rgb[:, :, 1] = y - .34414 * cb - .71414 * cr
    # B
    rgb[:, :, 2] = y + 1.772 * cb
    return rgb  # np.uint8(rgb)


def img_greyscale(img):
    return 0.299 * img[:, :, 0] + 0.587 * img[:, :, 1] + 0.114 * img[:, :, 2]

--------------------------------------------------------------------------------
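The error measures computed in `Result.evaluate` (RMSE, MAE, AbsRel, SqRel, mean log10 error and the δ < 1.25^k accuracies) can be checked by hand on a tiny input. The sketch below is a standalone NumPy restatement written for illustration only, not code from this repository; `depth_metrics` and its `min_depth`/`max_depth` defaults are hypothetical stand-ins for `Result` and its `mask_min`/`mask_max` arguments.

```python
import numpy as np


def depth_metrics(pred, gt, min_depth=1e-3, max_depth=10.0):
    """Depth error measures over valid pixels, mirroring the formulas in Result.evaluate.

    pred, gt: arrays of the same shape holding depths (assumed positive where gt is valid).
    min_depth, max_depth: hypothetical validity range, playing the role of mask_min/mask_max.
    """
    valid = (gt > min_depth) & (gt < max_depth)  # same strict inequalities as Result
    p = pred[valid].astype(np.float64)
    g = gt[valid].astype(np.float64)

    diff = p - g
    ratio = np.maximum(p / g, g / p)  # symmetric ratio, always >= 1
    return {
        "rmse": float(np.sqrt(np.mean(diff ** 2))),
        "mae": float(np.mean(np.abs(diff))),
        "absrel": float(np.mean(np.abs(diff) / g)),
        "sqrel": float(np.mean(diff ** 2 / g)),
        "lg10": float(np.mean(np.abs(np.log10(p) - np.log10(g)))),
        "delta1": float(np.mean(ratio < 1.25)),
        "delta2": float(np.mean(ratio < 1.25 ** 2)),
        "delta3": float(np.mean(ratio < 1.25 ** 3)),
    }


# Toy 2x2 example: the 0.0 ground-truth pixel is outside the valid range and is ignored.
gt = np.array([[1.0, 2.0], [4.0, 0.0]])
pred = np.array([[1.0, 2.5], [2.0, 3.0]])
metrics = depth_metrics(pred, gt)
print(metrics)
```

On the three valid pixels, AbsRel = (0 + 0.5/2 + 2/4) / 3 = 0.25. The second prediction is off by exactly 25% (ratio 1.25), and because the threshold test is strict it does not count toward delta1, which is therefore 1/3.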