├── .gitignore ├── README.md ├── datasets ├── __init__.py ├── kitti_dataset.py └── mono_dataset.py ├── evaluate_depth.py ├── img ├── overview.png ├── robustness.png ├── speed.png └── teaser_m.gif ├── kitti_utils.py ├── layers.py ├── license ├── lite-mono-pretrain-code ├── README.md ├── datasets.py ├── engine.py ├── main.py ├── models │ └── litemono.py ├── optim_factory.py ├── requirements.txt ├── sampler.py └── utils.py ├── networks ├── __init__.py ├── depth_decoder.py ├── depth_encoder.py ├── pose_decoder.py └── resnet_encoder.py ├── options.py ├── splits └── eigen │ └── test_files.txt ├── test_simple.py ├── train.py ├── trainer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .idea/ 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Lite-Mono 4 | **A Lightweight CNN and Transformer Architecture for Self-Supervised Monocular Depth Estimation** 5 | [[paper link]](https://arxiv.org/abs/2211.13202) 6 | 7 | Ning Zhang*, Francesco Nex, George Vosselman, Norman Kerle 8 | 9 | 10 | License: MIT 11 | 12 | 13 | teaser 14 | 15 | (Lite-Mono-8m 1024x320) 16 | 17 |
18 | 19 | 20 | ## Table of Contents 21 | - [Overview](#overview) 22 | - [Results](#results) 23 | - [KITTI](#kitti) 24 | - [Speed Evaluation](#speed-evaluation) 25 | - [Robustness](#robustness) 26 | - [Data Preparation](#data-preparation) 27 | - [Single Image Test](#single-image-test) 28 | - [Preparing Trained Model](#preparing-trained-model) 29 | - [Start Testing](#start-testing) 30 | - [Evaluation](#evaluation) 31 | - [Training](#training) 32 | - [Dependency Installation](#dependency-installation) 33 | - [Preparing Pre-trained Weights](#preparing-pre-trained-weights) 34 | - [Start Training](#start-training) 35 | - [Tensorboard Visualization](#tensorboard-visualization) 36 | - [Make Your Own Pre-training Weights On ImageNet](#make-your-own-pre-training-weights-on-imagenet) 37 | - [Citation](#citation) 38 | 39 | 40 | ## Overview 41 | overview 42 | 43 | 44 | ## Results 45 | ### KITTI 46 | You can download the trained models using the links below. 47 | 48 | | --model | Params | ImageNet Pretrained | Input size | Abs Rel | Sq Rel | RMSE | RMSE log | delta < 1.25 | delta < 1.25^2 | delta < 1.25^3 | 49 | |:---------------:|:------:|:-------------------:|:----------:|:---------:|:---------:|:---------:|:---------:|:------------:|:--------------:|:--------------:| 50 | | [**lite-mono**](https://surfdrive.surf.nl/files/index.php/s/CUjiK221EFLyXDY) | 3.1M | [yes](https://surfdrive.surf.nl/files/index.php/s/InMMGd5ZP2fXuia) | 640x192 | 0.107 | 0.765 | 4.561 | 0.183 | 0.886 | 0.963 | 0.983 | 51 | | [lite-mono-small](https://surfdrive.surf.nl/files/index.php/s/8cuZNH1CkNtQwxQ) | 2.5M | [yes](https://surfdrive.surf.nl/files/index.php/s/DYbWV4bsWImfJKu) | 640x192 | 0.110 | 0.802 | 4.671 | 0.186 | 0.879 | 0.961 | 0.982 | 52 | | [lite-mono-tiny](https://surfdrive.surf.nl/files/index.php/s/TFDlF3wYQy0Nhmg) | 2.2M | yes | 640x192 | 0.110 | 0.837 | 4.710 | 0.187 | 0.880 | 0.960 | 0.982 | 53 | | [**lite-mono-8m**](https://surfdrive.surf.nl/files/index.php/s/UlkVBi1p99NFWWI) | 8.7M | 
[yes](https://surfdrive.surf.nl/files/index.php/s/oil2ME6ymoLGDlL) | 640x192 | 0.101 | 0.729 | 4.454 | 0.178 | 0.897 | 0.965 | 0.983 | 54 | | [**lite-mono**](https://surfdrive.surf.nl/files/index.php/s/IK3VtPj6b5FkVnl) | 3.1M | yes | 1024x320 | 0.102 | 0.746 | 4.444 | 0.179 | 0.896 | 0.965 | 0.983 | 55 | | [lite-mono-small](https://surfdrive.surf.nl/files/index.php/s/w8mvJMkB1dP15pu) | 2.5M | yes | 1024x320 | 0.103 | 0.757 | 4.449 | 0.180 | 0.894 | 0.964 | 0.983 | 56 | | [lite-mono-tiny](https://surfdrive.surf.nl/files/index.php/s/myxcplTciOkgu5w) | 2.2M | yes | 1024x320 | 0.104 | 0.764 | 4.487 | 0.180 | 0.892 | 0.964 | 0.983 | 57 | | [**lite-mono-8m**](https://surfdrive.surf.nl/files/index.php/s/mgonNFAvoEJmMas) | 8.7M | yes | 1024x320 | 0.097 | 0.710 | 4.309 | 0.174 | 0.905 | 0.967 | 0.984 | 58 | 59 | 60 | ### Speed Evaluation 61 | speed evaluation 62 | 63 | 64 | ### Robustness 65 | robustness 66 | 67 | The [RoboDepth Challenge Team](https://github.com/ldkong1205/RoboDepth) is evaluating the robustness of different depth estimation algorithms. Lite-Mono has achieved the best robustness to date. 68 | 69 | ## Data Preparation 70 | Please refer to [Monodepth2](https://github.com/nianticlabs/monodepth2) to prepare your KITTI data. 71 | 72 | ## Single Image Test 73 | #### preparing trained model 74 | From this [table](#kitti) you can download trained models (depth encoder and depth decoder). 75 | 76 | Click on the links in the '--model' column to download a trained model. 
77 | 78 | #### start testing 79 | python test_simple.py --load_weights_folder path/to/your/weights/folder --image_path path/to/your/test/image 80 | 81 | ## Evaluation 82 | python evaluate_depth.py --load_weights_folder path/to/your/weights/folder --data_path path/to/kitti_data/ --model lite-mono 83 | 84 | 85 | ## Training 86 | #### dependency installation 87 | pip install 'git+https://github.com/saadnaeem-dev/pytorch-linear-warmup-cosine-annealing-warm-restarts-weight-decay' 88 | 89 | #### preparing pre-trained weights 90 | From this [table](#kitti) you can also download the weights of the backbone (depth encoder) pre-trained on ImageNet. 91 | 92 | Click 'yes' on a row to download the corresponding pre-trained weights. The weights are agnostic to image resolution. 93 | 94 | #### start training 95 | python train.py --data_path path/to/your/data --model_name mytrain --num_epochs 30 --batch_size 12 --mypretrain path/to/your/pretrained/weights --lr 0.0001 5e-6 31 0.0001 1e-5 31 96 | 97 | #### tensorboard visualization 98 | tensorboard --logdir ./tmp/mytrain 99 | 100 | ## Make Your Own Pre-training Weights On ImageNet 101 | Since a lot of people are interested in training their own backbone on ImageNet, I have also uploaded my pre-training [scripts](lite-mono-pretrain-code) to this repo. 
102 | 103 | 104 | ## Citation 105 | 106 | @InProceedings{Zhang_2023_CVPR, 107 | author = {Zhang, Ning and Nex, Francesco and Vosselman, George and Kerle, Norman}, 108 | title = {Lite-Mono: A Lightweight CNN and Transformer Architecture for Self-Supervised Monocular Depth Estimation}, 109 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 110 | month = {June}, 111 | year = {2023}, 112 | pages = {18537-18546} 113 | } -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .kitti_dataset import KITTIRAWDataset, KITTIOdomDataset, KITTIDepthDataset 2 | -------------------------------------------------------------------------------- /datasets/kitti_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import os 4 | import skimage.transform 5 | import numpy as np 6 | import PIL.Image as pil 7 | 8 | from kitti_utils import generate_depth_map 9 | from .mono_dataset import MonoDataset 10 | 11 | 12 | class KITTIDataset(MonoDataset): 13 | """Superclass for different types of KITTI dataset loaders 14 | """ 15 | def __init__(self, *args, **kwargs): 16 | super(KITTIDataset, self).__init__(*args, **kwargs) 17 | 18 | self.K = np.array([[0.58, 0, 0.5, 0], 19 | [0, 1.92, 0.5, 0], 20 | [0, 0, 1, 0], 21 | [0, 0, 0, 1]], dtype=np.float32) 22 | 23 | self.full_res_shape = (1242, 375) 24 | self.side_map = {"2": 2, "3": 3, "l": 2, "r": 3} 25 | 26 | def check_depth(self): 27 | line = self.filenames[0].split() 28 | scene_name = line[0] 29 | frame_index = int(line[1]) 30 | 31 | velo_filename = os.path.join( 32 | self.data_path, 33 | scene_name, 34 | "velodyne_points/data/{:010d}.bin".format(int(frame_index))) 35 | 36 | return os.path.isfile(velo_filename) 37 | 38 | def get_color(self, 
class KITTIRAWDataset(KITTIDataset):
    """KITTI loader whose ground-truth depth is rendered from the raw velodyne scans
    """
    def __init__(self, *args, **kwargs):
        super(KITTIRAWDataset, self).__init__(*args, **kwargs)

    def get_image_path(self, folder, frame_index, side):
        """Build the path of one colour image.

        The file name is the zero-padded (10 digit) frame index followed by
        ``self.img_ext``; the camera directory is selected through
        ``self.side_map[side]``.
        """
        file_name = "{:010d}{}".format(frame_index, self.img_ext)
        camera_dir = "image_0{}/data".format(self.side_map[side])
        return os.path.join(self.data_path, folder, camera_dir, file_name)

    def get_depth(self, folder, frame_index, side, do_flip):
        """Return the velodyne depth map for one frame, resized to ``self.full_res_shape``.

        The calibration directory is the first path component of ``folder``
        (joined onto ``self.data_path``). The map is mirrored left-right when
        ``do_flip`` is true.
        """
        top_level_dir = folder.split("/")[0]
        calib_path = os.path.join(self.data_path, top_level_dir)

        scan_rel_path = "velodyne_points/data/{:010d}.bin".format(int(frame_index))
        velo_filename = os.path.join(self.data_path, folder, scan_rel_path)

        sparse_depth = generate_depth_map(calib_path, velo_filename, self.side_map[side])

        # full_res_shape is stored in reverse order to what resize expects.
        # order=0 selects nearest-neighbour interpolation.
        target_shape = self.full_res_shape[::-1]
        depth_gt = skimage.transform.resize(
            sparse_depth, target_shape, order=0, preserve_range=True, mode='constant')

        return np.fliplr(depth_gt) if do_flip else depth_gt
class MonoDataset(data.Dataset):
    """Superclass for monocular dataloaders

    Subclasses must implement ``get_color``, ``check_depth`` and ``get_depth``.
    ``__getitem__`` also reads ``self.K`` (a 4x4 intrinsics matrix), which is
    not assigned in this class and therefore has to be set by the subclass.

    Args:
        data_path: stored as ``self.data_path`` for use by subclasses
        filenames: sequence of strings; each one is split on whitespace into
            ``folder`` and, when it has exactly three fields,
            ``frame_index`` and ``side``
        height: height of the scale-0 images
        width: width of the scale-0 images
        frame_idxs: iterable of integer frame offsets and/or the string "s"
        num_scales: number of pyramid scales to generate (scale i is
            ``height // 2**i`` x ``width // 2**i``)
        is_train: enables the random flip / colour augmentation decisions
        img_ext: stored as ``self.img_ext`` for use by subclasses
    """
    def __init__(self,
                 data_path,
                 filenames,
                 height,
                 width,
                 frame_idxs,
                 num_scales,
                 is_train=False,
                 img_ext='.jpg'):
        super(MonoDataset, self).__init__()

        self.data_path = data_path
        self.filenames = filenames
        self.height = height
        self.width = width
        self.num_scales = num_scales
        # Image.ANTIALIAS was only an alias of Image.LANCZOS and was removed in
        # Pillow 10; LANCZOS selects the same filter on old and new versions.
        self.interp = Image.LANCZOS

        self.frame_idxs = frame_idxs

        self.is_train = is_train
        self.img_ext = img_ext

        self.loader = pil_loader
        self.to_tensor = transforms.ToTensor()

        # We need to specify augmentations differently in newer versions of torchvision.
        # We first try the newer tuple version; if this fails we fall back to scalars
        try:
            self.brightness = (0.8, 1.2)
            self.contrast = (0.8, 1.2)
            self.saturation = (0.8, 1.2)
            self.hue = (-0.1, 0.1)
            transforms.ColorJitter.get_params(
                self.brightness, self.contrast, self.saturation, self.hue)
        except TypeError:
            self.brightness = 0.2
            self.contrast = 0.2
            self.saturation = 0.2
            self.hue = 0.1

        # One resize transform per pyramid scale, keyed by the scale index.
        self.resize = {}
        for i in range(self.num_scales):
            s = 2 ** i
            self.resize[i] = transforms.Resize((self.height // s, self.width // s),
                                               interpolation=self.interp)

        self.load_depth = self.check_depth()

    def preprocess(self, inputs, color_aug):
        """Resize colour images to the required scales and augment if required

        ``inputs`` is modified in place. For every key containing "color",
        scales ``0 .. num_scales - 1`` are generated, each one by resizing the
        image stored at the previous scale (scale 0 comes from scale -1).
        Afterwards every colour image is converted with ``self.to_tensor`` and
        a second copy, passed through ``color_aug`` first, is stored under
        ``(name + "_aug", frame_id, scale)``.

        ``color_aug`` is called once per image, so whether all images receive
        the same augmentation depends on the callable that is passed in.
        """
        # Iterate over a snapshot of the keys because new entries are added.
        for k in list(inputs):
            if "color" in k:
                n, im, _ = k
                for scale in range(self.num_scales):
                    inputs[(n, im, scale)] = self.resize[scale](inputs[(n, im, scale - 1)])

        for k in list(inputs):
            if "color" in k:
                f = inputs[k]
                n, im, i = k
                inputs[(n, im, i)] = self.to_tensor(f)
                inputs[(n + "_aug", im, i)] = self.to_tensor(color_aug(f))

    def __len__(self):
        """Return the number of entries in ``self.filenames``."""
        return len(self.filenames)

    def __getitem__(self, index):
        """Returns a single training item from the dataset as a dictionary.

        Values correspond to torch tensors.
        Keys in the dictionary are either strings or tuples:

            ("color", <frame_id>, <scale>)          for raw colour images,
            ("color_aug", <frame_id>, <scale>)      for augmented colour images,
            ("K", scale) or ("inv_K", scale)        for camera intrinsics,
            "stereo_T"                              for camera extrinsics, and
            "depth_gt"                              for ground truth depth maps.

        <frame_id> is either:
            an integer (e.g. 0, -1, or 1) representing the temporal step relative to 'index',
        or
            "s" for the opposite image in the stereo pair.

        <scale> is an integer representing the scale of the image relative to the fullsize image:
            -1      images at native resolution as loaded from disk
            0       images resized to (self.width,      self.height     )
            1       images resized to (self.width // 2, self.height // 2)
            2       images resized to (self.width // 4, self.height // 4)
            3       images resized to (self.width // 8, self.height // 8)

        The scale -1 entries are removed again before the dictionary is
        returned. "depth_gt" is only present when ``self.load_depth`` is true
        and "stereo_T" only when "s" is in ``self.frame_idxs``.
        """
        inputs = {}

        # The two random draws are kept in this order so seeded runs stay reproducible.
        do_color_aug = self.is_train and random.random() > 0.5
        do_flip = self.is_train and random.random() > 0.5

        line = self.filenames[index].split()
        folder = line[0]

        # Only a complete "folder frame_index side" entry carries frame/side info.
        if len(line) == 3:
            frame_index = int(line[1])
            side = line[2]
        else:
            frame_index = 0
            side = None

        for i in self.frame_idxs:
            if i == "s":
                other_side = {"r": "l", "l": "r"}[side]
                inputs[("color", i, -1)] = self.get_color(folder, frame_index, other_side, do_flip)
            else:
                inputs[("color", i, -1)] = self.get_color(folder, frame_index + i, side, do_flip)

        # adjusting intrinsics to match each scale in the pyramid
        for scale in range(self.num_scales):
            K = self.K.copy()

            K[0, :] *= self.width // (2 ** scale)
            K[1, :] *= self.height // (2 ** scale)

            inv_K = np.linalg.pinv(K)

            inputs[("K", scale)] = torch.from_numpy(K)
            inputs[("inv_K", scale)] = torch.from_numpy(inv_K)

        if do_color_aug:
            # NOTE(review): a ColorJitter instance is invoked separately for
            # every image in preprocess(); per the torchvision documentation it
            # samples its parameters on each call, so frames of one item may be
            # jittered differently -- confirm this is intended.
            color_aug = transforms.ColorJitter(
                self.brightness, self.contrast, self.saturation, self.hue)
        else:
            color_aug = (lambda x: x)

        self.preprocess(inputs, color_aug)

        # The native-resolution images are only needed as the resize source.
        for i in self.frame_idxs:
            del inputs[("color", i, -1)]
            del inputs[("color_aug", i, -1)]

        if self.load_depth:
            depth_gt = self.get_depth(folder, frame_index, side, do_flip)
            inputs["depth_gt"] = np.expand_dims(depth_gt, 0)
            inputs["depth_gt"] = torch.from_numpy(inputs["depth_gt"].astype(np.float32))

        if "s" in self.frame_idxs:
            stereo_T = np.eye(4, dtype=np.float32)
            # The x translation changes sign with the flip and with the side.
            baseline_sign = -1 if do_flip else 1
            side_sign = -1 if side == "l" else 1
            stereo_T[0, 3] = side_sign * baseline_sign * 0.1

            inputs["stereo_T"] = torch.from_numpy(stereo_T)

        return inputs

    def get_color(self, folder, frame_index, side, do_flip):
        """Load one colour image; must be implemented by subclasses."""
        raise NotImplementedError

    def check_depth(self):
        """Report whether ground-truth depth can be loaded; must be implemented by subclasses."""
        raise NotImplementedError

    def get_depth(self, folder, frame_index, side, do_flip):
        """Load one ground-truth depth map; must be implemented by subclasses."""
        raise NotImplementedError
def compute_errors(gt, pred):
    """Compute the depth error metrics between ground truth and prediction.

    Both arguments are used in element-wise numpy arithmetic.

    Returns the tuple
    ``(abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3)`` where ``a1``/``a2``/``a3``
    are the fractions of elements whose ratio ``max(gt / pred, pred / gt)`` is
    below ``1.25``, ``1.25 ** 2`` and ``1.25 ** 3`` respectively.
    """
    # Accuracy under threshold: symmetric ratio between the two depths.
    ratio = np.maximum((gt / pred), (pred / gt))
    a1, a2, a3 = [(ratio < 1.25 ** power).mean() for power in (1, 2, 3)]

    diff = gt - pred
    squared_diff = diff ** 2

    rmse = np.sqrt(squared_diff.mean())

    log_diff = np.log(gt) - np.log(pred)
    rmse_log = np.sqrt((log_diff ** 2).mean())

    # Relative errors are normalised by the ground-truth depth.
    abs_rel = np.mean(np.abs(diff) / gt)
    sq_rel = np.mean(squared_diff / gt)

    return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
def evaluate(opt):
    """Evaluates a pretrained model using a specified test set

    When ``opt.ext_disp_to_eval`` is None the encoder/decoder weights are loaded
    from ``opt.load_weights_folder`` ("encoder.pth" and "depth.pth"), the
    networks are moved to CUDA and run over the files listed in
    ``splits/<opt.eval_split>/test_files.txt``. Otherwise the disparities are
    loaded from the ``.npy`` file named by ``opt.ext_disp_to_eval``.

    Depending on the options the disparities are then optionally re-indexed
    (``opt.eval_eigen_to_benchmark``), optionally saved
    (``opt.save_pred_disps``), and -- unless ``opt.no_eval`` is set, in which
    case the process exits -- compared against
    ``splits/<opt.eval_split>/gt_depths.npz``. The mean of each error metric is
    printed; model complexity statistics are printed only when the networks
    were actually run in this call.
    """
    MIN_DEPTH = 1e-3
    MAX_DEPTH = 80

    # (flops, params, flops_e, params_e, flops_d, params_d); stays None when the
    # predictions come from a file, because no network is available to profile.
    profile_stats = None

    if opt.ext_disp_to_eval is None:

        opt.load_weights_folder = os.path.expanduser(opt.load_weights_folder)

        assert os.path.isdir(opt.load_weights_folder), \
            "Cannot find a folder at {}".format(opt.load_weights_folder)

        print("-> Loading weights from {}".format(opt.load_weights_folder))

        filenames = readlines(os.path.join(splits_dir, opt.eval_split, "test_files.txt"))
        encoder_path = os.path.join(opt.load_weights_folder, "encoder.pth")
        decoder_path = os.path.join(opt.load_weights_folder, "depth.pth")

        encoder_dict = torch.load(encoder_path)
        decoder_dict = torch.load(decoder_path)

        # The encoder checkpoint also stores the input resolution it was trained with.
        dataset = datasets.KITTIRAWDataset(opt.data_path, filenames,
                                           encoder_dict['height'], encoder_dict['width'],
                                           [0], 4, is_train=False)
        dataloader = DataLoader(dataset, 16, shuffle=False, num_workers=opt.num_workers,
                                pin_memory=True, drop_last=False)

        encoder = networks.LiteMono(model=opt.model,
                                    height=encoder_dict['height'],
                                    width=encoder_dict['width'])
        depth_decoder = networks.DepthDecoder(encoder.num_ch_enc, scales=range(3))
        model_dict = encoder.state_dict()
        depth_model_dict = depth_decoder.state_dict()
        # Drop checkpoint entries (e.g. 'height'/'width') that are not model parameters.
        encoder.load_state_dict({k: v for k, v in encoder_dict.items() if k in model_dict})
        depth_decoder.load_state_dict({k: v for k, v in decoder_dict.items() if k in depth_model_dict})

        encoder.cuda()
        encoder.eval()
        depth_decoder.cuda()
        depth_decoder.eval()

        pred_disps = []

        print("-> Computing predictions with size {}x{}".format(
            encoder_dict['width'], encoder_dict['height']))

        with torch.no_grad():
            for data in dataloader:
                input_color = data[("color", 0, 0)].cuda()

                if opt.post_process:
                    # Post-processed results require each image to have two forward passes
                    input_color = torch.cat((input_color, torch.flip(input_color, [3])), 0)

                # profile_once only looks at the first sample of the batch, so
                # its result is the same for every batch; run it a single time.
                if profile_stats is None:
                    profile_stats = profile_once(encoder, depth_decoder, input_color)

                output = depth_decoder(encoder(input_color))

                pred_disp, _ = disp_to_depth(output[("disp", 0)], opt.min_depth, opt.max_depth)
                pred_disp = pred_disp.cpu()[:, 0].numpy()

                if opt.post_process:
                    # Second half of the batch holds the flipped images; flip back before merging.
                    N = pred_disp.shape[0] // 2
                    pred_disp = batch_post_process_disparity(pred_disp[:N], pred_disp[N:, :, ::-1])

                pred_disps.append(pred_disp)

        pred_disps = np.concatenate(pred_disps)

    else:
        # Load predictions from file
        print("-> Loading predictions from {}".format(opt.ext_disp_to_eval))
        pred_disps = np.load(opt.ext_disp_to_eval)

        if opt.eval_eigen_to_benchmark:
            eigen_to_benchmark_ids = np.load(
                os.path.join(splits_dir, "benchmark", "eigen_to_benchmark_ids.npy"))

            pred_disps = pred_disps[eigen_to_benchmark_ids]

    if opt.save_pred_disps:
        output_path = os.path.join(
            opt.load_weights_folder, "disps_{}_split.npy".format(opt.eval_split))
        print("-> Saving predicted disparities to ", output_path)
        np.save(output_path, pred_disps)

    if opt.no_eval:
        print("-> Evaluation disabled. Done.")
        quit()

    gt_path = os.path.join(splits_dir, opt.eval_split, "gt_depths.npz")
    gt_depths = np.load(gt_path, fix_imports=True, encoding='latin1', allow_pickle=True)["data"]

    print("-> Evaluating")
    print(" Mono evaluation - using median scaling")

    errors = []
    ratios = []

    for i in range(pred_disps.shape[0]):
        gt_depth = gt_depths[i]
        gt_height, gt_width = gt_depth.shape[:2]

        # Bring the prediction to the ground-truth resolution before inverting it.
        pred_disp = pred_disps[i]
        pred_disp = cv2.resize(pred_disp, (gt_width, gt_height))
        pred_depth = 1 / pred_disp

        if opt.eval_split == "eigen":
            mask = np.logical_and(gt_depth > MIN_DEPTH, gt_depth < MAX_DEPTH)

            # Restrict the evaluation to a fixed fractional crop of the image.
            crop = np.array([0.40810811 * gt_height, 0.99189189 * gt_height,
                             0.03594771 * gt_width, 0.96405229 * gt_width]).astype(np.int32)
            crop_mask = np.zeros(mask.shape)
            crop_mask[crop[0]:crop[1], crop[2]:crop[3]] = 1
            mask = np.logical_and(mask, crop_mask)

        else:
            mask = gt_depth > 0

        pred_depth = pred_depth[mask]
        gt_depth = gt_depth[mask]

        pred_depth *= opt.pred_depth_scale_factor
        if not opt.disable_median_scaling:
            # Align the prediction's median with the ground truth's median.
            ratio = np.median(gt_depth) / np.median(pred_depth)
            ratios.append(ratio)
            pred_depth *= ratio

        pred_depth[pred_depth < MIN_DEPTH] = MIN_DEPTH
        pred_depth[pred_depth > MAX_DEPTH] = MAX_DEPTH
        errors.append(compute_errors(gt_depth, pred_depth))

    if not opt.disable_median_scaling:
        ratios = np.array(ratios)
        med = np.median(ratios)
        print(" Scaling ratios | med: {:0.3f} | std: {:0.3f}".format(med, np.std(ratios / med)))

    mean_errors = np.array(errors).mean(0)

    print("\n " + ("{:>8} | " * 7).format("abs_rel", "sq_rel", "rmse", "rmse_log", "a1", "a2", "a3"))
    print(("&{: 8.3f} " * 7).format(*mean_errors.tolist()) + "\\\\")
    if profile_stats is not None:
        flops, params, flops_e, params_e, flops_d, params_d = profile_stats
        print("\n " + ("flops: {0}, params: {1}, flops_e: {2}, params_e:{3}, flops_d:{4}, params_d:{5}").format(flops, params, flops_e, params_e, flops_d, params_d))
    print("\n-> Done!")
def generate_depth_map(calib_dir, velo_filename, cam=2, vel_depth=False):
    """Generate a depth map from velodyne data

    Args:
        calib_dir: directory containing 'calib_cam_to_cam.txt' and
            'calib_velo_to_cam.txt'
        velo_filename: velodyne scan, read through load_velodyne_points
        cam: camera number used to select the 'P_rect_0<cam>' projection matrix
        vel_depth: when true, the stored depth is the point's first velodyne
            coordinate instead of the third coordinate of the projected point

    Returns:
        2D array whose shape is taken from the 'S_rect_02' calibration entry
        (reversed); pixels without a projected point are 0.
    """
    # load calibration files
    cam2cam = read_calib_file(os.path.join(calib_dir, 'calib_cam_to_cam.txt'))
    velo2cam = read_calib_file(os.path.join(calib_dir, 'calib_velo_to_cam.txt'))
    velo2cam = np.hstack((velo2cam['R'].reshape(3, 3), velo2cam['T'][..., np.newaxis]))
    velo2cam = np.vstack((velo2cam, np.array([0, 0, 0, 1.0])))

    # get image shape
    im_shape = cam2cam["S_rect_02"][::-1].astype(np.int32)

    # compute projection matrix velodyne->image plane
    R_cam2rect = np.eye(4)
    R_cam2rect[:3, :3] = cam2cam['R_rect_00'].reshape(3, 3)
    P_rect = cam2cam['P_rect_0'+str(cam)].reshape(3, 4)
    P_velo2im = np.dot(np.dot(P_rect, R_cam2rect), velo2cam)

    # load velodyne points and remove all behind image plane (approximation)
    # each row of the velodyne data is forward, left, up, reflectance
    velo = load_velodyne_points(velo_filename)
    velo = velo[velo[:, 0] >= 0, :]

    # project the points to the camera
    velo_pts_im = np.dot(P_velo2im, velo.T).T
    velo_pts_im[:, :2] = velo_pts_im[:, :2] / velo_pts_im[:, 2][..., np.newaxis]

    if vel_depth:
        velo_pts_im[:, 2] = velo[:, 0]

    # check if in bounds
    # use minus 1 to get the exact same value as KITTI matlab code
    velo_pts_im[:, 0] = np.round(velo_pts_im[:, 0]) - 1
    velo_pts_im[:, 1] = np.round(velo_pts_im[:, 1]) - 1
    val_inds = (velo_pts_im[:, 0] >= 0) & (velo_pts_im[:, 1] >= 0)
    val_inds = val_inds & (velo_pts_im[:, 0] < im_shape[1]) & (velo_pts_im[:, 1] < im_shape[0])
    velo_pts_im = velo_pts_im[val_inds, :]

    # project to image
    # np.int was removed in NumPy 1.24; it was an alias of the builtin int,
    # so astype(int) produces exactly the same index arrays.
    depth = np.zeros((im_shape[:2]))
    rows = velo_pts_im[:, 1].astype(int)
    cols = velo_pts_im[:, 0].astype(int)
    depth[rows, cols] = velo_pts_im[:, 2]

    # find the duplicate points and choose the closest depth
    # NOTE(review): sub2ind computes row * (n - 1) + col - 1, which is not
    # unique per pixel (e.g. (0, n - 1) and (1, 0) give the same value), so
    # points at different pixels can be grouped here. Left unchanged so the
    # generated ground truth stays identical to earlier runs.
    inds = sub2ind(depth.shape, velo_pts_im[:, 1], velo_pts_im[:, 0])
    dupe_inds = [item for item, count in Counter(inds).items() if count > 1]
    for dd in dupe_inds:
        pts = np.where(inds == dd)[0]
        x_loc = int(velo_pts_im[pts[0], 0])
        y_loc = int(velo_pts_im[pts[0], 1])
        depth[y_loc, x_loc] = velo_pts_im[pts, 2].min()
    depth[depth < 0] = 0

    return depth
def get_translation_matrix(translation_vector):
    """Build a batch of 4x4 homogeneous matrices that translate by the given vectors
    """
    batch = translation_vector.shape[0]
    # one 3x1 column per matrix
    column = translation_vector.contiguous().view(-1, 3, 1)

    T = torch.zeros(batch, 4, 4).to(device=translation_vector.device)

    # identity on the diagonal, translation in the last column
    diag = torch.arange(4)
    T[:, diag, diag] = 1
    T[:, :3, 3:] = column

    return T
class DepthConv3x3(nn.Module):
    """Layer that pads the input by one pixel and applies a grouped 3x3 convolution
    """
    def __init__(self, in_channels, out_channels, use_refl=True):
        super(DepthConv3x3, self).__init__()

        in_ch, out_ch = int(in_channels), int(out_channels)

        # reflection padding by default, zero padding on request
        pad_layer = nn.ReflectionPad2d if use_refl else nn.ZeroPad2d
        self.pad = pad_layer(1)

        # bias-free convolution with one group per output channel
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, groups=out_ch, bias=False)

    def forward(self, x):
        return self.conv(self.pad(x))
class BackprojectDepth(nn.Module):
    """Layer that lifts a depth image to homogeneous 3D points, one per pixel
    """
    def __init__(self, batch_size, height, width):
        super(BackprojectDepth, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width

        # (2, H, W) pixel grid: channel 0 holds column indices, channel 1 row indices
        xs, ys = np.meshgrid(range(self.width), range(self.height), indexing='xy')
        grid = np.stack([xs, ys], axis=0).astype(np.float32)
        self.id_coords = nn.Parameter(torch.from_numpy(grid), requires_grad=False)

        # constant homogeneous row, reused in forward()
        self.ones = nn.Parameter(
            torch.ones(self.batch_size, 1, self.height * self.width),
            requires_grad=False)

        # flattened (x, y, 1) coordinates repeated for every batch element
        flat_xy = torch.stack(
            [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0)
        flat_xy = flat_xy.unsqueeze(0).repeat(batch_size, 1, 1)
        self.pix_coords = nn.Parameter(
            torch.cat([flat_xy, self.ones], 1), requires_grad=False)

    def forward(self, depth, inv_K):
        # rays through each pixel, scaled by that pixel's depth
        rays = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
        scaled = depth.view(self.batch_size, 1, -1) * rays
        return torch.cat([scaled, self.ones], 1)
def get_smooth_loss(disp, img):
    """Edge-aware smoothness loss for a disparity image
    Disparity gradients are down-weighted where the color image itself has
    strong gradients
    """
    def abs_dx(t):
        # absolute difference between horizontally adjacent pixels
        return torch.abs(t[:, :, :, :-1] - t[:, :, :, 1:])

    def abs_dy(t):
        # absolute difference between vertically adjacent pixels
        return torch.abs(t[:, :, :-1, :] - t[:, :, 1:, :])

    # image gradients averaged over channels give the edge weights
    weight_x = torch.exp(-torch.mean(abs_dx(img), 1, keepdim=True))
    weight_y = torch.exp(-torch.mean(abs_dy(img), 1, keepdim=True))

    smooth_x = abs_dx(disp) * weight_x
    smooth_y = abs_dy(disp) * weight_y

    return smooth_x.mean() + smooth_y.mean()
def compute_depth_errors(gt, pred):
    """Compute the depth error metrics between ground truth and prediction
    Returns (abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3)
    """
    # accuracy under threshold: share of pixels whose ratio is below 1.25 ** k
    ratio = torch.max((gt / pred), (pred / gt))
    a1, a2, a3 = ((ratio < 1.25 ** k).float().mean() for k in (1, 2, 3))

    diff = gt - pred
    sq_diff = diff ** 2
    rmse = torch.sqrt(sq_diff.mean())

    log_diff = torch.log(gt) - torch.log(pred)
    rmse_log = torch.sqrt((log_diff ** 2).mean())

    abs_rel = torch.mean(torch.abs(diff) / gt)
    sq_rel = torch.mean(sq_diff / gt)

    return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /lite-mono-pretrain-code/README.md: -------------------------------------------------------------------------------- 1 | I have not had time to clean up the code, but I have kept all the necessary lines in `litemono.py` to run the pretraining. 2 | 3 | In short, the last layer of the encoder of Lite-Mono should have `1000` channels for the classification task. Please add these lines to your current model file. The `main.py` file will create the model from this file. 4 | 5 | 6 | To train on a single machine with 2 GPUs, use the following command: 7 | 8 | python -m torch.distributed.launch --nproc_per_node=2 main.py --data_path data/imagenet/ 9 | 10 | 11 | Please check the code if you want to change parameters such as epochs, learning rates, etc.
-------------------------------------------------------------------------------- /lite-mono-pretrain-code/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torchvision import datasets, transforms 3 | from PIL import Image 4 | 5 | from timm.data.constants import \ 6 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 7 | from timm.data import create_transform 8 | from sampler import MultiScaleImageFolder 9 | 10 | 11 | # from typing import Any, Callable, cast, Dict, List, Optional, Tuple 12 | # from typing import Union 13 | # 14 | # from PIL import Image 15 | # # IMAGENET_DEFAULT_MEAN = (0.445, ) 16 | # # IMAGENET_DEFAULT_STD = (0.269, ) 17 | # IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") 18 | # 19 | # 20 | # def pil_loader(path: str) -> Image.Image: 21 | # # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 22 | # with open(path, "rb") as f: 23 | # img = Image.open(f) 24 | # return img.convert("L") 25 | # 26 | # 27 | # # TODO: specify the return type 28 | # def accimage_loader(path: str) -> Any: 29 | # import accimage 30 | # try: 31 | # return accimage.Image(path) 32 | # except OSError: 33 | # # Potentially a decoding problem, fall back to PIL.Image 34 | # return pil_loader(path) 35 | # 36 | # 37 | # def default_loader(path: str) -> Any: 38 | # from torchvision import get_image_backend 39 | # 40 | # if get_image_backend() == "accimage": 41 | # return accimage_loader(path) 42 | # else: 43 | # 44 | # return pil_loader(path) 45 | # 46 | # 47 | # class MyImageFolder(datasets.DatasetFolder): 48 | # def __init__( 49 | # self, 50 | # root: str, 51 | # transform: Optional[Callable] = None, 52 | # target_transform: Optional[Callable] = None, 53 | # loader: Callable[[str], Any] = default_loader, 54 | # is_valid_file: Optional[Callable[[str], bool]] = None, 55 | # ): 56 | # 
def build_dataset(is_train, args):
    """Create the dataset selected by args.data_set.

    Builds the transform with build_transform, prints it, then instantiates
    the dataset. Returns a (dataset, nb_classes) tuple. Raises
    NotImplementedError when args.data_set is not one of 'CIFAR', 'IMNET'
    or 'image_folder'.
    """
    transform = build_transform(is_train, args)

    # Echo the transform pipeline(s); a tuple holds several pipelines and
    # each of those is preceded by a dashed separator.
    print("Transform = ")
    is_multi = isinstance(transform, tuple)
    for pipeline in (transform if is_multi else (transform,)):
        if is_multi:
            print(" - - - - - - - - - - ")
        for step in pipeline.transforms:
            print(step)
    print("---------------------------")

    kind = args.data_set
    if kind == 'CIFAR':
        nb_classes = 100
        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True)
    elif kind == 'IMNET':
        nb_classes = 1000
        print("reading from datapath", args.data_path)
        split = 'train' if is_train else 'val'
        root = os.path.join(args.data_path, split)
        if is_train and args.multi_scale_sampler:
            # this dataset is built from args and is not given the transform
            dataset = MultiScaleImageFolder(root, args)
        else:
            dataset = datasets.ImageFolder(root, transform=transform)
    elif kind == "image_folder":
        nb_classes = args.nb_classes
        root = args.data_path if is_train else args.eval_data_path
        dataset = datasets.ImageFolder(root, transform=transform)
        assert len(dataset.class_to_idx) == nb_classes
    else:
        raise NotImplementedError()
    print("Number of the class = %d" % nb_classes)

    return dataset, nb_classes
112 | transform = create_transform( 113 | input_size=args.input_size, 114 | is_training=True, 115 | color_jitter=args.color_jitter if args.color_jitter > 0 else None, 116 | auto_augment=args.aa, 117 | interpolation=args.train_interpolation, 118 | re_prob=args.reprob, 119 | re_mode=args.remode, 120 | re_count=args.recount, 121 | mean=mean, 122 | std=std, 123 | ) 124 | if args.three_aug: # --aa should not be "" to use this as it actually overrides the auto-augment 125 | print(f"Using 3-Augments instead of Rand Augment") 126 | cur_augs = transform.transforms 127 | three_aug = transforms.RandomChoice([transforms.Grayscale(num_output_channels=3), 128 | transforms.RandomSolarize(threshold=192.0), 129 | transforms.GaussianBlur(kernel_size=(5, 9))]) 130 | final_transforms = cur_augs[0:2] + [three_aug] + cur_augs[2:] 131 | transform = transforms.Compose(final_transforms) 132 | if not resize_im: 133 | transform.transforms[0] = transforms.RandomCrop( 134 | args.input_size, padding=4) 135 | return transform 136 | 137 | t = [] 138 | if resize_im: 139 | # Warping (no cropping) when evaluated at 384 or larger 140 | if args.input_size >= 384: 141 | t.append( 142 | transforms.Resize((args.input_size, args.input_size), 143 | interpolation=transforms.InterpolationMode.BICUBIC), 144 | ) 145 | print(f"Warping {args.input_size} size input images...") 146 | else: 147 | if args.crop_pct is None: 148 | args.crop_pct = 224 / 256 149 | size = int(args.input_size / args.crop_pct) 150 | t.append( 151 | # To maintain same ratio w.r.t. 
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
                    wandb_logger=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
                    num_training_steps_per_epoch=None, update_freq=None, use_amp=False):
    """Train `model` for one epoch with gradient accumulation.

    Gradients are accumulated over `update_freq` consecutive batches before
    the optimizer steps. `lr_schedule_values` / `wd_schedule_values`, when
    given, are indexed by the global iteration `start_steps + step` and
    written into the optimizer's param groups. With `use_amp` the
    backward/step is delegated to `loss_scaler`; otherwise it is done here in
    full precision. Batches past `num_training_steps_per_epoch` optimizer
    steps are skipped.

    Returns a dict mapping every meter name to its global average.
    Raises AssertionError when the loss becomes non-finite.
    """
    model.train(True)
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    optimizer.zero_grad()

    for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        step = data_iter_step // update_freq
        if step >= num_training_steps_per_epoch:
            continue
        it = start_steps + step  # Global training iteration
        is_first_acc_step = data_iter_step % update_freq == 0
        is_last_acc_step = (data_iter_step + 1) % update_freq == 0

        # Update LR & WD once per optimizer step, on the first accumulation
        # batch. The parentheses matter: without them `and` binds tighter
        # than `or` and the update ran on every batch whenever an LR
        # schedule was supplied.
        if (lr_schedule_values is not None or wd_schedule_values is not None) and is_first_acc_step:
            for param_group in optimizer.param_groups:
                if lr_schedule_values is not None:
                    param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
                if wd_schedule_values is not None and param_group["weight_decay"] > 0:
                    param_group["weight_decay"] = wd_schedule_values[it]

        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if use_amp:
            with torch.cuda.amp.autocast():
                output = model(samples)
                loss = criterion(output, targets)
        else:  # Full precision
            output = model(samples)
            loss = criterion(output, targets)

        loss_value = loss.item()

        if not math.isfinite(loss_value):  # This could trigger if using AMP
            print("Loss is {}, stopping training".format(loss_value))
            # Raised explicitly so the stop also happens under `python -O`,
            # where a bare assert statement is stripped.
            raise AssertionError("Loss is {}, stopping training".format(loss_value))

        if use_amp:
            # This attribute is added by timm on one optimizer (adahessian)
            is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
            loss /= update_freq
            grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
                                    parameters=model.parameters(), create_graph=is_second_order,
                                    update_grad=is_last_acc_step)
            if is_last_acc_step:
                optimizer.zero_grad()
                if model_ema is not None:
                    model_ema.update(model)
        else:  # Full precision
            loss /= update_freq
            loss.backward()
            if is_last_acc_step:
                optimizer.step()
                optimizer.zero_grad()
                if model_ema is not None:
                    model_ema.update(model)

        # Only synchronize when a CUDA runtime exists so CPU runs do not fail.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        if mixup_fn is None:
            class_acc = (output.max(-1)[-1] == targets).float().mean()
        else:
            # targets were mixed, so a hard-label accuracy is not computed
            class_acc = None
        metric_logger.update(loss=loss_value)
        metric_logger.update(class_acc=class_acc)
        min_lr = 10.
        max_lr = 0.
        for group in optimizer.param_groups:
            min_lr = min(min_lr, group["lr"])
            max_lr = max(max_lr, group["lr"])

        metric_logger.update(lr=max_lr)
        metric_logger.update(min_lr=min_lr)
        # report the weight decay of the last group that has one
        weight_decay_value = None
        for group in optimizer.param_groups:
            if group["weight_decay"] > 0:
                weight_decay_value = group["weight_decay"]
        metric_logger.update(weight_decay=weight_decay_value)
        if use_amp:
            metric_logger.update(grad_norm=grad_norm)

        if log_writer is not None:
            log_writer.update(loss=loss_value, head="loss")
            log_writer.update(class_acc=class_acc, head="loss")
            log_writer.update(lr=max_lr, head="opt")
            log_writer.update(min_lr=min_lr, head="opt")
            log_writer.update(weight_decay=weight_decay_value, head="opt")
            if use_amp:
                log_writer.update(grad_norm=grad_norm, head="opt")
            log_writer.set_step()

        if wandb_logger:
            wandb_logger._wandb.log({
                'Rank-0 Batch Wise/train_loss': loss_value,
                'Rank-0 Batch Wise/train_max_lr': max_lr,
                'Rank-0 Batch Wise/train_min_lr': min_lr
            }, commit=False)
            # `is not None` rather than truthiness: an accuracy of exactly 0
            # is a valid value and must still be logged.
            if class_acc is not None:
                wandb_logger._wandb.log({'Rank-0 Batch Wise/train_class_acc': class_acc}, commit=False)
            if use_amp:
                wandb_logger._wandb.log({'Rank-0 Batch Wise/train_grad_norm': grad_norm}, commit=False)
            wandb_logger._wandb.log({'Rank-0 Batch Wise/global_train_step': it})

    # Gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
metric_logger.log_every(data_loader, 10, header): 144 | images = batch[0] 145 | target = batch[-1] 146 | 147 | images = images.to(device, non_blocking=True) 148 | target = target.to(device, non_blocking=True) 149 | 150 | # Compute output 151 | if use_amp: 152 | with torch.cuda.amp.autocast(): 153 | output = model(images) 154 | loss = criterion(output, target) 155 | else: 156 | output = model(images) 157 | loss = criterion(output, target) 158 | 159 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 160 | 161 | batch_size = images.shape[0] 162 | metric_logger.update(loss=loss.item()) 163 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 164 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 165 | # Gather the stats from all processes 166 | metric_logger.synchronize_between_processes() 167 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 168 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 169 | 170 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 171 | -------------------------------------------------------------------------------- /lite-mono-pretrain-code/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import numpy as np 4 | import time 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import json 8 | import os 9 | 10 | from pathlib import Path 11 | 12 | from timm.data.mixup import Mixup 13 | from timm.models import create_model 14 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 15 | from timm.utils import ModelEma 16 | from optim_factory import create_optimizer 17 | 18 | from datasets import build_dataset 19 | from engine import train_one_epoch, evaluate 20 | 21 | from utils import NativeScalerWithGradNormCount as NativeScaler 22 | import utils 23 | import models.model 24 | 25 | from sampler import 
MultiScaleSamplerDDP 26 | from fvcore.nn import FlopCountAnalysis 27 | 28 | 29 | def str2bool(v): 30 | """ 31 | Converts string to bool type; enables command line 32 | arguments in the format of '--arg1 true --arg2 false' 33 | """ 34 | if isinstance(v, bool): 35 | return v 36 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 37 | return True 38 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 39 | return False 40 | else: 41 | raise argparse.ArgumentTypeError('Boolean value expected.') 42 | 43 | 44 | def get_args_parser(): 45 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script for image classification', add_help=False) 46 | parser.add_argument('--batch_size', default=512, type=int, 47 | help='Per GPU batch size') 48 | parser.add_argument('--epochs', default=100, type=int) 49 | parser.add_argument('--update_freq', default=2, type=int, 50 | help='gradient accumulation steps') 51 | 52 | # Model parameters 53 | parser.add_argument('--model', default='depth_encoder', type=str, metavar='MODEL', 54 | help='Name of model to train') 55 | parser.add_argument('--drop_path', type=float, default=0.0, metavar='PCT', 56 | help='Drop path rate (default: 0.0)') 57 | parser.add_argument('--input_size', default=256, type=int, 58 | help='image input size') 59 | parser.add_argument('--layer_scale_init_value', default=1e-5, type=float, 60 | help="Layer scale initial values") 61 | 62 | # EMA related parameters 63 | parser.add_argument('--model_ema', type=str2bool, default=False) 64 | parser.add_argument('--model_ema_decay', type=float, default=0.9995, help='') # TODO: MobileViT is using 0.9995 65 | parser.add_argument('--model_ema_force_cpu', type=str2bool, default=False, help='') 66 | parser.add_argument('--model_ema_eval', type=str2bool, default=False, help='Using ema to eval during training.') 67 | 68 | # Optimization parameters 69 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', help='Optimizer (default: "adamw"') 70 | 
parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON', 71 | help='Optimizer Epsilon (default: 1e-8)') 72 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA', 73 | help='Optimizer Betas (default: None, use opt default)') 74 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 75 | help='Clip gradient norm (default: None, no clipping)') 76 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 77 | help='SGD momentum (default: 0.9)') 78 | parser.add_argument('--weight_decay', type=float, default=0.05, 79 | help='weight decay (default: 0.05)') 80 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the 81 | weight decay. We use a cosine schedule for WD and using a larger decay by 82 | the end of training improves performance for ViTs.""") 83 | 84 | parser.add_argument('--lr', type=float, default=1e-3, metavar='LR', 85 | help='learning rate (default: 6e-3), with total batch size 4096') 86 | parser.add_argument('--layer_decay', type=float, default=1.0) 87 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', 88 | help='lower lr bound for cyclic schedulers that hit 0 (1e-6)') 89 | parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', 90 | help='epochs to warmup LR, if scheduler supports') 91 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N', 92 | help='num of steps to warmup LR, will overload warmup_epochs if set > 0') 93 | parser.add_argument('--warmup_start_lr', type=float, default=0, metavar='LR', 94 | help='Starting LR for warmup (default 0)') 95 | 96 | # Augmentation parameters 97 | parser.add_argument('--color_jitter', type=float, default=0.0, metavar='PCT', 98 | help='Color jitter factor (default: 0.4)') 99 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 100 | help='Use AutoAugment policy. "v0" or "original". 
" + "(default: rand-m9-mstd0.5-inc1)'), 101 | parser.add_argument('--smoothing', type=float, default=0.0, 102 | help='Label smoothing (default: 0.1)') 103 | parser.add_argument('--train_interpolation', type=str, default='bicubic', 104 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 105 | 106 | # Evaluation parameters 107 | parser.add_argument('--crop_pct', type=float, default=None) 108 | 109 | # * Random Erase params 110 | parser.add_argument('--reprob', type=float, default=0.0, metavar='PCT', 111 | help='Random erase prob (default: 0.0)') 112 | parser.add_argument('--remode', type=str, default='pixel', 113 | help='Random erase mode (default: "pixel")') 114 | parser.add_argument('--recount', type=int, default=1, 115 | help='Random erase count (default: 1)') 116 | parser.add_argument('--resplit', type=str2bool, default=False, 117 | help='Do not random erase first (clean) augmentation split') 118 | 119 | # Mixup params 120 | parser.add_argument('--mixup', type=float, default=0.0, 121 | help='mixup alpha, mixup enabled if > 0.') 122 | parser.add_argument('--cutmix', type=float, default=0.0, 123 | help='cutmix alpha, cutmix enabled if > 0.') 124 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 125 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 126 | parser.add_argument('--mixup_prob', type=float, default=0.0, 127 | help='Probability of performing mixup or cutmix when either/both is enabled') 128 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 129 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 130 | parser.add_argument('--mixup_mode', type=str, default='batch', 131 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 132 | 133 | # Dataset parameters 134 | parser.add_argument('--data_path', default='data/ILSVRC2012/imagenet/', type=str, 135 | help='dataset path (path to full imagenet)') 136 | parser.add_argument('--eval_data_path', default=None, type=str, 137 | help='dataset path for evaluation') 138 | parser.add_argument('--nb_classes', default=1000, type=int, 139 | help='number of the classification types') 140 | parser.add_argument('--imagenet_default_mean_and_std', type=str2bool, default=True) 141 | parser.add_argument('--data_set', default='IMNET', choices=['IMNET', 'image_folder'], 142 | type=str, help='ImageNet dataset path') 143 | parser.add_argument('--output_dir', default='./save', 144 | help='path where to save, empty for no saving') 145 | parser.add_argument('--log_dir', default=None, 146 | help='path where to tensorboard log') 147 | parser.add_argument('--device', default='cuda', 148 | help='device to use for training / testing') 149 | parser.add_argument('--seed', default=0, type=int) 150 | 151 | parser.add_argument('--resume', default='', 152 | help='resume from checkpoint') 153 | parser.add_argument('--auto_resume', type=str2bool, default=False) 154 | parser.add_argument('--save_ckpt', type=str2bool, default=True) 155 | parser.add_argument('--save_ckpt_freq', default=1, type=int) 156 | parser.add_argument('--save_ckpt_num', default=3, type=int) 157 | 158 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 159 | help='start epoch') 160 | parser.add_argument('--eval', type=str2bool, default=False, 161 | help='Perform evaluation only') 162 | parser.add_argument('--dist_eval', type=str2bool, default=True, 163 | help='Enabling distributed evaluation') 164 | parser.add_argument('--disable_eval', type=str2bool, default=False, 165 | help='Disabling evaluation during training') 166 | parser.add_argument('--num_workers', default=10, type=int) 167 | parser.add_argument('--pin_mem', type=str2bool, default=True, 168 | 
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 169 | 170 | # Distributed training parameters 171 | parser.add_argument('--world_size', default=1, type=int, 172 | help='number of distributed processes') 173 | parser.add_argument('--local_rank', default=-1, type=int) 174 | parser.add_argument('--dist_on_itp', type=str2bool, default=False) 175 | parser.add_argument('--dist_url', default='env://', 176 | help='url used to set up distributed training') 177 | 178 | parser.add_argument('--use_amp', type=str2bool, default=True, 179 | help="Use PyTorch's AMP (Automatic Mixed Precision) or not") 180 | 181 | # Weights and Biases arguments 182 | parser.add_argument('--enable_wandb', type=str2bool, default=False, 183 | help="enable logging to Weights and Biases") 184 | parser.add_argument('--project', default='litemono', type=str, 185 | help="The name of the W&B project where you're sending the new run.") 186 | parser.add_argument('--wandb_ckpt', type=str2bool, default=False, 187 | help="Save model checkpoints as W&B Artifacts.") 188 | parser.add_argument("--multi_scale_sampler", action="store_true", help="Either to use multi-scale sampler or not.") 189 | parser.add_argument('--min_crop_size_w', default=160, type=int) 190 | parser.add_argument('--max_crop_size_w', default=320, type=int) 191 | parser.add_argument('--min_crop_size_h', default=160, type=int) 192 | parser.add_argument('--max_crop_size_h', default=320, type=int) 193 | parser.add_argument("--find_unused_params", action="store_true", 194 | help="Set this flag to enable unused parameters finding in DistributedDataParallel()") 195 | parser.add_argument("--three_aug", action="store_true", 196 | help="Either to use three augments proposed by DeiT-III") 197 | parser.add_argument('--classifier_dropout', default=0.0, type=float) 198 | parser.add_argument('--usi_eval', type=str2bool, default=False, 199 | help="Enable it when testing USI model.") 200 | 201 | return parser 202 | 203 | 204 | 
def main(args):
    """Run classification pre-training (or evaluation only) driven by ``args``.

    The function builds the train/validation datasets, samplers and data
    loaders, creates the model (optionally with an EMA copy and a
    DistributedDataParallel wrapper), the optimizer, the cosine LR and
    weight-decay schedules and the loss criterion, restores a checkpoint via
    ``utils.auto_load_model`` and then either

    * evaluates once and returns when ``args.eval`` is set, or
    * trains from ``args.start_epoch`` to ``args.epochs``, checkpointing,
      evaluating and logging after every epoch.

    :param args: parsed command-line namespace; it is also mutated here
        (``nb_classes``, ``crop_pct``, ``dist_eval``, ``weight_decay_end``).
    :raises NotImplementedError: when ``args.layer_decay`` differs from 1.0.
    """
    utils.init_distributed_mode(args)
    print(args)
    device = torch.device(args.device)

    # Eval/USI_eval configurations
    # The name chosen here is forwarded to utils.auto_load_model as ``state_dict_name``.
    if args.eval:
        if args.usi_eval:
            args.crop_pct = 0.95
            model_state_dict_name = 'state_dict'
        else:
            model_state_dict_name = 'model_ema'
    else:
        model_state_dict_name = 'model'

    # Fix the seed for reproducibility
    # The rank is added so every process gets a different seed.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    if args.disable_eval:
        # No validation set is built, so distributed evaluation is switched off too.
        args.dist_eval = False
        dataset_val = None
    else:
        dataset_val, _ = build_dataset(is_train=False, args=args)

    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    if args.multi_scale_sampler:
        # This sampler is later passed to the DataLoader as ``batch_sampler``.
        sampler_train = MultiScaleSamplerDDP(base_im_w=args.input_size, base_im_h=args.input_size,
                                             base_batch_size=args.batch_size, n_data_samples=len(dataset_train),
                                             is_training=True, distributed=args.distributed,
                                             min_crop_size_w=args.min_crop_size_w, max_crop_size_w=args.max_crop_size_w,
                                             min_crop_size_h=args.min_crop_size_h, max_crop_size_h=args.max_crop_size_h)
    else:
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True, seed=args.seed,
        )
    print("Sampler_train = %s" % str(sampler_train))
    if args.dist_eval:
        if len(dataset_val) % num_tasks != 0:
            print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                  'This will slightly alter validation results as extra duplicate entries are added to achieve '
                  'equal num of samples per-process.')
        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    # Only rank 0 creates the TensorBoard / W&B loggers; other ranks keep None.
    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
    else:
        log_writer = None

    if global_rank == 0 and args.enable_wandb:
        wandb_logger = utils.WandbLogger(args)
    else:
        wandb_logger = None

    if args.multi_scale_sampler:
        # Batches come from the batch sampler, hence batch_size=1 and no drop_last here.
        data_loader_train = torch.utils.data.DataLoader(
            dataset_train, batch_sampler=sampler_train,
            batch_size=1,
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
        )
    else:
        data_loader_train = torch.utils.data.DataLoader(
            dataset_train, sampler=sampler_train,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            drop_last=True,
        )

    if dataset_val is not None:
        data_loader_val = torch.utils.data.DataLoader(
            dataset_val, sampler=sampler_val,
            batch_size=int(1.5 * args.batch_size),
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            drop_last=False
        )
    else:
        data_loader_val = None

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        print("Mixup is activated!")
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_path_rate=args.drop_path,
        layer_scale_init_value=args.layer_scale_init_value,
        head_init_scale=1.0,
        input_res=args.input_size,
        classifier_dropout=args.classifier_dropout,
    )
    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume='')
        print("Using EMA with decay = %.8f" % args.model_ema_decay)

    model_without_ddp = model
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("Model = %s" % str(model_without_ddp))
    print('number of params:', n_parameters)

    # Effective batch size = per-process batch * gradient-accumulation steps * number of processes.
    total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
    num_training_steps_per_epoch = len(dataset_train) // total_batch_size
    print("LR = %.8f" % args.lr)
    print("Batch size = %d" % total_batch_size)
    print("Update frequent = %d" % args.update_freq)
    print("Number of training examples = %d" % len(dataset_train))
    print("Number of training training per epoch = %d" % num_training_steps_per_epoch)

    if args.layer_decay < 1.0 or args.layer_decay > 1.0:
        # Layer decay not supported
        raise NotImplementedError
    else:
        assigner = None

    # NOTE(review): assigner is always None at this point, so this branch never runs here.
    if assigner is not None:
        print("Assigned values = %s" % str(assigner.values))

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],
                                                          find_unused_parameters=args.find_unused_params)
        model_without_ddp = model.module

    optimizer = create_optimizer(
        args, model_without_ddp, skip_list=None,
        get_num_layer=assigner.get_layer_id if assigner is not None else None,
        get_layer_scale=assigner.get_scale if assigner is not None else None)

    loss_scaler = NativeScaler()  # if args.use_amp is False, this won't be used

    print("Use Cosine LR scheduler")
    lr_schedule_values = utils.cosine_scheduler(
        args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
        warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
        start_warmup_value=args.warmup_start_lr
    )

    # Without an explicit end value the weight decay schedule starts and ends at the same value.
    if args.weight_decay_end is None:
        args.weight_decay_end = args.weight_decay
    wd_schedule_values = utils.cosine_scheduler(
        args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
    print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))

    if mixup_fn is not None:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing > 0.:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    print("criterion = %s" % str(criterion))

    utils.auto_load_model(
        args=args, model=model, model_without_ddp=model_without_ddp,
        optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema, state_dict_name=model_state_dict_name)

    if args.eval:
        # NOTE(review): data_loader_val / dataset_val are None when args.disable_eval is set;
        # confirm that combination is rejected before reaching this point.
        print(f"Eval only mode")
        test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp)
        print(f"Accuracy of the network on {len(dataset_val)} test images: {test_stats['acc1']:.5f}%")
        return

    max_accuracy = 0.0
    if args.model_ema and args.model_ema_eval:
        max_accuracy_ema = 0.0

    def count_parameters(model):
        """Return the number of elements over all parameters that require gradients."""
        total_trainable_params = 0
        for name, parameter in model.named_parameters():
            if not parameter.requires_grad:
                continue
            params = parameter.numel()
            total_trainable_params += params
        return total_trainable_params

    total_params = count_parameters(model)
    # fvcore to calculate MAdds
    # A single uninitialised (1, 3, input_size, input_size) tensor with the model's dtype/device is the probe input.
    input_res = (3, args.input_size, args.input_size)
    input = torch.ones(()).new_empty((1, *input_res), dtype=next(model.parameters()).dtype,
                                     device=next(model.parameters()).device)
    flops = FlopCountAnalysis(model, input)
    model_flops = flops.total()
    print(f"Total Trainable Params: {round(total_params * 1e-6, 2)} M")
    print(f"MAdds: {round(model_flops * 1e-6, 2)} M")

    print("Start training for %d epochs" % args.epochs)
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        # Tell the active sampler which epoch is starting.
        if args.multi_scale_sampler:
            data_loader_train.batch_sampler.set_epoch(epoch)
        elif args.distributed:
            data_loader_train.sampler.set_epoch(epoch)
        if log_writer is not None:
            log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq)
        if wandb_logger:
            wandb_logger.set_steps()
        train_stats = train_one_epoch(
            model, criterion, data_loader_train, optimizer,
            device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn,
            log_writer=log_writer, wandb_logger=wandb_logger, start_steps=epoch * num_training_steps_per_epoch,
            lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values,
            num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq,
            use_amp=args.use_amp
        )
        # Periodic checkpoint: every save_ckpt_freq epochs and always after the last epoch.
        if args.output_dir and args.save_ckpt:
            if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
                utils.save_model(
                    args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                    loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema)
        if data_loader_val is not None:
            test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp)
            print(f"Accuracy of the model on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
            # A new best top-1 accuracy additionally stores a checkpoint tagged "best".
            if max_accuracy < test_stats["acc1"]:
                max_accuracy = test_stats["acc1"]
                if args.output_dir and args.save_ckpt:
                    utils.save_model(
                        args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                        loss_scaler=loss_scaler, epoch="best", model_ema=model_ema)
            print(f'Max accuracy: {max_accuracy:.2f}%')

            if log_writer is not None:
                log_writer.update(test_acc1=test_stats['acc1'], head="perf", step=epoch)
                log_writer.update(test_acc5=test_stats['acc5'], head="perf", step=epoch)
                log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch)

            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         **{f'test_{k}': v for k, v in test_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}

            # Repeat testing routines for EMA, if ema eval is turned on
            if args.model_ema and args.model_ema_eval:
                test_stats_ema = evaluate(data_loader_val, model_ema.ema, device, use_amp=args.use_amp)
                print(f"Accuracy of the model EMA on {len(dataset_val)} test images: {test_stats_ema['acc1']:.1f}%")
                if max_accuracy_ema < test_stats_ema["acc1"]:
                    max_accuracy_ema = test_stats_ema["acc1"]
                    if args.output_dir and args.save_ckpt:
                        utils.save_model(
                            args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                            loss_scaler=loss_scaler, epoch="best-ema", model_ema=model_ema)
                    print(f'Max EMA accuracy: {max_accuracy_ema:.2f}%')
                if log_writer is not None:
                    log_writer.update(test_acc1_ema=test_stats_ema['acc1'], head="perf", step=epoch)
                log_stats.update({**{f'test_{k}_ema': v for k, v in test_stats_ema.items()}})
        else:
            # Without a validation loader only the training statistics are logged.
            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}

        # One JSON object per epoch is appended to <output_dir>/log.txt by the main process.
        if args.output_dir and utils.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

        if wandb_logger:
            wandb_logger.log_epoch_metrics(log_stats)

    if wandb_logger and args.wandb_ckpt and args.save_ckpt and args.output_dir:
        wandb_logger.log_checkpoints()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
def get_num_layer_for_convnext(var_name):
    """
    Map a parameter name to its layer-decay group id.

    Divide [3, 3, 27, 3] layers into 12 groups; each group is three
    consecutive blocks, including possible neighboring downsample layers;
    adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py

    :param var_name: parameter name as produced by ``named_parameters()``,
        e.g. ``"stages.2.7.dwconv.weight"`` or ``"downsample_layers.1.0.weight"``.
    :return: the group id; ``13`` (``num_max_layer + 1``) for every name that
        belongs neither to ``downsample_layers`` nor to ``stages``.
    :raises ValueError: if the stage index in the name is not 0, 1, 2 or 3
        (previously this surfaced as an ``UnboundLocalError``).
    """
    num_max_layer = 12
    if var_name.startswith("downsample_layers"):
        stage_id = int(var_name.split('.')[1])
        if stage_id == 0:
            return 0
        if stage_id == 1 or stage_id == 2:
            return stage_id + 1
        if stage_id == 3:
            return 12
        raise ValueError(
            "Unsupported downsample stage index %d in parameter name %r" % (stage_id, var_name))

    if var_name.startswith("stages"):
        # Split once and reuse the pieces instead of re-splitting per field.
        name_parts = var_name.split('.')
        stage_id = int(name_parts[1])
        block_id = int(name_parts[2])
        if stage_id == 0 or stage_id == 1:
            return stage_id + 1
        if stage_id == 2:
            # The deep third stage is spread over several groups, three blocks each.
            return 3 + block_id // 3
        if stage_id == 3:
            return 12
        raise ValueError(
            "Unsupported stage index %d in parameter name %r" % (stage_id, var_name))

    return num_max_layer + 1
def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):
    """
    Build the optimizer selected by ``args.opt`` for ``model``.

    ``args.opt`` is matched case-insensitively. An optional prefix separated
    by ``_`` is supported: ``lookahead_<name>`` wraps the optimizer in
    ``Lookahead``; only the part after the last ``_`` selects the optimizer.

    :param args: namespace providing ``opt``, ``lr`` and ``weight_decay``;
        ``momentum`` is read for the SGD/RMSprop style optimizers, and
        ``opt_eps`` / ``opt_betas`` are forwarded when present and not None.
    :param model: module whose parameters are optimized.
    :param get_num_layer: optional callable forwarded to ``get_parameter_groups``.
    :param get_layer_scale: optional callable forwarded to ``get_parameter_groups``.
    :param filter_bias_and_bn: when True, parameters are grouped by
        ``get_parameter_groups`` (which assigns the per-group weight decay)
        and the optimizer-level weight decay is set to 0.
    :param skip_list: names excluded from weight decay; when None,
        ``model.no_weight_decay()`` is used if the model defines it.
    :return: the constructed optimizer.
    :raises ValueError: if ``args.opt`` does not name a supported optimizer.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    # if weight_decay and filter_bias_and_bn:
    if filter_bias_and_bn:
        skip = {}
        if skip_list is not None:
            skip = skip_list
        elif hasattr(model, 'no_weight_decay'):
            skip = model.no_weight_decay()
        parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
        # Weight decay now lives in the parameter groups; do not apply it twice.
        weight_decay = 0.
    else:
        parameters = model.parameters()

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
    if hasattr(args, 'opt_eps') and args.opt_eps is not None:
        opt_args['eps'] = args.opt_eps
    if hasattr(args, 'opt_betas') and args.opt_betas is not None:
        opt_args['betas'] = args.opt_betas

    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        # SGD variants take no ``eps`` argument.
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'nadam':
        optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        if not args.lr:
            opt_args['lr'] = None
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters, **opt_args)
    elif opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)
    else:
        # The former ``assert False and "Invalid optimizer"`` never displayed its
        # message and disappeared under ``python -O``, leaving ``optimizer`` unbound.
        raise ValueError("Invalid optimizer: %r" % (args.opt,))

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
class MultiScaleSamplerDDP(Sampler):
    """Batch sampler that yields variable-resolution batches for one replica.

    Every yielded batch is a list of ``(height, width, image_index)`` tuples.
    The ``(height, width, batch_size)`` combination of a batch is drawn at
    random from ``self.img_batch_pairs``.
    """

    def __init__(self, base_im_w: int, base_im_h: int, base_batch_size: int, n_data_samples: int,
                 min_crop_size_w: int = 160, max_crop_size_w: int = 320,
                 min_crop_size_h: int = 160, max_crop_size_h: int = 320,
                 n_scales: int = 5, is_training: bool = True, distributed=True) -> None:
        """
        :param base_im_w: base image width
        :param base_im_h: base image height
        :param base_batch_size: batch size used for the base resolution
        :param n_data_samples: number of samples in the dataset
        :param min_crop_size_w: smallest sampled width
        :param max_crop_size_w: largest sampled width
        :param min_crop_size_h: smallest sampled height
        :param max_crop_size_h: largest sampled height
        :param n_scales: number of scales requested from ``_image_batch_pairs``
        :param is_training: enables shuffling and multi-scale sampling
        :param distributed: when True, world size and rank come from ``torch.distributed``
        """
        # Replica layout: one replica with rank 0 unless running distributed.
        if distributed:
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()
        else:
            num_replicas, rank = 1, 0

        # Pad the index list (by repeating its head) so it divides evenly between replicas.
        num_samples_per_replica = int(math.ceil(n_data_samples * 1.0 / num_replicas))
        total_size = num_samples_per_replica * num_replicas
        img_indices = list(range(n_data_samples))
        n_padding = total_size - n_data_samples
        img_indices += img_indices[:n_padding]
        assert len(img_indices) == total_size

        self.shuffle = bool(is_training)
        if is_training:
            # Several (h, w, batch_size) candidates between the min. and max. crop sizes.
            self.img_batch_pairs = _image_batch_pairs(base_im_w, base_im_h, base_batch_size, num_replicas,
                                                      n_scales, 32,
                                                      min_crop_size_w, max_crop_size_w,
                                                      min_crop_size_h, max_crop_size_h)
        else:
            # Evaluation always uses the base resolution and batch size.
            self.img_batch_pairs = [(base_im_h, base_im_w, base_batch_size)]

        self.img_indices = img_indices
        self.n_samples_per_replica = num_samples_per_replica
        self.epoch = 0
        self.rank = rank
        self.num_replicas = num_replicas
        self.batch_size_gpu0 = base_batch_size

    def __iter__(self):
        """Yield batches of ``(height, width, image_index)`` tuples for this replica."""
        if self.shuffle:
            # Seeding with the epoch makes every replica shuffle identically.
            random.seed(self.epoch)
            random.shuffle(self.img_indices)
            random.shuffle(self.img_batch_pairs)

        # This replica's share: every num_replicas-th index, starting at its rank.
        replica_indices = self.img_indices[self.rank::self.num_replicas]
        n_local = self.n_samples_per_replica

        cursor = 0
        while cursor < n_local:
            curr_h, curr_w, curr_bsz = random.choice(self.img_batch_pairs)

            stop = min(cursor + curr_bsz, n_local)
            batch_ids = replica_indices[cursor:stop]
            missing = curr_bsz - len(batch_ids)
            if missing != 0:
                # Last batch is short: top it up from the start of this replica's indices.
                batch_ids += replica_indices[:missing]
            cursor += curr_bsz

            if batch_ids:
                yield [(curr_h, curr_w, img_id) for img_id in batch_ids]

    def set_epoch(self, epoch: int) -> None:
        """Store the epoch number used to seed the next shuffle."""
        self.epoch = epoch

    def __len__(self):
        """Return the number of samples assigned to one replica."""
        return self.n_samples_per_replica
