├── .gitignore ├── Monodepth.ipynb ├── README.md ├── data_loader.py ├── kitti_archives_to_download.txt ├── loss.py ├── main_monodepth_pytorch.py ├── models_resnet.py ├── readme_images └── demo.gif ├── transforms.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | __pycache__/ 3 | .idea 4 | data 5 | -------------------------------------------------------------------------------- /Monodepth.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import torch\n", 11 | "import numpy as np\n", 12 | "import skimage.transform\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "from easydict import EasyDict as edict\n", 15 | "\n", 16 | "from main_monodepth_pytorch import Model\n", 17 | "%reload_ext autoreload\n", 18 | "%autoreload 2" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "## Train" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "Check if CUDA is available" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "torch.cuda.is_available()" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "torch.cuda.device_count()" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "torch.cuda.empty_cache()" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "dict_parameters = edict({'data_dir':'data/kitti/train/',\n", 69 | " 'val_data_dir':'data/kitti/val/',\n", 70 | " 'model_path':'data/models/monodepth_resnet18_001.pth',\n", 71 | " 'output_directory':'data/output/',\n", 72 | " 'input_height':256,\n", 73 | " 'input_width':512,\n", 74 | " 'model':'resnet18_md',\n", 75 | " 'pretrained':True,\n", 76 | " 'mode':'train',\n", 77 | " 'epochs':200,\n", 78 | " 'learning_rate':1e-4,\n", 79 | " 'batch_size': 8,\n", 80 | " 'adjust_lr':True,\n", 81 | " 'device':'cuda:0',\n", 82 | " 'do_augmentation':True,\n", 83 | " 'augment_parameters':[0.8, 1.2, 0.5, 2.0, 0.8, 1.2],\n", 84 | " 'print_images':False,\n", 85 | " 'print_weights':False,\n", 86 | " 'input_channels': 3,\n", 87 | " 'num_workers': 8,\n", 88 | " 'use_multiple_gpu': False})" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "model = Model(dict_parameters)\n", 98 | "#model.load('data/models/monodepth_resnet18_001_last.pth')" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": { 105 | "scrolled": true 106 | }, 107 | "outputs": [], 108 | "source": [ 109 | "model.train()" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "## Test the model" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "dict_parameters_test = edict({'data_dir':'data/test',\n", 126 | " 'model_path':'data/models/monodepth_resnet18_001_cpt.pth',\n", 127 | " 
'output_directory':'data/output/',\n", 128 | " 'input_height':256,\n", 129 | " 'input_width':512,\n", 130 | " 'model':'resnet18_md',\n", 131 | " 'pretrained':False,\n", 132 | " 'mode':'test',\n", 133 | " 'device':'cuda:0',\n", 134 | " 'input_channels':3,\n", 135 | " 'num_workers':4,\n", 136 | " 'use_multiple_gpu':False})\n", 137 | "model_test = Model(dict_parameters_test)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "model_test.test()" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "disp = np.load('data/output/disparities_pp.npy') # Or disparities.npy for output without post-processing\n", 156 | "disp.shape" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "disp_to_img = skimage.transform.resize(disp[0].squeeze(), [375, 1242], mode='constant')\n", 166 | "plt.imshow(disp_to_img, cmap='plasma')" 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "metadata": {}, 172 | "source": [ 173 | "Save a color image" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "plt.imsave(os.path.join(dict_parameters_test.output_directory,\n", 183 | " dict_parameters_test.model_path.split('/')[-1][:-4]+'_test_output.png'), disp_to_img, cmap='plasma')" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "metadata": {}, 189 | "source": [ 190 | "Save all test images" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "for i in range(disp.shape[0]):\n", 200 | " disp_to_img = skimage.transform.resize(disp[i].squeeze(), [375, 1242], mode='constant')\n", 201 | " plt.imsave(os.path.join(dict_parameters_test.output_directory,\n", 202 | " 'pred_'+str(i)+'.png'), disp_to_img, cmap='plasma')" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": {}, 208 | "source": [ 209 | "Save a grayscale image" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "plt.imsave(os.path.join(dict_parameters_test.output_directory,\n", 219 | " dict_parameters_test.model_path.split('/')[-1][:-4]+'_gray.png'), disp_to_img, cmap='gray')" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [] 228 | } 229 | ], 230 | "metadata": { 231 | "kernelspec": { 232 | "display_name": "Python 3", 233 | "language": "python", 234 | "name": "python3" 235 | }, 236 | "language_info": { 237 | "codemirror_mode": { 238 | "name": "ipython", 239 | "version": 3 240 | }, 241 | "file_extension": ".py", 242 | "mimetype": "text/x-python", 243 | "name": "python", 244 | "nbconvert_exporter": "python", 245 | "pygments_lexer": "ipython3", 246 | "version": "3.6.5" 247 | } 248 | }, 249 | "nbformat": 4, 250 | "nbformat_minor": 2 251 | } 252 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MonoDepth 2 | ![demo.gif animation](readme_images/demo.gif) 3 | 4 | This repo is inspired by an amazing work of [Clément 
Godard](http://www0.cs.ucl.ac.uk/staff/C.Godard/), [Oisin Mac Aodha](http://vision.caltech.edu/~macaodha/) and [Gabriel J. Brostow](http://www0.cs.ucl.ac.uk/staff/g.brostow/) on Unsupervised Monocular Depth Estimation. 5 | The original code and paper can be found via the following links: 6 | 1. [Original repo](https://github.com/mrharicot/monodepth) 7 | 2. [Original paper](https://arxiv.org/abs/1609.03677) 8 | 9 | ## MonoDepth-PyTorch 10 | This repository contains code and additional material for the PyTorch port of the MonoDepth deep learning algorithm. For more information about the original work, please visit the [authors' website](http://visual.cs.ucl.ac.uk/pubs/monoDepth/). 11 | 12 | ## Purpose 13 | 14 | The purpose of this repository is to provide a more lightweight depth estimation model with better accuracy. 15 | In our version of MonoDepth we use ResNet50 as the encoder. As in the original repo, it is slightly modified (with one additional downsampling stage). 16 | 17 | We also added a ResNet18 version and use batch normalization in both encoders for training stability. 18 | Moreover, the feature extractor is flexible: any ResNet from the torchvision model zoo can be used as the encoder, 19 | with an option to load pretrained weights. 20 | 21 | ## Dataset 22 | ### KITTI 23 | 24 | This algorithm requires stereo image pairs for training and single images for testing. 25 | The [KITTI](http://www.cvlibs.net/datasets/kitti/raw_data.php) dataset was used for training. 26 | It contains 38,237 training samples. 27 | The raw dataset (about 175 GB) can be downloaded by running: 28 | ```shell 29 | wget -i kitti_archives_to_download.txt -P ~/my/output/folder/ 30 | ``` 31 | kitti_archives_to_download.txt can be found in this repo. 32 | 33 | ## Dataloader 34 | The dataloader assumes the following structure of the folder with training examples (the **'data_dir'** argument contains the path to that folder): 35 | each drive subfolder contains "image_02/data" for left images and "image_03/data" for right images. 36 | This structure is the default for the KITTI dataset. 37 | 38 | Example data folder structure (the path to the "kitti" directory should be passed as **'data_dir'** in this example): 39 | ``` 40 | data 41 | ├── kitti 42 | │   ├── 2011_09_26_drive_0001_sync 43 | │   │   ├── image_02 44 | │   │   │   ├── data 45 | │   │   │   │   ├── 0000000000.png 46 | │   │   │   │   └── ... 47 | │   │   ├── image_03 48 | │   │   │   ├── data 49 | │   │   │   │   ├── 0000000000.png 50 | │   │   │   │   └── ... 51 | │   ├── ... 52 | ├── models 53 | ├── output 54 | ├── test 55 | │   ├── left 56 | │   │   ├── test_1.jpg 57 | │   │   └── ... 58 | ``` 59 | 60 | ## Training 61 | An example of training can be found in the [Monodepth](Monodepth.ipynb) notebook. 62 | 63 | For training, the Model class from main_monodepth_pytorch.py should be initialized with the following parameters (as an easydict): 64 | - `data_dir`: path to the dataset folder 65 | - `val_data_dir`: path to the validation dataset folder 66 | - `model_path`: path to save the trained model 67 | - `output_directory`: where to save disparities for tested images 68 | - `input_height` 69 | - `input_width` 70 | - `model`: model for the encoder (resnet18_md, resnet50_md, or any torchvision version of ResNet such as resnet18, resnet34, etc.) 71 | - `pretrained`: when a torchvision encoder is used, download its pretrained weights 72 | - `mode`: train or test 73 | - `epochs`: number of epochs 74 | - `learning_rate` 75 | - `batch_size` 76 | - `adjust_lr`: apply learning rate decay or not 77 | - `device`: 'cuda:0' or 'cpu' 78 | - `do_augmentation`: apply data augmentation or not 79 | - `augment_parameters`: lowest and highest values for gamma, brightness and color respectively 80 | - `print_images` 81 | - `print_weights` 82 | - `input_channels`: number of channels in the input tensor (3 for RGB images) 83 | - `num_workers`: number of workers to use in the dataloader 84 | 85 | Optionally, after initialization a previously saved model can be loaded via `model.load`. 86 | 87 | After that, calling train() on the Model object starts the training process. 88 | 89 | Training can also be started by calling main_monodepth_pytorch.py from the terminal and passing the parameters as argparse arguments, as in the sketch below.
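For example, a training run from the terminal might look like this (a minimal sketch: the positional arguments follow the argparse definitions in main_monodepth_pytorch.py, and the paths below are placeholders taken from the notebook):

```shell
# Positional arguments: data_dir, val_data_dir, model_path, output_directory.
# Optional flags such as --model, --epochs or --batch_size are listed in
# return_arguments() in main_monodepth_pytorch.py; defaults are used here.
python main_monodepth_pytorch.py \
    data/kitti/train/ \
    data/kitti/val/ \
    data/models/monodepth_resnet18_001.pth \
    data/output/ \
    --mode train
```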
90 | 91 | ## Train results and pretrained model 92 | 93 | The results shown in the demo GIF were obtained using a model with a **resnet18** encoder, which can be downloaded from [here](https://my.pcloud.com/publink/show?code=XZb5r97ZD7HDDlc237BMjoCbWJVYMm0FLKcy). 94 | 95 | For training, the following parameters were used: 96 | ``` 97 | `model`: 'resnet18_md' 98 | `epochs`: 200, 99 | `learning_rate`: 1e-4, 100 | `batch_size`: 8, 101 | `adjust_lr`: True, 102 | `do_augmentation`: True 103 | ``` 104 | The provided model was trained on the whole dataset, except for the subsets listed below, which were used for hold-out validation. 105 | 106 | ``` 107 | 2011_09_26_drive_0002_sync 2011_09_29_drive_0071_sync 108 | 2011_09_26_drive_0014_sync 2011_09_30_drive_0033_sync 109 | 2011_09_26_drive_0020_sync 2011_10_03_drive_0042_sync 110 | 2011_09_26_drive_0079_sync 111 | ``` 112 | 113 | The demo GIF is a visualization of the predictions on the `2011_09_26_drive_0014_sync` subset. 114 | 115 | See the [Monodepth](Monodepth.ipynb) notebook for details on the training. 116 | 117 | ## Testing 118 | An example of testing can also be found in the [Monodepth](Monodepth.ipynb) notebook. 119 | 120 | For testing, the Model class from main_monodepth_pytorch.py should be initialized with the following parameters (as an easydict): 121 | - `data_dir`: path to the dataset folder 122 | - `model_path`: path to the trained model to load 123 | - `pretrained`: use torchvision pretrained weights for the encoder (usually False for testing, since the weights are loaded from `model_path`) 124 | - `output_directory`: where to save disparities for tested images 125 | - `input_height` 126 | - `input_width` 127 | - `model`: model for the encoder (resnet18_md, resnet50_md or any torchvision ResNet); it should match the model being loaded 128 | - `mode`: train or test 129 | - `input_channels`: number of channels in the input tensor (3 for RGB images) 130 | - `num_workers`: number of workers to use in the dataloader 131 | 132 | After that, calling test() on the Model object starts the testing process. 133 | 134 | Testing can also be started by calling [main_monodepth_pytorch.py](main_monodepth_pytorch.py) from the terminal with `--mode test`, analogously to the training command above. 135 | 136 | ## Requirements 137 | This code was tested with PyTorch 0.4.1, CUDA 9.1 and Ubuntu 16.04. 
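One possible environment setup (a sketch only: the exact PyTorch build depends on your CUDA installation, and the pinned versions below are assumptions based on the tested configuration above):

```shell
# Assumed commands; pick the torch/torchvision wheels that match your CUDA setup.
pip install torch==0.4.1 torchvision
pip install numpy matplotlib easydict
```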
Other required modules: 138 | 139 | ``` 140 | torchvision 141 | numpy 142 | matplotlib 143 | easydict 144 | ``` 145 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class KittiLoader(Dataset): 8 | def __init__(self, root_dir, mode, transform=None): 9 | left_dir = os.path.join(root_dir, 'image_02/data/') 10 | self.left_paths = sorted([os.path.join(left_dir, fname) for fname\ 11 | in os.listdir(left_dir)]) 12 | if mode == 'train': 13 | right_dir = os.path.join(root_dir, 'image_03/data/') 14 | self.right_paths = sorted([os.path.join(right_dir, fname) for fname\ 15 | in os.listdir(right_dir)]) 16 | assert len(self.right_paths) == len(self.left_paths) 17 | self.transform = transform 18 | self.mode = mode 19 | 20 | 21 | def __len__(self): 22 | return len(self.left_paths) 23 | 24 | def __getitem__(self, idx): 25 | left_image = Image.open(self.left_paths[idx]) 26 | if self.mode == 'train': 27 | right_image = Image.open(self.right_paths[idx]) 28 | sample = {'left_image': left_image, 'right_image': right_image} 29 | 30 | if self.transform: 31 | sample = self.transform(sample) 32 | return sample 33 | else: 34 | return sample 35 | else: 36 | if self.transform: 37 | left_image = self.transform(left_image) 38 | return left_image 39 | -------------------------------------------------------------------------------- /kitti_archives_to_download.txt: -------------------------------------------------------------------------------- 1 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_calib.zip 2 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0001/2011_09_26_drive_0001_sync.zip 3 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0002/2011_09_26_drive_0002_sync.zip 4 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0005/2011_09_26_drive_0005_sync.zip 5 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0009/2011_09_26_drive_0009_sync.zip 6 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0011/2011_09_26_drive_0011_sync.zip 7 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0013/2011_09_26_drive_0013_sync.zip 8 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0014/2011_09_26_drive_0014_sync.zip 9 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0015/2011_09_26_drive_0015_sync.zip 10 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0017/2011_09_26_drive_0017_sync.zip 11 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0018/2011_09_26_drive_0018_sync.zip 12 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0019/2011_09_26_drive_0019_sync.zip 13 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0020/2011_09_26_drive_0020_sync.zip 14 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0022/2011_09_26_drive_0022_sync.zip 15 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0023/2011_09_26_drive_0023_sync.zip 16 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0027/2011_09_26_drive_0027_sync.zip 17 | 
https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0028/2011_09_26_drive_0028_sync.zip 18 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0029/2011_09_26_drive_0029_sync.zip 19 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0032/2011_09_26_drive_0032_sync.zip 20 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0035/2011_09_26_drive_0035_sync.zip 21 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0036/2011_09_26_drive_0036_sync.zip 22 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0039/2011_09_26_drive_0039_sync.zip 23 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0046/2011_09_26_drive_0046_sync.zip 24 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0048/2011_09_26_drive_0048_sync.zip 25 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0051/2011_09_26_drive_0051_sync.zip 26 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0052/2011_09_26_drive_0052_sync.zip 27 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0056/2011_09_26_drive_0056_sync.zip 28 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0057/2011_09_26_drive_0057_sync.zip 29 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0059/2011_09_26_drive_0059_sync.zip 30 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0060/2011_09_26_drive_0060_sync.zip 31 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0061/2011_09_26_drive_0061_sync.zip 32 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0064/2011_09_26_drive_0064_sync.zip 33 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0070/2011_09_26_drive_0070_sync.zip 34 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0079/2011_09_26_drive_0079_sync.zip 35 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0084/2011_09_26_drive_0084_sync.zip 36 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0086/2011_09_26_drive_0086_sync.zip 37 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0087/2011_09_26_drive_0087_sync.zip 38 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0091/2011_09_26_drive_0091_sync.zip 39 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0093/2011_09_26_drive_0093_sync.zip 40 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0095/2011_09_26_drive_0095_sync.zip 41 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0096/2011_09_26_drive_0096_sync.zip 42 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0101/2011_09_26_drive_0101_sync.zip 43 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0104/2011_09_26_drive_0104_sync.zip 44 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0106/2011_09_26_drive_0106_sync.zip 45 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0113/2011_09_26_drive_0113_sync.zip 46 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_drive_0117/2011_09_26_drive_0117_sync.zip 47 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_28_calib.zip 48 | 
https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_28_drive_0001/2011_09_28_drive_0001_sync.zip 49 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_28_drive_0002/2011_09_28_drive_0002_sync.zip 50 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_29_calib.zip 51 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_29_drive_0004/2011_09_29_drive_0004_sync.zip 52 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_29_drive_0026/2011_09_29_drive_0026_sync.zip 53 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_29_drive_0071/2011_09_29_drive_0071_sync.zip 54 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_30_calib.zip 55 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_30_drive_0016/2011_09_30_drive_0016_sync.zip 56 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_30_drive_0018/2011_09_30_drive_0018_sync.zip 57 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_30_drive_0020/2011_09_30_drive_0020_sync.zip 58 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_30_drive_0027/2011_09_30_drive_0027_sync.zip 59 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_30_drive_0028/2011_09_30_drive_0028_sync.zip 60 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_30_drive_0033/2011_09_30_drive_0033_sync.zip 61 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_30_drive_0034/2011_09_30_drive_0034_sync.zip 62 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_10_03_calib.zip 63 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_10_03_drive_0027/2011_10_03_drive_0027_sync.zip 64 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_10_03_drive_0034/2011_10_03_drive_0034_sync.zip 65 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_10_03_drive_0042/2011_10_03_drive_0042_sync.zip 66 | https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_10_03_drive_0047/2011_10_03_drive_0047_sync.zip 67 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class MonodepthLoss(nn.modules.Module): 7 | def __init__(self, n=4, SSIM_w=0.85, disp_gradient_w=1.0, lr_w=1.0): 8 | super(MonodepthLoss, self).__init__() 9 | self.SSIM_w = SSIM_w 10 | self.disp_gradient_w = disp_gradient_w 11 | self.lr_w = lr_w 12 | self.n = n 13 | 14 | def scale_pyramid(self, img, num_scales): 15 | scaled_imgs = [img] 16 | s = img.size() 17 | h = s[2] 18 | w = s[3] 19 | for i in range(num_scales - 1): 20 | ratio = 2 ** (i + 1) 21 | nh = h // ratio 22 | nw = w // ratio 23 | scaled_imgs.append(nn.functional.interpolate(img, 24 | size=[nh, nw], mode='bilinear', 25 | align_corners=True)) 26 | return scaled_imgs 27 | 28 | def gradient_x(self, img): 29 | # Pad input to keep output size consistent 30 | img = F.pad(img, (0, 1, 0, 0), mode="replicate") 31 | gx = img[:, :, :, :-1] - img[:, :, :, 1:] # NCHW 32 | return gx 33 | 34 | def gradient_y(self, img): 35 | # Pad input to keep output size consistent 36 | img = F.pad(img, (0, 0, 0, 1), mode="replicate") 37 | gy = img[:, :, :-1, :] - img[:, :, 1:, :] # NCHW 38 | return gy 39 | 40 | def apply_disparity(self, img, disp): 41 | batch_size, _, height, width = img.size() 42 | 43 | # Original coordinates of pixels 44 | x_base = 
torch.linspace(0, 1, width).repeat(batch_size, 45 | height, 1).type_as(img) 46 | y_base = torch.linspace(0, 1, height).repeat(batch_size, 47 | width, 1).transpose(1, 2).type_as(img) 48 | 49 | # Apply shift in X direction 50 | x_shifts = disp[:, 0, :, :] # Disparity is passed in NCHW format with 1 channel 51 | flow_field = torch.stack((x_base + x_shifts, y_base), dim=3) 52 | # In grid_sample coordinates are assumed to be between -1 and 1 53 | output = F.grid_sample(img, 2*flow_field - 1, mode='bilinear', 54 | padding_mode='zeros') 55 | 56 | return output 57 | 58 | def generate_image_left(self, img, disp): 59 | return self.apply_disparity(img, -disp) 60 | 61 | def generate_image_right(self, img, disp): 62 | return self.apply_disparity(img, disp) 63 | 64 | def SSIM(self, x, y): 65 | C1 = 0.01 ** 2 66 | C2 = 0.03 ** 2 67 | 68 | mu_x = nn.AvgPool2d(3, 1)(x) 69 | mu_y = nn.AvgPool2d(3, 1)(y) 70 | mu_x_mu_y = mu_x * mu_y 71 | mu_x_sq = mu_x.pow(2) 72 | mu_y_sq = mu_y.pow(2) 73 | 74 | sigma_x = nn.AvgPool2d(3, 1)(x * x) - mu_x_sq 75 | sigma_y = nn.AvgPool2d(3, 1)(y * y) - mu_y_sq 76 | sigma_xy = nn.AvgPool2d(3, 1)(x * y) - mu_x_mu_y 77 | 78 | SSIM_n = (2 * mu_x_mu_y + C1) * (2 * sigma_xy + C2) 79 | SSIM_d = (mu_x_sq + mu_y_sq + C1) * (sigma_x + sigma_y + C2) 80 | SSIM = SSIM_n / SSIM_d 81 | 82 | return torch.clamp((1 - SSIM) / 2, 0, 1) 83 | 84 | def disp_smoothness(self, disp, pyramid): 85 | disp_gradients_x = [self.gradient_x(d) for d in disp] 86 | disp_gradients_y = [self.gradient_y(d) for d in disp] 87 | 88 | image_gradients_x = [self.gradient_x(img) for img in pyramid] 89 | image_gradients_y = [self.gradient_y(img) for img in pyramid] 90 | 91 | weights_x = [torch.exp(-torch.mean(torch.abs(g), 1, 92 | keepdim=True)) for g in image_gradients_x] 93 | weights_y = [torch.exp(-torch.mean(torch.abs(g), 1, 94 | keepdim=True)) for g in image_gradients_y] 95 | 96 | smoothness_x = [disp_gradients_x[i] * weights_x[i] 97 | for i in range(self.n)] 98 | smoothness_y = [disp_gradients_y[i] * weights_y[i] 99 | for i in range(self.n)] 100 | 101 | return [torch.abs(smoothness_x[i]) + torch.abs(smoothness_y[i]) 102 | for i in range(self.n)] 103 | 104 | def forward(self, input, target): 105 | """ 106 | Args: 107 | input [disp1, disp2, disp3, disp4] 108 | target [left, right] 109 | 110 | Return: 111 | (float): The loss 112 | """ 113 | left, right = target 114 | left_pyramid = self.scale_pyramid(left, self.n) 115 | right_pyramid = self.scale_pyramid(right, self.n) 116 | 117 | # Prepare disparities 118 | disp_left_est = [d[:, 0, :, :].unsqueeze(1) for d in input] 119 | disp_right_est = [d[:, 1, :, :].unsqueeze(1) for d in input] 120 | 121 | self.disp_left_est = disp_left_est 122 | self.disp_right_est = disp_right_est 123 | # Generate images 124 | left_est = [self.generate_image_left(right_pyramid[i], 125 | disp_left_est[i]) for i in range(self.n)] 126 | right_est = [self.generate_image_right(left_pyramid[i], 127 | disp_right_est[i]) for i in range(self.n)] 128 | self.left_est = left_est 129 | self.right_est = right_est 130 | 131 | # L-R Consistency 132 | right_left_disp = [self.generate_image_left(disp_right_est[i], 133 | disp_left_est[i]) for i in range(self.n)] 134 | left_right_disp = [self.generate_image_right(disp_left_est[i], 135 | disp_right_est[i]) for i in range(self.n)] 136 | 137 | # Disparities smoothness 138 | disp_left_smoothness = self.disp_smoothness(disp_left_est, 139 | left_pyramid) 140 | disp_right_smoothness = self.disp_smoothness(disp_right_est, 141 | right_pyramid) 142 | 143 | # L1 144 | l1_left = 
[torch.mean(torch.abs(left_est[i] - left_pyramid[i])) 145 | for i in range(self.n)] 146 | l1_right = [torch.mean(torch.abs(right_est[i] 147 | - right_pyramid[i])) for i in range(self.n)] 148 | 149 | # SSIM 150 | ssim_left = [torch.mean(self.SSIM(left_est[i], 151 | left_pyramid[i])) for i in range(self.n)] 152 | ssim_right = [torch.mean(self.SSIM(right_est[i], 153 | right_pyramid[i])) for i in range(self.n)] 154 | 155 | image_loss_left = [self.SSIM_w * ssim_left[i] 156 | + (1 - self.SSIM_w) * l1_left[i] 157 | for i in range(self.n)] 158 | image_loss_right = [self.SSIM_w * ssim_right[i] 159 | + (1 - self.SSIM_w) * l1_right[i] 160 | for i in range(self.n)] 161 | image_loss = sum(image_loss_left + image_loss_right) 162 | 163 | # L-R Consistency 164 | lr_left_loss = [torch.mean(torch.abs(right_left_disp[i] 165 | - disp_left_est[i])) for i in range(self.n)] 166 | lr_right_loss = [torch.mean(torch.abs(left_right_disp[i] 167 | - disp_right_est[i])) for i in range(self.n)] 168 | lr_loss = sum(lr_left_loss + lr_right_loss) 169 | 170 | # Disparities smoothness 171 | disp_left_loss = [torch.mean(torch.abs( 172 | disp_left_smoothness[i])) / 2 ** i 173 | for i in range(self.n)] 174 | disp_right_loss = [torch.mean(torch.abs( 175 | disp_right_smoothness[i])) / 2 ** i 176 | for i in range(self.n)] 177 | disp_gradient_loss = sum(disp_left_loss + disp_right_loss) 178 | 179 | loss = image_loss + self.disp_gradient_w * disp_gradient_loss\ 180 | + self.lr_w * lr_loss 181 | self.image_loss = image_loss 182 | self.disp_gradient_loss = disp_gradient_loss 183 | self.lr_loss = lr_loss 184 | return loss 185 | -------------------------------------------------------------------------------- /main_monodepth_pytorch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import torch 4 | import numpy as np 5 | import torch.optim as optim 6 | 7 | # custom modules 8 | 9 | from loss import MonodepthLoss 10 | from utils import get_model, to_device, prepare_dataloader 11 | 12 | # plot params 13 | 14 | import matplotlib.pyplot as plt 15 | import matplotlib as mpl 16 | mpl.rcParams['figure.figsize'] = (15, 10) 17 | 18 | 19 | def return_arguments(): 20 | parser = argparse.ArgumentParser(description='PyTorch Monodepth') 21 | 22 | parser.add_argument('data_dir', 23 | help='path to the dataset folder. \ 24 | It should contain subfolders with following structure:\ 25 | "image_02/data" for left images and \ 26 | "image_03/data" for right images' 27 | ) 28 | parser.add_argument('val_data_dir', 29 | help='path to the validation dataset folder. 
\ 30 | It should contain subfolders with following structure:\ 31 | "image_02/data" for left images and \ 32 | "image_03/data" for right images' 33 | ) 34 | parser.add_argument('model_path', help='path to the trained model') 35 | parser.add_argument('output_directory', 36 | help='where save dispairities\ 37 | for tested images' 38 | ) 39 | parser.add_argument('--input_height', type=int, help='input height', 40 | default=256) 41 | parser.add_argument('--input_width', type=int, help='input width', 42 | default=512) 43 | parser.add_argument('--model', default='resnet18_md', 44 | help='encoder architecture: ' + 45 | 'resnet18_md or resnet50_md ' + '(default: resnet18)' 46 | + 'or torchvision version of any resnet model' 47 | ) 48 | parser.add_argument('--pretrained', default=False, 49 | help='Use weights of pretrained model' 50 | ) 51 | parser.add_argument('--mode', default='train', 52 | help='mode: train or test (default: train)') 53 | parser.add_argument('--epochs', default=50, 54 | help='number of total epochs to run') 55 | parser.add_argument('--learning_rate', default=1e-4, 56 | help='initial learning rate (default: 1e-4)') 57 | parser.add_argument('--batch_size', default=256, 58 | help='mini-batch size (default: 256)') 59 | parser.add_argument('--adjust_lr', default=True, 60 | help='apply learning rate decay or not\ 61 | (default: True)' 62 | ) 63 | parser.add_argument('--device', 64 | default='cuda:0', 65 | help='choose cpu or cuda:0 device"' 66 | ) 67 | parser.add_argument('--do_augmentation', default=True, 68 | help='do augmentation of images or not') 69 | parser.add_argument('--augment_parameters', default=[ 70 | 0.8, 71 | 1.2, 72 | 0.5, 73 | 2.0, 74 | 0.8, 75 | 1.2, 76 | ], 77 | help='lowest and highest values for gamma,\ 78 | brightness and color respectively' 79 | ) 80 | parser.add_argument('--print_images', default=False, 81 | help='print disparity and image\ 82 | generated from disparity on every iteration' 83 | ) 84 | parser.add_argument('--print_weights', default=False, 85 | help='print weights of every layer') 86 | parser.add_argument('--input_channels', default=3, 87 | help='Number of channels in input tensor') 88 | parser.add_argument('--num_workers', default=4, 89 | help='Number of workers in dataloader') 90 | parser.add_argument('--use_multiple_gpu', default=False) 91 | args = parser.parse_args() 92 | return args 93 | 94 | 95 | def adjust_learning_rate(optimizer, epoch, learning_rate): 96 | """Sets the learning rate to the initial LR\ 97 | decayed by 2 every 10 epochs after 30 epoches""" 98 | 99 | if epoch >= 30 and epoch < 40: 100 | lr = learning_rate / 2 101 | elif epoch >= 40: 102 | lr = learning_rate / 4 103 | else: 104 | lr = learning_rate 105 | for param_group in optimizer.param_groups: 106 | param_group['lr'] = lr 107 | 108 | 109 | def post_process_disparity(disp): 110 | (_, h, w) = disp.shape 111 | l_disp = disp[0, :, :] 112 | r_disp = np.fliplr(disp[1, :, :]) 113 | m_disp = 0.5 * (l_disp + r_disp) 114 | (l, _) = np.meshgrid(np.linspace(0, 1, w), np.linspace(0, 1, h)) 115 | l_mask = 1.0 - np.clip(20 * (l - 0.05), 0, 1) 116 | r_mask = np.fliplr(l_mask) 117 | return r_mask * l_disp + l_mask * r_disp + (1.0 - l_mask - r_mask) * m_disp 118 | 119 | 120 | class Model: 121 | 122 | def __init__(self, args): 123 | self.args = args 124 | 125 | # Set up model 126 | self.device = args.device 127 | self.model = get_model(args.model, input_channels=args.input_channels, pretrained=args.pretrained) 128 | self.model = self.model.to(self.device) 129 | if args.use_multiple_gpu: 130 
| self.model = torch.nn.DataParallel(self.model) 131 | 132 | if args.mode == 'train': 133 | self.loss_function = MonodepthLoss( 134 | n=4, 135 | SSIM_w=0.85, 136 | disp_gradient_w=0.1, lr_w=1).to(self.device) 137 | self.optimizer = optim.Adam(self.model.parameters(), 138 | lr=args.learning_rate) 139 | self.val_n_img, self.val_loader = prepare_dataloader(args.val_data_dir, args.mode, 140 | args.augment_parameters, 141 | False, args.batch_size, 142 | (args.input_height, args.input_width), 143 | args.num_workers) 144 | else: 145 | self.model.load_state_dict(torch.load(args.model_path)) 146 | args.augment_parameters = None 147 | args.do_augmentation = False 148 | args.batch_size = 1 149 | 150 | # Load data 151 | self.output_directory = args.output_directory 152 | self.input_height = args.input_height 153 | self.input_width = args.input_width 154 | 155 | self.n_img, self.loader = prepare_dataloader(args.data_dir, args.mode, args.augment_parameters, 156 | args.do_augmentation, args.batch_size, 157 | (args.input_height, args.input_width), 158 | args.num_workers) 159 | 160 | 161 | if 'cuda' in self.device: 162 | torch.cuda.synchronize() 163 | 164 | 165 | def train(self): 166 | losses = [] 167 | val_losses = [] 168 | best_loss = float('Inf') 169 | best_val_loss = float('Inf') 170 | 171 | running_val_loss = 0.0 172 | self.model.eval() 173 | for data in self.val_loader: 174 | data = to_device(data, self.device) 175 | left = data['left_image'] 176 | right = data['right_image'] 177 | disps = self.model(left) 178 | loss = self.loss_function(disps, [left, right]) 179 | val_losses.append(loss.item()) 180 | running_val_loss += loss.item() 181 | 182 | running_val_loss /= self.val_n_img / self.args.batch_size 183 | print('Val_loss:', running_val_loss) 184 | 185 | for epoch in range(self.args.epochs): 186 | if self.args.adjust_lr: 187 | adjust_learning_rate(self.optimizer, epoch, 188 | self.args.learning_rate) 189 | c_time = time.time() 190 | running_loss = 0.0 191 | self.model.train() 192 | for data in self.loader: 193 | # Load data 194 | data = to_device(data, self.device) 195 | left = data['left_image'] 196 | right = data['right_image'] 197 | 198 | # One optimization iteration 199 | self.optimizer.zero_grad() 200 | disps = self.model(left) 201 | loss = self.loss_function(disps, [left, right]) 202 | loss.backward() 203 | self.optimizer.step() 204 | losses.append(loss.item()) 205 | 206 | # Print statistics 207 | if self.args.print_weights: 208 | j = 1 209 | for (name, parameter) in self.model.named_parameters(): 210 | if name.split(sep='.')[-1] == 'weight': 211 | plt.subplot(5, 9, j) 212 | plt.hist(parameter.data.view(-1)) 213 | plt.xlim([-1, 1]) 214 | plt.title(name.split(sep='.')[0]) 215 | j += 1 216 | plt.show() 217 | 218 | if self.args.print_images: 219 | print('disp_left_est[0]') 220 | plt.imshow(np.squeeze( 221 | np.transpose(self.loss_function.disp_left_est[0][0, 222 | :, :, :].cpu().detach().numpy(), 223 | (1, 2, 0)))) 224 | plt.show() 225 | print('left_est[0]') 226 | plt.imshow(np.transpose(self.loss_function\ 227 | .left_est[0][0, :, :, :].cpu().detach().numpy(), 228 | (1, 2, 0))) 229 | plt.show() 230 | print('disp_right_est[0]') 231 | plt.imshow(np.squeeze( 232 | np.transpose(self.loss_function.disp_right_est[0][0, 233 | :, :, :].cpu().detach().numpy(), 234 | (1, 2, 0)))) 235 | plt.show() 236 | print('right_est[0]') 237 | plt.imshow(np.transpose(self.loss_function.right_est[0][0, 238 | :, :, :].cpu().detach().numpy(), (1, 2, 239 | 0))) 240 | plt.show() 241 | running_loss += loss.item() 242 | 243 | 
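            # End-of-epoch validation pass: evaluate on the validation loader in eval mode; the averaged validation loss drives checkpoint saving below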
running_val_loss = 0.0 244 | self.model.eval() 245 | for data in self.val_loader: 246 | data = to_device(data, self.device) 247 | left = data['left_image'] 248 | right = data['right_image'] 249 | disps = self.model(left) 250 | loss = self.loss_function(disps, [left, right]) 251 | val_losses.append(loss.item()) 252 | running_val_loss += loss.item() 253 | 254 | # Estimate loss per image 255 | running_loss /= self.n_img / self.args.batch_size 256 | running_val_loss /= self.val_n_img / self.args.batch_size 257 | print ( 258 | 'Epoch:', 259 | epoch + 1, 260 | 'train_loss:', 261 | running_loss, 262 | 'val_loss:', 263 | running_val_loss, 264 | 'time:', 265 | round(time.time() - c_time, 3), 266 | 's', 267 | ) 268 | self.save(self.args.model_path[:-4] + '_last.pth') 269 | if running_val_loss < best_val_loss: 270 | self.save(self.args.model_path[:-4] + '_cpt.pth') 271 | best_val_loss = running_val_loss 272 | print('Model_saved') 273 | 274 | print('Finished Training. Best validation loss:', best_val_loss) 275 | self.save(self.args.model_path) 276 | 277 | def save(self, path): 278 | torch.save(self.model.state_dict(), path) 279 | 280 | def load(self, path): 281 | self.model.load_state_dict(torch.load(path)) 282 | 283 | def test(self): 284 | self.model.eval() 285 | disparities = np.zeros((self.n_img, 286 | self.input_height, self.input_width), 287 | dtype=np.float32) 288 | disparities_pp = np.zeros((self.n_img, 289 | self.input_height, self.input_width), 290 | dtype=np.float32) 291 | with torch.no_grad(): 292 | for (i, data) in enumerate(self.loader): 293 | # Get the inputs 294 | data = to_device(data, self.device) 295 | left = data.squeeze() 296 | # Do a forward pass 297 | disps = self.model(left) 298 | disp = disps[0][:, 0, :, :].unsqueeze(1) 299 | disparities[i] = disp[0].squeeze().cpu().numpy() 300 | disparities_pp[i] = \ 301 | post_process_disparity(disps[0][:, 0, :, :]\ 302 | .cpu().numpy()) 303 | 304 | np.save(self.output_directory + '/disparities.npy', disparities) 305 | np.save(self.output_directory + '/disparities_pp.npy', 306 | disparities_pp) 307 | print('Finished Testing') 308 | 309 | 310 | def main(): 311 | args = return_arguments() 312 | if args.mode == 'train': 313 | model = Model(args) 314 | model.train() 315 | elif args.mode == 'test': 316 | model_test = Model(args) 317 | model_test.test() 318 | 319 | 320 | if __name__ == '__main__': 321 | main() 322 | 323 | -------------------------------------------------------------------------------- /models_resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import importlib 7 | 8 | 9 | class conv(nn.Module): 10 | def __init__(self, num_in_layers, num_out_layers, kernel_size, stride): 11 | super(conv, self).__init__() 12 | self.kernel_size = kernel_size 13 | self.conv_base = nn.Conv2d(num_in_layers, num_out_layers, kernel_size=kernel_size, stride=stride) 14 | self.normalize = nn.BatchNorm2d(num_out_layers) 15 | 16 | def forward(self, x): 17 | p = int(np.floor((self.kernel_size-1)/2)) 18 | p2d = (p, p, p, p) 19 | x = self.conv_base(F.pad(x, p2d)) 20 | x = self.normalize(x) 21 | return F.elu(x, inplace=True) 22 | 23 | 24 | class convblock(nn.Module): 25 | def __init__(self, num_in_layers, num_out_layers, kernel_size): 26 | super(convblock, self).__init__() 27 | self.conv1 = conv(num_in_layers, num_out_layers, kernel_size, 1) 28 | self.conv2 = 
conv(num_out_layers, num_out_layers, kernel_size, 2) 29 | 30 | def forward(self, x): 31 | x = self.conv1(x) 32 | return self.conv2(x) 33 | 34 | 35 | class maxpool(nn.Module): 36 | def __init__(self, kernel_size): 37 | super(maxpool, self).__init__() 38 | self.kernel_size = kernel_size 39 | 40 | def forward(self, x): 41 | p = int(np.floor((self.kernel_size-1) / 2)) 42 | p2d = (p, p, p, p) 43 | return F.max_pool2d(F.pad(x, p2d), self.kernel_size, stride=2) 44 | 45 | 46 | class resconv(nn.Module): 47 | def __init__(self, num_in_layers, num_out_layers, stride): 48 | super(resconv, self).__init__() 49 | self.num_out_layers = num_out_layers 50 | self.stride = stride 51 | self.conv1 = conv(num_in_layers, num_out_layers, 1, 1) 52 | self.conv2 = conv(num_out_layers, num_out_layers, 3, stride) 53 | self.conv3 = nn.Conv2d(num_out_layers, 4*num_out_layers, kernel_size=1, stride=1) 54 | self.conv4 = nn.Conv2d(num_in_layers, 4*num_out_layers, kernel_size=1, stride=stride) 55 | self.normalize = nn.BatchNorm2d(4*num_out_layers) 56 | 57 | def forward(self, x): 58 | # do_proj = x.size()[1] != self.num_out_layers or self.stride == 2 59 | do_proj = True 60 | shortcut = [] 61 | x_out = self.conv1(x) 62 | x_out = self.conv2(x_out) 63 | x_out = self.conv3(x_out) 64 | if do_proj: 65 | shortcut = self.conv4(x) 66 | else: 67 | shortcut = x 68 | return F.elu(self.normalize(x_out + shortcut), inplace=True) 69 | 70 | 71 | class resconv_basic(nn.Module): 72 | # for resnet18 73 | def __init__(self, num_in_layers, num_out_layers, stride): 74 | super(resconv_basic, self).__init__() 75 | self.num_out_layers = num_out_layers 76 | self.stride = stride 77 | self.conv1 = conv(num_in_layers, num_out_layers, 3, stride) 78 | self.conv2 = conv(num_out_layers, num_out_layers, 3, 1) 79 | self.conv3 = nn.Conv2d(num_in_layers, num_out_layers, kernel_size=1, stride=stride) 80 | self.normalize = nn.BatchNorm2d(num_out_layers) 81 | 82 | def forward(self, x): 83 | # do_proj = x.size()[1] != self.num_out_layers or self.stride == 2 84 | do_proj = True 85 | shortcut = [] 86 | x_out = self.conv1(x) 87 | x_out = self.conv2(x_out) 88 | if do_proj: 89 | shortcut = self.conv3(x) 90 | else: 91 | shortcut = x 92 | return F.elu(self.normalize(x_out + shortcut), inplace=True) 93 | 94 | 95 | def resblock(num_in_layers, num_out_layers, num_blocks, stride): 96 | layers = [] 97 | layers.append(resconv(num_in_layers, num_out_layers, stride)) 98 | for i in range(1, num_blocks - 1): 99 | layers.append(resconv(4 * num_out_layers, num_out_layers, 1)) 100 | layers.append(resconv(4 * num_out_layers, num_out_layers, 1)) 101 | return nn.Sequential(*layers) 102 | 103 | 104 | def resblock_basic(num_in_layers, num_out_layers, num_blocks, stride): 105 | layers = [] 106 | layers.append(resconv_basic(num_in_layers, num_out_layers, stride)) 107 | for i in range(1, num_blocks): 108 | layers.append(resconv_basic(num_out_layers, num_out_layers, 1)) 109 | return nn.Sequential(*layers) 110 | 111 | 112 | class upconv(nn.Module): 113 | def __init__(self, num_in_layers, num_out_layers, kernel_size, scale): 114 | super(upconv, self).__init__() 115 | self.scale = scale 116 | self.conv1 = conv(num_in_layers, num_out_layers, kernel_size, 1) 117 | 118 | def forward(self, x): 119 | x = nn.functional.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=True) 120 | return self.conv1(x) 121 | 122 | 123 | class get_disp(nn.Module): 124 | def __init__(self, num_in_layers): 125 | super(get_disp, self).__init__() 126 | self.conv1 = nn.Conv2d(num_in_layers, 2, 
kernel_size=3, stride=1) 127 | self.normalize = nn.BatchNorm2d(2) 128 | self.sigmoid = torch.nn.Sigmoid() 129 | 130 | def forward(self, x): 131 | p = 1 132 | p2d = (p, p, p, p) 133 | x = self.conv1(F.pad(x, p2d)) 134 | x = self.normalize(x) 135 | return 0.3 * self.sigmoid(x) 136 | 137 | 138 | class Resnet50_md(nn.Module): 139 | def __init__(self, num_in_layers): 140 | super(Resnet50_md, self).__init__() 141 | # encoder 142 | self.conv1 = conv(num_in_layers, 64, 7, 2) # H/2 - 64D 143 | self.pool1 = maxpool(3) # H/4 - 64D 144 | self.conv2 = resblock(64, 64, 3, 2) # H/8 - 256D 145 | self.conv3 = resblock(256, 128, 4, 2) # H/16 - 512D 146 | self.conv4 = resblock(512, 256, 6, 2) # H/32 - 1024D 147 | self.conv5 = resblock(1024, 512, 3, 2) # H/64 - 2048D 148 | 149 | # decoder 150 | self.upconv6 = upconv(2048, 512, 3, 2) 151 | self.iconv6 = conv(1024 + 512, 512, 3, 1) 152 | 153 | self.upconv5 = upconv(512, 256, 3, 2) 154 | self.iconv5 = conv(512+256, 256, 3, 1) 155 | 156 | self.upconv4 = upconv(256, 128, 3, 2) 157 | self.iconv4 = conv(256+128, 128, 3, 1) 158 | self.disp4_layer = get_disp(128) 159 | 160 | self.upconv3 = upconv(128, 64, 3, 2) 161 | self.iconv3 = conv(64+64+2, 64, 3, 1) 162 | self.disp3_layer = get_disp(64) 163 | 164 | self.upconv2 = upconv(64, 32, 3, 2) 165 | self.iconv2 = conv(32+64+2, 32, 3, 1) 166 | self.disp2_layer = get_disp(32) 167 | 168 | self.upconv1 = upconv(32, 16, 3, 2) 169 | self.iconv1 = conv(16+2, 16, 3, 1) 170 | self.disp1_layer = get_disp(16) 171 | 172 | for m in self.modules(): 173 | if isinstance(m, nn.Conv2d): 174 | nn.init.xavier_uniform_(m.weight) 175 | 176 | def forward(self, x): 177 | # encoder 178 | x1 = self.conv1(x) 179 | x_pool1 = self.pool1(x1) 180 | x2 = self.conv2(x_pool1) 181 | x3 = self.conv3(x2) 182 | x4 = self.conv4(x3) 183 | x5 = self.conv5(x4) 184 | 185 | # skips 186 | skip1 = x1 187 | skip2 = x_pool1 188 | skip3 = x2 189 | skip4 = x3 190 | skip5 = x4 191 | 192 | # decoder 193 | upconv6 = self.upconv6(x5) 194 | concat6 = torch.cat((upconv6, skip5), 1) 195 | iconv6 = self.iconv6(concat6) 196 | 197 | upconv5 = self.upconv5(iconv6) 198 | concat5 = torch.cat((upconv5, skip4), 1) 199 | iconv5 = self.iconv5(concat5) 200 | 201 | upconv4 = self.upconv4(iconv5) 202 | concat4 = torch.cat((upconv4, skip3), 1) 203 | iconv4 = self.iconv4(concat4) 204 | self.disp4 = self.disp4_layer(iconv4) 205 | self.udisp4 = nn.functional.interpolate(self.disp4, scale_factor=2, mode='bilinear', align_corners=True) 206 | 207 | upconv3 = self.upconv3(iconv4) 208 | concat3 = torch.cat((upconv3, skip2, self.udisp4), 1) 209 | iconv3 = self.iconv3(concat3) 210 | self.disp3 = self.disp3_layer(iconv3) 211 | self.udisp3 = nn.functional.interpolate(self.disp3, scale_factor=2, mode='bilinear', align_corners=True) 212 | 213 | upconv2 = self.upconv2(iconv3) 214 | concat2 = torch.cat((upconv2, skip1, self.udisp3), 1) 215 | iconv2 = self.iconv2(concat2) 216 | self.disp2 = self.disp2_layer(iconv2) 217 | self.udisp2 = nn.functional.interpolate(self.disp2, scale_factor=2, mode='bilinear', align_corners=True) 218 | 219 | upconv1 = self.upconv1(iconv2) 220 | concat1 = torch.cat((upconv1, self.udisp2), 1) 221 | iconv1 = self.iconv1(concat1) 222 | self.disp1 = self.disp1_layer(iconv1) 223 | return self.disp1, self.disp2, self.disp3, self.disp4 224 | 225 | 226 | class Resnet18_md(nn.Module): 227 | def __init__(self, num_in_layers): 228 | super(Resnet18_md, self).__init__() 229 | # encoder 230 | self.conv1 = conv(num_in_layers, 64, 7, 2) # H/2 - 64D 231 | self.pool1 = maxpool(3) # H/4 - 64D 232 | 
self.conv2 = resblock_basic(64, 64, 2, 2) # H/8 - 64D 233 | self.conv3 = resblock_basic(64, 128, 2, 2) # H/16 - 128D 234 | self.conv4 = resblock_basic(128, 256, 2, 2) # H/32 - 256D 235 | self.conv5 = resblock_basic(256, 512, 2, 2) # H/64 - 512D 236 | 237 | # decoder 238 | self.upconv6 = upconv(512, 512, 3, 2) 239 | self.iconv6 = conv(256+512, 512, 3, 1) 240 | 241 | self.upconv5 = upconv(512, 256, 3, 2) 242 | self.iconv5 = conv(128+256, 256, 3, 1) 243 | 244 | self.upconv4 = upconv(256, 128, 3, 2) 245 | self.iconv4 = conv(64+128, 128, 3, 1) 246 | self.disp4_layer = get_disp(128) 247 | 248 | self.upconv3 = upconv(128, 64, 3, 2) 249 | self.iconv3 = conv(64+64 + 2, 64, 3, 1) 250 | self.disp3_layer = get_disp(64) 251 | 252 | self.upconv2 = upconv(64, 32, 3, 2) 253 | self.iconv2 = conv(64+32 + 2, 32, 3, 1) 254 | self.disp2_layer = get_disp(32) 255 | 256 | self.upconv1 = upconv(32, 16, 3, 2) 257 | self.iconv1 = conv(16+2, 16, 3, 1) 258 | self.disp1_layer = get_disp(16) 259 | 260 | for m in self.modules(): 261 | if isinstance(m, nn.Conv2d): 262 | nn.init.xavier_uniform_(m.weight) 263 | 264 | def forward(self, x): 265 | # encoder 266 | x1 = self.conv1(x) 267 | x_pool1 = self.pool1(x1) 268 | x2 = self.conv2(x_pool1) 269 | x3 = self.conv3(x2) 270 | x4 = self.conv4(x3) 271 | x5 = self.conv5(x4) 272 | 273 | # skips 274 | skip1 = x1 275 | skip2 = x_pool1 276 | skip3 = x2 277 | skip4 = x3 278 | skip5 = x4 279 | 280 | # decoder 281 | upconv6 = self.upconv6(x5) 282 | concat6 = torch.cat((upconv6, skip5), 1) 283 | iconv6 = self.iconv6(concat6) 284 | 285 | upconv5 = self.upconv5(iconv6) 286 | concat5 = torch.cat((upconv5, skip4), 1) 287 | iconv5 = self.iconv5(concat5) 288 | 289 | upconv4 = self.upconv4(iconv5) 290 | concat4 = torch.cat((upconv4, skip3), 1) 291 | iconv4 = self.iconv4(concat4) 292 | self.disp4 = self.disp4_layer(iconv4) 293 | self.udisp4 = nn.functional.interpolate(self.disp4, scale_factor=2, mode='bilinear', align_corners=True) 294 | 295 | upconv3 = self.upconv3(iconv4) 296 | concat3 = torch.cat((upconv3, skip2, self.udisp4), 1) 297 | iconv3 = self.iconv3(concat3) 298 | self.disp3 = self.disp3_layer(iconv3) 299 | self.udisp3 = nn.functional.interpolate(self.disp3, scale_factor=2, mode='bilinear', align_corners=True) 300 | 301 | upconv2 = self.upconv2(iconv3) 302 | concat2 = torch.cat((upconv2, skip1, self.udisp3), 1) 303 | iconv2 = self.iconv2(concat2) 304 | self.disp2 = self.disp2_layer(iconv2) 305 | self.udisp2 = nn.functional.interpolate(self.disp2, scale_factor=2, mode='bilinear', align_corners=True) 306 | 307 | upconv1 = self.upconv1(iconv2) 308 | concat1 = torch.cat((upconv1, self.udisp2), 1) 309 | iconv1 = self.iconv1(concat1) 310 | self.disp1 = self.disp1_layer(iconv1) 311 | return self.disp1, self.disp2, self.disp3, self.disp4 312 | 313 | 314 | def class_for_name(module_name, class_name): 315 | # load the module, will raise ImportError if module cannot be loaded 316 | m = importlib.import_module(module_name) 317 | # get the class, will raise AttributeError if class cannot be found 318 | return getattr(m, class_name) 319 | 320 | 321 | class ResnetModel(nn.Module): 322 | def __init__(self, num_in_layers, encoder='resnet18', pretrained=False): 323 | super(ResnetModel, self).__init__() 324 | assert encoder in ['resnet18', 'resnet34', 'resnet50',\ 325 | 'resnet101', 'resnet152'],\ 326 | "Incorrect encoder type" 327 | if encoder in ['resnet18', 'resnet34']: 328 | filters = [64, 128, 256, 512] 329 | else: 330 | filters = [256, 512, 1024, 2048] 331 | resnet = 
class_for_name("torchvision.models", encoder)\ 332 | (pretrained=pretrained) 333 | if num_in_layers != 3: # Number of input channels 334 | self.firstconv = nn.Conv2d(num_in_layers, 64, 335 | kernel_size=(7, 7), stride=(2, 2), 336 | padding=(3, 3), bias=False) 337 | else: 338 | self.firstconv = resnet.conv1 # H/2 339 | self.firstbn = resnet.bn1 340 | self.firstrelu = resnet.relu 341 | self.firstmaxpool = resnet.maxpool # H/4 342 | 343 | # encoder 344 | self.encoder1 = resnet.layer1 # H/4 345 | self.encoder2 = resnet.layer2 # H/8 346 | self.encoder3 = resnet.layer3 # H/16 347 | self.encoder4 = resnet.layer4 # H/32 348 | 349 | # decoder 350 | self.upconv6 = upconv(filters[3], 512, 3, 2) 351 | self.iconv6 = conv(filters[2] + 512, 512, 3, 1) 352 | 353 | self.upconv5 = upconv(512, 256, 3, 2) 354 | self.iconv5 = conv(filters[1] + 256, 256, 3, 1) 355 | 356 | self.upconv4 = upconv(256, 128, 3, 2) 357 | self.iconv4 = conv(filters[0] + 128, 128, 3, 1) 358 | self.disp4_layer = get_disp(128) 359 | 360 | self.upconv3 = upconv(128, 64, 3, 1) # 361 | self.iconv3 = conv(64 + 64 + 2, 64, 3, 1) 362 | self.disp3_layer = get_disp(64) 363 | 364 | self.upconv2 = upconv(64, 32, 3, 2) 365 | self.iconv2 = conv(64 + 32 + 2, 32, 3, 1) 366 | self.disp2_layer = get_disp(32) 367 | 368 | self.upconv1 = upconv(32, 16, 3, 2) 369 | self.iconv1 = conv(16 + 2, 16, 3, 1) 370 | self.disp1_layer = get_disp(16) 371 | 372 | for m in self.modules(): 373 | if isinstance(m, nn.Conv2d): 374 | nn.init.xavier_uniform_(m.weight) 375 | 376 | def forward(self, x): 377 | # encoder 378 | x_first_conv = self.firstconv(x) 379 | x = self.firstbn(x_first_conv) 380 | x = self.firstrelu(x) 381 | x_pool1 = self.firstmaxpool(x) 382 | x1 = self.encoder1(x_pool1) 383 | x2 = self.encoder2(x1) 384 | x3 = self.encoder3(x2) 385 | x4 = self.encoder4(x3) 386 | # skips 387 | skip1 = x_first_conv 388 | skip2 = x_pool1 389 | skip3 = x1 390 | skip4 = x2 391 | skip5 = x3 392 | 393 | # decoder 394 | upconv6 = self.upconv6(x4) 395 | concat6 = torch.cat((upconv6, skip5), 1) 396 | iconv6 = self.iconv6(concat6) 397 | 398 | upconv5 = self.upconv5(iconv6) 399 | concat5 = torch.cat((upconv5, skip4), 1) 400 | iconv5 = self.iconv5(concat5) 401 | 402 | upconv4 = self.upconv4(iconv5) 403 | concat4 = torch.cat((upconv4, skip3), 1) 404 | iconv4 = self.iconv4(concat4) 405 | self.disp4 = self.disp4_layer(iconv4) 406 | self.udisp4 = nn.functional.interpolate(self.disp4, scale_factor=1, mode='bilinear', align_corners=True) 407 | self.disp4 = nn.functional.interpolate(self.disp4, scale_factor=0.5, mode='bilinear', align_corners=True) 408 | 409 | upconv3 = self.upconv3(iconv4) 410 | concat3 = torch.cat((upconv3, skip2, self.udisp4), 1) 411 | iconv3 = self.iconv3(concat3) 412 | self.disp3 = self.disp3_layer(iconv3) 413 | self.udisp3 = nn.functional.interpolate(self.disp3, scale_factor=2, mode='bilinear', align_corners=True) 414 | 415 | upconv2 = self.upconv2(iconv3) 416 | concat2 = torch.cat((upconv2, skip1, self.udisp3), 1) 417 | iconv2 = self.iconv2(concat2) 418 | self.disp2 = self.disp2_layer(iconv2) 419 | self.udisp2 = nn.functional.interpolate(self.disp2, scale_factor=2, mode='bilinear', align_corners=True) 420 | 421 | upconv1 = self.upconv1(iconv2) 422 | concat1 = torch.cat((upconv1, self.udisp2), 1) 423 | iconv1 = self.iconv1(concat1) 424 | self.disp1 = self.disp1_layer(iconv1) 425 | return self.disp1, self.disp2, self.disp3, self.disp4 426 | -------------------------------------------------------------------------------- /readme_images/demo.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/OniroAI/MonoDepth-PyTorch/0b7d60bd1dab0e8b6a7a1bab9c0eb68ebda51c5c/readme_images/demo.gif -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import numpy as np 4 | 5 | 6 | 7 | def image_transforms(mode='train', augment_parameters=[0.8, 1.2, 0.5, 2.0, 0.8, 1.2], 8 | do_augmentation=True, transformations=None, size=(256, 512)): 9 | if mode == 'train': 10 | data_transform = transforms.Compose([ 11 | ResizeImage(train=True, size=size), 12 | RandomFlip(do_augmentation), 13 | ToTensor(train=True), 14 | AugmentImagePair(augment_parameters, do_augmentation) 15 | ]) 16 | return data_transform 17 | elif mode == 'test': 18 | data_transform = transforms.Compose([ 19 | ResizeImage(train=False, size=size), 20 | ToTensor(train=False), 21 | DoTest(), 22 | ]) 23 | return data_transform 24 | elif mode == 'custom': 25 | data_transform = transforms.Compose(transformations) 26 | return data_transform 27 | else: 28 | print('Wrong mode') 29 | 30 | 31 | class ResizeImage(object): 32 | def __init__(self, train=True, size=(256, 512)): 33 | self.train = train 34 | self.transform = transforms.Resize(size) 35 | 36 | def __call__(self, sample): 37 | if self.train: 38 | left_image = sample['left_image'] 39 | right_image = sample['right_image'] 40 | new_right_image = self.transform(right_image) 41 | new_left_image = self.transform(left_image) 42 | sample = {'left_image': new_left_image, 'right_image': new_right_image} 43 | else: 44 | left_image = sample 45 | new_left_image = self.transform(left_image) 46 | sample = new_left_image 47 | return sample 48 | 49 | 50 | class DoTest(object): 51 | def __call__(self, sample): 52 | new_sample = torch.stack((sample, torch.flip(sample, [2]))) 53 | return new_sample 54 | 55 | 56 | class ToTensor(object): 57 | def __init__(self, train): 58 | self.train = train 59 | self.transform = transforms.ToTensor() 60 | 61 | def __call__(self, sample): 62 | if self.train: 63 | left_image = sample['left_image'] 64 | right_image = sample['right_image'] 65 | new_right_image = self.transform(right_image) 66 | new_left_image = self.transform(left_image) 67 | sample = {'left_image': new_left_image, 68 | 'right_image': new_right_image} 69 | else: 70 | left_image = sample 71 | sample = self.transform(left_image) 72 | return sample 73 | 74 | 75 | class RandomFlip(object): 76 | def __init__(self, do_augmentation): 77 | self.transform = transforms.RandomHorizontalFlip(p=1) 78 | self.do_augmentation = do_augmentation 79 | 80 | def __call__(self, sample): 81 | left_image = sample['left_image'] 82 | right_image = sample['right_image'] 83 | k = np.random.uniform(0, 1, 1) 84 | if self.do_augmentation: 85 | if k > 0.5: 86 | fliped_left = self.transform(right_image) 87 | fliped_right = self.transform(left_image) 88 | sample = {'left_image': fliped_left, 'right_image': fliped_right} 89 | else: 90 | sample = {'left_image': left_image, 'right_image': right_image} 91 | return sample 92 | 93 | 94 | class AugmentImagePair(object): 95 | def __init__(self, augment_parameters, do_augmentation): 96 | self.do_augmentation = do_augmentation 97 | self.gamma_low = augment_parameters[0] # 0.8 98 | self.gamma_high = augment_parameters[1] # 1.2 99 | self.brightness_low = augment_parameters[2] # 0.5 100 | self.brightness_high 
= augment_parameters[3] # 2.0 101 | self.color_low = augment_parameters[4] # 0.8 102 | self.color_high = augment_parameters[5] # 1.2 103 | 104 | def __call__(self, sample): 105 | left_image = sample['left_image'] 106 | right_image = sample['right_image'] 107 | p = np.random.uniform(0, 1, 1) 108 | if self.do_augmentation: 109 | if p > 0.5: 110 | # randomly shift gamma 111 | random_gamma = np.random.uniform(self.gamma_low, self.gamma_high) 112 | left_image_aug = left_image ** random_gamma 113 | right_image_aug = right_image ** random_gamma 114 | 115 | # randomly shift brightness 116 | random_brightness = np.random.uniform(self.brightness_low, self.brightness_high) 117 | left_image_aug = left_image_aug * random_brightness 118 | right_image_aug = right_image_aug * random_brightness 119 | 120 | # randomly shift color 121 | random_colors = np.random.uniform(self.color_low, self.color_high, 3) 122 | for i in range(3): 123 | left_image_aug[i, :, :] *= random_colors[i] 124 | right_image_aug[i, :, :] *= random_colors[i] 125 | 126 | # saturate 127 | left_image_aug = torch.clamp(left_image_aug, 0, 1) 128 | right_image_aug = torch.clamp(right_image_aug, 0, 1) 129 | 130 | sample = {'left_image': left_image_aug, 'right_image': right_image_aug} 131 | 132 | else: 133 | sample = {'left_image': left_image, 'right_image': right_image} 134 | return sample 135 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import collections 3 | import os 4 | from torch.utils.data import DataLoader, ConcatDataset 5 | 6 | 7 | from models_resnet import Resnet18_md, Resnet50_md, ResnetModel 8 | from data_loader import KittiLoader 9 | from transforms import image_transforms 10 | 11 | def to_device(input, device): 12 | if torch.is_tensor(input): 13 | return input.to(device=device) 14 | elif isinstance(input, str): 15 | return input 16 | elif isinstance(input, collections.Mapping): 17 | return {k: to_device(sample, device=device) for k, sample in input.items()} 18 | elif isinstance(input, collections.Sequence): 19 | return [to_device(sample, device=device) for sample in input] 20 | else: 21 | raise TypeError(f"Input must contain tensor, dict or list, found {type(input)}") 22 | 23 | 24 | def get_model(model, input_channels=3, pretrained=False): 25 | if model == 'resnet50_md': 26 | out_model = Resnet50_md(input_channels) 27 | elif model == 'resnet18_md': 28 | out_model = Resnet18_md(input_channels) 29 | else: 30 | out_model = ResnetModel(input_channels, encoder=model, pretrained=pretrained) 31 | return out_model 32 | 33 | 34 | def prepare_dataloader(data_directory, mode, augment_parameters, 35 | do_augmentation, batch_size, size, num_workers): 36 | data_dirs = os.listdir(data_directory) 37 | data_transform = image_transforms( 38 | mode=mode, 39 | augment_parameters=augment_parameters, 40 | do_augmentation=do_augmentation, 41 | size = size) 42 | datasets = [KittiLoader(os.path.join(data_directory, 43 | data_dir), mode, transform=data_transform) 44 | for data_dir in data_dirs] 45 | dataset = ConcatDataset(datasets) 46 | n_img = len(dataset) 47 | print('Use a dataset with', n_img, 'images') 48 | if mode == 'train': 49 | loader = DataLoader(dataset, batch_size=batch_size, 50 | shuffle=True, num_workers=num_workers, 51 | pin_memory=True) 52 | else: 53 | loader = DataLoader(dataset, batch_size=batch_size, 54 | shuffle=False, num_workers=num_workers, 55 | pin_memory=True) 56 | return n_img, 
loader 57 | --------------------------------------------------------------------------------