def make_divisible(v: Union[float, int],
                   divisor: Optional[int] = 8,
                   min_value: Optional[Union[float, int]] = None) -> Union[float, int]:
    """
    Snap ``v`` to a nearby multiple of ``divisor``.

    The value is rounded to the nearest multiple (halves round up), clamped
    from below by ``min_value``, and bumped by one extra ``divisor`` if the
    result ended up more than 10% smaller than ``v``.

    Adapted from the TensorFlow MobileNet implementation:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py

    :param v: value to adjust
    :param divisor: step the result is aligned to
    :param min_value: lower bound of the result; defaults to ``divisor``
    :return: the adjusted value
    """
    lower_bound = divisor if min_value is None else min_value
    nearest_multiple = (int(v + divisor / 2) // divisor) * divisor
    adjusted = max(lower_bound, nearest_multiple)
    # Make sure that round down does not go down by more than 10%.
    shrank_too_much = adjusted < 0.9 * v
    return adjusted + divisor if shrank_too_much else adjusted
class SmoothedValue(object):
    """Running statistics of a scalar series.

    Keeps the most recent ``window_size`` values for windowed statistics
    (median, avg, max, value) and a running total/count for the average over
    the whole series (global_avg).
    """

    def __init__(self, window_size=20, fmt=None):
        """
        :param window_size: number of most recent values kept for windowed statistics
        :param fmt: format string used by ``__str__``; it may reference
            ``median``, ``avg``, ``global_avg``, ``max`` and ``value``
        """
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        """Record ``value`` once in the window and with weight ``n`` in the totals."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Sum ``count`` and ``total`` over all processes.

        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        packed = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(packed)
        synced_count, synced_total = packed.tolist()
        self.count = int(synced_count)
        self.total = synced_total

    @property
    def median(self):
        """Median of the values currently in the window."""
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window (computed in float32)."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Weighted mean over every value recorded so far."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        stats = dict(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
        return self.fmt.format(**stats)
| def log_every(self, iterable, print_freq, header=None): 120 | i = 0 121 | if not header: 122 | header = '' 123 | start_time = time.time() 124 | end = time.time() 125 | iter_time = SmoothedValue(fmt='{avg:.4f}') 126 | data_time = SmoothedValue(fmt='{avg:.4f}') 127 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 128 | log_msg = [ 129 | header, 130 | '[{0' + space_fmt + '}/{1}]', 131 | 'eta: {eta}', 132 | '{meters}', 133 | 'time: {time}', 134 | 'data: {data}' 135 | ] 136 | if torch.cuda.is_available(): 137 | log_msg.append('max mem: {memory:.0f}') 138 | log_msg = self.delimiter.join(log_msg) 139 | MB = 1024.0 * 1024.0 140 | for obj in iterable: 141 | data_time.update(time.time() - end) 142 | yield obj 143 | iter_time.update(time.time() - end) 144 | if i % print_freq == 0 or i == len(iterable) - 1: 145 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 146 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 147 | if torch.cuda.is_available(): 148 | print(log_msg.format( 149 | i, len(iterable), eta=eta_string, 150 | meters=str(self), 151 | time=str(iter_time), data=str(data_time), 152 | memory=torch.cuda.max_memory_allocated() / MB)) 153 | else: 154 | print(log_msg.format( 155 | i, len(iterable), eta=eta_string, 156 | meters=str(self), 157 | time=str(iter_time), data=str(data_time))) 158 | i += 1 159 | end = time.time() 160 | total_time = time.time() - start_time 161 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 162 | print('{} Total time: {} ({:.4f} s / it)'.format( 163 | header, total_time_str, total_time / len(iterable))) 164 | 165 | 166 | class TensorboardLogger(object): 167 | def __init__(self, log_dir): 168 | self.writer = SummaryWriter(logdir=log_dir) 169 | self.step = 0 170 | 171 | def set_step(self, step=None): 172 | if step is not None: 173 | self.step = step 174 | else: 175 | self.step += 1 176 | 177 | def update(self, head='scalar', step=None, **kwargs): 178 | for k, v in kwargs.items(): 179 | if v is None: 
class WandbLogger(object):
    """Small wrapper around the ``wandb`` module for epoch-level logging."""

    def __init__(self, args):
        """
        :param args: namespace providing ``project`` (and ``output_dir`` for
            ``log_checkpoints``); it is also passed to ``wandb.init`` as the
            run config
        :raises ImportError: if ``wandb`` is not installed
        """
        self.args = args

        try:
            import wandb
        except ImportError as err:
            raise ImportError(
                "To use the Weights and Biases Logger please install wandb. "
                "Run `pip install wandb` to install it."
            ) from err
        self._wandb = wandb

        # Initialize a W&B run unless one is already active
        if self._wandb.run is None:
            self._wandb.init(
                project=args.project,
                config=args
            )

    def log_epoch_metrics(self, metrics, commit=True):
        """
        Log train/test metrics onto W&B.

        ``n_parameters`` is written to the run summary, ``epoch`` is logged
        on its own, and every remaining key containing ``train`` or ``test``
        is logged under ``Global Train/`` or ``Global Test/``.  Other keys
        are ignored.

        :param metrics: mapping of metric name to value; it is not modified
        :param commit: accepted for interface compatibility; currently unused
        """
        # Work on a copy so the caller's dict keeps all of its keys.
        metrics = dict(metrics)

        # Log number of model parameters as W&B summary
        self._wandb.summary['n_parameters'] = metrics.pop('n_parameters', None)

        # Log current epoch; a missing key is logged as None rather than
        # raising KeyError.
        self._wandb.log({'epoch': metrics.pop('epoch', None)}, commit=False)

        for k, v in metrics.items():
            if 'train' in k:
                self._wandb.log({f'Global Train/{k}': v}, commit=False)
            elif 'test' in k:
                self._wandb.log({f'Global Test/{k}': v}, commit=False)

        # Empty log call that flushes the values accumulated above.
        self._wandb.log({})

    def log_checkpoints(self):
        """Upload ``args.output_dir`` as a model artifact aliased latest/best."""
        output_dir = self.args.output_dir
        model_artifact = self._wandb.Artifact(
            self._wandb.run.id + "_model", type="model"
        )

        model_artifact.add_dir(output_dir)
        self._wandb.log_artifact(model_artifact, aliases=["latest", "best"])

    def set_steps(self):
        """Declare which metric acts as the x-axis for each metric group."""
        # Set global training step
        self._wandb.define_metric('Rank-0 Batch Wise/*', step_metric='Rank-0 Batch Wise/global_train_step')
        # Set epoch-wise step
        self._wandb.define_metric('Global Train/*', step_metric='epoch')
        self._wandb.define_metric('Global Test/*', step_metric='epoch')
def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
    """Load ``state_dict`` into ``model`` and print a report instead of raising.

    Every sub-module is loaded through ``_load_from_state_dict``; missing
    keys, unexpected keys and error messages are collected and printed.

    :param model: module to load into
    :param state_dict: mapping of parameter names to tensors
    :param prefix: prefix prepended to every parameter name of ``model``
    :param ignore_missing: ``'|'``-separated substrings; a missing key that
        contains one of them is reported as ignored rather than as
        uninitialised.  Empty substrings are skipped.
    """
    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=''):
        # Metadata is keyed by the module path without the trailing dot.
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(model, prefix=prefix)

    # Drop empty patterns: '' is a substring of every key, so keeping it
    # would classify every missing key as ignored.
    ignore_patterns = [p for p in ignore_missing.split('|') if p]

    warn_missing_keys = []
    ignore_missing_keys = []
    for key in missing_keys:
        if any(pattern in key for pattern in ignore_patterns):
            ignore_missing_keys.append(key)
        else:
            warn_missing_keys.append(key)

    missing_keys = warn_missing_keys

    if len(missing_keys) > 0:
        print("Weights of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, missing_keys))
    if len(unexpected_keys) > 0:
        print("Weights from pretrained model not used in {}: {}".format(
            model.__class__.__name__, unexpected_keys))
    if len(ignore_missing_keys) > 0:
        print("Ignored weights of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, ignore_missing_keys))
    if len(error_msgs) > 0:
        print('\n'.join(error_msgs))
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
                     start_warmup_value=0, warmup_steps=-1):
    """Build a per-iteration schedule: linear warmup, then cosine decay.

    :param base_value: value reached at the end of warmup / start of decay
    :param final_value: value the cosine decay approaches
    :param epochs: number of epochs
    :param niter_per_ep: iterations per epoch
    :param warmup_epochs: warmup length in epochs
    :param start_warmup_value: first value of the warmup ramp
    :param warmup_steps: if positive, warmup length in iterations; takes
        precedence over ``warmup_epochs``
    :return: 1-D numpy array of length ``epochs * niter_per_ep``
    """
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_steps > 0:
        warmup_iters = warmup_steps
    print("Set warmup steps = %d" % warmup_iters)
    # Gate on the resolved iteration count, not on warmup_epochs: otherwise
    # warmup_steps > 0 with warmup_epochs == 0 shortens the cosine part
    # below without adding the ramp, and the length check fails.
    if warmup_iters > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    num_decay_iters = len(iters)
    schedule = np.array(
        [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / num_decay_iters))
         for i in iters])

    schedule = np.concatenate((warmup_schedule, schedule))

    assert len(schedule) == epochs * niter_per_ep
    return schedule
output_dir / ('checkpoint-%s.pth' % to_del) 469 | if os.path.exists(old_ckpt): 470 | os.remove(old_ckpt) 471 | 472 | 473 | def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None, state_dict_name='model'): 474 | output_dir = Path(args.output_dir) 475 | if args.auto_resume and len(args.resume) == 0: 476 | import glob 477 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth')) 478 | latest_ckpt = -1 479 | for ckpt in all_checkpoints: 480 | t = ckpt.split('-')[-1].split('.')[0] 481 | if t.isdigit(): 482 | latest_ckpt = max(int(t), latest_ckpt) 483 | if latest_ckpt >= 0: 484 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt) 485 | print("Auto resume checkpoint: %s" % args.resume) 486 | 487 | if args.resume: 488 | if args.resume.startswith('https'): 489 | checkpoint = torch.hub.load_state_dict_from_url( 490 | args.resume, map_location='cpu', check_hash=True) 491 | else: 492 | checkpoint = torch.load(args.resume, map_location='cpu') 493 | model_without_ddp.load_state_dict(checkpoint[state_dict_name]) 494 | print("Resume checkpoint %s" % args.resume) 495 | if 'optimizer' in checkpoint and 'epoch' in checkpoint: 496 | optimizer.load_state_dict(checkpoint['optimizer']) 497 | if not isinstance(checkpoint['epoch'], str): # does not support resuming with 'best', 'best-ema' 498 | args.start_epoch = checkpoint['epoch'] + 1 499 | else: 500 | assert args.eval, 'Does not support resuming with checkpoint-best' 501 | if hasattr(args, 'model_ema') and args.model_ema: 502 | if 'model_ema' in checkpoint.keys(): 503 | model_ema.ema.load_state_dict(checkpoint['model_ema']) 504 | else: 505 | model_ema.ema.load_state_dict(checkpoint['model']) 506 | if 'scaler' in checkpoint: 507 | loss_scaler.load_state_dict(checkpoint['scaler']) 508 | print("With optim & sched!") 509 | -------------------------------------------------------------------------------- /networks/__init__.py: 
class DepthDecoder(nn.Module):
    """Three-level up-sampling decoder producing sigmoid disparity maps."""

    def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True):
        """
        :param num_ch_enc: numpy array with the channel count of each encoder stage
        :param scales: decoder levels for which a disparity head is created
        :param num_output_channels: channels of every disparity output
        :param use_skips: concatenate encoder features at levels above 0
        """
        super().__init__()

        self.num_output_channels = num_output_channels
        self.use_skips = use_skips
        self.upsample_mode = 'bilinear'
        self.scales = scales

        self.num_ch_enc = num_ch_enc
        self.num_ch_dec = (self.num_ch_enc / 2).astype('int')
        # NOTE(review): with a three-stage encoder num_ch_dec has three
        # entries, so the default scales=range(4) would index past its end
        # below - confirm callers always pass scales explicitly.

        # decoder; insertion order fixes the indices inside self.decoder
        self.convs = OrderedDict()
        for level in (2, 1, 0):
            out_channels = self.num_ch_dec[level]

            # first conv of the level: takes the previous (coarser) output
            if level == 2:
                first_in = self.num_ch_enc[-1]
            else:
                first_in = self.num_ch_dec[level + 1]
            self.convs[("upconv", level, 0)] = ConvBlock(first_in, out_channels)

            # second conv of the level: also sees the skip connection
            second_in = self.num_ch_dec[level]
            if self.use_skips and level > 0:
                second_in = second_in + self.num_ch_enc[level - 1]
            self.convs[("upconv", level, 1)] = ConvBlock(second_in, out_channels)

        for scale in self.scales:
            self.convs[("dispconv", scale)] = Conv3x3(self.num_ch_dec[scale], self.num_output_channels)

        # Register the layers so their parameters belong to the module.
        self.decoder = nn.ModuleList(list(self.convs.values()))
        self.sigmoid = nn.Sigmoid()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal weights and zero bias for conv / linear layers."""
        if not isinstance(m, (nn.Conv2d, nn.Linear)):
            return
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, input_features):
        """
        :param input_features: list of encoder feature maps; the last entry
            is decoded and earlier entries serve as skip connections
        :return: dict mapping ``("disp", level)`` to a sigmoid disparity map
        """
        self.outputs = {}
        feat = input_features[-1]
        for level in (2, 1, 0):
            feat = self.convs[("upconv", level, 0)](feat)
            parts = [upsample(feat)]
            if self.use_skips and level > 0:
                parts.append(input_features[level - 1])
            feat = self.convs[("upconv", level, 1)](torch.cat(parts, 1))

            if level in self.scales:
                disp = upsample(self.convs[("dispconv", level)](feat), mode='bilinear')
                self.outputs[("disp", level)] = self.sigmoid(disp)

        return self.outputs
class XCA(nn.Module):
    """ Cross-Covariance Attention (XCA): attention is computed between
    channels instead of tokens.  Per head, the softmax-normalised
    cross-covariance matrix Q^T K (size d_h x d_h) weights the channels of V.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        """
        :param dim: channel dimension of the tokens
        :param num_heads: number of attention heads
        :param qkv_bias: add a bias to the fused q/k/v projection
        :param qk_scale: accepted but not used by this module
        :param attn_drop: dropout rate applied to the attention matrix
        :param proj_drop: dropout rate applied after the output projection
        """
        super().__init__()
        self.num_heads = num_heads
        # one learnable scaling factor per head
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """
        :param x: tensor of shape (B, N, C)
        :return: tensor of shape (B, N, C)
        """
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        # (B, N, 3, heads, d) -> (3, B, heads, d, N): channels become rows
        q, k, v = fused.permute(2, 0, 3, 4, 1).unbind(0)

        # L2-normalise every channel over the token axis
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        # (B, heads, d, d) channel-to-channel attention
        weights = torch.matmul(q, k.transpose(-2, -1)) * self.temperature
        weights = self.attn_drop(weights.softmax(dim=-1))

        # (B, heads, d, N) -> (B, N, heads, d) -> (B, N, C)
        mixed = torch.matmul(weights, v).permute(0, 3, 1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters that should be excluded from weight decay."""
        return {'temperature'}
class Conv(nn.Module):
    """2-D convolution, optionally followed by BatchNorm + GELU."""

    def __init__(self, nIn, nOut, kSize, stride, padding=0, dilation=(1, 1), groups=1, bn_act=False, bias=False):
        """
        :param nIn: number of input channels
        :param nOut: number of output channels
        :param kSize: kernel size
        :param stride: convolution stride
        :param padding: zero padding added to both sides
        :param dilation: dilation rate
        :param groups: number of channel groups
        :param bn_act: append a ``BNGELU`` block after the convolution
        :param bias: give the convolution a bias term
        """
        super().__init__()

        self.bn_act = bn_act

        self.conv = nn.Conv2d(
            nIn, nOut,
            kernel_size=kSize,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        # The normalisation/activation block only exists when requested.
        if self.bn_act:
            self.bn_gelu = BNGELU(nOut)

    def forward(self, x):
        """Apply the convolution and, if enabled, BatchNorm + GELU."""
        features = self.conv(x)
        if not self.bn_act:
            return features
        return self.bn_gelu(features)
class DilatedConv(nn.Module):
    """
    One layer of the Consecutive Dilated Convolutions (CDC) module: a dilated
    depth-wise convolution with BatchNorm, then a point-wise expand / GELU /
    project stack, added back onto the input.
    """
    def __init__(self, dim, k, dilation=1, stride=1, drop_path=0.,
                 layer_scale_init_value=1e-6, expan_ratio=6):
        """
        :param dim: number of input (and output) channels
        :param k: kernel size of the depth-wise convolution
        :param dilation: dilation rate of the depth-wise convolution
        :param stride: stride of the depth-wise convolution
        :param drop_path: drop-path rate applied to the residual branch
        :param layer_scale_init_value: initial value of the per-channel scale
            ``gamma``; a non-positive value disables the scale
        :param expan_ratio: channel expansion factor of the point-wise stack
        """
        super().__init__()

        # groups=dim makes the convolution depth-wise
        self.ddwconv = CDilated(dim, dim, kSize=k, stride=stride, groups=dim, d=dilation)
        self.bn1 = nn.BatchNorm2d(dim)

        # Created, and therefore part of the module's parameters, but not
        # called in forward(); only bn1 normalises the features.
        self.norm = LayerNorm(dim, eps=1e-6)

        hidden_dim = expan_ratio * dim
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, dim)

        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim),
                                      requires_grad=True)
        else:
            self.gamma = None

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """
        :param x: feature map of shape (N, C, H, W)
        :return: feature map with the residual branch added to ``x``
        """
        shortcut = x

        branch = self.bn1(self.ddwconv(x))

        # nn.Linear acts on the last axis, so move channels there
        branch = branch.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        branch = self.pwconv2(self.act(self.pwconv1(branch)))
        if self.gamma is not None:
            branch = self.gamma * branch
        branch = branch.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        return shortcut + self.drop_path(branch)
class AvgPool(nn.Module):
    """Apply ``ratio`` successive 3x3, stride-2 average poolings."""

    def __init__(self, ratio):
        """
        :param ratio: number of pooling layers to chain
        """
        super().__init__()
        self.pool = nn.ModuleList(
            [nn.AvgPool2d(3, stride=2, padding=1) for _ in range(ratio)]
        )

    def forward(self, x):
        """Run ``x`` through every pooling layer in order."""
        out = x
        for layer in self.pool:
            out = layer(out)
        return out
128] 318 | if height == 192 and width == 640: 319 | self.dilation = [[1, 2, 3], [1, 2, 3], [1, 2, 3, 2, 4, 6]] 320 | elif height == 320 and width == 1024: 321 | self.dilation = [[1, 2, 5], [1, 2, 5], [1, 2, 5, 2, 4, 10]] 322 | 323 | elif model == 'lite-mono-tiny': 324 | self.num_ch_enc = np.array([32, 64, 128]) 325 | self.depth = [4, 4, 7] 326 | self.dims = [32, 64, 128] 327 | if height == 192 and width == 640: 328 | self.dilation = [[1, 2, 3], [1, 2, 3], [1, 2, 3, 2, 4, 6]] 329 | elif height == 320 and width == 1024: 330 | self.dilation = [[1, 2, 5], [1, 2, 5], [1, 2, 5, 2, 4, 10]] 331 | 332 | elif model == 'lite-mono-8m': 333 | self.num_ch_enc = np.array([64, 128, 224]) 334 | self.depth = [4, 4, 10] 335 | self.dims = [64, 128, 224] 336 | if height == 192 and width == 640: 337 | self.dilation = [[1, 2, 3], [1, 2, 3], [1, 2, 3, 1, 2, 3, 2, 4, 6]] 338 | elif height == 320 and width == 1024: 339 | self.dilation = [[1, 2, 3], [1, 2, 3], [1, 2, 3, 1, 2, 3, 2, 4, 6]] 340 | 341 | for g in global_block_type: 342 | assert g in ['None', 'LGFI'] 343 | 344 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 345 | stem1 = nn.Sequential( 346 | Conv(in_chans, self.dims[0], kSize=3, stride=2, padding=1, bn_act=True), 347 | Conv(self.dims[0], self.dims[0], kSize=3, stride=1, padding=1, bn_act=True), 348 | Conv(self.dims[0], self.dims[0], kSize=3, stride=1, padding=1, bn_act=True), 349 | ) 350 | 351 | self.stem2 = nn.Sequential( 352 | Conv(self.dims[0]+3, self.dims[0], kSize=3, stride=2, padding=1, bn_act=False), 353 | ) 354 | 355 | self.downsample_layers.append(stem1) 356 | 357 | self.input_downsample = nn.ModuleList() 358 | for i in range(1, 5): 359 | self.input_downsample.append(AvgPool(i)) 360 | 361 | for i in range(2): 362 | downsample_layer = nn.Sequential( 363 | Conv(self.dims[i]*2+3, self.dims[i+1], kSize=3, stride=2, padding=1, bn_act=False), 364 | ) 365 | self.downsample_layers.append(downsample_layer) 366 | 367 | self.stages = 
nn.ModuleList() 368 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depth))] 369 | cur = 0 370 | for i in range(3): 371 | stage_blocks = [] 372 | for j in range(self.depth[i]): 373 | if j > self.depth[i] - global_block[i] - 1: 374 | if global_block_type[i] == 'LGFI': 375 | stage_blocks.append(LGFI(dim=self.dims[i], drop_path=dp_rates[cur + j], 376 | expan_ratio=expan_ratio, 377 | use_pos_emb=use_pos_embd_xca[i], num_heads=heads[i], 378 | layer_scale_init_value=layer_scale_init_value, 379 | )) 380 | 381 | else: 382 | raise NotImplementedError 383 | else: 384 | stage_blocks.append(DilatedConv(dim=self.dims[i], k=3, dilation=self.dilation[i][j], drop_path=dp_rates[cur + j], 385 | layer_scale_init_value=layer_scale_init_value, 386 | expan_ratio=expan_ratio)) 387 | 388 | self.stages.append(nn.Sequential(*stage_blocks)) 389 | cur += self.depth[i] 390 | 391 | self.apply(self._init_weights) 392 | 393 | def _init_weights(self, m): 394 | if isinstance(m, (nn.Conv2d, nn.Linear)): 395 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 396 | 397 | elif isinstance(m, (LayerNorm, nn.LayerNorm)): 398 | nn.init.constant_(m.bias, 0) 399 | nn.init.constant_(m.weight, 1.0) 400 | 401 | elif isinstance(m, nn.BatchNorm2d): 402 | nn.init.constant_(m.weight, 1) 403 | nn.init.constant_(m.bias, 0) 404 | 405 | def forward_features(self, x): 406 | features = [] 407 | x = (x - 0.45) / 0.225 408 | 409 | x_down = [] 410 | for i in range(4): 411 | x_down.append(self.input_downsample[i](x)) 412 | 413 | tmp_x = [] 414 | x = self.downsample_layers[0](x) 415 | x = self.stem2(torch.cat((x, x_down[0]), dim=1)) 416 | tmp_x.append(x) 417 | 418 | for s in range(len(self.stages[0])-1): 419 | x = self.stages[0][s](x) 420 | x = self.stages[0][-1](x) 421 | tmp_x.append(x) 422 | features.append(x) 423 | 424 | for i in range(1, 3): 425 | tmp_x.append(x_down[i]) 426 | x = torch.cat(tmp_x, dim=1) 427 | x = self.downsample_layers[i](x) 428 | 429 | tmp_x = [x] 430 | 
class PoseDecoder(nn.Module):
    """Convolutional head that maps encoder features to axis-angle and
    translation outputs."""

    def __init__(self, num_ch_enc, num_input_features, num_frames_to_predict_for=None, stride=1):
        """
        :param num_ch_enc: channel counts of the encoder stages; only the
            last entry is used
        :param num_input_features: number of feature lists passed to forward
        :param num_frames_to_predict_for: number of 6-value outputs per
            sample; defaults to ``num_input_features - 1``
        :param stride: stride of the two 3x3 pose convolutions
        """
        super().__init__()

        self.num_ch_enc = num_ch_enc
        self.num_input_features = num_input_features

        if num_frames_to_predict_for is None:
            num_frames_to_predict_for = num_input_features - 1
        self.num_frames_to_predict_for = num_frames_to_predict_for

        # Insertion order fixes the indices inside self.net.
        self.convs = OrderedDict([
            ("squeeze", nn.Conv2d(self.num_ch_enc[-1], 256, 1)),
            (("pose", 0), nn.Conv2d(num_input_features * 256, 256, 3, stride, 1)),
            (("pose", 1), nn.Conv2d(256, 256, 3, stride, 1)),
            (("pose", 2), nn.Conv2d(256, 6 * num_frames_to_predict_for, 1)),
        ])

        self.relu = nn.ReLU()

        # Register the layers so their parameters belong to the module.
        self.net = nn.ModuleList(list(self.convs.values()))

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal weights and zero bias for conv / linear layers."""
        if not isinstance(m, (nn.Conv2d, nn.Linear)):
            return
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, input_features):
        """
        :param input_features: list of feature lists; the last feature map
            of each list is used
        :return: ``(axisangle, translation)``, each of shape
            ``(-1, num_frames_to_predict_for, 1, 3)``
        """
        squeezed = []
        for features in input_features:
            squeezed.append(self.relu(self.convs["squeeze"](features[-1])))

        out = torch.cat(squeezed, 1)
        out = self.relu(self.convs[("pose", 0)](out))
        out = self.relu(self.convs[("pose", 1)](out))
        out = self.convs[("pose", 2)](out)  # no activation on the last layer

        # average over width, then height
        out = out.mean(3).mean(2)

        out = 0.01 * out.view(-1, self.num_frames_to_predict_for, 1, 6)

        return out[..., :3], out[..., 3:]
nn.init.constant_(m.bias, 0) 33 | 34 | 35 | def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1): 36 | """Constructs a ResNet model. 37 | Args: 38 | num_layers (int): Number of resnet layers. Must be 18 or 50 39 | pretrained (bool): If True, returns a model pre-trained on ImageNet 40 | num_input_images (int): Number of frames stacked as input 41 | """ 42 | assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet" 43 | blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers] 44 | block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers] 45 | model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images) 46 | 47 | if pretrained: 48 | loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)]) 49 | loaded['conv1.weight'] = torch.cat( 50 | [loaded['conv1.weight']] * num_input_images, 1) / num_input_images 51 | model.load_state_dict(loaded) 52 | return model 53 | 54 | 55 | class ResnetEncoder(nn.Module): 56 | """Pytorch module for a resnet encoder 57 | """ 58 | def __init__(self, num_layers, pretrained, num_input_images=1): 59 | super(ResnetEncoder, self).__init__() 60 | 61 | self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 62 | std=[0.229, 0.224, 0.225]) 63 | 64 | self.num_ch_enc = np.array([64, 64, 128, 256, 512]) 65 | 66 | resnets = {18: models.resnet18, 67 | 34: models.resnet34, 68 | 50: models.resnet50, 69 | 101: models.resnet101, 70 | 152: models.resnet152} 71 | 72 | if num_layers not in resnets: 73 | raise ValueError("{} is not a valid number of resnet layers".format(num_layers)) 74 | 75 | if num_input_images > 1: 76 | self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images) 77 | else: 78 | self.encoder = resnets[num_layers](pretrained) 79 | 80 | if num_layers > 34: 81 | self.num_ch_enc[1:] *= 4 82 | 83 | def forward(self, input_image): 84 | self.features = [] 85 | x = (input_image - 0.45) / 0.225 86 | x = 
self.encoder.conv1(x) 87 | x = self.encoder.bn1(x) 88 | self.features.append(self.encoder.relu(x)) 89 | self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1]))) 90 | self.features.append(self.encoder.layer2(self.features[-1])) 91 | self.features.append(self.encoder.layer3(self.features[-1])) 92 | self.features.append(self.encoder.layer4(self.features[-1])) 93 | 94 | return self.features 95 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import os 4 | import argparse 5 | 6 | file_dir = os.path.dirname(__file__) # the directory that options.py resides in 7 | 8 | 9 | class LiteMonoOptions: 10 | def __init__(self): 11 | self.parser = argparse.ArgumentParser(description="Lite-Mono options") 12 | 13 | # PATHS 14 | self.parser.add_argument("--data_path", 15 | type=str, 16 | help="path to the training data", 17 | default=os.path.join(file_dir, "kitti_data")) 18 | self.parser.add_argument("--log_dir", 19 | type=str, 20 | help="log directory", 21 | default="./tmp") 22 | 23 | # TRAINING options 24 | self.parser.add_argument("--model_name", 25 | type=str, 26 | help="the name of the folder to save the model in", 27 | default="lite-mono") 28 | self.parser.add_argument("--split", 29 | type=str, 30 | help="which training split to use", 31 | choices=["eigen_zhou", "eigen_full", "odom", "benchmark"], 32 | default="eigen_zhou") 33 | self.parser.add_argument("--model", 34 | type=str, 35 | help="which model to load", 36 | choices=["lite-mono", "lite-mono-small", "lite-mono-tiny", "lite-mono-8m"], 37 | default="lite-mono") 38 | self.parser.add_argument("--weight_decay", 39 | type=float, 40 | help="weight decay in AdamW", 41 | default=1e-2) 42 | self.parser.add_argument("--drop_path", 43 | type=float, 44 | help="drop path rate", 45 | default=0.2) 46 | 
self.parser.add_argument("--num_layers", 47 | type=int, 48 | help="number of resnet layers", 49 | default=18, 50 | choices=[18, 34, 50, 101, 152]) 51 | self.parser.add_argument("--dataset", 52 | type=str, 53 | help="dataset to train on", 54 | default="kitti", 55 | choices=["kitti", "kitti_odom", "kitti_depth", "kitti_test"]) 56 | self.parser.add_argument("--png", 57 | help="if set, trains from raw KITTI png files (instead of jpgs)", 58 | action="store_true") 59 | self.parser.add_argument("--height", 60 | type=int, 61 | help="input image height", 62 | default=192) 63 | self.parser.add_argument("--width", 64 | type=int, 65 | help="input image width", 66 | default=640) 67 | self.parser.add_argument("--disparity_smoothness", 68 | type=float, 69 | help="disparity smoothness weight", 70 | default=1e-3) 71 | self.parser.add_argument("--scales", 72 | nargs="+", 73 | type=int, 74 | help="scales used in the loss", 75 | default=[0, 1, 2]) 76 | self.parser.add_argument("--min_depth", 77 | type=float, 78 | help="minimum depth", 79 | default=0.1) 80 | self.parser.add_argument("--max_depth", 81 | type=float, 82 | help="maximum depth", 83 | default=100.0) 84 | self.parser.add_argument("--use_stereo", 85 | help="if set, uses stereo pair for training", 86 | action="store_true") 87 | self.parser.add_argument("--frame_ids", 88 | nargs="+", 89 | type=int, 90 | help="frames to load", 91 | default=[0, -1, 1]) 92 | 93 | self.parser.add_argument("--profile", 94 | type=bool, 95 | help="profile once at the beginning of the training", 96 | default=True) 97 | 98 | # OPTIMIZATION options 99 | self.parser.add_argument("--batch_size", 100 | type=int, 101 | help="batch size", 102 | default=16) 103 | self.parser.add_argument("--lr", 104 | nargs="+", 105 | type=float, 106 | help="learning rates of DepthNet and PoseNet. 
" 107 | "Initial learning rate, " 108 | "minimum learning rate, " 109 | "First cycle step size.", 110 | default=[0.0001, 5e-6, 31, 0.0001, 1e-5, 31]) 111 | self.parser.add_argument("--num_epochs", 112 | type=int, 113 | help="number of epochs", 114 | default=50) 115 | self.parser.add_argument("--scheduler_step_size", 116 | type=int, 117 | help="step size of the scheduler", 118 | default=15) 119 | 120 | # ABLATION options 121 | self.parser.add_argument("--v1_multiscale", 122 | help="if set, uses monodepth v1 multiscale", 123 | action="store_true") 124 | self.parser.add_argument("--avg_reprojection", 125 | help="if set, uses average reprojection loss", 126 | action="store_true") 127 | self.parser.add_argument("--disable_automasking", 128 | help="if set, doesn't do auto-masking", 129 | action="store_true") 130 | self.parser.add_argument("--predictive_mask", 131 | help="if set, uses a predictive masking scheme as in Zhou et al", 132 | action="store_true") 133 | self.parser.add_argument("--no_ssim", 134 | help="if set, disables ssim in the loss", 135 | action="store_true") 136 | self.parser.add_argument("--mypretrain", 137 | type=str, 138 | help="if set, use my pretrained encoder") 139 | self.parser.add_argument("--weights_init", 140 | type=str, 141 | help="pretrained or scratch", 142 | default="pretrained", 143 | choices=["pretrained", "scratch"]) 144 | self.parser.add_argument("--pose_model_input", 145 | type=str, 146 | help="how many images the pose network gets", 147 | default="pairs", 148 | choices=["pairs", "all"]) 149 | self.parser.add_argument("--pose_model_type", 150 | type=str, 151 | help="normal or shared", 152 | default="separate_resnet", 153 | choices=["posecnn", "separate_resnet", "shared"]) 154 | 155 | # SYSTEM options 156 | self.parser.add_argument("--no_cuda", 157 | help="if set disables CUDA", 158 | action="store_true") 159 | self.parser.add_argument("--num_workers", 160 | type=int, 161 | help="number of dataloader workers", 162 | default=12) 163 | 164 
| # LOADING options 165 | self.parser.add_argument("--load_weights_folder", 166 | type=str, 167 | help="name of model to load") 168 | self.parser.add_argument("--models_to_load", 169 | nargs="+", 170 | type=str, 171 | help="models to load", 172 | default=["encoder", "depth", "pose_encoder", "pose"]) 173 | 174 | # LOGGING options 175 | self.parser.add_argument("--log_frequency", 176 | type=int, 177 | help="number of batches between each tensorboard log", 178 | default=250) 179 | self.parser.add_argument("--save_frequency", 180 | type=int, 181 | help="number of epochs between each save", 182 | default=1) 183 | 184 | # EVALUATION options 185 | self.parser.add_argument("--disable_median_scaling", 186 | help="if set disables median scaling in evaluation", 187 | action="store_true") 188 | self.parser.add_argument("--pred_depth_scale_factor", 189 | help="if set multiplies predictions by this number", 190 | type=float, 191 | default=1) 192 | self.parser.add_argument("--ext_disp_to_eval", 193 | type=str, 194 | help="optional path to a .npy disparities file to evaluate") 195 | self.parser.add_argument("--eval_split", 196 | type=str, 197 | default="eigen", 198 | choices=[ 199 | "eigen"], 200 | help="which split to run eval on") 201 | self.parser.add_argument("--save_pred_disps", 202 | help="if set saves predicted disparities", 203 | action="store_true") 204 | self.parser.add_argument("--no_eval", 205 | help="if set disables evaluation", 206 | action="store_true") 207 | self.parser.add_argument("--eval_out_dir", 208 | help="if set will output the disparities to this folder", 209 | type=str) 210 | self.parser.add_argument("--post_process", 211 | help="if set will perform the flipping post processing " 212 | "from the original monodepth paper", 213 | action="store_true") 214 | 215 | def parse(self): 216 | self.options = self.parser.parse_args() 217 | return self.options 218 | -------------------------------------------------------------------------------- /test_simple.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import os 4 | import sys 5 | import glob 6 | import argparse 7 | import numpy as np 8 | import PIL.Image as pil 9 | import matplotlib as mpl 10 | import matplotlib.cm as cm 11 | 12 | import torch 13 | from torchvision import transforms, datasets 14 | 15 | import networks 16 | from layers import disp_to_depth 17 | import cv2 18 | import heapq 19 | from PIL import ImageFile 20 | ImageFile.LOAD_TRUNCATED_IMAGES = True 21 | 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser( 25 | description='Simple testing function for Lite-Mono models.') 26 | 27 | parser.add_argument('--image_path', type=str, 28 | help='path to a test image or folder of images', required=True) 29 | 30 | parser.add_argument('--load_weights_folder', type=str, 31 | help='path of a pretrained model to use', 32 | ) 33 | 34 | parser.add_argument('--test', 35 | action='store_true', 36 | help='if set, read images from a .txt file', 37 | ) 38 | 39 | parser.add_argument('--model', type=str, 40 | help='name of a pretrained model to use', 41 | default="lite-mono", 42 | choices=[ 43 | "lite-mono", 44 | "lite-mono-small", 45 | "lite-mono-tiny", 46 | "lite-mono-8m"]) 47 | 48 | parser.add_argument('--ext', type=str, 49 | help='image extension to search for in folder', default="jpg") 50 | parser.add_argument("--no_cuda", 51 | help='if set, disables CUDA', 52 | action='store_true') 53 | 54 | return parser.parse_args() 55 | 56 | 57 | def test_simple(args): 58 | """Function to predict for a single image or folder of images 59 | """ 60 | assert args.load_weights_folder is not None, \ 61 | "You must specify the --load_weights_folder parameter" 62 | 63 | if torch.cuda.is_available() and not args.no_cuda: 64 | device = torch.device("cuda") 65 | else: 66 | device = torch.device("cpu") 67 | 68 | print("-> Loading model from ", args.load_weights_folder) 69 
| encoder_path = os.path.join(args.load_weights_folder, "encoder.pth") 70 | decoder_path = os.path.join(args.load_weights_folder, "depth.pth") 71 | 72 | encoder_dict = torch.load(encoder_path) 73 | decoder_dict = torch.load(decoder_path) 74 | 75 | # extract the height and width of image that this model was trained with 76 | feed_height = encoder_dict['height'] 77 | feed_width = encoder_dict['width'] 78 | 79 | # LOADING PRETRAINED MODEL 80 | print(" Loading pretrained encoder") 81 | encoder = networks.LiteMono(model=args.model, 82 | height=feed_height, 83 | width=feed_width) 84 | 85 | model_dict = encoder.state_dict() 86 | encoder.load_state_dict({k: v for k, v in encoder_dict.items() if k in model_dict}) 87 | 88 | encoder.to(device) 89 | encoder.eval() 90 | 91 | print(" Loading pretrained decoder") 92 | depth_decoder = networks.DepthDecoder(encoder.num_ch_enc, scales=range(3)) 93 | depth_model_dict = depth_decoder.state_dict() 94 | depth_decoder.load_state_dict({k: v for k, v in decoder_dict.items() if k in depth_model_dict}) 95 | 96 | depth_decoder.to(device) 97 | depth_decoder.eval() 98 | 99 | # FINDING INPUT IMAGES 100 | if os.path.isfile(args.image_path) and not args.test: 101 | # Only testing on a single image 102 | paths = [args.image_path] 103 | output_directory = os.path.dirname(args.image_path) 104 | elif os.path.isfile(args.image_path) and args.test: 105 | gt_path = os.path.join('splits', 'eigen', "gt_depths.npz") 106 | gt_depths = np.load(gt_path, fix_imports=True, encoding='latin1', allow_pickle=True)["data"] 107 | 108 | side_map = {"2": 2, "3": 3, "l": 2, "r": 3} 109 | # reading images from .txt file 110 | paths = [] 111 | with open(args.image_path) as f: 112 | filenames = f.readlines() 113 | for i in range(len(filenames)): 114 | filename = filenames[i] 115 | line = filename.split() 116 | folder = line[0] 117 | if len(line) == 3: 118 | frame_index = int(line[1]) 119 | side = line[2] 120 | 121 | f_str = "{:010d}{}".format(frame_index, '.jpg') 122 | 
image_path = os.path.join( 123 | 'kitti_data', 124 | folder, 125 | "image_0{}/data".format(side_map[side]), 126 | f_str) 127 | paths.append(image_path) 128 | 129 | elif os.path.isdir(args.image_path): 130 | # Searching folder for images 131 | paths = glob.glob(os.path.join(args.image_path, '*.{}'.format(args.ext))) 132 | output_directory = args.image_path 133 | else: 134 | raise Exception("Can not find args.image_path: {}".format(args.image_path)) 135 | 136 | print("-> Predicting on {:d} test images".format(len(paths))) 137 | 138 | # PREDICTING ON EACH IMAGE IN TURN 139 | with torch.no_grad(): 140 | for idx, image_path in enumerate(paths): 141 | 142 | if image_path.endswith("_disp.jpg"): 143 | # don't try to predict disparity for a disparity image! 144 | continue 145 | 146 | # Load image and preprocess 147 | input_image = pil.open(image_path).convert('RGB') 148 | original_width, original_height = input_image.size 149 | input_image = input_image.resize((feed_width, feed_height), pil.LANCZOS) 150 | input_image = transforms.ToTensor()(input_image).unsqueeze(0) 151 | 152 | # PREDICTION 153 | input_image = input_image.to(device) 154 | features = encoder(input_image) 155 | outputs = depth_decoder(features) 156 | 157 | disp = outputs[("disp", 0)] 158 | 159 | disp_resized = torch.nn.functional.interpolate( 160 | disp, (original_height, original_width), mode="bilinear", align_corners=False) 161 | 162 | # Saving numpy file 163 | output_name = os.path.splitext(os.path.basename(image_path))[0] 164 | # output_name = os.path.splitext(image_path)[0].split('/')[-1] 165 | scaled_disp, depth = disp_to_depth(disp, 0.1, 100) 166 | 167 | name_dest_npy = os.path.join(output_directory, "{}_disp.npy".format(output_name)) 168 | np.save(name_dest_npy, scaled_disp.cpu().numpy()) 169 | 170 | # Saving colormapped depth image 171 | disp_resized_np = disp_resized.squeeze().cpu().numpy() 172 | vmax = np.percentile(disp_resized_np, 95) 173 | normalizer = 
mpl.colors.Normalize(vmin=disp_resized_np.min(), vmax=vmax) 174 | mapper = cm.ScalarMappable(norm=normalizer, cmap='magma') 175 | colormapped_im = (mapper.to_rgba(disp_resized_np)[:, :, :3] * 255).astype(np.uint8) 176 | im = pil.fromarray(colormapped_im) 177 | 178 | name_dest_im = os.path.join(output_directory, "{}_disp.jpeg".format(output_name)) 179 | im.save(name_dest_im) 180 | 181 | print(" Processed {:d} of {:d} images - saved predictions to:".format( 182 | idx + 1, len(paths))) 183 | print(" - {}".format(name_dest_im)) 184 | print(" - {}".format(name_dest_npy)) 185 | 186 | 187 | print('-> Done!') 188 | 189 | 190 | if __name__ == '__main__': 191 | args = parse_args() 192 | test_simple(args) 193 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | from options import LiteMonoOptions 4 | from trainer import Trainer 5 | 6 | options = LiteMonoOptions() 7 | opts = options.parse() 8 | 9 | 10 | if __name__ == "__main__": 11 | trainer = Trainer(opts) 12 | trainer.train() 13 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | 4 | import time 5 | import torch.optim as optim 6 | from torch.utils.data import DataLoader 7 | from tensorboardX import SummaryWriter 8 | 9 | import json 10 | 11 | from utils import * 12 | from kitti_utils import * 13 | from layers import * 14 | 15 | import datasets 16 | import networks 17 | from linear_warmup_cosine_annealing_warm_restarts_weight_decay import ChainedScheduler 18 | 19 | 20 | # torch.backends.cudnn.benchmark = True 21 | 22 | 23 | def time_sync(): 24 | # PyTorch-accurate time 25 | if torch.cuda.is_available(): 26 | torch.cuda.synchronize() 
27 | return time.time() 28 | 29 | 30 | class Trainer: 31 | def __init__(self, options): 32 | self.opt = options 33 | self.log_path = os.path.join(self.opt.log_dir, self.opt.model_name) 34 | 35 | # checking height and width are multiples of 32 36 | assert self.opt.height % 32 == 0, "'height' must be a multiple of 32" 37 | assert self.opt.width % 32 == 0, "'width' must be a multiple of 32" 38 | 39 | self.models = {} 40 | self.models_pose = {} 41 | self.parameters_to_train = [] 42 | self.parameters_to_train_pose = [] 43 | 44 | self.device = torch.device("cpu" if self.opt.no_cuda else "cuda") 45 | self.profile = self.opt.profile 46 | 47 | self.num_scales = len(self.opt.scales) 48 | self.frame_ids = len(self.opt.frame_ids) 49 | self.num_pose_frames = 2 if self.opt.pose_model_input == "pairs" else self.num_input_frames 50 | 51 | assert self.opt.frame_ids[0] == 0, "frame_ids must start with 0" 52 | 53 | self.use_pose_net = not (self.opt.use_stereo and self.opt.frame_ids == [0]) 54 | 55 | if self.opt.use_stereo: 56 | self.opt.frame_ids.append("s") 57 | 58 | self.models["encoder"] = networks.LiteMono(model=self.opt.model, 59 | drop_path_rate=self.opt.drop_path, 60 | width=self.opt.width, height=self.opt.height) 61 | 62 | self.models["encoder"].to(self.device) 63 | self.parameters_to_train += list(self.models["encoder"].parameters()) 64 | 65 | self.models["depth"] = networks.DepthDecoder(self.models["encoder"].num_ch_enc, 66 | self.opt.scales) 67 | self.models["depth"].to(self.device) 68 | self.parameters_to_train += list(self.models["depth"].parameters()) 69 | 70 | if self.use_pose_net: 71 | if self.opt.pose_model_type == "separate_resnet": 72 | self.models_pose["pose_encoder"] = networks.ResnetEncoder( 73 | self.opt.num_layers, 74 | self.opt.weights_init == "pretrained", 75 | num_input_images=self.num_pose_frames) 76 | 77 | self.models_pose["pose_encoder"].to(self.device) 78 | self.parameters_to_train_pose += list(self.models_pose["pose_encoder"].parameters()) 79 | 80 | 
self.models_pose["pose"] = networks.PoseDecoder( 81 | self.models_pose["pose_encoder"].num_ch_enc, 82 | num_input_features=1, 83 | num_frames_to_predict_for=2) 84 | 85 | elif self.opt.pose_model_type == "shared": 86 | self.models_pose["pose"] = networks.PoseDecoder( 87 | self.models["encoder"].num_ch_enc, self.num_pose_frames) 88 | 89 | elif self.opt.pose_model_type == "posecnn": 90 | self.models_pose["pose"] = networks.PoseCNN( 91 | self.num_input_frames if self.opt.pose_model_input == "all" else 2) 92 | 93 | self.models_pose["pose"].to(self.device) 94 | self.parameters_to_train_pose += list(self.models_pose["pose"].parameters()) 95 | 96 | if self.opt.predictive_mask: 97 | assert self.opt.disable_automasking, \ 98 | "When using predictive_mask, please disable automasking with --disable_automasking" 99 | 100 | # Our implementation of the predictive masking baseline has the same architecture 101 | # as our depth decoder. We predict a separate mask for each source frame. 102 | self.models["predictive_mask"] = networks.DepthDecoder( 103 | self.models["encoder"].num_ch_enc, self.opt.scales, 104 | num_output_channels=(len(self.opt.frame_ids) - 1)) 105 | self.models["predictive_mask"].to(self.device) 106 | self.parameters_to_train += list(self.models["predictive_mask"].parameters()) 107 | 108 | self.model_optimizer = optim.AdamW(self.parameters_to_train, self.opt.lr[0], weight_decay=self.opt.weight_decay) 109 | if self.use_pose_net: 110 | self.model_pose_optimizer = optim.AdamW(self.parameters_to_train_pose, self.opt.lr[3], weight_decay=self.opt.weight_decay) 111 | 112 | self.model_lr_scheduler = ChainedScheduler( 113 | self.model_optimizer, 114 | T_0=int(self.opt.lr[2]), 115 | T_mul=1, 116 | eta_min=self.opt.lr[1], 117 | last_epoch=-1, 118 | max_lr=self.opt.lr[0], 119 | warmup_steps=0, 120 | gamma=0.9 121 | ) 122 | self.model_pose_lr_scheduler = ChainedScheduler( 123 | self.model_pose_optimizer, 124 | T_0=int(self.opt.lr[5]), 125 | T_mul=1, 126 | 
eta_min=self.opt.lr[4], 127 | last_epoch=-1, 128 | max_lr=self.opt.lr[3], 129 | warmup_steps=0, 130 | gamma=0.9 131 | ) 132 | 133 | if self.opt.load_weights_folder is not None: 134 | self.load_model() 135 | 136 | if self.opt.mypretrain is not None: 137 | self.load_pretrain() 138 | 139 | print("Training model named:\n ", self.opt.model_name) 140 | print("Models and tensorboard events files are saved to:\n ", self.opt.log_dir) 141 | print("Training is using:\n ", self.device) 142 | 143 | # data 144 | datasets_dict = {"kitti": datasets.KITTIRAWDataset, 145 | "kitti_odom": datasets.KITTIOdomDataset} 146 | self.dataset = datasets_dict[self.opt.dataset] 147 | 148 | fpath = os.path.join(os.path.dirname(__file__), "splits", self.opt.split, "{}_files.txt") 149 | 150 | train_filenames = readlines(fpath.format("train")) 151 | val_filenames = readlines(fpath.format("val")) 152 | img_ext = '.png' if self.opt.png else '.jpg' 153 | 154 | num_train_samples = len(train_filenames) 155 | self.num_total_steps = num_train_samples // self.opt.batch_size * self.opt.num_epochs 156 | 157 | train_dataset = self.dataset( 158 | self.opt.data_path, train_filenames, self.opt.height, self.opt.width, 159 | self.opt.frame_ids, 4, is_train=True, img_ext=img_ext) 160 | self.train_loader = DataLoader( 161 | train_dataset, self.opt.batch_size, True, 162 | num_workers=self.opt.num_workers, pin_memory=True, drop_last=True) 163 | val_dataset = self.dataset( 164 | self.opt.data_path, val_filenames, self.opt.height, self.opt.width, 165 | self.opt.frame_ids, 4, is_train=False, img_ext=img_ext) 166 | self.val_loader = DataLoader( 167 | val_dataset, self.opt.batch_size, True, 168 | num_workers=self.opt.num_workers, pin_memory=True, drop_last=True) 169 | self.val_iter = iter(self.val_loader) 170 | 171 | self.writers = {} 172 | for mode in ["train", "val"]: 173 | self.writers[mode] = SummaryWriter(os.path.join(self.log_path, mode)) 174 | 175 | if not self.opt.no_ssim: 176 | self.ssim = SSIM() 177 | 
self.ssim.to(self.device) 178 | 179 | self.backproject_depth = {} 180 | self.project_3d = {} 181 | for scale in self.opt.scales: 182 | h = self.opt.height // (2 ** scale) 183 | w = self.opt.width // (2 ** scale) 184 | 185 | self.backproject_depth[scale] = BackprojectDepth(self.opt.batch_size, h, w) 186 | self.backproject_depth[scale].to(self.device) 187 | 188 | self.project_3d[scale] = Project3D(self.opt.batch_size, h, w) 189 | self.project_3d[scale].to(self.device) 190 | 191 | self.depth_metric_names = [ 192 | "de/abs_rel", "de/sq_rel", "de/rms", "de/log_rms", "da/a1", "da/a2", "da/a3"] 193 | 194 | print("Using split:\n ", self.opt.split) 195 | print("There are {:d} training items and {:d} validation items\n".format( 196 | len(train_dataset), len(val_dataset))) 197 | 198 | self.save_opts() 199 | 200 | def set_train(self): 201 | """Convert all models to training mode 202 | """ 203 | for m in self.models.values(): 204 | m.train() 205 | 206 | def set_eval(self): 207 | """Convert all models to testing/evaluation mode 208 | """ 209 | for m in self.models.values(): 210 | m.eval() 211 | 212 | def train(self): 213 | """Run the entire training pipeline 214 | """ 215 | self.epoch = 0 216 | self.step = 0 217 | self.start_time = time.time() 218 | for self.epoch in range(self.opt.num_epochs): 219 | self.run_epoch() 220 | if (self.epoch + 1) % self.opt.save_frequency == 0: 221 | self.save_model() 222 | 223 | def run_epoch(self): 224 | """Run a single epoch of training and validation 225 | """ 226 | 227 | print("Training") 228 | self.set_train() 229 | 230 | self.model_lr_scheduler.step() 231 | if self.use_pose_net: 232 | self.model_pose_lr_scheduler.step() 233 | 234 | for batch_idx, inputs in enumerate(self.train_loader): 235 | 236 | before_op_time = time.time() 237 | 238 | outputs, losses = self.process_batch(inputs) 239 | 240 | self.model_optimizer.zero_grad() 241 | if self.use_pose_net: 242 | self.model_pose_optimizer.zero_grad() 243 | losses["loss"].backward() 244 | 
self.model_optimizer.step() 245 | if self.use_pose_net: 246 | self.model_pose_optimizer.step() 247 | 248 | duration = time.time() - before_op_time 249 | 250 | # log less frequently after the first 20000 steps to save time & disk space 251 | early_phase = batch_idx % self.opt.log_frequency == 0 and self.step < 20000 252 | late_phase = self.step % 2000 == 0 253 | 254 | if early_phase or late_phase: 255 | self.log_time(batch_idx, duration, losses["loss"].cpu().data) 256 | 257 | if "depth_gt" in inputs: 258 | self.compute_depth_losses(inputs, outputs, losses) 259 | 260 | self.log("train", inputs, outputs, losses) 261 | self.val() 262 | 263 | self.step += 1 264 | 265 | def process_batch(self, inputs): 266 | """Pass a minibatch through the network and generate images and losses 267 | """ 268 | for key, ipt in inputs.items(): 269 | inputs[key] = ipt.to(self.device) 270 | 271 | if self.opt.pose_model_type == "shared": 272 | # If we are using a shared encoder for both depth and pose (as advocated 273 | # in monodepthv1), then all images are fed separately through the depth encoder. 
274 | all_color_aug = torch.cat([inputs[("color_aug", i, 0)] for i in self.opt.frame_ids]) 275 | all_features = self.models["encoder"](all_color_aug) 276 | all_features = [torch.split(f, self.opt.batch_size) for f in all_features] 277 | 278 | features = {} 279 | for i, k in enumerate(self.opt.frame_ids): 280 | features[k] = [f[i] for f in all_features] 281 | 282 | outputs = self.models["depth"](features[0]) 283 | else: 284 | # Otherwise, we only feed the image with frame_id 0 through the depth encoder 285 | 286 | features = self.models["encoder"](inputs["color_aug", 0, 0]) 287 | 288 | outputs = self.models["depth"](features) 289 | 290 | if self.opt.predictive_mask: 291 | outputs["predictive_mask"] = self.models["predictive_mask"](features) 292 | 293 | if self.use_pose_net: 294 | outputs.update(self.predict_poses(inputs, features)) 295 | 296 | self.generate_images_pred(inputs, outputs) 297 | losses = self.compute_losses(inputs, outputs) 298 | 299 | return outputs, losses 300 | 301 | def predict_poses(self, inputs, features): 302 | """Predict poses between input frames for monocular sequences. 303 | """ 304 | outputs = {} 305 | if self.num_pose_frames == 2: 306 | # In this setting, we compute the pose to each source frame via a 307 | # separate forward pass through the pose network. 
308 | 309 | # select what features the pose network takes as input 310 | if self.opt.pose_model_type == "shared": 311 | pose_feats = {f_i: features[f_i] for f_i in self.opt.frame_ids} 312 | else: 313 | pose_feats = {f_i: inputs["color_aug", f_i, 0] for f_i in self.opt.frame_ids} 314 | 315 | for f_i in self.opt.frame_ids[1:]: 316 | if f_i != "s": 317 | # To maintain ordering we always pass frames in temporal order 318 | if f_i < 0: 319 | pose_inputs = [pose_feats[f_i], pose_feats[0]] 320 | else: 321 | pose_inputs = [pose_feats[0], pose_feats[f_i]] 322 | 323 | if self.opt.pose_model_type == "separate_resnet": 324 | pose_inputs = [self.models_pose["pose_encoder"](torch.cat(pose_inputs, 1))] 325 | elif self.opt.pose_model_type == "posecnn": 326 | pose_inputs = torch.cat(pose_inputs, 1) 327 | 328 | axisangle, translation = self.models_pose["pose"](pose_inputs) 329 | outputs[("axisangle", 0, f_i)] = axisangle 330 | outputs[("translation", 0, f_i)] = translation 331 | 332 | # Invert the matrix if the frame id is negative 333 | outputs[("cam_T_cam", 0, f_i)] = transformation_from_parameters( 334 | axisangle[:, 0], translation[:, 0], invert=(f_i < 0)) 335 | 336 | else: 337 | # Here we input all frames to the pose net (and predict all poses) together 338 | if self.opt.pose_model_type in ["separate_resnet", "posecnn"]: 339 | pose_inputs = torch.cat( 340 | [inputs[("color_aug", i, 0)] for i in self.opt.frame_ids if i != "s"], 1) 341 | 342 | if self.opt.pose_model_type == "separate_resnet": 343 | pose_inputs = [self.models["pose_encoder"](pose_inputs)] 344 | 345 | elif self.opt.pose_model_type == "shared": 346 | pose_inputs = [features[i] for i in self.opt.frame_ids if i != "s"] 347 | 348 | axisangle, translation = self.models_pose["pose"](pose_inputs) 349 | 350 | for i, f_i in enumerate(self.opt.frame_ids[1:]): 351 | if f_i != "s": 352 | outputs[("axisangle", 0, f_i)] = axisangle 353 | outputs[("translation", 0, f_i)] = translation 354 | outputs[("cam_T_cam", 0, f_i)] = 
transformation_from_parameters( 355 | axisangle[:, i], translation[:, i]) 356 | 357 | return outputs 358 | 359 | def val(self): 360 | """Validate the model on a single minibatch 361 | """ 362 | self.set_eval() 363 | try: 364 | inputs = self.val_iter.next() 365 | except StopIteration: 366 | self.val_iter = iter(self.val_loader) 367 | inputs = self.val_iter.next() 368 | 369 | with torch.no_grad(): 370 | outputs, losses = self.process_batch(inputs) 371 | 372 | if "depth_gt" in inputs: 373 | self.compute_depth_losses(inputs, outputs, losses) 374 | 375 | self.log("val", inputs, outputs, losses) 376 | del inputs, outputs, losses 377 | 378 | self.set_train() 379 | 380 | def generate_images_pred(self, inputs, outputs): 381 | """Generate the warped (reprojected) color images for a minibatch. 382 | Generated images are saved into the `outputs` dictionary. 383 | """ 384 | for scale in self.opt.scales: 385 | disp = outputs[("disp", scale)] 386 | if self.opt.v1_multiscale: 387 | source_scale = scale 388 | else: 389 | disp = F.interpolate( 390 | disp, [self.opt.height, self.opt.width], mode="bilinear", align_corners=False) 391 | source_scale = 0 392 | 393 | _, depth = disp_to_depth(disp, self.opt.min_depth, self.opt.max_depth) 394 | 395 | outputs[("depth", 0, scale)] = depth 396 | 397 | for i, frame_id in enumerate(self.opt.frame_ids[1:]): 398 | 399 | if frame_id == "s": 400 | T = inputs["stereo_T"] 401 | else: 402 | T = outputs[("cam_T_cam", 0, frame_id)] 403 | 404 | # from the authors of https://arxiv.org/abs/1712.00175 405 | if self.opt.pose_model_type == "posecnn": 406 | 407 | axisangle = outputs[("axisangle", 0, frame_id)] 408 | translation = outputs[("translation", 0, frame_id)] 409 | 410 | inv_depth = 1 / depth 411 | mean_inv_depth = inv_depth.mean(3, True).mean(2, True) 412 | 413 | T = transformation_from_parameters( 414 | axisangle[:, 0], translation[:, 0] * mean_inv_depth[:, 0], frame_id < 0) 415 | 416 | cam_points = self.backproject_depth[source_scale]( 417 | depth, 
inputs[("inv_K", source_scale)]) 418 | pix_coords = self.project_3d[source_scale]( 419 | cam_points, inputs[("K", source_scale)], T) 420 | 421 | outputs[("sample", frame_id, scale)] = pix_coords 422 | 423 | outputs[("color", frame_id, scale)] = F.grid_sample( 424 | inputs[("color", frame_id, source_scale)], 425 | outputs[("sample", frame_id, scale)], 426 | padding_mode="border", align_corners=True) 427 | 428 | if not self.opt.disable_automasking: 429 | outputs[("color_identity", frame_id, scale)] = \ 430 | inputs[("color", frame_id, source_scale)] 431 | 432 | def compute_reprojection_loss(self, pred, target): 433 | """Computes reprojection loss between a batch of predicted and target images 434 | """ 435 | abs_diff = torch.abs(target - pred) 436 | l1_loss = abs_diff.mean(1, True) 437 | 438 | if self.opt.no_ssim: 439 | reprojection_loss = l1_loss 440 | else: 441 | ssim_loss = self.ssim(pred, target).mean(1, True) 442 | reprojection_loss = 0.85 * ssim_loss + 0.15 * l1_loss 443 | 444 | return reprojection_loss 445 | 446 | def compute_losses(self, inputs, outputs): 447 | """Compute the reprojection and smoothness losses for a minibatch 448 | """ 449 | 450 | losses = {} 451 | total_loss = 0 452 | 453 | for scale in self.opt.scales: 454 | loss = 0 455 | reprojection_losses = [] 456 | 457 | if self.opt.v1_multiscale: 458 | source_scale = scale 459 | else: 460 | source_scale = 0 461 | 462 | disp = outputs[("disp", scale)] 463 | color = inputs[("color", 0, scale)] 464 | target = inputs[("color", 0, source_scale)] 465 | 466 | for frame_id in self.opt.frame_ids[1:]: 467 | pred = outputs[("color", frame_id, scale)] 468 | reprojection_losses.append(self.compute_reprojection_loss(pred, target)) 469 | 470 | reprojection_losses = torch.cat(reprojection_losses, 1) 471 | 472 | if not self.opt.disable_automasking: 473 | identity_reprojection_losses = [] 474 | for frame_id in self.opt.frame_ids[1:]: 475 | pred = inputs[("color", frame_id, source_scale)] 476 | 
identity_reprojection_losses.append( 477 | self.compute_reprojection_loss(pred, target)) 478 | 479 | identity_reprojection_losses = torch.cat(identity_reprojection_losses, 1) 480 | 481 | if self.opt.avg_reprojection: 482 | identity_reprojection_loss = identity_reprojection_losses.mean(1, keepdim=True) 483 | else: 484 | # save both images, and do min all at once below 485 | identity_reprojection_loss = identity_reprojection_losses 486 | 487 | elif self.opt.predictive_mask: 488 | # use the predicted mask 489 | mask = outputs["predictive_mask"]["disp", scale] 490 | if not self.opt.v1_multiscale: 491 | mask = F.interpolate( 492 | mask, [self.opt.height, self.opt.width], 493 | mode="bilinear", align_corners=False) 494 | 495 | reprojection_losses *= mask 496 | 497 | # add a loss pushing mask to 1 (using nn.BCELoss for stability) 498 | weighting_loss = 0.2 * nn.BCELoss()(mask, torch.ones(mask.shape).cuda()) 499 | loss += weighting_loss.mean() 500 | 501 | if self.opt.avg_reprojection: 502 | reprojection_loss = reprojection_losses.mean(1, keepdim=True) 503 | else: 504 | reprojection_loss = reprojection_losses 505 | 506 | if not self.opt.disable_automasking: 507 | # add random numbers to break ties 508 | identity_reprojection_loss += torch.randn( 509 | identity_reprojection_loss.shape, device=self.device) * 0.00001 510 | 511 | combined = torch.cat((identity_reprojection_loss, reprojection_loss), dim=1) 512 | else: 513 | combined = reprojection_loss 514 | 515 | if combined.shape[1] == 1: 516 | to_optimise = combined 517 | else: 518 | to_optimise, idxs = torch.min(combined, dim=1) 519 | 520 | if not self.opt.disable_automasking: 521 | outputs["identity_selection/{}".format(scale)] = ( 522 | idxs > identity_reprojection_loss.shape[1] - 1).float() 523 | 524 | loss += to_optimise.mean() 525 | 526 | mean_disp = disp.mean(2, True).mean(3, True) 527 | norm_disp = disp / (mean_disp + 1e-7) 528 | smooth_loss = get_smooth_loss(norm_disp, color) 529 | 530 | loss += 
self.opt.disparity_smoothness * smooth_loss / (2 ** scale) 531 | total_loss += loss 532 | losses["loss/{}".format(scale)] = loss 533 | 534 | total_loss /= self.num_scales 535 | losses["loss"] = total_loss 536 | return losses 537 | 538 | def compute_depth_losses(self, inputs, outputs, losses): 539 | """Compute depth metrics, to allow monitoring during training 540 | 541 | This isn't particularly accurate as it averages over the entire batch, 542 | so is only used to give an indication of validation performance 543 | """ 544 | depth_pred = outputs[("depth", 0, 0)] 545 | depth_pred = torch.clamp(F.interpolate( 546 | depth_pred, [375, 1242], mode="bilinear", align_corners=False), 1e-3, 80) 547 | depth_pred = depth_pred.detach() 548 | 549 | depth_gt = inputs["depth_gt"] 550 | mask = depth_gt > 0 551 | 552 | # garg/eigen crop 553 | crop_mask = torch.zeros_like(mask) 554 | crop_mask[:, :, 153:371, 44:1197] = 1 555 | mask = mask * crop_mask 556 | 557 | depth_gt = depth_gt[mask] 558 | depth_pred = depth_pred[mask] 559 | depth_pred *= torch.median(depth_gt) / torch.median(depth_pred) 560 | 561 | depth_pred = torch.clamp(depth_pred, min=1e-3, max=80) 562 | 563 | depth_errors = compute_depth_errors(depth_gt, depth_pred) 564 | 565 | for i, metric in enumerate(self.depth_metric_names): 566 | losses[metric] = np.array(depth_errors[i].cpu()) 567 | 568 | def log_time(self, batch_idx, duration, loss): 569 | """Print a logging statement to the terminal 570 | """ 571 | samples_per_sec = self.opt.batch_size / duration 572 | time_sofar = time.time() - self.start_time 573 | training_time_left = ( 574 | self.num_total_steps / self.step - 1.0) * time_sofar if self.step > 0 else 0 575 | print_string = "epoch {:>3} | lr {:.6f} |lr_p {:.6f} | batch {:>6} | examples/s: {:5.1f}" + \ 576 | " | loss: {:.5f} | time elapsed: {} | time left: {}" 577 | print(print_string.format(self.epoch, self.model_optimizer.state_dict()['param_groups'][0]['lr'], 578 | 
self.model_pose_optimizer.state_dict()['param_groups'][0]['lr'], 579 | batch_idx, samples_per_sec, loss, 580 | sec_to_hm_str(time_sofar), sec_to_hm_str(training_time_left))) 581 | 582 | def log(self, mode, inputs, outputs, losses): 583 | """Write an event to the tensorboard events file 584 | """ 585 | writer = self.writers[mode] 586 | for l, v in losses.items(): 587 | writer.add_scalar("{}".format(l), v, self.step) 588 | 589 | for j in range(min(4, self.opt.batch_size)): # write a maxmimum of four images 590 | for s in self.opt.scales: 591 | for frame_id in self.opt.frame_ids: 592 | writer.add_image( 593 | "color_{}_{}/{}".format(frame_id, s, j), 594 | inputs[("color", frame_id, s)][j].data, self.step) 595 | if s == 0 and frame_id != 0: 596 | writer.add_image( 597 | "color_pred_{}_{}/{}".format(frame_id, s, j), 598 | outputs[("color", frame_id, s)][j].data, self.step) 599 | 600 | writer.add_image( 601 | "disp_{}/{}".format(s, j), 602 | normalize_image(outputs[("disp", s)][j]), self.step) 603 | 604 | if self.opt.predictive_mask: 605 | for f_idx, frame_id in enumerate(self.opt.frame_ids[1:]): 606 | writer.add_image( 607 | "predictive_mask_{}_{}/{}".format(frame_id, s, j), 608 | outputs["predictive_mask"][("disp", s)][j, f_idx][None, ...], 609 | self.step) 610 | 611 | elif not self.opt.disable_automasking: 612 | writer.add_image( 613 | "automask_{}/{}".format(s, j), 614 | outputs["identity_selection/{}".format(s)][j][None, ...], self.step) 615 | 616 | def save_opts(self): 617 | """Save options to disk so we know what we ran this experiment with 618 | """ 619 | models_dir = os.path.join(self.log_path, "models") 620 | if not os.path.exists(models_dir): 621 | os.makedirs(models_dir) 622 | to_save = self.opt.__dict__.copy() 623 | 624 | with open(os.path.join(models_dir, 'opt.json'), 'w') as f: 625 | json.dump(to_save, f, indent=2) 626 | 627 | def save_model(self): 628 | """Save model weights to disk 629 | """ 630 | save_folder = os.path.join(self.log_path, "models", 
"weights_{}".format(self.epoch)) 631 | if not os.path.exists(save_folder): 632 | os.makedirs(save_folder) 633 | 634 | for model_name, model in self.models.items(): 635 | save_path = os.path.join(save_folder, "{}.pth".format(model_name)) 636 | to_save = model.state_dict() 637 | if model_name == 'encoder': 638 | # save the sizes - these are needed at prediction time 639 | to_save['height'] = self.opt.height 640 | to_save['width'] = self.opt.width 641 | to_save['use_stereo'] = self.opt.use_stereo 642 | torch.save(to_save, save_path) 643 | 644 | for model_name, model in self.models_pose.items(): 645 | save_path = os.path.join(save_folder, "{}.pth".format(model_name)) 646 | to_save = model.state_dict() 647 | torch.save(to_save, save_path) 648 | 649 | save_path = os.path.join(save_folder, "{}.pth".format("adam")) 650 | torch.save(self.model_optimizer.state_dict(), save_path) 651 | 652 | save_path = os.path.join(save_folder, "{}.pth".format("adam_pose")) 653 | if self.use_pose_net: 654 | torch.save(self.model_pose_optimizer.state_dict(), save_path) 655 | 656 | def load_pretrain(self): 657 | self.opt.mypretrain = os.path.expanduser(self.opt.mypretrain) 658 | path = self.opt.mypretrain 659 | model_dict = self.models["encoder"].state_dict() 660 | pretrained_dict = torch.load(path)['model'] 661 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in model_dict and not k.startswith('norm'))} 662 | model_dict.update(pretrained_dict) 663 | self.models["encoder"].load_state_dict(model_dict) 664 | print('mypretrain loaded.') 665 | 666 | def load_model(self): 667 | """Load model(s) from disk 668 | """ 669 | self.opt.load_weights_folder = os.path.expanduser(self.opt.load_weights_folder) 670 | 671 | assert os.path.isdir(self.opt.load_weights_folder), \ 672 | "Cannot find folder {}".format(self.opt.load_weights_folder) 673 | print("loading model from folder {}".format(self.opt.load_weights_folder)) 674 | 675 | for n in self.opt.models_to_load: 676 | print("Loading {} 
weights...".format(n)) 677 | path = os.path.join(self.opt.load_weights_folder, "{}.pth".format(n)) 678 | 679 | if n in ['pose_encoder', 'pose']: 680 | model_dict = self.models_pose[n].state_dict() 681 | pretrained_dict = torch.load(path) 682 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 683 | model_dict.update(pretrained_dict) 684 | self.models_pose[n].load_state_dict(model_dict) 685 | else: 686 | model_dict = self.models[n].state_dict() 687 | pretrained_dict = torch.load(path) 688 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 689 | model_dict.update(pretrained_dict) 690 | self.models[n].load_state_dict(model_dict) 691 | 692 | # loading adam state 693 | 694 | optimizer_load_path = os.path.join(self.opt.load_weights_folder, "adam.pth") 695 | optimizer_pose_load_path = os.path.join(self.opt.load_weights_folder, "adam_pose.pth") 696 | if os.path.isfile(optimizer_load_path): 697 | print("Loading Adam weights") 698 | optimizer_dict = torch.load(optimizer_load_path) 699 | optimizer_pose_dict = torch.load(optimizer_pose_load_path) 700 | self.model_optimizer.load_state_dict(optimizer_dict) 701 | self.model_pose_optimizer.load_state_dict(optimizer_pose_dict) 702 | else: 703 | print("Cannot find Adam weights so Adam is randomly initialized") 704 | 705 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | import os 3 | import hashlib 4 | import zipfile 5 | from six.moves import urllib 6 | 7 | 8 | def readlines(filename): 9 | """Read all the lines in a text file and return as a list 10 | """ 11 | with open(filename, 'r') as f: 12 | lines = f.read().splitlines() 13 | return lines 14 | 15 | 16 | def normalize_image(x): 17 | """Rescale image pixels to span range [0, 1] 18 | """ 19 | ma = float(x.max().cpu().data) 20 | 
mi = float(x.min().cpu().data) 21 | d = ma - mi if ma != mi else 1e5 22 | return (x - mi) / d 23 | 24 | 25 | def sec_to_hm(t): 26 | """Convert time in seconds to time in hours, minutes and seconds 27 | e.g. 10239 -> (2, 50, 39) 28 | """ 29 | t = int(t) 30 | s = t % 60 31 | t //= 60 32 | m = t % 60 33 | t //= 60 34 | return t, m, s 35 | 36 | 37 | def sec_to_hm_str(t): 38 | """Convert time in seconds to a nice string 39 | e.g. 10239 -> '02h50m39s' 40 | """ 41 | h, m, s = sec_to_hm(t) 42 | return "{:02d}h{:02d}m{:02d}s".format(h, m, s) 43 | 44 | 45 | def download_model_if_doesnt_exist(model_name): 46 | """If pretrained kitti model doesn't exist, download and unzip it 47 | """ 48 | # values are tuples of (, ) 49 | download_paths = { 50 | "mono_640x192": 51 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono_640x192.zip", 52 | "a964b8356e08a02d009609d9e3928f7c"), 53 | "stereo_640x192": 54 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/stereo_640x192.zip", 55 | "3dfb76bcff0786e4ec07ac00f658dd07"), 56 | "mono+stereo_640x192": 57 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono%2Bstereo_640x192.zip", 58 | "c024d69012485ed05d7eaa9617a96b81"), 59 | "mono_no_pt_640x192": 60 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono_no_pt_640x192.zip", 61 | "9c2f071e35027c895a4728358ffc913a"), 62 | "stereo_no_pt_640x192": 63 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/stereo_no_pt_640x192.zip", 64 | "41ec2de112905f85541ac33a854742d1"), 65 | "mono+stereo_no_pt_640x192": 66 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono%2Bstereo_no_pt_640x192.zip", 67 | "46c3b824f541d143a45c37df65fbab0a"), 68 | "mono_1024x320": 69 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono_1024x320.zip", 70 | "0ab0766efdfeea89a0d9ea8ba90e1e63"), 71 | "stereo_1024x320": 72 | 
("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/stereo_1024x320.zip", 73 | "afc2f2126d70cf3fdf26b550898b501a"), 74 | "mono+stereo_1024x320": 75 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono%2Bstereo_1024x320.zip", 76 | "cdc5fc9b23513c07d5b19235d9ef08f7"), 77 | } 78 | 79 | if not os.path.exists("models"): 80 | os.makedirs("models") 81 | 82 | model_path = os.path.join("models", model_name) 83 | 84 | def check_file_matches_md5(checksum, fpath): 85 | if not os.path.exists(fpath): 86 | return False 87 | with open(fpath, 'rb') as f: 88 | current_md5checksum = hashlib.md5(f.read()).hexdigest() 89 | return current_md5checksum == checksum 90 | 91 | # see if we have the model already downloaded... 92 | if not os.path.exists(os.path.join(model_path, "encoder.pth")): 93 | 94 | model_url, required_md5checksum = download_paths[model_name] 95 | 96 | if not check_file_matches_md5(required_md5checksum, model_path + ".zip"): 97 | print("-> Downloading pretrained model to {}".format(model_path + ".zip")) 98 | urllib.request.urlretrieve(model_url, model_path + ".zip") 99 | 100 | if not check_file_matches_md5(required_md5checksum, model_path + ".zip"): 101 | print(" Failed to download a file which matches the checksum - quitting") 102 | quit() 103 | 104 | print(" Unzipping model...") 105 | with zipfile.ZipFile(model_path + ".zip", 'r') as f: 106 | f.extractall(model_path) 107 | 108 | print(" Model unzipped to {}".format(model_path)) 109 | --------------------------------------------------------------------------------