├── .gitignore
├── README.md
├── datasets
│   ├── __init__.py
│   ├── kitti_dataset.py
│   └── mono_dataset.py
├── evaluate_depth.py
├── img
│   ├── overview.png
│   ├── robustness.png
│   ├── speed.png
│   └── teaser_m.gif
├── kitti_utils.py
├── layers.py
├── license
├── lite-mono-pretrain-code
│   ├── README.md
│   ├── datasets.py
│   ├── engine.py
│   ├── main.py
│   ├── models
│   │   └── litemono.py
│   ├── optim_factory.py
│   ├── requirements.txt
│   ├── sampler.py
│   └── utils.py
├── networks
│   ├── __init__.py
│   ├── depth_decoder.py
│   ├── depth_encoder.py
│   ├── pose_decoder.py
│   └── resnet_encoder.py
├── options.py
├── splits
│   └── eigen
│       └── test_files.txt
├── test_simple.py
├── train.py
├── trainer.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .idea/
106 | .env
107 | .venv
108 | env/
109 | venv/
110 | ENV/
111 | env.bak/
112 | venv.bak/
113 |
114 | # Spyder project settings
115 | .spyderproject
116 | .spyproject
117 |
118 | # Rope project settings
119 | .ropeproject
120 |
121 | # mkdocs documentation
122 | /site
123 |
124 | # mypy
125 | .mypy_cache/
126 | .dmypy.json
127 | dmypy.json
128 |
129 | # Pyre type checker
130 | .pyre/
131 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Lite-Mono
4 | **A Lightweight CNN and Transformer Architecture for Self-Supervised Monocular Depth Estimation**
5 | [[paper link]](https://arxiv.org/abs/2211.13202)
6 |
7 | Ning Zhang*, Francesco Nex, George Vosselman, Norman Kerle
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | (Lite-Mono-8m 1024x320)
16 |
17 |
18 |
19 |
20 | ## Table of Contents
21 | - [Overview](#overview)
22 | - [Results](#results)
23 | - [KITTI](#kitti)
24 | - [Speed Evaluation](#speed-evaluation)
25 | - [Robustness](#robustness)
26 | - [Data Preparation](#data-preparation)
27 | - [Single Image Test](#single-image-test)
28 | - [Preparing Trained Model](#preparing-trained-model)
29 | - [Start Testing](#start-testing)
30 | - [Evaluation](#evaluation)
31 | - [Training](#training)
32 | - [Dependency Installation](#dependency-installation)
33 | - [Preparing Pre-trained Weights](#preparing-pre-trained-weights)
34 | - [Start Training](#start-training)
35 | - [Tensorboard Visualization](#tensorboard-visualization)
36 | - [Make Your Own Pre-training Weights On ImageNet](#make-your-own-pre-training-weights-on-imagenet)
37 | - [Citation](#citation)
38 |
39 |
40 | ## Overview
41 |
42 |
43 |
44 | ## Results
45 | ### KITTI
46 | You can download the trained models using the links below.
47 |
48 | | --model | Params | ImageNet Pretrained | Input size | Abs Rel | Sq Rel | RMSE | RMSE log | delta < 1.25 | delta < 1.25^2 | delta < 1.25^3 |
49 | |:---------------:|:------:|:-------------------:|:----------:|:---------:|:---------:|:---------:|:---------:|:------------:|:--------------:|:--------------:|
50 | | [**lite-mono**](https://surfdrive.surf.nl/files/index.php/s/CUjiK221EFLyXDY) | 3.1M | [yes](https://surfdrive.surf.nl/files/index.php/s/InMMGd5ZP2fXuia) | 640x192 | 0.107 | 0.765 | 4.561 | 0.183 | 0.886 | 0.963 | 0.983 |
51 | | [lite-mono-small](https://surfdrive.surf.nl/files/index.php/s/8cuZNH1CkNtQwxQ) | 2.5M | [yes](https://surfdrive.surf.nl/files/index.php/s/DYbWV4bsWImfJKu) | 640x192 | 0.110 | 0.802 | 4.671 | 0.186 | 0.879 | 0.961 | 0.982 |
52 | | [lite-mono-tiny](https://surfdrive.surf.nl/files/index.php/s/TFDlF3wYQy0Nhmg) | 2.2M | yes | 640x192 | 0.110 | 0.837 | 4.710 | 0.187 | 0.880 | 0.960 | 0.982 |
53 | | [**lite-mono-8m**](https://surfdrive.surf.nl/files/index.php/s/UlkVBi1p99NFWWI) | 8.7M | [yes](https://surfdrive.surf.nl/files/index.php/s/oil2ME6ymoLGDlL) | 640x192 | 0.101 | 0.729 | 4.454 | 0.178 | 0.897 | 0.965 | 0.983 |
54 | | [**lite-mono**](https://surfdrive.surf.nl/files/index.php/s/IK3VtPj6b5FkVnl) | 3.1M | yes | 1024x320 | 0.102 | 0.746 | 4.444 | 0.179 | 0.896 | 0.965 | 0.983 |
55 | | [lite-mono-small](https://surfdrive.surf.nl/files/index.php/s/w8mvJMkB1dP15pu) | 2.5M | yes | 1024x320 | 0.103 | 0.757 | 4.449 | 0.180 | 0.894 | 0.964 | 0.983 |
56 | | [lite-mono-tiny](https://surfdrive.surf.nl/files/index.php/s/myxcplTciOkgu5w) | 2.2M | yes | 1024x320 | 0.104 | 0.764 | 4.487 | 0.180 | 0.892 | 0.964 | 0.983 |
57 | | [**lite-mono-8m**](https://surfdrive.surf.nl/files/index.php/s/mgonNFAvoEJmMas) | 8.7M | yes | 1024x320 | 0.097 | 0.710 | 4.309 | 0.174 | 0.905 | 0.967 | 0.984 |
58 |
59 |
60 | ### Speed Evaluation
61 |
62 |
63 |
64 | ### Robustness
65 |
66 |
67 | The [RoboDepth Challenge Team](https://github.com/ldkong1205/RoboDepth) is evaluating the robustness of different depth estimation algorithms. Lite-Mono has achieved the best robustness to date.
68 |
69 | ## Data Preparation
70 | Please refer to [Monodepth2](https://github.com/nianticlabs/monodepth2) to prepare your KITTI data.
71 |
72 | ## Single Image Test
73 | #### Preparing Trained Model
74 | From this [table](#kitti) you can download trained models (depth encoder and depth decoder).
75 |
76 | Click on the links in the '--model' column to download a trained model.
77 |
78 | #### Start Testing
79 | python test_simple.py --load_weights_folder path/to/your/weights/folder --image_path path/to/your/test/image
80 |
81 | ## Evaluation
82 | python evaluate_depth.py --load_weights_folder path/to/your/weights/folder --data_path path/to/kitti_data/ --model lite-mono
83 |
84 |
85 | ## Training
86 | #### Dependency Installation
87 | pip install 'git+https://github.com/saadnaeem-dev/pytorch-linear-warmup-cosine-annealing-warm-restarts-weight-decay'
88 |
89 | #### Preparing Pre-trained Weights
90 | From this [table](#kitti) you can also download weights of backbone (depth encoder) pre-trained on ImageNet.
91 |
92 | Click 'yes' on a row to download specific pre-trained weights. The weights are agnostic to image resolutions.
93 |
94 | #### Start Training
95 | python train.py --data_path path/to/your/data --model_name mytrain --num_epochs 30 --batch_size 12 --mypretrain path/to/your/pretrained/weights --lr 0.0001 5e-6 31 0.0001 1e-5 31
96 |
97 | #### Tensorboard Visualization
98 | tensorboard --logdir ./tmp/mytrain
99 |
100 | ## Make Your Own Pre-training Weights On ImageNet
101 | Since a lot of people are interested in training their own backbone on ImageNet, I have also uploaded my pre-training [scripts](lite-mono-pretrain-code) to this repo.
102 |
103 |
104 | ## Citation
105 |
106 | @InProceedings{Zhang_2023_CVPR,
107 | author = {Zhang, Ning and Nex, Francesco and Vosselman, George and Kerle, Norman},
108 | title = {Lite-Mono: A Lightweight CNN and Transformer Architecture for Self-Supervised Monocular Depth Estimation},
109 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
110 | month = {June},
111 | year = {2023},
112 | pages = {18537-18546}
113 | }
--------------------------------------------------------------------------------
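
If you prefer to drive the single-image test and evaluation described in the README from Python rather than through `test_simple.py`, the following minimal sketch mirrors the weight-loading code in `evaluate_depth.py`. The weights folder, image path, `lite-mono` variant and the 0.1 m / 100 m depth range below are assumptions; adjust them to your setup.

```python
import os
import torch
from PIL import Image
from torchvision import transforms

import networks
from layers import disp_to_depth

# Assumed download location of a trained model from the table above.
weights_folder = "weights/lite-mono_640x192"
encoder_dict = torch.load(os.path.join(weights_folder, "encoder.pth"), map_location="cpu")
decoder_dict = torch.load(os.path.join(weights_folder, "depth.pth"), map_location="cpu")

encoder = networks.LiteMono(model="lite-mono",
                            height=encoder_dict['height'], width=encoder_dict['width'])
decoder = networks.DepthDecoder(encoder.num_ch_enc, scales=range(3))
encoder.load_state_dict({k: v for k, v in encoder_dict.items() if k in encoder.state_dict()})
decoder.load_state_dict({k: v for k, v in decoder_dict.items() if k in decoder.state_dict()})
encoder.eval()
decoder.eval()

# Resize the test image to the resolution the model was trained with.
img = Image.open("path/to/your/test/image.jpg").convert("RGB")
img = img.resize((encoder_dict['width'], encoder_dict['height']))
x = transforms.ToTensor()(img).unsqueeze(0)

with torch.no_grad():
    disp = decoder(encoder(x))[("disp", 0)]
    # 0.1 / 100 are assumed min/max depth values; see options.py for the real defaults.
    scaled_disp, depth = disp_to_depth(disp, 0.1, 100)
print(depth.shape)  # (1, 1, H, W) relative depth map
```
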
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .kitti_dataset import KITTIRAWDataset, KITTIOdomDataset, KITTIDepthDataset
2 |
--------------------------------------------------------------------------------
/datasets/kitti_dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import os
4 | import skimage.transform
5 | import numpy as np
6 | import PIL.Image as pil
7 |
8 | from kitti_utils import generate_depth_map
9 | from .mono_dataset import MonoDataset
10 |
11 |
12 | class KITTIDataset(MonoDataset):
13 | """Superclass for different types of KITTI dataset loaders
14 | """
15 | def __init__(self, *args, **kwargs):
16 | super(KITTIDataset, self).__init__(*args, **kwargs)
17 |
18 | self.K = np.array([[0.58, 0, 0.5, 0],
19 | [0, 1.92, 0.5, 0],
20 | [0, 0, 1, 0],
21 | [0, 0, 0, 1]], dtype=np.float32)
22 |
23 | self.full_res_shape = (1242, 375)
24 | self.side_map = {"2": 2, "3": 3, "l": 2, "r": 3}
25 |
26 | def check_depth(self):
27 | line = self.filenames[0].split()
28 | scene_name = line[0]
29 | frame_index = int(line[1])
30 |
31 | velo_filename = os.path.join(
32 | self.data_path,
33 | scene_name,
34 | "velodyne_points/data/{:010d}.bin".format(int(frame_index)))
35 |
36 | return os.path.isfile(velo_filename)
37 |
38 | def get_color(self, folder, frame_index, side, do_flip):
39 | color = self.loader(self.get_image_path(folder, frame_index, side))
40 |
41 | if do_flip:
42 | color = color.transpose(pil.FLIP_LEFT_RIGHT)
43 |
44 | return color
45 |
46 |
47 | class KITTIRAWDataset(KITTIDataset):
48 | """KITTI dataset which loads the original velodyne depth maps for ground truth
49 | """
50 | def __init__(self, *args, **kwargs):
51 | super(KITTIRAWDataset, self).__init__(*args, **kwargs)
52 |
53 | def get_image_path(self, folder, frame_index, side):
54 | f_str = "{:010d}{}".format(frame_index, self.img_ext)
55 | image_path = os.path.join(
56 | self.data_path, folder, "image_0{}/data".format(self.side_map[side]), f_str)
57 | return image_path
58 |
59 | def get_depth(self, folder, frame_index, side, do_flip):
60 | calib_path = os.path.join(self.data_path, folder.split("/")[0])
61 |
62 | velo_filename = os.path.join(
63 | self.data_path,
64 | folder,
65 | "velodyne_points/data/{:010d}.bin".format(int(frame_index)))
66 |
67 | depth_gt = generate_depth_map(calib_path, velo_filename, self.side_map[side])
68 | depth_gt = skimage.transform.resize(
69 | depth_gt, self.full_res_shape[::-1], order=0, preserve_range=True, mode='constant')
70 |
71 | if do_flip:
72 | depth_gt = np.fliplr(depth_gt)
73 |
74 | return depth_gt
75 |
76 |
77 | class KITTIOdomDataset(KITTIDataset):
78 | """KITTI dataset for odometry training and testing
79 | """
80 | def __init__(self, *args, **kwargs):
81 | super(KITTIOdomDataset, self).__init__(*args, **kwargs)
82 |
83 | def get_image_path(self, folder, frame_index, side):
84 | f_str = "{:06d}{}".format(frame_index, self.img_ext)
85 | image_path = os.path.join(
86 | self.data_path,
87 | "sequences/{:02d}".format(int(folder)),
88 | "image_{}".format(self.side_map[side]),
89 | f_str)
90 | return image_path
91 |
92 |
93 | class KITTIDepthDataset(KITTIDataset):
94 | """KITTI dataset which uses the updated ground truth depth maps
95 | """
96 | def __init__(self, *args, **kwargs):
97 | super(KITTIDepthDataset, self).__init__(*args, **kwargs)
98 |
99 | def get_image_path(self, folder, frame_index, side):
100 | f_str = "{:010d}{}".format(frame_index, self.img_ext)
101 | image_path = os.path.join(
102 | self.data_path,
103 | folder,
104 | "image_0{}/data".format(self.side_map[side]),
105 | f_str)
106 | return image_path
107 |
108 | def get_depth(self, folder, frame_index, side, do_flip):
109 | f_str = "{:010d}.png".format(frame_index)
110 | depth_path = os.path.join(
111 | self.data_path,
112 | folder,
113 | "proj_depth/groundtruth/image_0{}".format(self.side_map[side]),
114 | f_str)
115 |
116 | depth_gt = pil.open(depth_path)
117 | depth_gt = depth_gt.resize(self.full_res_shape, pil.NEAREST)
118 | depth_gt = np.array(depth_gt).astype(np.float32) / 256
119 |
120 | if do_flip:
121 | depth_gt = np.fliplr(depth_gt)
122 |
123 | return depth_gt
124 |
--------------------------------------------------------------------------------
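
As a quick illustration of how `KITTIRAWDataset.get_image_path` above resolves entries from a split file, here is a hedged sketch; the split line and the `kitti_data` root are illustrative placeholders.

```python
# One line of an eigen-style split file: "<folder> <frame_index> <side>".
line = "2011_09_26/2011_09_26_drive_0002_sync 69 l"
folder, frame_index, side = line.split()[0], int(line.split()[1]), line.split()[2]

side_map = {"2": 2, "3": 3, "l": 2, "r": 3}        # as defined in KITTIDataset
f_str = "{:010d}{}".format(frame_index, ".jpg")    # zero-padded frame name
image_path = "kitti_data/{}/image_0{}/data/{}".format(folder, side_map[side], f_str)
print(image_path)
# kitti_data/2011_09_26/2011_09_26_drive_0002_sync/image_02/data/0000000069.jpg
```
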
/datasets/mono_dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import os
4 | import random
5 | import numpy as np
6 | import copy
7 | from PIL import Image # using pillow-simd for increased speed
8 |
9 | import torch
10 | import torch.utils.data as data
11 | from torchvision import transforms
12 |
13 |
14 | def pil_loader(path):
15 | with open(path, 'rb') as f:
16 | with Image.open(f) as img:
17 | return img.convert('RGB')
18 |
19 |
20 | class MonoDataset(data.Dataset):
21 | """Superclass for monocular dataloaders
22 |
23 | Args:
24 | data_path
25 | filenames
26 | height
27 | width
28 | frame_idxs
29 | num_scales
30 | is_train
31 | img_ext
32 | """
33 | def __init__(self,
34 | data_path,
35 | filenames,
36 | height,
37 | width,
38 | frame_idxs,
39 | num_scales,
40 | is_train=False,
41 | img_ext='.jpg'):
42 | super(MonoDataset, self).__init__()
43 |
44 | self.data_path = data_path
45 | self.filenames = filenames
46 | self.height = height
47 | self.width = width
48 | self.num_scales = num_scales
49 | self.interp = Image.ANTIALIAS
50 |
51 | self.frame_idxs = frame_idxs
52 |
53 | self.is_train = is_train
54 | self.img_ext = img_ext
55 |
56 | self.loader = pil_loader
57 | self.to_tensor = transforms.ToTensor()
58 |
59 | # We need to specify augmentations differently in newer versions of torchvision.
60 | # We first try the newer tuple version; if this fails we fall back to scalars
61 | try:
62 | self.brightness = (0.8, 1.2)
63 | self.contrast = (0.8, 1.2)
64 | self.saturation = (0.8, 1.2)
65 | self.hue = (-0.1, 0.1)
66 | transforms.ColorJitter.get_params(
67 | self.brightness, self.contrast, self.saturation, self.hue)
68 | except TypeError:
69 | self.brightness = 0.2
70 | self.contrast = 0.2
71 | self.saturation = 0.2
72 | self.hue = 0.1
73 |
74 | self.resize = {}
75 | for i in range(self.num_scales):
76 | s = 2 ** i
77 | self.resize[i] = transforms.Resize((self.height // s, self.width // s),
78 | interpolation=self.interp)
79 |
80 | self.load_depth = self.check_depth()
81 |
82 | def preprocess(self, inputs, color_aug):
83 | """Resize colour images to the required scales and augment if required
84 |
85 | We create the color_aug object in advance and apply the same augmentation to all
86 | images in this item. This ensures that all images input to the pose network receive the
87 | same augmentation.
88 | """
89 | for k in list(inputs):
90 | frame = inputs[k]
91 | if "color" in k:
92 | n, im, i = k
93 | for i in range(self.num_scales):
94 | inputs[(n, im, i)] = self.resize[i](inputs[(n, im, i - 1)])
95 |
96 | for k in list(inputs):
97 | f = inputs[k]
98 | if "color" in k:
99 | n, im, i = k
100 | inputs[(n, im, i)] = self.to_tensor(f)
101 | inputs[(n + "_aug", im, i)] = self.to_tensor(color_aug(f))
102 |
103 | def __len__(self):
104 | return len(self.filenames)
105 |
106 | def __getitem__(self, index):
107 | """Returns a single training item from the dataset as a dictionary.
108 |
109 | Values correspond to torch tensors.
110 | Keys in the dictionary are either strings or tuples:
111 |
112 |             ("color", <frame_id>, <scale>)          for raw colour images,
113 |             ("color_aug", <frame_id>, <scale>)      for augmented colour images,
114 |             ("K", scale) or ("inv_K", scale)        for camera intrinsics,
115 |             "stereo_T"                              for camera extrinsics, and
116 |             "depth_gt"                              for ground truth depth maps.
117 |
118 |         <frame_id> is either:
119 |             an integer (e.g. 0, -1, or 1) representing the temporal step relative to 'index',
120 |         or
121 |             "s" for the opposite image in the stereo pair.
122 |
123 |         <scale> is an integer representing the scale of the image relative to the fullsize image:
124 |             -1      images at native resolution as loaded from disk
125 |             0       images resized to (self.width,      self.height     )
126 |             1       images resized to (self.width // 2, self.height // 2)
127 |             2       images resized to (self.width // 4, self.height // 4)
128 |             3       images resized to (self.width // 8, self.height // 8)
129 | """
130 | inputs = {}
131 |
132 | do_color_aug = self.is_train and random.random() > 0.5
133 | do_flip = self.is_train and random.random() > 0.5
134 |
135 | line = self.filenames[index].split()
136 | folder = line[0]
137 |
138 | if len(line) == 3:
139 | frame_index = int(line[1])
140 | else:
141 | frame_index = 0
142 |
143 | if len(line) == 3:
144 | side = line[2]
145 | else:
146 | side = None
147 |
148 | for i in self.frame_idxs:
149 | if i == "s":
150 | other_side = {"r": "l", "l": "r"}[side]
151 | inputs[("color", i, -1)] = self.get_color(folder, frame_index, other_side, do_flip)
152 | else:
153 | inputs[("color", i, -1)] = self.get_color(folder, frame_index + i, side, do_flip)
154 |
155 | # adjusting intrinsics to match each scale in the pyramid
156 | for scale in range(self.num_scales):
157 | K = self.K.copy()
158 |
159 | K[0, :] *= self.width // (2 ** scale)
160 | K[1, :] *= self.height // (2 ** scale)
161 |
162 | inv_K = np.linalg.pinv(K)
163 |
164 | inputs[("K", scale)] = torch.from_numpy(K)
165 | inputs[("inv_K", scale)] = torch.from_numpy(inv_K)
166 |
167 | if do_color_aug:
168 | # color_aug = transforms.ColorJitter.get_params(
169 | # self.brightness, self.contrast, self.saturation, self.hue)
170 | color_aug = transforms.ColorJitter(
171 | self.brightness, self.contrast, self.saturation, self.hue)
172 | else:
173 | color_aug = (lambda x: x)
174 |
175 | self.preprocess(inputs, color_aug)
176 |
177 | for i in self.frame_idxs:
178 | del inputs[("color", i, -1)]
179 | del inputs[("color_aug", i, -1)]
180 |
181 | if self.load_depth:
182 | depth_gt = self.get_depth(folder, frame_index, side, do_flip)
183 | inputs["depth_gt"] = np.expand_dims(depth_gt, 0)
184 | inputs["depth_gt"] = torch.from_numpy(inputs["depth_gt"].astype(np.float32))
185 |
186 | if "s" in self.frame_idxs:
187 | stereo_T = np.eye(4, dtype=np.float32)
188 | baseline_sign = -1 if do_flip else 1
189 | side_sign = -1 if side == "l" else 1
190 | stereo_T[0, 3] = side_sign * baseline_sign * 0.1
191 |
192 | inputs["stereo_T"] = torch.from_numpy(stereo_T)
193 |
194 | return inputs
195 |
196 | def get_color(self, folder, frame_index, side, do_flip):
197 | raise NotImplementedError
198 |
199 | def check_depth(self):
200 | raise NotImplementedError
201 |
202 | def get_depth(self, folder, frame_index, side, do_flip):
203 | raise NotImplementedError
204 |
--------------------------------------------------------------------------------
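
The `__getitem__` method above rescales the normalized intrinsics defined in `datasets/kitti_dataset.py` to pixel units at every pyramid level. A small sketch of that computation, assuming a 640x192 training resolution:

```python
import numpy as np

# Normalized intrinsics as defined in datasets/kitti_dataset.py.
K = np.array([[0.58, 0, 0.5, 0],
              [0, 1.92, 0.5, 0],
              [0, 0, 1, 0],
              [0, 0, 0, 1]], dtype=np.float32)

width, height, num_scales = 640, 192, 4             # assumed training resolution
for scale in range(num_scales):
    K_s = K.copy()
    K_s[0, :] *= width // (2 ** scale)               # fx, cx in pixels at this scale
    K_s[1, :] *= height // (2 ** scale)              # fy, cy in pixels at this scale
    inv_K_s = np.linalg.pinv(K_s)                    # stored as ("inv_K", scale)
    print(scale, K_s[0, 0], K_s[1, 1])               # scale 0 -> fx = 371.2, fy ≈ 368.6
```
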
/evaluate_depth.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | import os
3 | import cv2
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import DataLoader
7 | from layers import disp_to_depth
8 | from utils import readlines
9 | from options import LiteMonoOptions
10 | import datasets
11 | import networks
12 | import time
13 | from thop import clever_format
14 | from thop import profile
15 |
16 |
17 | cv2.setNumThreads(0) # This speeds up evaluation 5x on our unix systems (OpenCV 3.3.1)
18 |
19 | splits_dir = os.path.join(os.path.dirname(__file__), "splits")
20 |
21 |
22 | def profile_once(encoder, decoder, x):
23 | x_e = x[0, :, :, :].unsqueeze(0)
24 | x_d = encoder(x_e)
25 | flops_e, params_e = profile(encoder, inputs=(x_e, ), verbose=False)
26 | flops_d, params_d = profile(decoder, inputs=(x_d, ), verbose=False)
27 |
28 | flops, params = clever_format([flops_e + flops_d, params_e + params_d], "%.3f")
29 | flops_e, params_e = clever_format([flops_e, params_e], "%.3f")
30 | flops_d, params_d = clever_format([flops_d, params_d], "%.3f")
31 |
32 | return flops, params, flops_e, params_e, flops_d, params_d
33 |
34 |
35 | def time_sync():
36 | # PyTorch-accurate time
37 | if torch.cuda.is_available():
38 | torch.cuda.synchronize()
39 | return time.time()
40 |
41 |
42 | def compute_errors(gt, pred):
43 | """Computation of error metrics between predicted and ground truth depths
44 | """
45 | thresh = np.maximum((gt / pred), (pred / gt))
46 | a1 = (thresh < 1.25 ).mean()
47 | a2 = (thresh < 1.25 ** 2).mean()
48 | a3 = (thresh < 1.25 ** 3).mean()
49 |
50 | rmse = (gt - pred) ** 2
51 | rmse = np.sqrt(rmse.mean())
52 |
53 | rmse_log = (np.log(gt) - np.log(pred)) ** 2
54 | rmse_log = np.sqrt(rmse_log.mean())
55 |
56 | abs_rel = np.mean(np.abs(gt - pred) / gt)
57 |
58 | sq_rel = np.mean(((gt - pred) ** 2) / gt)
59 |
60 | return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
61 |
62 |
63 | def batch_post_process_disparity(l_disp, r_disp):
64 | """Apply the disparity post-processing method as introduced in Monodepthv1
65 | """
66 | _, h, w = l_disp.shape
67 | m_disp = 0.5 * (l_disp + r_disp)
68 | l, _ = np.meshgrid(np.linspace(0, 1, w), np.linspace(0, 1, h))
69 | l_mask = (1.0 - np.clip(20 * (l - 0.05), 0, 1))[None, ...]
70 | r_mask = l_mask[:, :, ::-1]
71 | return r_mask * l_disp + l_mask * r_disp + (1.0 - l_mask - r_mask) * m_disp
72 |
73 |
74 | def evaluate(opt):
75 | """Evaluates a pretrained model using a specified test set
76 | """
77 | MIN_DEPTH = 1e-3
78 | MAX_DEPTH = 80
79 |
80 | if opt.ext_disp_to_eval is None:
81 |
82 | opt.load_weights_folder = os.path.expanduser(opt.load_weights_folder)
83 |
84 | assert os.path.isdir(opt.load_weights_folder), \
85 | "Cannot find a folder at {}".format(opt.load_weights_folder)
86 |
87 | print("-> Loading weights from {}".format(opt.load_weights_folder))
88 |
89 | filenames = readlines(os.path.join(splits_dir, opt.eval_split, "test_files.txt"))
90 | encoder_path = os.path.join(opt.load_weights_folder, "encoder.pth")
91 | decoder_path = os.path.join(opt.load_weights_folder, "depth.pth")
92 |
93 | encoder_dict = torch.load(encoder_path)
94 | decoder_dict = torch.load(decoder_path)
95 |
96 | dataset = datasets.KITTIRAWDataset(opt.data_path, filenames,
97 | encoder_dict['height'], encoder_dict['width'],
98 | [0], 4, is_train=False)
99 | dataloader = DataLoader(dataset, 16, shuffle=False, num_workers=opt.num_workers,
100 | pin_memory=True, drop_last=False)
101 |
102 | encoder = networks.LiteMono(model=opt.model,
103 | height=encoder_dict['height'],
104 | width=encoder_dict['width'])
105 | depth_decoder = networks.DepthDecoder(encoder.num_ch_enc, scales=range(3))
106 | model_dict = encoder.state_dict()
107 | depth_model_dict = depth_decoder.state_dict()
108 | encoder.load_state_dict({k: v for k, v in encoder_dict.items() if k in model_dict})
109 | depth_decoder.load_state_dict({k: v for k, v in decoder_dict.items() if k in depth_model_dict})
110 |
111 | encoder.cuda()
112 | encoder.eval()
113 | depth_decoder.cuda()
114 | depth_decoder.eval()
115 |
116 | pred_disps = []
117 |
118 | print("-> Computing predictions with size {}x{}".format(
119 | encoder_dict['width'], encoder_dict['height']))
120 |
121 | with torch.no_grad():
122 | for data in dataloader:
123 | input_color = data[("color", 0, 0)].cuda()
124 |
125 | if opt.post_process:
126 | # Post-processed results require each image to have two forward passes
127 | input_color = torch.cat((input_color, torch.flip(input_color, [3])), 0)
128 |
129 | flops, params, flops_e, params_e, flops_d, params_d = profile_once(encoder, depth_decoder, input_color)
130 | t1 = time_sync()
131 | output = depth_decoder(encoder(input_color))
132 | t2 = time_sync()
133 |
134 | pred_disp, _ = disp_to_depth(output[("disp", 0)], opt.min_depth, opt.max_depth)
135 | pred_disp = pred_disp.cpu()[:, 0].numpy()
136 |
137 | if opt.post_process:
138 | N = pred_disp.shape[0] // 2
139 | pred_disp = batch_post_process_disparity(pred_disp[:N], pred_disp[N:, :, ::-1])
140 |
141 | pred_disps.append(pred_disp)
142 |
143 | pred_disps = np.concatenate(pred_disps)
144 |
145 | else:
146 | # Load predictions from file
147 | print("-> Loading predictions from {}".format(opt.ext_disp_to_eval))
148 | pred_disps = np.load(opt.ext_disp_to_eval)
149 |
150 | if opt.eval_eigen_to_benchmark:
151 | eigen_to_benchmark_ids = np.load(
152 | os.path.join(splits_dir, "benchmark", "eigen_to_benchmark_ids.npy"))
153 |
154 | pred_disps = pred_disps[eigen_to_benchmark_ids]
155 |
156 | if opt.save_pred_disps:
157 | output_path = os.path.join(
158 | opt.load_weights_folder, "disps_{}_split.npy".format(opt.eval_split))
159 | print("-> Saving predicted disparities to ", output_path)
160 | np.save(output_path, pred_disps)
161 |
162 | if opt.no_eval:
163 | print("-> Evaluation disabled. Done.")
164 | quit()
165 |
166 | gt_path = os.path.join(splits_dir, opt.eval_split, "gt_depths.npz")
167 | gt_depths = np.load(gt_path, fix_imports=True, encoding='latin1', allow_pickle=True)["data"]
168 |
169 | print("-> Evaluating")
170 | print(" Mono evaluation - using median scaling")
171 |
172 | errors = []
173 | ratios = []
174 |
175 | for i in range(pred_disps.shape[0]):
176 | gt_depth = gt_depths[i]
177 | gt_height, gt_width = gt_depth.shape[:2]
178 |
179 | pred_disp = pred_disps[i]
180 | pred_disp = cv2.resize(pred_disp, (gt_width, gt_height))
181 | pred_depth = 1 / pred_disp
182 |
183 | if opt.eval_split == "eigen":
184 | mask = np.logical_and(gt_depth > MIN_DEPTH, gt_depth < MAX_DEPTH)
185 |
186 | crop = np.array([0.40810811 * gt_height, 0.99189189 * gt_height,
187 | 0.03594771 * gt_width, 0.96405229 * gt_width]).astype(np.int32)
188 | crop_mask = np.zeros(mask.shape)
189 | crop_mask[crop[0]:crop[1], crop[2]:crop[3]] = 1
190 | mask = np.logical_and(mask, crop_mask)
191 |
192 | else:
193 | mask = gt_depth > 0
194 |
195 | pred_depth = pred_depth[mask]
196 | gt_depth = gt_depth[mask]
197 |
198 | pred_depth *= opt.pred_depth_scale_factor
199 | if not opt.disable_median_scaling:
200 | ratio = np.median(gt_depth) / np.median(pred_depth)
201 | ratios.append(ratio)
202 | pred_depth *= ratio
203 |
204 | pred_depth[pred_depth < MIN_DEPTH] = MIN_DEPTH
205 | pred_depth[pred_depth > MAX_DEPTH] = MAX_DEPTH
206 | errors.append(compute_errors(gt_depth, pred_depth))
207 |
208 | if not opt.disable_median_scaling:
209 | ratios = np.array(ratios)
210 | med = np.median(ratios)
211 | print(" Scaling ratios | med: {:0.3f} | std: {:0.3f}".format(med, np.std(ratios / med)))
212 |
213 | mean_errors = np.array(errors).mean(0)
214 |
215 | print("\n " + ("{:>8} | " * 7).format("abs_rel", "sq_rel", "rmse", "rmse_log", "a1", "a2", "a3"))
216 | print(("&{: 8.3f} " * 7).format(*mean_errors.tolist()) + "\\\\")
217 | print("\n " + ("flops: {0}, params: {1}, flops_e: {2}, params_e:{3}, flops_d:{4}, params_d:{5}").format(flops, params, flops_e, params_e, flops_d, params_d))
218 | print("\n-> Done!")
219 |
220 |
221 | if __name__ == "__main__":
222 | options = LiteMonoOptions()
223 | evaluate(options.parse())
224 |
--------------------------------------------------------------------------------
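
For intuition, here is a tiny self-contained example of the median-scaling step and a couple of the metrics computed by `compute_errors` above; the numbers are made up and purely illustrative.

```python
import numpy as np

gt = np.array([2.0, 4.0, 8.0, 16.0])      # fake ground-truth depths in metres
pred = np.array([1.1, 2.0, 3.9, 8.2])     # fake predictions, correct only up to scale

ratio = np.median(gt) / np.median(pred)   # median scaling, as in evaluate()
pred_scaled = pred * ratio

abs_rel = np.mean(np.abs(gt - pred_scaled) / gt)
rmse = np.sqrt(np.mean((gt - pred_scaled) ** 2))
a1 = (np.maximum(gt / pred_scaled, pred_scaled / gt) < 1.25).mean()
print(ratio, abs_rel, rmse, a1)
```
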
/img/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/noahzn/Lite-Mono/4874b35df8ed4da16159ce8be8c697028b72bf76/img/overview.png
--------------------------------------------------------------------------------
/img/robustness.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/noahzn/Lite-Mono/4874b35df8ed4da16159ce8be8c697028b72bf76/img/robustness.png
--------------------------------------------------------------------------------
/img/speed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/noahzn/Lite-Mono/4874b35df8ed4da16159ce8be8c697028b72bf76/img/speed.png
--------------------------------------------------------------------------------
/img/teaser_m.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/noahzn/Lite-Mono/4874b35df8ed4da16159ce8be8c697028b72bf76/img/teaser_m.gif
--------------------------------------------------------------------------------
/kitti_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | import os
3 | import numpy as np
4 | from collections import Counter
5 |
6 |
7 | def load_velodyne_points(filename):
8 | """Load 3D point cloud from KITTI file format
9 | (adapted from https://github.com/hunse/kitti)
10 | """
11 | points = np.fromfile(filename, dtype=np.float32).reshape(-1, 4)
12 | points[:, 3] = 1.0 # homogeneous
13 | return points
14 |
15 |
16 | def read_calib_file(path):
17 | """Read KITTI calibration file
18 | (from https://github.com/hunse/kitti)
19 | """
20 | float_chars = set("0123456789.e+- ")
21 | data = {}
22 | with open(path, 'r') as f:
23 | for line in f.readlines():
24 | key, value = line.split(':', 1)
25 | value = value.strip()
26 | data[key] = value
27 | if float_chars.issuperset(value):
28 | # try to cast to float array
29 | try:
30 | data[key] = np.array(list(map(float, value.split(' '))))
31 | except ValueError:
32 | # casting error: data[key] already eq. value, so pass
33 | pass
34 |
35 | return data
36 |
37 |
38 | def sub2ind(matrixSize, rowSub, colSub):
39 | """Convert row, col matrix subscripts to linear indices
40 | """
41 | m, n = matrixSize
42 | return rowSub * (n-1) + colSub - 1
43 |
44 |
45 | def generate_depth_map(calib_dir, velo_filename, cam=2, vel_depth=False):
46 | """Generate a depth map from velodyne data
47 | """
48 | # load calibration files
49 | cam2cam = read_calib_file(os.path.join(calib_dir, 'calib_cam_to_cam.txt'))
50 | velo2cam = read_calib_file(os.path.join(calib_dir, 'calib_velo_to_cam.txt'))
51 | velo2cam = np.hstack((velo2cam['R'].reshape(3, 3), velo2cam['T'][..., np.newaxis]))
52 | velo2cam = np.vstack((velo2cam, np.array([0, 0, 0, 1.0])))
53 |
54 | # get image shape
55 | im_shape = cam2cam["S_rect_02"][::-1].astype(np.int32)
56 |
57 | # compute projection matrix velodyne->image plane
58 | R_cam2rect = np.eye(4)
59 | R_cam2rect[:3, :3] = cam2cam['R_rect_00'].reshape(3, 3)
60 | P_rect = cam2cam['P_rect_0'+str(cam)].reshape(3, 4)
61 | P_velo2im = np.dot(np.dot(P_rect, R_cam2rect), velo2cam)
62 |
63 | # load velodyne points and remove all behind image plane (approximation)
64 | # each row of the velodyne data is forward, left, up, reflectance
65 | velo = load_velodyne_points(velo_filename)
66 | velo = velo[velo[:, 0] >= 0, :]
67 |
68 | # project the points to the camera
69 | velo_pts_im = np.dot(P_velo2im, velo.T).T
70 | velo_pts_im[:, :2] = velo_pts_im[:, :2] / velo_pts_im[:, 2][..., np.newaxis]
71 |
72 | if vel_depth:
73 | velo_pts_im[:, 2] = velo[:, 0]
74 |
75 | # check if in bounds
76 | # use minus 1 to get the exact same value as KITTI matlab code
77 | velo_pts_im[:, 0] = np.round(velo_pts_im[:, 0]) - 1
78 | velo_pts_im[:, 1] = np.round(velo_pts_im[:, 1]) - 1
79 | val_inds = (velo_pts_im[:, 0] >= 0) & (velo_pts_im[:, 1] >= 0)
80 | val_inds = val_inds & (velo_pts_im[:, 0] < im_shape[1]) & (velo_pts_im[:, 1] < im_shape[0])
81 | velo_pts_im = velo_pts_im[val_inds, :]
82 |
83 | # project to image
84 | depth = np.zeros((im_shape[:2]))
85 |     depth[velo_pts_im[:, 1].astype(int), velo_pts_im[:, 0].astype(int)] = velo_pts_im[:, 2]  # np.int was removed in NumPy >= 1.24
86 |
87 | # find the duplicate points and choose the closest depth
88 | inds = sub2ind(depth.shape, velo_pts_im[:, 1], velo_pts_im[:, 0])
89 | dupe_inds = [item for item, count in Counter(inds).items() if count > 1]
90 | for dd in dupe_inds:
91 | pts = np.where(inds == dd)[0]
92 | x_loc = int(velo_pts_im[pts[0], 0])
93 | y_loc = int(velo_pts_im[pts[0], 1])
94 | depth[y_loc, x_loc] = velo_pts_im[pts, 2].min()
95 | depth[depth < 0] = 0
96 |
97 | return depth
98 |
--------------------------------------------------------------------------------
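
A hedged usage sketch for `generate_depth_map`, mirroring `KITTIRAWDataset.get_depth` in `datasets/kitti_dataset.py`; the KITTI root and drive below are placeholders for wherever the raw data was unpacked.

```python
import os
import numpy as np
from kitti_utils import generate_depth_map

data_path = "kitti_data"                                  # assumed KITTI raw root
folder = "2011_09_26/2011_09_26_drive_0002_sync"          # example drive
frame_index, cam = 69, 2                                  # cam 2 = left colour camera

calib_dir = os.path.join(data_path, folder.split("/")[0])
velo_file = os.path.join(data_path, folder,
                         "velodyne_points/data/{:010d}.bin".format(frame_index))

depth = generate_depth_map(calib_dir, velo_file, cam)
print(depth.shape, np.count_nonzero(depth))               # sparse (H, W) depth image
```
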
/layers.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import numpy as np
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import math
9 |
10 |
11 | def disp_to_depth(disp, min_depth, max_depth):
12 | """Convert network's sigmoid output into depth prediction
13 | The formula for this conversion is given in the 'additional considerations'
14 | section of the paper.
15 | """
16 | min_disp = 1 / max_depth
17 | max_disp = 1 / min_depth
18 | scaled_disp = min_disp + (max_disp - min_disp) * disp
19 | depth = 1 / scaled_disp
20 | return scaled_disp, depth
21 |
22 |
23 | def transformation_from_parameters(axisangle, translation, invert=False):
24 | """Convert the network's (axisangle, translation) output into a 4x4 matrix
25 | """
26 | R = rot_from_axisangle(axisangle)
27 | t = translation.clone()
28 |
29 | if invert:
30 | R = R.transpose(1, 2)
31 | t *= -1
32 |
33 | T = get_translation_matrix(t)
34 |
35 | if invert:
36 | M = torch.matmul(R, T)
37 | else:
38 | M = torch.matmul(T, R)
39 |
40 | return M
41 |
42 |
43 | def get_translation_matrix(translation_vector):
44 | """Convert a translation vector into a 4x4 transformation matrix
45 | """
46 | T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device)
47 |
48 | t = translation_vector.contiguous().view(-1, 3, 1)
49 |
50 | T[:, 0, 0] = 1
51 | T[:, 1, 1] = 1
52 | T[:, 2, 2] = 1
53 | T[:, 3, 3] = 1
54 | T[:, :3, 3, None] = t
55 |
56 | return T
57 |
58 |
59 | def rot_from_axisangle(vec):
60 | """Convert an axisangle rotation into a 4x4 transformation matrix
61 | (adapted from https://github.com/Wallacoloo/printipi)
62 | Input 'vec' has to be Bx1x3
63 | """
64 | angle = torch.norm(vec, 2, 2, True)
65 | axis = vec / (angle + 1e-7)
66 |
67 | ca = torch.cos(angle)
68 | sa = torch.sin(angle)
69 | C = 1 - ca
70 |
71 | x = axis[..., 0].unsqueeze(1)
72 | y = axis[..., 1].unsqueeze(1)
73 | z = axis[..., 2].unsqueeze(1)
74 |
75 | xs = x * sa
76 | ys = y * sa
77 | zs = z * sa
78 | xC = x * C
79 | yC = y * C
80 | zC = z * C
81 | xyC = x * yC
82 | yzC = y * zC
83 | zxC = z * xC
84 |
85 | rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device)
86 |
87 | rot[:, 0, 0] = torch.squeeze(x * xC + ca)
88 | rot[:, 0, 1] = torch.squeeze(xyC - zs)
89 | rot[:, 0, 2] = torch.squeeze(zxC + ys)
90 | rot[:, 1, 0] = torch.squeeze(xyC + zs)
91 | rot[:, 1, 1] = torch.squeeze(y * yC + ca)
92 | rot[:, 1, 2] = torch.squeeze(yzC - xs)
93 | rot[:, 2, 0] = torch.squeeze(zxC - ys)
94 | rot[:, 2, 1] = torch.squeeze(yzC + xs)
95 | rot[:, 2, 2] = torch.squeeze(z * zC + ca)
96 | rot[:, 3, 3] = 1
97 |
98 | return rot
99 |
100 |
101 | class ConvBlock(nn.Module):
102 | """Layer to perform a convolution followed by ELU
103 | """
104 | def __init__(self, in_channels, out_channels):
105 | super(ConvBlock, self).__init__()
106 |
107 | self.conv = Conv3x3(in_channels, out_channels)
108 | self.nonlin = nn.ELU(inplace=True)
109 |
110 | def forward(self, x):
111 | out = self.conv(x)
112 | out = self.nonlin(out)
113 | return out
114 |
115 |
116 | class ConvBlockDepth(nn.Module):
117 | """Layer to perform a convolution followed by ELU
118 | """
119 | def __init__(self, in_channels, out_channels):
120 | super(ConvBlockDepth, self).__init__()
121 |
122 | self.conv = DepthConv3x3(in_channels, out_channels)
123 | self.nonlin = nn.GELU()
124 |
125 | def forward(self, x):
126 | out = self.conv(x)
127 | out = self.nonlin(out)
128 | return out
129 |
130 |
131 | class DepthConv3x3(nn.Module):
132 | """Layer to pad and convolve input
133 | """
134 | def __init__(self, in_channels, out_channels, use_refl=True):
135 | super(DepthConv3x3, self).__init__()
136 |
137 | if use_refl:
138 | self.pad = nn.ReflectionPad2d(1)
139 | else:
140 | self.pad = nn.ZeroPad2d(1)
141 | # self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)
142 | self.conv = nn.Conv2d(int(in_channels), int(out_channels), kernel_size=3, groups=int(out_channels), bias=False)
143 |
144 | def forward(self, x):
145 | out = self.pad(x)
146 | out = self.conv(out)
147 | return out
148 |
149 | class Conv3x3(nn.Module):
150 | """Layer to pad and convolve input
151 | """
152 | def __init__(self, in_channels, out_channels, use_refl=True):
153 | super(Conv3x3, self).__init__()
154 |
155 | if use_refl:
156 | self.pad = nn.ReflectionPad2d(1)
157 | else:
158 | self.pad = nn.ZeroPad2d(1)
159 | self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)
160 | # self.conv = nn.Conv2d(int(in_channels), int(out_channels), kernel_size=3, padding=3 // 2, groups=int(out_channels), bias=False)
161 |
162 | def forward(self, x):
163 | out = self.pad(x)
164 | out = self.conv(out)
165 | return out
166 |
167 |
168 | class BackprojectDepth(nn.Module):
169 | """Layer to transform a depth image into a point cloud
170 | """
171 | def __init__(self, batch_size, height, width):
172 | super(BackprojectDepth, self).__init__()
173 |
174 | self.batch_size = batch_size
175 | self.height = height
176 | self.width = width
177 |
178 | meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
179 | self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
180 | self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords),
181 | requires_grad=False)
182 |
183 | self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
184 | requires_grad=False)
185 |
186 | self.pix_coords = torch.unsqueeze(torch.stack(
187 | [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)
188 | self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1)
189 | self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1),
190 | requires_grad=False)
191 |
192 | def forward(self, depth, inv_K):
193 | cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
194 | cam_points = depth.view(self.batch_size, 1, -1) * cam_points
195 | cam_points = torch.cat([cam_points, self.ones], 1)
196 |
197 | return cam_points
198 |
199 |
200 | class Project3D(nn.Module):
201 | """Layer which projects 3D points into a camera with intrinsics K and at position T
202 | """
203 | def __init__(self, batch_size, height, width, eps=1e-7):
204 | super(Project3D, self).__init__()
205 |
206 | self.batch_size = batch_size
207 | self.height = height
208 | self.width = width
209 | self.eps = eps
210 |
211 | def forward(self, points, K, T):
212 | P = torch.matmul(K, T)[:, :3, :]
213 |
214 | cam_points = torch.matmul(P, points)
215 |
216 | pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
217 | pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
218 | pix_coords = pix_coords.permute(0, 2, 3, 1)
219 | pix_coords[..., 0] /= self.width - 1
220 | pix_coords[..., 1] /= self.height - 1
221 | pix_coords = (pix_coords - 0.5) * 2
222 | return pix_coords
223 |
224 |
225 | def upsample(x, scale_factor=2, mode="bilinear"):
226 | """Upsample input tensor by a factor of 2
227 | """
228 | return F.interpolate(x, scale_factor=scale_factor, mode=mode)
229 |
230 |
231 | def get_smooth_loss(disp, img):
232 | """Computes the smoothness loss for a disparity image
233 | The color image is used for edge-aware smoothness
234 | """
235 | grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
236 | grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])
237 |
238 | grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True)
239 | grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True)
240 |
241 | grad_disp_x *= torch.exp(-grad_img_x)
242 | grad_disp_y *= torch.exp(-grad_img_y)
243 |
244 | return grad_disp_x.mean() + grad_disp_y.mean()
245 |
246 |
247 | class SSIM(nn.Module):
248 | """Layer to compute the SSIM loss between a pair of images
249 | """
250 | def __init__(self):
251 | super(SSIM, self).__init__()
252 | self.mu_x_pool = nn.AvgPool2d(3, 1)
253 | self.mu_y_pool = nn.AvgPool2d(3, 1)
254 | self.sig_x_pool = nn.AvgPool2d(3, 1)
255 | self.sig_y_pool = nn.AvgPool2d(3, 1)
256 | self.sig_xy_pool = nn.AvgPool2d(3, 1)
257 |
258 | self.refl = nn.ReflectionPad2d(1)
259 |
260 | self.C1 = 0.01 ** 2
261 | self.C2 = 0.03 ** 2
262 |
263 | def forward(self, x, y):
264 | x = self.refl(x)
265 | y = self.refl(y)
266 |
267 | mu_x = self.mu_x_pool(x)
268 | mu_y = self.mu_y_pool(y)
269 |
270 | sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2
271 | sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2
272 | sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y
273 |
274 | SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2)
275 | SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2)
276 |
277 | return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)
278 |
279 |
280 | def compute_depth_errors(gt, pred):
281 | """Computation of error metrics between predicted and ground truth depths
282 | """
283 | thresh = torch.max((gt / pred), (pred / gt))
284 | a1 = (thresh < 1.25 ).float().mean()
285 | a2 = (thresh < 1.25 ** 2).float().mean()
286 | a3 = (thresh < 1.25 ** 3).float().mean()
287 |
288 | rmse = (gt - pred) ** 2
289 | rmse = torch.sqrt(rmse.mean())
290 |
291 | rmse_log = (torch.log(gt) - torch.log(pred)) ** 2
292 | rmse_log = torch.sqrt(rmse_log.mean())
293 |
294 | abs_rel = torch.mean(torch.abs(gt - pred) / gt)
295 |
296 | sq_rel = torch.mean((gt - pred) ** 2 / gt)
297 |
298 | return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
299 |
300 |
301 |
--------------------------------------------------------------------------------
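
A short worked example of `disp_to_depth` above, assuming the commonly used 0.1 m / 100 m depth range (check `options.py` for the values actually used in training):

```python
import torch
from layers import disp_to_depth

disp = torch.tensor([0.0, 0.5, 1.0])            # sigmoid outputs from the decoder
scaled_disp, depth = disp_to_depth(disp, min_depth=0.1, max_depth=100)
# min_disp = 1/100 = 0.01 and max_disp = 1/0.1 = 10, so:
#   disp = 0.0 -> scaled_disp = 0.01 -> depth = 100 m (far)
#   disp = 1.0 -> scaled_disp = 10.0 -> depth = 0.1 m (near)
print(scaled_disp, depth)
```
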
/license:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Noah
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/lite-mono-pretrain-code/README.md:
--------------------------------------------------------------------------------
1 | I have not had time to clean up the code, but I have kept all the necessary lines in `litemono.py` to run the pretraining.
2 |
3 | In short, the last layer of the encoder of Lite-Mono should have `1000` output channels for the classification task. Please add these lines to your current model file. The `main.py` file will create the model from this file.
4 |
5 |
6 | To train on a single machine with 2 GPUs, use the following command.
7 |
8 | python -m torch.distributed.launch --nproc_per_node=2 main.py --data_path data/imagenet/
9 |
10 |
11 | Please check the code if you want to change parameters such as epochs, learning rates, etc.
--------------------------------------------------------------------------------
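
As a hedged illustration of the 1000-channel requirement described above (the real head lives in `litemono.py`; the layer names, the pooling choice, and the assumption that the encoder returns a list of feature maps are mine, not the repo's):

```python
import torch.nn as nn

class EncoderForClassification(nn.Module):
    """Wraps the Lite-Mono encoder with an ImageNet classification head (sketch)."""
    def __init__(self, encoder, last_channels, num_classes=1000):
        super().__init__()
        self.encoder = encoder                    # Lite-Mono depth encoder backbone
        self.norm = nn.LayerNorm(last_channels)
        self.head = nn.Linear(last_channels, num_classes)

    def forward(self, x):
        feats = self.encoder(x)[-1]               # assumed: last (coarsest) feature map
        feats = feats.mean(dim=[-2, -1])          # global average pooling -> (B, C)
        return self.head(self.norm(feats))        # (B, 1000) class logits
```
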
/lite-mono-pretrain-code/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torchvision import datasets, transforms
3 | from PIL import Image
4 |
5 | from timm.data.constants import \
6 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
7 | from timm.data import create_transform
8 | from sampler import MultiScaleImageFolder
9 |
10 |
11 | # from typing import Any, Callable, cast, Dict, List, Optional, Tuple
12 | # from typing import Union
13 | #
14 | # from PIL import Image
15 | # # IMAGENET_DEFAULT_MEAN = (0.445, )
16 | # # IMAGENET_DEFAULT_STD = (0.269, )
17 | # IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
18 | #
19 | #
20 | # def pil_loader(path: str) -> Image.Image:
21 | # # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
22 | # with open(path, "rb") as f:
23 | # img = Image.open(f)
24 | # return img.convert("L")
25 | #
26 | #
27 | # # TODO: specify the return type
28 | # def accimage_loader(path: str) -> Any:
29 | # import accimage
30 | # try:
31 | # return accimage.Image(path)
32 | # except OSError:
33 | # # Potentially a decoding problem, fall back to PIL.Image
34 | # return pil_loader(path)
35 | #
36 | #
37 | # def default_loader(path: str) -> Any:
38 | # from torchvision import get_image_backend
39 | #
40 | # if get_image_backend() == "accimage":
41 | # return accimage_loader(path)
42 | # else:
43 | #
44 | # return pil_loader(path)
45 | #
46 | #
47 | # class MyImageFolder(datasets.DatasetFolder):
48 | # def __init__(
49 | # self,
50 | # root: str,
51 | # transform: Optional[Callable] = None,
52 | # target_transform: Optional[Callable] = None,
53 | # loader: Callable[[str], Any] = default_loader,
54 | # is_valid_file: Optional[Callable[[str], bool]] = None,
55 | # ):
56 | # super().__init__(
57 | # root,
58 | # loader,
59 | # IMG_EXTENSIONS if is_valid_file is None else None,
60 | # transform=transform,
61 | # target_transform=target_transform,
62 | # is_valid_file=is_valid_file,
63 | # )
64 | # self.imgs = self.samples
65 |
66 |
67 | def build_dataset(is_train, args):
68 | transform = build_transform(is_train, args)
69 |
70 | print("Transform = ")
71 | if isinstance(transform, tuple):
72 | for trans in transform:
73 | print(" - - - - - - - - - - ")
74 | for t in trans.transforms:
75 | print(t)
76 | else:
77 | for t in transform.transforms:
78 | print(t)
79 | print("---------------------------")
80 |
81 | if args.data_set == 'CIFAR':
82 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True)
83 | nb_classes = 100
84 | elif args.data_set == 'IMNET':
85 | print("reading from datapath", args.data_path)
86 | root = os.path.join(args.data_path, 'train' if is_train else 'val')
87 | if is_train and args.multi_scale_sampler:
88 | dataset = MultiScaleImageFolder(root, args)
89 | else:
90 | dataset = datasets.ImageFolder(root, transform=transform)
91 | nb_classes = 1000
92 | elif args.data_set == "image_folder":
93 | root = args.data_path if is_train else args.eval_data_path
94 | dataset = datasets.ImageFolder(root, transform=transform)
95 | nb_classes = args.nb_classes
96 | assert len(dataset.class_to_idx) == nb_classes
97 | else:
98 | raise NotImplementedError()
99 | print("Number of the class = %d" % nb_classes)
100 |
101 | return dataset, nb_classes
102 |
103 |
104 | def build_transform(is_train, args):
105 | resize_im = args.input_size > 32
106 | imagenet_default_mean_and_std = args.imagenet_default_mean_and_std
107 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
108 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD
109 |
110 | if is_train:
111 | # This should always dispatch to transforms_imagenet_train
112 | transform = create_transform(
113 | input_size=args.input_size,
114 | is_training=True,
115 | color_jitter=args.color_jitter if args.color_jitter > 0 else None,
116 | auto_augment=args.aa,
117 | interpolation=args.train_interpolation,
118 | re_prob=args.reprob,
119 | re_mode=args.remode,
120 | re_count=args.recount,
121 | mean=mean,
122 | std=std,
123 | )
124 | if args.three_aug: # --aa should not be "" to use this as it actually overrides the auto-augment
125 | print(f"Using 3-Augments instead of Rand Augment")
126 | cur_augs = transform.transforms
127 | three_aug = transforms.RandomChoice([transforms.Grayscale(num_output_channels=3),
128 | transforms.RandomSolarize(threshold=192.0),
129 | transforms.GaussianBlur(kernel_size=(5, 9))])
130 | final_transforms = cur_augs[0:2] + [three_aug] + cur_augs[2:]
131 | transform = transforms.Compose(final_transforms)
132 | if not resize_im:
133 | transform.transforms[0] = transforms.RandomCrop(
134 | args.input_size, padding=4)
135 | return transform
136 |
137 | t = []
138 | if resize_im:
139 | # Warping (no cropping) when evaluated at 384 or larger
140 | if args.input_size >= 384:
141 | t.append(
142 | transforms.Resize((args.input_size, args.input_size),
143 | interpolation=transforms.InterpolationMode.BICUBIC),
144 | )
145 | print(f"Warping {args.input_size} size input images...")
146 | else:
147 | if args.crop_pct is None:
148 | args.crop_pct = 224 / 256
149 | size = int(args.input_size / args.crop_pct)
150 | t.append(
151 | # To maintain same ratio w.r.t. 224 images
152 | transforms.Resize(size, interpolation=Image.BICUBIC),
153 | )
154 | t.append(transforms.CenterCrop(args.input_size))
155 |
156 | t.append(transforms.ToTensor())
157 | t.append(transforms.Normalize(mean, std))
158 | return transforms.Compose(t)
159 |
--------------------------------------------------------------------------------
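
A hedged sketch of how `build_dataset` above can be driven; the argument values are plausible ImageNet defaults rather than the exact ones defined in `main.py`.

```python
from argparse import Namespace
from datasets import build_dataset

args = Namespace(
    data_set='IMNET', data_path='data/imagenet/', eval_data_path=None, nb_classes=1000,
    multi_scale_sampler=False, input_size=256, imagenet_default_mean_and_std=True,
    color_jitter=0.4, aa='rand-m9-mstd0.5-inc1', train_interpolation='bicubic',
    reprob=0.25, remode='pixel', recount=1, three_aug=False, crop_pct=None,
)
train_set, nb_classes = build_dataset(is_train=True, args=args)   # reads train/ subfolder
val_set, _ = build_dataset(is_train=False, args=args)             # reads val/ subfolder
print(len(train_set), len(val_set), nb_classes)
```
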
/lite-mono-pretrain-code/engine.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Iterable, Optional
3 | import torch
4 | from timm.data import Mixup
5 | from timm.utils import accuracy, ModelEma
6 |
7 | import utils
8 |
9 |
10 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
11 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
12 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
13 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
14 | wandb_logger=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
15 | num_training_steps_per_epoch=None, update_freq=None, use_amp=False):
16 | model.train(True)
17 | metric_logger = utils.MetricLogger(delimiter=" ")
18 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
19 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
20 | header = 'Epoch: [{}]'.format(epoch)
21 | print_freq = 10
22 |
23 | optimizer.zero_grad()
24 |
25 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
26 | step = data_iter_step // update_freq
27 | if step >= num_training_steps_per_epoch:
28 | continue
29 | it = start_steps + step # Global training iteration
30 | # Update LR & WD for the first acc
31 | if lr_schedule_values is not None or wd_schedule_values is not None and data_iter_step % update_freq == 0:
32 | for i, param_group in enumerate(optimizer.param_groups):
33 | if lr_schedule_values is not None:
34 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
35 | if wd_schedule_values is not None and param_group["weight_decay"] > 0:
36 | param_group["weight_decay"] = wd_schedule_values[it]
37 |
38 | samples = samples.to(device, non_blocking=True)
39 | targets = targets.to(device, non_blocking=True)
40 |
41 | if mixup_fn is not None:
42 | samples, targets = mixup_fn(samples, targets)
43 |
44 | if use_amp:
45 | with torch.cuda.amp.autocast():
46 | output = model(samples)
47 | loss = criterion(output, targets)
48 | else: # Full precision
49 | output = model(samples)
50 | loss = criterion(output, targets)
51 |
52 | loss_value = loss.item()
53 |
54 | if not math.isfinite(loss_value): # This could trigger if using AMP
55 | print("Loss is {}, stopping training".format(loss_value))
56 | assert math.isfinite(loss_value)
57 |
58 | if use_amp:
59 | # This attribute is added by timm on one optimizer (adahessian)
60 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
61 | loss /= update_freq
62 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
63 | parameters=model.parameters(), create_graph=is_second_order,
64 | update_grad=(data_iter_step + 1) % update_freq == 0)
65 |
66 | # for name, param in model.named_parameters():
67 | # if param.grad is None:
68 | # print(name)
69 | if (data_iter_step + 1) % update_freq == 0:
70 | optimizer.zero_grad()
71 | if model_ema is not None:
72 | model_ema.update(model)
73 | else: # Full precision
74 | loss /= update_freq
75 | loss.backward()
76 | if (data_iter_step + 1) % update_freq == 0:
77 | optimizer.step()
78 | optimizer.zero_grad()
79 | if model_ema is not None:
80 | model_ema.update(model)
81 |
82 | torch.cuda.synchronize()
83 |
84 | if mixup_fn is None:
85 | class_acc = (output.max(-1)[-1] == targets).float().mean()
86 | else:
87 | class_acc = None
88 | metric_logger.update(loss=loss_value)
89 | metric_logger.update(class_acc=class_acc)
90 | min_lr = 10.
91 | max_lr = 0.
92 | for group in optimizer.param_groups:
93 | min_lr = min(min_lr, group["lr"])
94 | max_lr = max(max_lr, group["lr"])
95 |
96 | metric_logger.update(lr=max_lr)
97 | metric_logger.update(min_lr=min_lr)
98 | weight_decay_value = None
99 | for group in optimizer.param_groups:
100 | if group["weight_decay"] > 0:
101 | weight_decay_value = group["weight_decay"]
102 | metric_logger.update(weight_decay=weight_decay_value)
103 | if use_amp:
104 | metric_logger.update(grad_norm=grad_norm)
105 |
106 | if log_writer is not None:
107 | log_writer.update(loss=loss_value, head="loss")
108 | log_writer.update(class_acc=class_acc, head="loss")
109 | log_writer.update(lr=max_lr, head="opt")
110 | log_writer.update(min_lr=min_lr, head="opt")
111 | log_writer.update(weight_decay=weight_decay_value, head="opt")
112 | if use_amp:
113 | log_writer.update(grad_norm=grad_norm, head="opt")
114 | log_writer.set_step()
115 |
116 | if wandb_logger:
117 | wandb_logger._wandb.log({
118 | 'Rank-0 Batch Wise/train_loss': loss_value,
119 | 'Rank-0 Batch Wise/train_max_lr': max_lr,
120 | 'Rank-0 Batch Wise/train_min_lr': min_lr
121 | }, commit=False)
122 | if class_acc:
123 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_class_acc': class_acc}, commit=False)
124 | if use_amp:
125 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_grad_norm': grad_norm}, commit=False)
126 | wandb_logger._wandb.log({'Rank-0 Batch Wise/global_train_step': it})
127 |
128 | # Gather the stats from all processes
129 | metric_logger.synchronize_between_processes()
130 | print("Averaged stats:", metric_logger)
131 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
132 |
133 |
134 | @torch.no_grad()
135 | def evaluate(data_loader, model, device, use_amp=False):
136 | criterion = torch.nn.CrossEntropyLoss()
137 |
138 | metric_logger = utils.MetricLogger(delimiter=" ")
139 | header = 'Test:'
140 |
141 | # Switch to evaluation mode
142 | model.eval()
143 | for batch in metric_logger.log_every(data_loader, 10, header):
144 | images = batch[0]
145 | target = batch[-1]
146 |
147 | images = images.to(device, non_blocking=True)
148 | target = target.to(device, non_blocking=True)
149 |
150 | # Compute output
151 | if use_amp:
152 | with torch.cuda.amp.autocast():
153 | output = model(images)
154 | loss = criterion(output, target)
155 | else:
156 | output = model(images)
157 | loss = criterion(output, target)
158 |
159 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
160 |
161 | batch_size = images.shape[0]
162 | metric_logger.update(loss=loss.item())
163 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
164 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
165 | # Gather the stats from all processes
166 | metric_logger.synchronize_between_processes()
167 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
168 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
169 |
170 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
171 |
--------------------------------------------------------------------------------
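
Note on the gradient-accumulation logic in `train_one_epoch` above: the loss is divided by `update_freq` and the optimizer only steps every `update_freq` iterations, so the effective batch size becomes `batch_size * update_freq * world_size`. A minimal sketch of the same pattern, using a hypothetical toy model, plain SGD, and no AMP:

    import torch

    model = torch.nn.Linear(8, 2)                       # hypothetical tiny model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    update_freq = 2                                     # accumulate gradients over 2 mini-batches

    optimizer.zero_grad()
    for step in range(8):
        x, y = torch.randn(4, 8), torch.randint(0, 2, (4,))
        loss = torch.nn.functional.cross_entropy(model(x), y)
        (loss / update_freq).backward()                 # scale so accumulated gradients match one large batch
        if (step + 1) % update_freq == 0:
            optimizer.step()                            # apply the accumulated gradients
            optimizer.zero_grad()
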
/lite-mono-pretrain-code/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import numpy as np
4 | import time
5 | import torch
6 | import torch.backends.cudnn as cudnn
7 | import json
8 | import os
9 |
10 | from pathlib import Path
11 |
12 | from timm.data.mixup import Mixup
13 | from timm.models import create_model
14 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
15 | from timm.utils import ModelEma
16 | from optim_factory import create_optimizer
17 |
18 | from datasets import build_dataset
19 | from engine import train_one_epoch, evaluate
20 |
21 | from utils import NativeScalerWithGradNormCount as NativeScaler
22 | import utils
23 | import models.litemono
24 |
25 | from sampler import MultiScaleSamplerDDP
26 | from fvcore.nn import FlopCountAnalysis
27 |
28 |
29 | def str2bool(v):
30 | """
31 | Converts string to bool type; enables command line
32 | arguments in the format of '--arg1 true --arg2 false'
33 | """
34 | if isinstance(v, bool):
35 | return v
36 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
37 | return True
38 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
39 | return False
40 | else:
41 | raise argparse.ArgumentTypeError('Boolean value expected.')
42 |
43 |
44 | def get_args_parser():
45 |     parser = argparse.ArgumentParser('Lite-Mono pre-training and evaluation script for image classification', add_help=False)
46 | parser.add_argument('--batch_size', default=512, type=int,
47 | help='Per GPU batch size')
48 | parser.add_argument('--epochs', default=100, type=int)
49 | parser.add_argument('--update_freq', default=2, type=int,
50 | help='gradient accumulation steps')
51 |
52 | # Model parameters
53 | parser.add_argument('--model', default='depth_encoder', type=str, metavar='MODEL',
54 | help='Name of model to train')
55 | parser.add_argument('--drop_path', type=float, default=0.0, metavar='PCT',
56 | help='Drop path rate (default: 0.0)')
57 | parser.add_argument('--input_size', default=256, type=int,
58 | help='image input size')
59 | parser.add_argument('--layer_scale_init_value', default=1e-5, type=float,
60 | help="Layer scale initial values")
61 |
62 | # EMA related parameters
63 | parser.add_argument('--model_ema', type=str2bool, default=False)
64 | parser.add_argument('--model_ema_decay', type=float, default=0.9995, help='') # TODO: MobileViT is using 0.9995
65 | parser.add_argument('--model_ema_force_cpu', type=str2bool, default=False, help='')
66 | parser.add_argument('--model_ema_eval', type=str2bool, default=False, help='Using ema to eval during training.')
67 |
68 | # Optimization parameters
69 |     parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', help='Optimizer (default: "adamw")')
70 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
71 | help='Optimizer Epsilon (default: 1e-8)')
72 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
73 | help='Optimizer Betas (default: None, use opt default)')
74 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
75 | help='Clip gradient norm (default: None, no clipping)')
76 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
77 | help='SGD momentum (default: 0.9)')
78 | parser.add_argument('--weight_decay', type=float, default=0.05,
79 | help='weight decay (default: 0.05)')
80 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
81 | weight decay. We use a cosine schedule for WD and using a larger decay by
82 | the end of training improves performance for ViTs.""")
83 |
84 | parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
85 |                         help='learning rate (default: 1e-3), with total batch size 4096')
86 | parser.add_argument('--layer_decay', type=float, default=1.0)
87 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
88 | help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')
89 | parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
90 |                         help='epochs to warm up the LR, if the scheduler supports it')
91 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
92 |                         help='number of steps to warm up the LR; overrides warmup_epochs if set > 0')
93 | parser.add_argument('--warmup_start_lr', type=float, default=0, metavar='LR',
94 | help='Starting LR for warmup (default 0)')
95 |
96 | # Augmentation parameters
97 | parser.add_argument('--color_jitter', type=float, default=0.0, metavar='PCT',
98 |                         help='Color jitter factor (default: 0.0)')
99 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
100 |                         help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
101 | parser.add_argument('--smoothing', type=float, default=0.0,
102 |                         help='Label smoothing (default: 0.0)')
103 | parser.add_argument('--train_interpolation', type=str, default='bicubic',
104 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
105 |
106 | # Evaluation parameters
107 | parser.add_argument('--crop_pct', type=float, default=None)
108 |
109 | # * Random Erase params
110 | parser.add_argument('--reprob', type=float, default=0.0, metavar='PCT',
111 | help='Random erase prob (default: 0.0)')
112 | parser.add_argument('--remode', type=str, default='pixel',
113 | help='Random erase mode (default: "pixel")')
114 | parser.add_argument('--recount', type=int, default=1,
115 | help='Random erase count (default: 1)')
116 | parser.add_argument('--resplit', type=str2bool, default=False,
117 | help='Do not random erase first (clean) augmentation split')
118 |
119 | # Mixup params
120 | parser.add_argument('--mixup', type=float, default=0.0,
121 | help='mixup alpha, mixup enabled if > 0.')
122 | parser.add_argument('--cutmix', type=float, default=0.0,
123 | help='cutmix alpha, cutmix enabled if > 0.')
124 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
125 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
126 | parser.add_argument('--mixup_prob', type=float, default=0.0,
127 | help='Probability of performing mixup or cutmix when either/both is enabled')
128 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
129 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
130 | parser.add_argument('--mixup_mode', type=str, default='batch',
131 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
132 |
133 | # Dataset parameters
134 | parser.add_argument('--data_path', default='data/ILSVRC2012/imagenet/', type=str,
135 | help='dataset path (path to full imagenet)')
136 | parser.add_argument('--eval_data_path', default=None, type=str,
137 | help='dataset path for evaluation')
138 | parser.add_argument('--nb_classes', default=1000, type=int,
139 | help='number of the classification types')
140 | parser.add_argument('--imagenet_default_mean_and_std', type=str2bool, default=True)
141 | parser.add_argument('--data_set', default='IMNET', choices=['IMNET', 'image_folder'],
142 | type=str, help='ImageNet dataset path')
143 | parser.add_argument('--output_dir', default='./save',
144 | help='path where to save, empty for no saving')
145 | parser.add_argument('--log_dir', default=None,
146 | help='path where to tensorboard log')
147 | parser.add_argument('--device', default='cuda',
148 | help='device to use for training / testing')
149 | parser.add_argument('--seed', default=0, type=int)
150 |
151 | parser.add_argument('--resume', default='',
152 | help='resume from checkpoint')
153 | parser.add_argument('--auto_resume', type=str2bool, default=False)
154 | parser.add_argument('--save_ckpt', type=str2bool, default=True)
155 | parser.add_argument('--save_ckpt_freq', default=1, type=int)
156 | parser.add_argument('--save_ckpt_num', default=3, type=int)
157 |
158 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
159 | help='start epoch')
160 | parser.add_argument('--eval', type=str2bool, default=False,
161 | help='Perform evaluation only')
162 | parser.add_argument('--dist_eval', type=str2bool, default=True,
163 | help='Enabling distributed evaluation')
164 | parser.add_argument('--disable_eval', type=str2bool, default=False,
165 | help='Disabling evaluation during training')
166 | parser.add_argument('--num_workers', default=10, type=int)
167 | parser.add_argument('--pin_mem', type=str2bool, default=True,
168 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
169 |
170 | # Distributed training parameters
171 | parser.add_argument('--world_size', default=1, type=int,
172 | help='number of distributed processes')
173 | parser.add_argument('--local_rank', default=-1, type=int)
174 | parser.add_argument('--dist_on_itp', type=str2bool, default=False)
175 | parser.add_argument('--dist_url', default='env://',
176 | help='url used to set up distributed training')
177 |
178 | parser.add_argument('--use_amp', type=str2bool, default=True,
179 | help="Use PyTorch's AMP (Automatic Mixed Precision) or not")
180 |
181 | # Weights and Biases arguments
182 | parser.add_argument('--enable_wandb', type=str2bool, default=False,
183 | help="enable logging to Weights and Biases")
184 | parser.add_argument('--project', default='litemono', type=str,
185 | help="The name of the W&B project where you're sending the new run.")
186 | parser.add_argument('--wandb_ckpt', type=str2bool, default=False,
187 | help="Save model checkpoints as W&B Artifacts.")
188 |     parser.add_argument("--multi_scale_sampler", action="store_true", help="Whether to use the multi-scale sampler or not.")
189 | parser.add_argument('--min_crop_size_w', default=160, type=int)
190 | parser.add_argument('--max_crop_size_w', default=320, type=int)
191 | parser.add_argument('--min_crop_size_h', default=160, type=int)
192 | parser.add_argument('--max_crop_size_h', default=320, type=int)
193 | parser.add_argument("--find_unused_params", action="store_true",
194 | help="Set this flag to enable unused parameters finding in DistributedDataParallel()")
195 | parser.add_argument("--three_aug", action="store_true",
196 |                         help="Whether to use the three augmentations proposed by DeiT-III")
197 | parser.add_argument('--classifier_dropout', default=0.0, type=float)
198 | parser.add_argument('--usi_eval', type=str2bool, default=False,
199 | help="Enable it when testing USI model.")
200 |
201 | return parser
202 |
203 |
204 | def main(args):
205 | utils.init_distributed_mode(args)
206 | print(args)
207 | device = torch.device(args.device)
208 |
209 | # Eval/USI_eval configurations
210 | if args.eval:
211 | if args.usi_eval:
212 | args.crop_pct = 0.95
213 | model_state_dict_name = 'state_dict'
214 | else:
215 | model_state_dict_name = 'model_ema'
216 | else:
217 | model_state_dict_name = 'model'
218 |
219 | # Fix the seed for reproducibility
220 | seed = args.seed + utils.get_rank()
221 | torch.manual_seed(seed)
222 | np.random.seed(seed)
223 | cudnn.benchmark = True
224 |
225 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
226 | if args.disable_eval:
227 | args.dist_eval = False
228 | dataset_val = None
229 | else:
230 | dataset_val, _ = build_dataset(is_train=False, args=args)
231 |
232 | num_tasks = utils.get_world_size()
233 | global_rank = utils.get_rank()
234 | if args.multi_scale_sampler:
235 | sampler_train = MultiScaleSamplerDDP(base_im_w=args.input_size, base_im_h=args.input_size,
236 | base_batch_size=args.batch_size, n_data_samples=len(dataset_train),
237 | is_training=True, distributed=args.distributed,
238 | min_crop_size_w=args.min_crop_size_w, max_crop_size_w=args.max_crop_size_w,
239 | min_crop_size_h=args.min_crop_size_h, max_crop_size_h=args.max_crop_size_h)
240 | else:
241 | sampler_train = torch.utils.data.DistributedSampler(
242 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True, seed=args.seed,
243 | )
244 | print("Sampler_train = %s" % str(sampler_train))
245 | if args.dist_eval:
246 | if len(dataset_val) % num_tasks != 0:
247 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
248 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
249 | 'equal num of samples per-process.')
250 | sampler_val = torch.utils.data.DistributedSampler(
251 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
252 | else:
253 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
254 |
255 | if global_rank == 0 and args.log_dir is not None:
256 | os.makedirs(args.log_dir, exist_ok=True)
257 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
258 | else:
259 | log_writer = None
260 |
261 | if global_rank == 0 and args.enable_wandb:
262 | wandb_logger = utils.WandbLogger(args)
263 | else:
264 | wandb_logger = None
265 |
266 | if args.multi_scale_sampler:
267 | data_loader_train = torch.utils.data.DataLoader(
268 | dataset_train, batch_sampler=sampler_train,
269 | batch_size=1,
270 | num_workers=args.num_workers,
271 | pin_memory=args.pin_mem,
272 | )
273 | else:
274 | data_loader_train = torch.utils.data.DataLoader(
275 | dataset_train, sampler=sampler_train,
276 | batch_size=args.batch_size,
277 | num_workers=args.num_workers,
278 | pin_memory=args.pin_mem,
279 | drop_last=True,
280 | )
281 |
282 | if dataset_val is not None:
283 | data_loader_val = torch.utils.data.DataLoader(
284 | dataset_val, sampler=sampler_val,
285 | batch_size=int(1.5 * args.batch_size),
286 | num_workers=args.num_workers,
287 | pin_memory=args.pin_mem,
288 | drop_last=False
289 | )
290 | else:
291 | data_loader_val = None
292 |
293 | mixup_fn = None
294 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
295 | if mixup_active:
296 | print("Mixup is activated!")
297 | mixup_fn = Mixup(
298 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
299 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
300 | label_smoothing=args.smoothing, num_classes=args.nb_classes)
301 |
302 | model = create_model(
303 | args.model,
304 | pretrained=False,
305 | num_classes=args.nb_classes,
306 | drop_path_rate=args.drop_path,
307 | layer_scale_init_value=args.layer_scale_init_value,
308 | head_init_scale=1.0,
309 | input_res=args.input_size,
310 | classifier_dropout=args.classifier_dropout,
311 | )
312 | model.to(device)
313 |
314 | model_ema = None
315 | if args.model_ema:
316 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
317 | model_ema = ModelEma(
318 | model,
319 | decay=args.model_ema_decay,
320 | device='cpu' if args.model_ema_force_cpu else '',
321 | resume='')
322 | print("Using EMA with decay = %.8f" % args.model_ema_decay)
323 |
324 | model_without_ddp = model
325 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
326 |
327 | print("Model = %s" % str(model_without_ddp))
328 | print('number of params:', n_parameters)
329 |
330 | total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
331 | num_training_steps_per_epoch = len(dataset_train) // total_batch_size
332 | print("LR = %.8f" % args.lr)
333 | print("Batch size = %d" % total_batch_size)
334 |     print("Update frequency = %d" % args.update_freq)
335 | print("Number of training examples = %d" % len(dataset_train))
336 |     print("Number of training steps per epoch = %d" % num_training_steps_per_epoch)
337 |
338 |     if args.layer_decay != 1.0:
339 | # Layer decay not supported
340 | raise NotImplementedError
341 | else:
342 | assigner = None
343 |
344 | if assigner is not None:
345 | print("Assigned values = %s" % str(assigner.values))
346 |
347 | if args.distributed:
348 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],
349 | find_unused_parameters=args.find_unused_params)
350 | model_without_ddp = model.module
351 |
352 | optimizer = create_optimizer(
353 | args, model_without_ddp, skip_list=None,
354 | get_num_layer=assigner.get_layer_id if assigner is not None else None,
355 | get_layer_scale=assigner.get_scale if assigner is not None else None)
356 |
357 | loss_scaler = NativeScaler() # if args.use_amp is False, this won't be used
358 |
359 | print("Use Cosine LR scheduler")
360 | lr_schedule_values = utils.cosine_scheduler(
361 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
362 | warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
363 | start_warmup_value=args.warmup_start_lr
364 | )
365 |
366 | if args.weight_decay_end is None:
367 | args.weight_decay_end = args.weight_decay
368 | wd_schedule_values = utils.cosine_scheduler(
369 | args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
370 | print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values)))
371 |
372 | if mixup_fn is not None:
373 | # smoothing is handled with mixup label transform
374 | criterion = SoftTargetCrossEntropy()
375 | elif args.smoothing > 0.:
376 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
377 | else:
378 | criterion = torch.nn.CrossEntropyLoss()
379 |
380 | print("criterion = %s" % str(criterion))
381 |
382 | utils.auto_load_model(
383 | args=args, model=model, model_without_ddp=model_without_ddp,
384 | optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema, state_dict_name=model_state_dict_name)
385 |
386 | if args.eval:
387 |         print("Eval-only mode")
388 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp)
389 | print(f"Accuracy of the network on {len(dataset_val)} test images: {test_stats['acc1']:.5f}%")
390 | return
391 |
392 | max_accuracy = 0.0
393 | if args.model_ema and args.model_ema_eval:
394 | max_accuracy_ema = 0.0
395 |
396 | def count_parameters(model):
397 | total_trainable_params = 0
398 | for name, parameter in model.named_parameters():
399 | if not parameter.requires_grad:
400 | continue
401 | params = parameter.numel()
402 | total_trainable_params += params
403 | return total_trainable_params
404 |
405 | total_params = count_parameters(model)
406 | # fvcore to calculate MAdds
407 | input_res = (3, args.input_size, args.input_size)
408 | input = torch.ones(()).new_empty((1, *input_res), dtype=next(model.parameters()).dtype,
409 | device=next(model.parameters()).device)
410 | flops = FlopCountAnalysis(model, input)
411 | model_flops = flops.total()
412 | print(f"Total Trainable Params: {round(total_params * 1e-6, 2)} M")
413 | print(f"MAdds: {round(model_flops * 1e-6, 2)} M")
414 |
415 | print("Start training for %d epochs" % args.epochs)
416 | start_time = time.time()
417 | for epoch in range(args.start_epoch, args.epochs):
418 | if args.multi_scale_sampler:
419 | data_loader_train.batch_sampler.set_epoch(epoch)
420 | elif args.distributed:
421 | data_loader_train.sampler.set_epoch(epoch)
422 | if log_writer is not None:
423 | log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq)
424 | if wandb_logger:
425 | wandb_logger.set_steps()
426 | train_stats = train_one_epoch(
427 | model, criterion, data_loader_train, optimizer,
428 | device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn,
429 | log_writer=log_writer, wandb_logger=wandb_logger, start_steps=epoch * num_training_steps_per_epoch,
430 | lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values,
431 | num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq,
432 | use_amp=args.use_amp
433 | )
434 | if args.output_dir and args.save_ckpt:
435 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
436 | utils.save_model(
437 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
438 | loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema)
439 | if data_loader_val is not None:
440 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp)
441 | print(f"Accuracy of the model on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
442 | if max_accuracy < test_stats["acc1"]:
443 | max_accuracy = test_stats["acc1"]
444 | if args.output_dir and args.save_ckpt:
445 | utils.save_model(
446 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
447 | loss_scaler=loss_scaler, epoch="best", model_ema=model_ema)
448 | print(f'Max accuracy: {max_accuracy:.2f}%')
449 |
450 | if log_writer is not None:
451 | log_writer.update(test_acc1=test_stats['acc1'], head="perf", step=epoch)
452 | log_writer.update(test_acc5=test_stats['acc5'], head="perf", step=epoch)
453 | log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch)
454 |
455 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
456 | **{f'test_{k}': v for k, v in test_stats.items()},
457 | 'epoch': epoch,
458 | 'n_parameters': n_parameters}
459 |
460 | # Repeat testing routines for EMA, if ema eval is turned on
461 | if args.model_ema and args.model_ema_eval:
462 | test_stats_ema = evaluate(data_loader_val, model_ema.ema, device, use_amp=args.use_amp)
463 | print(f"Accuracy of the model EMA on {len(dataset_val)} test images: {test_stats_ema['acc1']:.1f}%")
464 | if max_accuracy_ema < test_stats_ema["acc1"]:
465 | max_accuracy_ema = test_stats_ema["acc1"]
466 | if args.output_dir and args.save_ckpt:
467 | utils.save_model(
468 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
469 | loss_scaler=loss_scaler, epoch="best-ema", model_ema=model_ema)
470 | print(f'Max EMA accuracy: {max_accuracy_ema:.2f}%')
471 | if log_writer is not None:
472 | log_writer.update(test_acc1_ema=test_stats_ema['acc1'], head="perf", step=epoch)
473 | log_stats.update({**{f'test_{k}_ema': v for k, v in test_stats_ema.items()}})
474 | else:
475 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
476 | 'epoch': epoch,
477 | 'n_parameters': n_parameters}
478 |
479 | if args.output_dir and utils.is_main_process():
480 | if log_writer is not None:
481 | log_writer.flush()
482 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
483 | f.write(json.dumps(log_stats) + "\n")
484 |
485 | if wandb_logger:
486 | wandb_logger.log_epoch_metrics(log_stats)
487 |
488 | if wandb_logger and args.wandb_ckpt and args.save_ckpt and args.output_dir:
489 | wandb_logger.log_checkpoints()
490 |
491 | total_time = time.time() - start_time
492 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
493 | print('Training time {}'.format(total_time_str))
494 |
495 |
496 | if __name__ == '__main__':
497 | parser = argparse.ArgumentParser('Lite-Mono pretraining code', parents=[get_args_parser()])
498 | args = parser.parse_args()
499 | if args.output_dir:
500 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
501 | main(args)
502 |
--------------------------------------------------------------------------------
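
The effective batch size and the number of optimizer steps per epoch printed by `main()` follow directly from the arguments (`total_batch_size = batch_size * update_freq * world_size`). A quick check with illustrative numbers; the world size and dataset size below are assumptions, not values from this repository:

    batch_size = 512          # per-GPU batch size (--batch_size)
    update_freq = 2           # gradient accumulation steps (--update_freq)
    world_size = 4            # number of distributed processes (assumed)
    n_train = 1_281_167       # ImageNet-1k training images (approximate)

    total_batch_size = batch_size * update_freq * world_size    # 4096
    steps_per_epoch = n_train // total_batch_size                # 312
    print(total_batch_size, steps_per_epoch)
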
/lite-mono-pretrain-code/models/litemono.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 |
5 | from torch import nn
6 | import torchvision.models as models
7 | import torch.utils.model_zoo as model_zoo
8 | import torch.nn.functional as F
9 | from timm.models.layers import DropPath
10 | from timm.models.layers import trunc_normal_
11 | import math
12 | import torch.cuda
13 |
14 | __all__ = ["litemono"]
15 |
16 |
17 | class LiteMono(nn.Module):
18 |
19 |     def __init__(self, in_chans=3, num_classes=1000, head_init_scale=1., **kwargs):
20 | super().__init__()
21 |
22 |
23 | self.norm = nn.LayerNorm(dims[-2], eps=1e-6) # Final norm layer
24 | self.head = nn.Linear(dims[-2], num_classes)
25 |
26 | self.apply(self._init_weights)
27 | self.head_dropout = nn.Dropout()
28 | self.head.weight.data.mul_(head_init_scale)
29 | self.head.bias.data.mul_(head_init_scale)
30 |
31 |
32 | def forward_features(self, x):
33 |         # x = (x - 0.45) / 0.225  # do not normalize here: the training pipeline already normalizes the input, so keep this line commented out.
34 |
35 | for i in range(1, 3):
36 | ...
37 |
38 | return self.norm(x.mean([-2, -1])) # Global average pooling, (N, C, H, W) -> (N, C)
39 |
40 | def forward(self, x):
41 | x = self.forward_features(x)
42 | x = self.head(x)
43 | return x
44 |
45 |
--------------------------------------------------------------------------------
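
The classification head of the pre-training model operates on globally pooled features: `x.mean([-2, -1])` collapses the spatial dimensions before the final LayerNorm and Linear layer. A quick shape check; the batch size and channel count are arbitrary examples:

    import torch

    x = torch.randn(2, 128, 7, 7)      # (N, C, H, W) feature map
    pooled = x.mean([-2, -1])          # global average pooling over H and W
    print(pooled.shape)                # torch.Size([2, 128]), fed to LayerNorm + Linear head
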
/lite-mono-pretrain-code/optim_factory.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import optim as optim
3 | from timm.optim.adafactor import Adafactor
4 | from timm.optim.adahessian import Adahessian
5 | from timm.optim.adamp import AdamP
6 | from timm.optim.lookahead import Lookahead
7 | from timm.optim.nadam import Nadam
8 | from timm.optim.novograd import NovoGrad
9 | from timm.optim.nvnovograd import NvNovoGrad
10 | from timm.optim.radam import RAdam
11 | from timm.optim.rmsprop_tf import RMSpropTF
12 | from timm.optim.sgdp import SGDP
13 |
14 | import json
15 |
16 | try:
17 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
18 |
19 | has_apex = True
20 | except ImportError:
21 | has_apex = False
22 |
23 |
24 | def get_num_layer_for_convnext(var_name):
25 | """
26 | Divide [3, 3, 27, 3] layers into 12 groups; each group is three
27 | consecutive blocks, including possible neighboring downsample layers;
28 | adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
29 | """
30 | num_max_layer = 12
31 | if var_name.startswith("downsample_layers"):
32 | stage_id = int(var_name.split('.')[1])
33 | if stage_id == 0:
34 | layer_id = 0
35 | elif stage_id == 1 or stage_id == 2:
36 | layer_id = stage_id + 1
37 | elif stage_id == 3:
38 | layer_id = 12
39 | return layer_id
40 |
41 | elif var_name.startswith("stages"):
42 | stage_id = int(var_name.split('.')[1])
43 | block_id = int(var_name.split('.')[2])
44 | if stage_id == 0 or stage_id == 1:
45 | layer_id = stage_id + 1
46 | elif stage_id == 2:
47 | layer_id = 3 + block_id // 3
48 | elif stage_id == 3:
49 | layer_id = 12
50 | return layer_id
51 | else:
52 | return num_max_layer + 1
53 |
54 |
55 | class LayerDecayValueAssigner(object):
56 | def __init__(self, values):
57 | self.values = values
58 |
59 | def get_scale(self, layer_id):
60 | return self.values[layer_id]
61 |
62 | def get_layer_id(self, var_name):
63 | return get_num_layer_for_convnext(var_name)
64 |
65 |
66 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
67 | parameter_group_names = {}
68 | parameter_group_vars = {}
69 |
70 | for name, param in model.named_parameters():
71 | if not param.requires_grad:
72 | continue # frozen weights
73 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
74 | group_name = "no_decay"
75 | this_weight_decay = 0.
76 | else:
77 | group_name = "decay"
78 | this_weight_decay = weight_decay
79 | if get_num_layer is not None:
80 | layer_id = get_num_layer(name)
81 | group_name = "layer_%d_%s" % (layer_id, group_name)
82 | else:
83 | layer_id = None
84 |
85 | if group_name not in parameter_group_names:
86 | if get_layer_scale is not None:
87 | scale = get_layer_scale(layer_id)
88 | else:
89 | scale = 1.
90 |
91 | parameter_group_names[group_name] = {
92 | "weight_decay": this_weight_decay,
93 | "params": [],
94 | "lr_scale": scale
95 | }
96 | parameter_group_vars[group_name] = {
97 | "weight_decay": this_weight_decay,
98 | "params": [],
99 | "lr_scale": scale
100 | }
101 |
102 | parameter_group_vars[group_name]["params"].append(param)
103 | parameter_group_names[group_name]["params"].append(name)
104 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
105 | return list(parameter_group_vars.values())
106 |
107 |
108 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):
109 | opt_lower = args.opt.lower()
110 | weight_decay = args.weight_decay
111 | # if weight_decay and filter_bias_and_bn:
112 | if filter_bias_and_bn:
113 | skip = {}
114 | if skip_list is not None:
115 | skip = skip_list
116 | elif hasattr(model, 'no_weight_decay'):
117 | skip = model.no_weight_decay()
118 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
119 | weight_decay = 0.
120 | else:
121 | parameters = model.parameters()
122 |
123 | if 'fused' in opt_lower:
124 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
125 |
126 | opt_args = dict(lr=args.lr, weight_decay=weight_decay)
127 | if hasattr(args, 'opt_eps') and args.opt_eps is not None:
128 | opt_args['eps'] = args.opt_eps
129 | if hasattr(args, 'opt_betas') and args.opt_betas is not None:
130 | opt_args['betas'] = args.opt_betas
131 |
132 | opt_split = opt_lower.split('_')
133 | opt_lower = opt_split[-1]
134 | if opt_lower == 'sgd' or opt_lower == 'nesterov':
135 | opt_args.pop('eps', None)
136 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
137 | elif opt_lower == 'momentum':
138 | opt_args.pop('eps', None)
139 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
140 | elif opt_lower == 'adam':
141 | optimizer = optim.Adam(parameters, **opt_args)
142 | elif opt_lower == 'adamw':
143 | optimizer = optim.AdamW(parameters, **opt_args)
144 | elif opt_lower == 'nadam':
145 | optimizer = Nadam(parameters, **opt_args)
146 | elif opt_lower == 'radam':
147 | optimizer = RAdam(parameters, **opt_args)
148 | elif opt_lower == 'adamp':
149 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
150 | elif opt_lower == 'sgdp':
151 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
152 | elif opt_lower == 'adadelta':
153 | optimizer = optim.Adadelta(parameters, **opt_args)
154 | elif opt_lower == 'adafactor':
155 | if not args.lr:
156 | opt_args['lr'] = None
157 | optimizer = Adafactor(parameters, **opt_args)
158 | elif opt_lower == 'adahessian':
159 | optimizer = Adahessian(parameters, **opt_args)
160 | elif opt_lower == 'rmsprop':
161 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
162 | elif opt_lower == 'rmsproptf':
163 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
164 | elif opt_lower == 'novograd':
165 | optimizer = NovoGrad(parameters, **opt_args)
166 | elif opt_lower == 'nvnovograd':
167 | optimizer = NvNovoGrad(parameters, **opt_args)
168 | elif opt_lower == 'fusedsgd':
169 | opt_args.pop('eps', None)
170 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
171 | elif opt_lower == 'fusedmomentum':
172 | opt_args.pop('eps', None)
173 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
174 | elif opt_lower == 'fusedadam':
175 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
176 | elif opt_lower == 'fusedadamw':
177 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
178 | elif opt_lower == 'fusedlamb':
179 | optimizer = FusedLAMB(parameters, **opt_args)
180 | elif opt_lower == 'fusednovograd':
181 | opt_args.setdefault('betas', (0.95, 0.98))
182 | optimizer = FusedNovoGrad(parameters, **opt_args)
183 | else:
184 |         assert False, "Invalid optimizer"
185 |
186 | if len(opt_split) > 1:
187 | if opt_split[0] == 'lookahead':
188 | optimizer = Lookahead(optimizer)
189 |
190 | return optimizer
191 |
--------------------------------------------------------------------------------
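
The grouping rule in `get_parameter_groups` is: one-dimensional parameters (norm scales and shifts), biases, and anything in the skip list receive zero weight decay, while all remaining parameters receive the configured value. A small sketch of the same rule on a toy module (the module is made up purely for illustration):

    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Linear(8, 2))

    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if len(p.shape) == 1 or name.endswith(".bias"):
            no_decay.append(name)      # norm parameters and biases: weight_decay = 0
        else:
            decay.append(name)         # conv / linear weight matrices: weight_decay = args.weight_decay
    print("decay:", decay)             # ['0.weight', '2.weight']
    print("no_decay:", no_decay)       # ['0.bias', '1.weight', '1.bias', '2.bias']
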
/lite-mono-pretrain-code/requirements.txt:
--------------------------------------------------------------------------------
1 | timm==0.4.12
2 | tensorboardX==2.2
3 | six==1.16.0
4 | fvcore==0.1.5.post20220414
5 | protobuf==3.20.*
--------------------------------------------------------------------------------
/lite-mono-pretrain-code/sampler.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data.sampler import Sampler
2 | import torch.distributed as dist
3 | import math
4 | import random
5 | import numpy as np
6 | from torchvision.datasets import ImageFolder
7 | from timm.data import create_transform
8 | from timm.data.constants import \
9 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
10 | from torchvision import transforms
11 | from typing import Tuple
12 | from typing import Optional, Union
13 |
14 |
15 | class MultiScaleSamplerDDP(Sampler):
16 | def __init__(self, base_im_w: int, base_im_h: int, base_batch_size: int, n_data_samples: int,
17 | min_crop_size_w: int = 160, max_crop_size_w: int = 320,
18 | min_crop_size_h: int = 160, max_crop_size_h: int = 320,
19 | n_scales: int = 5, is_training: bool = True, distributed=True) -> None:
20 | # min. and max. spatial dimensions
21 | min_im_w, max_im_w = min_crop_size_w, max_crop_size_w
22 | min_im_h, max_im_h = min_crop_size_h, max_crop_size_h
23 |
24 | # Get the GPU and node related information
25 | if not distributed:
26 | num_replicas = 1
27 | rank = 0
28 | else:
29 | num_replicas = dist.get_world_size()
30 | rank = dist.get_rank()
31 |
32 | # adjust the total samples to avoid batch dropping
33 | num_samples_per_replica = int(math.ceil(n_data_samples * 1.0 / num_replicas))
34 | total_size = num_samples_per_replica * num_replicas
35 | img_indices = [idx for idx in range(n_data_samples)]
36 | img_indices += img_indices[:(total_size - n_data_samples)]
37 | assert len(img_indices) == total_size
38 |
39 |         self.shuffle = is_training
40 | if is_training:
41 | self.img_batch_pairs = _image_batch_pairs(base_im_w, base_im_h, base_batch_size, num_replicas, n_scales, 32,
42 | min_im_w, max_im_w, min_im_h, max_im_h)
43 | else:
44 | self.img_batch_pairs = [(base_im_h, base_im_w, base_batch_size)]
45 |
46 | self.img_indices = img_indices
47 | self.n_samples_per_replica = num_samples_per_replica
48 | self.epoch = 0
49 | self.rank = rank
50 | self.num_replicas = num_replicas
51 | self.batch_size_gpu0 = base_batch_size
52 |
53 | def __iter__(self):
54 | if self.shuffle:
55 | random.seed(self.epoch)
56 | random.shuffle(self.img_indices)
57 | random.shuffle(self.img_batch_pairs)
58 | indices_rank_i = self.img_indices[self.rank:len(self.img_indices):self.num_replicas]
59 | else:
60 | indices_rank_i = self.img_indices[self.rank:len(self.img_indices):self.num_replicas]
61 |
62 | start_index = 0
63 | while start_index < self.n_samples_per_replica:
64 | curr_h, curr_w, curr_bsz = random.choice(self.img_batch_pairs)
65 |
66 | end_index = min(start_index + curr_bsz, self.n_samples_per_replica)
67 | batch_ids = indices_rank_i[start_index:end_index]
68 | n_batch_samples = len(batch_ids)
69 | if n_batch_samples != curr_bsz:
70 | batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
71 | start_index += curr_bsz
72 |
73 | if len(batch_ids) > 0:
74 | batch = [(curr_h, curr_w, b_id) for b_id in batch_ids]
75 | yield batch
76 |
77 | def set_epoch(self, epoch: int) -> None:
78 | self.epoch = epoch
79 |
80 | def __len__(self):
81 | return self.n_samples_per_replica
82 |
83 |
84 | def _image_batch_pairs(crop_size_w: int,
85 | crop_size_h: int,
86 | batch_size_gpu0: int,
87 | n_gpus: int,
88 | max_scales: Optional[float] = 5,
89 | check_scale_div_factor: Optional[int] = 32,
90 | min_crop_size_w: Optional[int] = 160,
91 | max_crop_size_w: Optional[int] = 320,
92 | min_crop_size_h: Optional[int] = 160,
93 | max_crop_size_h: Optional[int] = 320,
94 | *args, **kwargs) -> list:
95 | """
96 | This function creates batch and image size pairs. For a given batch size and image size, different image sizes
97 | are generated and batch size is adjusted so that GPU memory can be utilized efficiently.
98 |
99 | :param crop_size_w: Base Image width (e.g., 224)
100 | :param crop_size_h: Base Image height (e.g., 224)
101 | :param batch_size_gpu0: Batch size on GPU 0 for base image
102 | :param n_gpus: Number of available GPUs
103 | :param max_scales: Number of scales. How many image sizes that we want to generate between min and max scale factors.
104 | :param check_scale_div_factor: Check if image scales are divisible by this factor.
105 | :param min_crop_size_w: Min. crop size along width
106 | :param max_crop_size_w: Max. crop size along width
107 | :param min_crop_size_h: Min. crop size along height
108 | :param max_crop_size_h: Max. crop size along height
109 | :param args:
110 | :param kwargs:
111 | :return: a sorted list of tuples. Each index is of the form (h, w, batch_size)
112 | """
113 |
114 | width_dims = list(np.linspace(min_crop_size_w, max_crop_size_w, max_scales))
115 | if crop_size_w not in width_dims:
116 | width_dims.append(crop_size_w)
117 |
118 | height_dims = list(np.linspace(min_crop_size_h, max_crop_size_h, max_scales))
119 | if crop_size_h not in height_dims:
120 | height_dims.append(crop_size_h)
121 |
122 | image_scales = set()
123 |
124 | for h, w in zip(height_dims, width_dims):
125 | # ensure that sampled sizes are divisible by check_scale_div_factor
126 | # This is important in some cases where input undergoes a fixed number of down-sampling stages
127 | # for instance, in ImageNet training, CNNs usually have 5 downsampling stages, which downsamples the
128 | # input image of resolution 224x224 to 7x7 size
129 | h = make_divisible(h, check_scale_div_factor)
130 | w = make_divisible(w, check_scale_div_factor)
131 | image_scales.add((h, w))
132 |
133 | image_scales = list(image_scales)
134 |
135 | img_batch_tuples = set()
136 | n_elements = crop_size_w * crop_size_h * batch_size_gpu0
137 |     for (crop_h, crop_w) in image_scales:
138 |         # compute the batch size for sampled image resolutions with respect to the base resolution
139 |         _bsz = max(batch_size_gpu0, int(round(n_elements/(crop_h * crop_w), 2)))
140 | 
141 |         _bsz = make_divisible(_bsz, n_gpus)
142 |         _bsz = _bsz if _bsz % 2 == 0 else _bsz - 1  # Batch size must be even
143 |         img_batch_tuples.add((crop_h, crop_w, _bsz))
144 |
145 | img_batch_tuples = list(img_batch_tuples)
146 | return sorted(img_batch_tuples)
147 |
148 |
149 | def make_divisible(v: Union[float, int],
150 | divisor: Optional[int] = 8,
151 | min_value: Optional[Union[float, int]] = None) -> Union[float, int]:
152 | """
153 | This function is taken from the original tf repo.
154 | It ensures that all layers have a channel number that is divisible by 8
155 | It can be seen here:
156 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
157 | :param v:
158 | :param divisor:
159 | :param min_value:
160 | :return:
161 | """
162 | if min_value is None:
163 | min_value = divisor
164 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
165 | # Make sure that round down does not go down by more than 10%.
166 | if new_v < 0.9 * v:
167 | new_v += divisor
168 | return new_v
169 |
170 |
171 | class MultiScaleImageFolder(ImageFolder):
172 | def __init__(self, root, args) -> None:
173 | self.args = args
174 | ImageFolder.__init__(self, root=root, transform=None, target_transform=None, is_valid_file=None)
175 |
176 | def get_transforms(self, size: int):
177 | imagenet_default_mean_and_std = self.args.imagenet_default_mean_and_std
178 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
179 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD
180 | resize_im = size > 32
181 | transform = create_transform(
182 | input_size=size,
183 | is_training=True,
184 | color_jitter=self.args.color_jitter,
185 | auto_augment=self.args.aa,
186 | interpolation=self.args.train_interpolation,
187 | re_prob=self.args.reprob,
188 | re_mode=self.args.remode,
189 | re_count=self.args.recount,
190 | mean=mean,
191 | std=std,
192 | )
193 | if not resize_im:
194 | transform.transforms[0] = transforms.RandomCrop(size, padding=4)
195 |
196 | return transform
197 |
198 | def __getitem__(self, batch_indexes_tup: Tuple):
199 | crop_size_h, crop_size_w, img_index = batch_indexes_tup
200 | transforms = self.get_transforms(size=int(crop_size_w))
201 |
202 | path, target = self.samples[img_index]
203 | sample = self.loader(path)
204 | if transforms is not None:
205 | sample = transforms(sample)
206 | if self.target_transform is not None:
207 | target = self.target_transform(target)
208 |
209 | return sample, target
210 |
211 | def __len__(self):
212 | return len(self.samples)
213 |
--------------------------------------------------------------------------------
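
The multi-scale sampler keeps the pixel budget per batch roughly constant: for each sampled resolution the batch size is rescaled by the ratio of pixel counts relative to the base resolution, rounded to be divisible by the number of GPUs, and forced to be even. A worked example with assumed numbers (base 256x256, base batch size 64, 4 GPUs); the divisibility step below is a simplified `make_divisible` without the 10% round-down guard:

    base_h = base_w = 256
    base_bsz = 64
    n_gpus = 4

    n_elements = base_h * base_w * base_bsz                          # pixel budget per batch
    for h, w in [(160, 160), (256, 256), (320, 320)]:
        bsz = max(base_bsz, int(round(n_elements / (h * w), 2)))     # rescale by pixel ratio
        bsz = max(n_gpus, int(bsz + n_gpus / 2) // n_gpus * n_gpus)  # make divisible by n_gpus
        bsz = bsz if bsz % 2 == 0 else bsz - 1                       # batch size must be even
        print((h, w, bsz))              # (160, 160, 164), (256, 256, 64), (320, 320, 64)
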
/lite-mono-pretrain-code/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import time
4 | from collections import defaultdict, deque
5 | import datetime
6 | import numpy as np
7 | from timm.utils import get_state_dict
8 |
9 | from pathlib import Path
10 |
11 | import torch
12 | import torch.distributed as dist
13 | from torch._six import inf
14 |
15 | from tensorboardX import SummaryWriter
16 |
17 | import subprocess
18 |
19 |
20 | class SmoothedValue(object):
21 | """Track a series of values and provide access to smoothed values over a
22 | window or the global series average.
23 | """
24 |
25 | def __init__(self, window_size=20, fmt=None):
26 | if fmt is None:
27 | fmt = "{median:.4f} ({global_avg:.4f})"
28 | self.deque = deque(maxlen=window_size)
29 | self.total = 0.0
30 | self.count = 0
31 | self.fmt = fmt
32 |
33 | def update(self, value, n=1):
34 | self.deque.append(value)
35 | self.count += n
36 | self.total += value * n
37 |
38 | def synchronize_between_processes(self):
39 | """
40 | Warning: does not synchronize the deque!
41 | """
42 | if not is_dist_avail_and_initialized():
43 | return
44 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
45 | dist.barrier()
46 | dist.all_reduce(t)
47 | t = t.tolist()
48 | self.count = int(t[0])
49 | self.total = t[1]
50 |
51 | @property
52 | def median(self):
53 | d = torch.tensor(list(self.deque))
54 | return d.median().item()
55 |
56 | @property
57 | def avg(self):
58 | d = torch.tensor(list(self.deque), dtype=torch.float32)
59 | return d.mean().item()
60 |
61 | @property
62 | def global_avg(self):
63 | return self.total / self.count
64 |
65 | @property
66 | def max(self):
67 | return max(self.deque)
68 |
69 | @property
70 | def value(self):
71 | return self.deque[-1]
72 |
73 | def __str__(self):
74 | return self.fmt.format(
75 | median=self.median,
76 | avg=self.avg,
77 | global_avg=self.global_avg,
78 | max=self.max,
79 | value=self.value)
80 |
81 |
82 | class MetricLogger(object):
83 | def __init__(self, delimiter="\t"):
84 | self.meters = defaultdict(SmoothedValue)
85 | self.delimiter = delimiter
86 |
87 | def update(self, **kwargs):
88 | for k, v in kwargs.items():
89 | if v is None:
90 | continue
91 | if isinstance(v, torch.Tensor):
92 | v = v.item()
93 | assert isinstance(v, (float, int))
94 | self.meters[k].update(v)
95 |
96 | def __getattr__(self, attr):
97 | if attr in self.meters:
98 | return self.meters[attr]
99 | if attr in self.__dict__:
100 | return self.__dict__[attr]
101 | raise AttributeError("'{}' object has no attribute '{}'".format(
102 | type(self).__name__, attr))
103 |
104 | def __str__(self):
105 | loss_str = []
106 | for name, meter in self.meters.items():
107 | loss_str.append(
108 | "{}: {}".format(name, str(meter))
109 | )
110 | return self.delimiter.join(loss_str)
111 |
112 | def synchronize_between_processes(self):
113 | for meter in self.meters.values():
114 | meter.synchronize_between_processes()
115 |
116 | def add_meter(self, name, meter):
117 | self.meters[name] = meter
118 |
119 | def log_every(self, iterable, print_freq, header=None):
120 | i = 0
121 | if not header:
122 | header = ''
123 | start_time = time.time()
124 | end = time.time()
125 | iter_time = SmoothedValue(fmt='{avg:.4f}')
126 | data_time = SmoothedValue(fmt='{avg:.4f}')
127 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
128 | log_msg = [
129 | header,
130 | '[{0' + space_fmt + '}/{1}]',
131 | 'eta: {eta}',
132 | '{meters}',
133 | 'time: {time}',
134 | 'data: {data}'
135 | ]
136 | if torch.cuda.is_available():
137 | log_msg.append('max mem: {memory:.0f}')
138 | log_msg = self.delimiter.join(log_msg)
139 | MB = 1024.0 * 1024.0
140 | for obj in iterable:
141 | data_time.update(time.time() - end)
142 | yield obj
143 | iter_time.update(time.time() - end)
144 | if i % print_freq == 0 or i == len(iterable) - 1:
145 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
146 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
147 | if torch.cuda.is_available():
148 | print(log_msg.format(
149 | i, len(iterable), eta=eta_string,
150 | meters=str(self),
151 | time=str(iter_time), data=str(data_time),
152 | memory=torch.cuda.max_memory_allocated() / MB))
153 | else:
154 | print(log_msg.format(
155 | i, len(iterable), eta=eta_string,
156 | meters=str(self),
157 | time=str(iter_time), data=str(data_time)))
158 | i += 1
159 | end = time.time()
160 | total_time = time.time() - start_time
161 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
162 | print('{} Total time: {} ({:.4f} s / it)'.format(
163 | header, total_time_str, total_time / len(iterable)))
164 |
165 |
166 | class TensorboardLogger(object):
167 | def __init__(self, log_dir):
168 | self.writer = SummaryWriter(logdir=log_dir)
169 | self.step = 0
170 |
171 | def set_step(self, step=None):
172 | if step is not None:
173 | self.step = step
174 | else:
175 | self.step += 1
176 |
177 | def update(self, head='scalar', step=None, **kwargs):
178 | for k, v in kwargs.items():
179 | if v is None:
180 | continue
181 | if isinstance(v, torch.Tensor):
182 | v = v.item()
183 | assert isinstance(v, (float, int))
184 | self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step)
185 |
186 | def flush(self):
187 | self.writer.flush()
188 |
189 |
190 | class WandbLogger(object):
191 | def __init__(self, args):
192 | self.args = args
193 |
194 | try:
195 | import wandb
196 | self._wandb = wandb
197 | except ImportError:
198 | raise ImportError(
199 | "To use the Weights and Biases Logger please install wandb."
200 | "Run `pip install wandb` to install it."
201 | )
202 |
203 | # Initialize a W&B run
204 | if self._wandb.run is None:
205 | self._wandb.init(
206 | project=args.project,
207 | config=args
208 | )
209 |
210 | def log_epoch_metrics(self, metrics, commit=True):
211 | """
212 | Log train/test metrics onto W&B.
213 | """
214 | # Log number of model parameters as W&B summary
215 | self._wandb.summary['n_parameters'] = metrics.get('n_parameters', None)
216 | metrics.pop('n_parameters', None)
217 |
218 | # Log current epoch
219 | self._wandb.log({'epoch': metrics.get('epoch')}, commit=False)
220 | metrics.pop('epoch')
221 |
222 | for k, v in metrics.items():
223 | if 'train' in k:
224 | self._wandb.log({f'Global Train/{k}': v}, commit=False)
225 | elif 'test' in k:
226 | self._wandb.log({f'Global Test/{k}': v}, commit=False)
227 |
228 | self._wandb.log({})
229 |
230 | def log_checkpoints(self):
231 | output_dir = self.args.output_dir
232 | model_artifact = self._wandb.Artifact(
233 | self._wandb.run.id + "_model", type="model"
234 | )
235 |
236 | model_artifact.add_dir(output_dir)
237 | self._wandb.log_artifact(model_artifact, aliases=["latest", "best"])
238 |
239 | def set_steps(self):
240 | # Set global training step
241 | self._wandb.define_metric('Rank-0 Batch Wise/*', step_metric='Rank-0 Batch Wise/global_train_step')
242 | # Set epoch-wise step
243 | self._wandb.define_metric('Global Train/*', step_metric='epoch')
244 | self._wandb.define_metric('Global Test/*', step_metric='epoch')
245 |
246 |
247 | def setup_for_distributed(is_master):
248 | """
249 | This function disables printing when not in master process
250 | """
251 | import builtins as __builtin__
252 | builtin_print = __builtin__.print
253 |
254 | def print(*args, **kwargs):
255 | force = kwargs.pop('force', False)
256 | if is_master or force:
257 | builtin_print(*args, **kwargs)
258 |
259 | __builtin__.print = print
260 |
261 |
262 | def is_dist_avail_and_initialized():
263 | if not dist.is_available():
264 | return False
265 | if not dist.is_initialized():
266 | return False
267 | return True
268 |
269 |
270 | def get_world_size():
271 | if not is_dist_avail_and_initialized():
272 | return 1
273 | return dist.get_world_size()
274 |
275 |
276 | def get_rank():
277 | if not is_dist_avail_and_initialized():
278 | return 0
279 | return dist.get_rank()
280 |
281 |
282 | def is_main_process():
283 | return get_rank() == 0
284 |
285 |
286 | def save_on_master(*args, **kwargs):
287 | if is_main_process():
288 | torch.save(*args, **kwargs)
289 |
290 |
291 | def init_distributed_mode(args):
292 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
293 | args.rank = int(os.environ["RANK"])
294 | args.world_size = int(os.environ['WORLD_SIZE'])
295 | args.gpu = int(os.environ['LOCAL_RANK'])
296 | args.dist_url = 'env://'
297 | os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
298 | print('Using distributed mode: 1')
299 | elif 'SLURM_PROCID' in os.environ:
300 | proc_id = int(os.environ['SLURM_PROCID'])
301 | ntasks = int(os.environ['SLURM_NTASKS'])
302 | node_list = os.environ['SLURM_NODELIST']
303 | num_gpus = torch.cuda.device_count()
304 | addr = subprocess.getoutput(
305 | 'scontrol show hostname {} | head -n1'.format(node_list))
306 | os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')
307 | os.environ['MASTER_ADDR'] = addr
308 | os.environ['WORLD_SIZE'] = str(ntasks)
309 | os.environ['RANK'] = str(proc_id)
310 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
311 | os.environ['LOCAL_SIZE'] = str(num_gpus)
312 | args.dist_url = 'env://'
313 | args.world_size = ntasks
314 | args.rank = proc_id
315 | args.gpu = proc_id % num_gpus
316 | print('Using distributed mode: slurm')
317 | print(f"world: {os.environ['WORLD_SIZE']}, rank:{os.environ['RANK']},"
318 | f" local_rank{os.environ['LOCAL_RANK']}, local_size{os.environ['LOCAL_SIZE']}")
319 | else:
320 | print('Not using distributed mode')
321 | args.distributed = False
322 | return
323 |
324 | args.distributed = True
325 |
326 | torch.cuda.set_device(args.gpu)
327 | args.dist_backend = 'nccl'
328 | print('| distributed init (rank {}): {}'.format(
329 | args.rank, args.dist_url), flush=True)
330 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
331 | world_size=args.world_size, rank=args.rank)
332 | torch.distributed.barrier()
333 | setup_for_distributed(args.rank == 0)
334 |
335 |
336 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
337 | missing_keys = []
338 | unexpected_keys = []
339 | error_msgs = []
340 | # copy state_dict so _load_from_state_dict can modify it
341 | metadata = getattr(state_dict, '_metadata', None)
342 | state_dict = state_dict.copy()
343 | if metadata is not None:
344 | state_dict._metadata = metadata
345 |
346 | def load(module, prefix=''):
347 | local_metadata = {} if metadata is None else metadata.get(
348 | prefix[:-1], {})
349 | module._load_from_state_dict(
350 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
351 | for name, child in module._modules.items():
352 | if child is not None:
353 | load(child, prefix + name + '.')
354 |
355 | load(model, prefix=prefix)
356 |
357 | warn_missing_keys = []
358 | ignore_missing_keys = []
359 | for key in missing_keys:
360 | keep_flag = True
361 | for ignore_key in ignore_missing.split('|'):
362 | if ignore_key in key:
363 | keep_flag = False
364 | break
365 | if keep_flag:
366 | warn_missing_keys.append(key)
367 | else:
368 | ignore_missing_keys.append(key)
369 |
370 | missing_keys = warn_missing_keys
371 |
372 | if len(missing_keys) > 0:
373 | print("Weights of {} not initialized from pretrained model: {}".format(
374 | model.__class__.__name__, missing_keys))
375 | if len(unexpected_keys) > 0:
376 | print("Weights from pretrained model not used in {}: {}".format(
377 | model.__class__.__name__, unexpected_keys))
378 | if len(ignore_missing_keys) > 0:
379 | print("Ignored weights of {} not initialized from pretrained model: {}".format(
380 | model.__class__.__name__, ignore_missing_keys))
381 | if len(error_msgs) > 0:
382 | print('\n'.join(error_msgs))
383 |
384 |
385 | class NativeScalerWithGradNormCount:
386 | state_dict_key = "amp_scaler"
387 |
388 | def __init__(self):
389 | self._scaler = torch.cuda.amp.GradScaler()
390 |
391 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
392 | self._scaler.scale(loss).backward(create_graph=create_graph)
393 | if update_grad:
394 | if clip_grad is not None:
395 | assert parameters is not None
396 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
397 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
398 | else:
399 | self._scaler.unscale_(optimizer)
400 | norm = get_grad_norm_(parameters)
401 | self._scaler.step(optimizer)
402 | self._scaler.update()
403 | else:
404 | norm = None
405 | return norm
406 |
407 | def state_dict(self):
408 | return self._scaler.state_dict()
409 |
410 | def load_state_dict(self, state_dict):
411 | self._scaler.load_state_dict(state_dict)
412 |
413 |
414 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
415 | if isinstance(parameters, torch.Tensor):
416 | parameters = [parameters]
417 | parameters = [p for p in parameters if p.grad is not None]
418 | norm_type = float(norm_type)
419 | if len(parameters) == 0:
420 | return torch.tensor(0.)
421 | device = parameters[0].grad.device
422 | if norm_type == inf:
423 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
424 | else:
425 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
426 | return total_norm
427 |
428 |
429 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
430 | start_warmup_value=0, warmup_steps=-1):
431 | warmup_schedule = np.array([])
432 | warmup_iters = warmup_epochs * niter_per_ep
433 | if warmup_steps > 0:
434 | warmup_iters = warmup_steps
435 | print("Set warmup steps = %d" % warmup_iters)
436 |     if warmup_iters > 0:
437 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
438 |
439 | iters = np.arange(epochs * niter_per_ep - warmup_iters)
440 | schedule = np.array(
441 | [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])
442 |
443 | schedule = np.concatenate((warmup_schedule, schedule))
444 |
445 | assert len(schedule) == epochs * niter_per_ep
446 | return schedule
447 |
448 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
449 | output_dir = Path(args.output_dir)
450 | epoch_name = str(epoch)
451 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
452 | for checkpoint_path in checkpoint_paths:
453 | to_save = {
454 | 'model': model_without_ddp.state_dict(),
455 | 'optimizer': optimizer.state_dict(),
456 | 'epoch': epoch,
457 | 'scaler': loss_scaler.state_dict(),
458 | 'args': args,
459 | }
460 |
461 | if model_ema is not None:
462 | to_save['model_ema'] = get_state_dict(model_ema)
463 |
464 | save_on_master(to_save, checkpoint_path)
465 |
466 | if is_main_process() and isinstance(epoch, int):
467 | to_del = epoch - args.save_ckpt_num * args.save_ckpt_freq
468 | old_ckpt = output_dir / ('checkpoint-%s.pth' % to_del)
469 | if os.path.exists(old_ckpt):
470 | os.remove(old_ckpt)
471 |
472 |
473 | def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None, state_dict_name='model'):
474 | output_dir = Path(args.output_dir)
475 | if args.auto_resume and len(args.resume) == 0:
476 | import glob
477 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
478 | latest_ckpt = -1
479 | for ckpt in all_checkpoints:
480 | t = ckpt.split('-')[-1].split('.')[0]
481 | if t.isdigit():
482 | latest_ckpt = max(int(t), latest_ckpt)
483 | if latest_ckpt >= 0:
484 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
485 | print("Auto resume checkpoint: %s" % args.resume)
486 |
487 | if args.resume:
488 | if args.resume.startswith('https'):
489 | checkpoint = torch.hub.load_state_dict_from_url(
490 | args.resume, map_location='cpu', check_hash=True)
491 | else:
492 | checkpoint = torch.load(args.resume, map_location='cpu')
493 | model_without_ddp.load_state_dict(checkpoint[state_dict_name])
494 | print("Resume checkpoint %s" % args.resume)
495 | if 'optimizer' in checkpoint and 'epoch' in checkpoint:
496 | optimizer.load_state_dict(checkpoint['optimizer'])
497 | if not isinstance(checkpoint['epoch'], str): # does not support resuming with 'best', 'best-ema'
498 | args.start_epoch = checkpoint['epoch'] + 1
499 | else:
500 | assert args.eval, 'Does not support resuming with checkpoint-best'
501 | if hasattr(args, 'model_ema') and args.model_ema:
502 | if 'model_ema' in checkpoint.keys():
503 | model_ema.ema.load_state_dict(checkpoint['model_ema'])
504 | else:
505 | model_ema.ema.load_state_dict(checkpoint['model'])
506 | if 'scaler' in checkpoint:
507 | loss_scaler.load_state_dict(checkpoint['scaler'])
508 | print("With optim & sched!")
509 |
--------------------------------------------------------------------------------
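The `cosine_scheduler` above produces one learning-rate (or weight-decay) value per training iteration: an optional linear warmup followed by a cosine decay from `base_value` to `final_value`. A minimal sketch of how such a per-iteration schedule could be consumed is shown below; the stand-in model, optimizer, and loop sizes are illustrative assumptions, not code from this repository.

```python
# Sketch only: stepping an optimizer with the values returned by cosine_scheduler.
# The model, data, and hyper-parameters are illustrative assumptions.
import torch
from utils import cosine_scheduler  # lite-mono-pretrain-code/utils.py

epochs, niter_per_ep = 10, 100
lr_schedule = cosine_scheduler(base_value=1e-3, final_value=1e-5,
                               epochs=epochs, niter_per_ep=niter_per_ep,
                               warmup_epochs=1)          # len == epochs * niter_per_ep

model = torch.nn.Linear(8, 2)                            # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=lr_schedule[0])

for epoch in range(epochs):
    for it in range(niter_per_ep):
        step = epoch * niter_per_ep + it
        for group in optimizer.param_groups:             # apply this iteration's value
            group["lr"] = lr_schedule[step]
        loss = model(torch.randn(4, 8)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```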
/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet_encoder import ResnetEncoder
2 | from .pose_decoder import PoseDecoder
3 | from .depth_decoder import DepthDecoder
4 | from .depth_encoder import LiteMono
5 |
--------------------------------------------------------------------------------
/networks/depth_decoder.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | from collections import OrderedDict
3 | from layers import *
4 | from timm.models.layers import trunc_normal_
5 |
6 |
7 | class DepthDecoder(nn.Module):
8 | def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True):
9 | super().__init__()
10 |
11 | self.num_output_channels = num_output_channels
12 | self.use_skips = use_skips
13 | self.upsample_mode = 'bilinear'
14 | self.scales = scales
15 |
16 | self.num_ch_enc = num_ch_enc
17 | self.num_ch_dec = (self.num_ch_enc / 2).astype('int')
18 |
19 | # decoder
20 | self.convs = OrderedDict()
21 | for i in range(2, -1, -1):
22 | # upconv_0
23 | num_ch_in = self.num_ch_enc[-1] if i == 2 else self.num_ch_dec[i + 1]
24 | num_ch_out = self.num_ch_dec[i]
25 | self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)
26 | # print(i, num_ch_in, num_ch_out)
27 | # upconv_1
28 | num_ch_in = self.num_ch_dec[i]
29 | if self.use_skips and i > 0:
30 | num_ch_in += self.num_ch_enc[i - 1]
31 | num_ch_out = self.num_ch_dec[i]
32 | self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)
33 |
34 | for s in self.scales:
35 | self.convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], self.num_output_channels)
36 |
37 | self.decoder = nn.ModuleList(list(self.convs.values()))
38 | self.sigmoid = nn.Sigmoid()
39 |
40 | self.apply(self._init_weights)
41 |
42 | def _init_weights(self, m):
43 | if isinstance(m, (nn.Conv2d, nn.Linear)):
44 | trunc_normal_(m.weight, std=.02)
45 | if m.bias is not None:
46 | nn.init.constant_(m.bias, 0)
47 |
48 | def forward(self, input_features):
49 | self.outputs = {}
50 | x = input_features[-1]
51 | for i in range(2, -1, -1):
52 | x = self.convs[("upconv", i, 0)](x)
53 | x = [upsample(x)]
54 |
55 | if self.use_skips and i > 0:
56 | x += [input_features[i - 1]]
57 | x = torch.cat(x, 1)
58 | x = self.convs[("upconv", i, 1)](x)
59 |
60 | if i in self.scales:
61 | f = upsample(self.convs[("dispconv", i)](x), mode='bilinear')
62 | self.outputs[("disp", i)] = self.sigmoid(f)
63 |
64 | return self.outputs
65 |
66 |
--------------------------------------------------------------------------------
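The decoder above halves the encoder channel counts (`num_ch_dec = num_ch_enc / 2`), upsamples with skip connections, and returns a dictionary of sigmoid disparity maps keyed by `("disp", scale)`. A minimal sketch of wiring it to the `LiteMono` encoder, assuming the repository root is on `PYTHONPATH`:

```python
# Sketch: LiteMono encoder -> DepthDecoder on a dummy input at the default resolution.
import torch
import networks

encoder = networks.LiteMono(model="lite-mono", height=192, width=640)
decoder = networks.DepthDecoder(encoder.num_ch_enc, scales=range(3))

with torch.no_grad():
    features = encoder(torch.rand(1, 3, 192, 640))   # three feature maps at 1/4, 1/8, 1/16
    outputs = decoder(features)

for s in range(3):
    print(("disp", s), outputs[("disp", s)].shape)   # sigmoid disparity, finer as s decreases
```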
/networks/depth_encoder.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 | from timm.models.layers import DropPath
6 | import math
7 | import torch.cuda
8 |
9 |
10 | class PositionalEncodingFourier(nn.Module):
11 | """
12 |     Positional encoding relying on a Fourier kernel matching the one used in the
13 |     "Attention Is All You Need" paper. The implementation builds on the DETR code
14 | https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
15 | """
16 |
17 | def __init__(self, hidden_dim=32, dim=768, temperature=10000):
18 | super().__init__()
19 | self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1)
20 | self.scale = 2 * math.pi
21 | self.temperature = temperature
22 | self.hidden_dim = hidden_dim
23 | self.dim = dim
24 |
25 | def forward(self, B, H, W):
26 | mask = torch.zeros(B, H, W).bool().to(self.token_projection.weight.device)
27 | not_mask = ~mask
28 | y_embed = not_mask.cumsum(1, dtype=torch.float32)
29 | x_embed = not_mask.cumsum(2, dtype=torch.float32)
30 | eps = 1e-6
31 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
32 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
33 |
34 | dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=mask.device)
35 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.hidden_dim)
36 |
37 | pos_x = x_embed[:, :, :, None] / dim_t
38 | pos_y = y_embed[:, :, :, None] / dim_t
39 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(),
40 | pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
41 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(),
42 | pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
43 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
44 | pos = self.token_projection(pos)
45 | return pos
46 |
47 |
48 | class XCA(nn.Module):
49 | """ Cross-Covariance Attention (XCA) operation where the channels are updated using a weighted
50 | sum. The weights are obtained from the (softmax normalized) Cross-covariance
51 | matrix (Q^T K \\in d_h \\times d_h)
52 | """
53 |
54 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
55 | super().__init__()
56 | self.num_heads = num_heads
57 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
58 |
59 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
60 | self.attn_drop = nn.Dropout(attn_drop)
61 | self.proj = nn.Linear(dim, dim)
62 | self.proj_drop = nn.Dropout(proj_drop)
63 |
64 | def forward(self, x):
65 | B, N, C = x.shape
66 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
67 | qkv = qkv.permute(2, 0, 3, 1, 4)
68 | q, k, v = qkv[0], qkv[1], qkv[2]
69 |
70 | q = q.transpose(-2, -1)
71 | k = k.transpose(-2, -1)
72 | v = v.transpose(-2, -1)
73 |
74 | q = torch.nn.functional.normalize(q, dim=-1)
75 | k = torch.nn.functional.normalize(k, dim=-1)
76 |
77 | attn = (q @ k.transpose(-2, -1)) * self.temperature
78 | attn = attn.softmax(dim=-1)
79 | attn = self.attn_drop(attn)
80 |
81 | x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
82 | x = self.proj(x)
83 | x = self.proj_drop(x)
84 | return x
85 |
86 | @torch.jit.ignore
87 | def no_weight_decay(self):
88 | return {'temperature'}
89 |
90 |
91 | class LayerNorm(nn.Module):
92 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
93 | super().__init__()
94 | self.weight = nn.Parameter(torch.ones(normalized_shape))
95 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
96 | self.eps = eps
97 | self.data_format = data_format
98 | if self.data_format not in ["channels_last", "channels_first"]:
99 | raise NotImplementedError
100 | self.normalized_shape = (normalized_shape,)
101 |
102 |
103 | def forward(self, x):
104 | if self.data_format == "channels_last":
105 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
106 | elif self.data_format == "channels_first":
107 | u = x.mean(1, keepdim=True)
108 | s = (x - u).pow(2).mean(1, keepdim=True)
109 | x = (x - u) / torch.sqrt(s + self.eps)
110 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
111 | return x
112 |
113 |
114 | class BNGELU(nn.Module):
115 | def __init__(self, nIn):
116 | super().__init__()
117 | self.bn = nn.BatchNorm2d(nIn, eps=1e-5)
118 | self.act = nn.GELU()
119 |
120 | def forward(self, x):
121 | output = self.bn(x)
122 | output = self.act(output)
123 |
124 | return output
125 |
126 |
127 | class Conv(nn.Module):
128 | def __init__(self, nIn, nOut, kSize, stride, padding=0, dilation=(1, 1), groups=1, bn_act=False, bias=False):
129 | super().__init__()
130 |
131 | self.bn_act = bn_act
132 |
133 | self.conv = nn.Conv2d(nIn, nOut, kernel_size=kSize,
134 | stride=stride, padding=padding,
135 | dilation=dilation, groups=groups, bias=bias)
136 |
137 | if self.bn_act:
138 | self.bn_gelu = BNGELU(nOut)
139 |
140 | def forward(self, x):
141 | output = self.conv(x)
142 |
143 | if self.bn_act:
144 | output = self.bn_gelu(output)
145 |
146 | return output
147 |
148 |
149 | class CDilated(nn.Module):
150 | """
151 | This class defines the dilated convolution.
152 | """
153 |
154 | def __init__(self, nIn, nOut, kSize, stride=1, d=1, groups=1, bias=False):
155 | """
156 | :param nIn: number of input channels
157 | :param nOut: number of output channels
158 | :param kSize: kernel size
159 | :param stride: optional stride rate for down-sampling
160 | :param d: optional dilation rate
161 | """
162 | super().__init__()
163 | padding = int((kSize - 1) / 2) * d
164 | self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=bias,
165 | dilation=d, groups=groups)
166 |
167 | def forward(self, input):
168 | """
169 | :param input: input feature map
170 | :return: transformed feature map
171 | """
172 |
173 | output = self.conv(input)
174 | return output
175 |
176 |
177 | class DilatedConv(nn.Module):
178 | """
179 | A single Dilated Convolution layer in the Consecutive Dilated Convolutions (CDC) module.
180 | """
181 | def __init__(self, dim, k, dilation=1, stride=1, drop_path=0.,
182 | layer_scale_init_value=1e-6, expan_ratio=6):
183 | """
184 | :param dim: input dimension
185 | :param k: kernel size
186 | :param dilation: dilation rate
187 | :param drop_path: drop_path rate
188 |         :param layer_scale_init_value: initial value of the per-channel layer scale
189 |         :param expan_ratio: expansion ratio of the inverted bottleneck
190 | """
191 |
192 | super().__init__()
193 |
194 | self.ddwconv = CDilated(dim, dim, kSize=k, stride=stride, groups=dim, d=dilation)
195 | self.bn1 = nn.BatchNorm2d(dim)
196 |
197 | self.norm = LayerNorm(dim, eps=1e-6)
198 | self.pwconv1 = nn.Linear(dim, expan_ratio * dim)
199 | self.act = nn.GELU()
200 | self.pwconv2 = nn.Linear(expan_ratio * dim, dim)
201 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim),
202 | requires_grad=True) if layer_scale_init_value > 0 else None
203 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
204 |
205 | def forward(self, x):
206 | input = x
207 |
208 | x = self.ddwconv(x)
209 | x = self.bn1(x)
210 |
211 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
212 | x = self.pwconv1(x)
213 | x = self.act(x)
214 | x = self.pwconv2(x)
215 | if self.gamma is not None:
216 | x = self.gamma * x
217 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
218 |
219 | x = input + self.drop_path(x)
220 |
221 | return x
222 |
223 |
224 | class LGFI(nn.Module):
225 | """
226 | Local-Global Features Interaction
227 | """
228 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, expan_ratio=6,
229 | use_pos_emb=True, num_heads=6, qkv_bias=True, attn_drop=0., drop=0.):
230 | super().__init__()
231 |
232 | self.dim = dim
233 | self.pos_embd = None
234 | if use_pos_emb:
235 | self.pos_embd = PositionalEncodingFourier(dim=self.dim)
236 |
237 | self.norm_xca = LayerNorm(self.dim, eps=1e-6)
238 |
239 | self.gamma_xca = nn.Parameter(layer_scale_init_value * torch.ones(self.dim),
240 | requires_grad=True) if layer_scale_init_value > 0 else None
241 | self.xca = XCA(self.dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
242 |
243 | self.norm = LayerNorm(self.dim, eps=1e-6)
244 | self.pwconv1 = nn.Linear(self.dim, expan_ratio * self.dim)
245 | self.act = nn.GELU()
246 | self.pwconv2 = nn.Linear(expan_ratio * self.dim, self.dim)
247 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((self.dim)),
248 | requires_grad=True) if layer_scale_init_value > 0 else None
249 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
250 |
251 | def forward(self, x):
252 | input_ = x
253 |
254 | # XCA
255 | B, C, H, W = x.shape
256 | x = x.reshape(B, C, H * W).permute(0, 2, 1)
257 |
258 | if self.pos_embd:
259 | pos_encoding = self.pos_embd(B, H, W).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
260 | x = x + pos_encoding
261 |
262 | x = x + self.gamma_xca * self.xca(self.norm_xca(x))
263 |
264 | x = x.reshape(B, H, W, C)
265 |
266 | # Inverted Bottleneck
267 | x = self.norm(x)
268 | x = self.pwconv1(x)
269 | x = self.act(x)
270 | x = self.pwconv2(x)
271 | if self.gamma is not None:
272 | x = self.gamma * x
273 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
274 |
275 | x = input_ + self.drop_path(x)
276 |
277 | return x
278 |
279 |
280 | class AvgPool(nn.Module):
281 | def __init__(self, ratio):
282 | super().__init__()
283 | self.pool = nn.ModuleList()
284 | for i in range(0, ratio):
285 | self.pool.append(nn.AvgPool2d(3, stride=2, padding=1))
286 |
287 | def forward(self, x):
288 | for pool in self.pool:
289 | x = pool(x)
290 |
291 | return x
292 |
293 |
294 | class LiteMono(nn.Module):
295 | """
296 | Lite-Mono
297 | """
298 | def __init__(self, in_chans=3, model='lite-mono', height=192, width=640,
299 | global_block=[1, 1, 1], global_block_type=['LGFI', 'LGFI', 'LGFI'],
300 | drop_path_rate=0.2, layer_scale_init_value=1e-6, expan_ratio=6,
301 | heads=[8, 8, 8], use_pos_embd_xca=[True, False, False], **kwargs):
302 |
303 | super().__init__()
304 |
305 | if model == 'lite-mono':
306 | self.num_ch_enc = np.array([48, 80, 128])
307 | self.depth = [4, 4, 10]
308 | self.dims = [48, 80, 128]
309 | if height == 192 and width == 640:
310 | self.dilation = [[1, 2, 3], [1, 2, 3], [1, 2, 3, 1, 2, 3, 2, 4, 6]]
311 | elif height == 320 and width == 1024:
312 | self.dilation = [[1, 2, 5], [1, 2, 5], [1, 2, 5, 1, 2, 5, 2, 4, 10]]
313 |
314 | elif model == 'lite-mono-small':
315 | self.num_ch_enc = np.array([48, 80, 128])
316 | self.depth = [4, 4, 7]
317 | self.dims = [48, 80, 128]
318 | if height == 192 and width == 640:
319 | self.dilation = [[1, 2, 3], [1, 2, 3], [1, 2, 3, 2, 4, 6]]
320 | elif height == 320 and width == 1024:
321 | self.dilation = [[1, 2, 5], [1, 2, 5], [1, 2, 5, 2, 4, 10]]
322 |
323 | elif model == 'lite-mono-tiny':
324 | self.num_ch_enc = np.array([32, 64, 128])
325 | self.depth = [4, 4, 7]
326 | self.dims = [32, 64, 128]
327 | if height == 192 and width == 640:
328 | self.dilation = [[1, 2, 3], [1, 2, 3], [1, 2, 3, 2, 4, 6]]
329 | elif height == 320 and width == 1024:
330 | self.dilation = [[1, 2, 5], [1, 2, 5], [1, 2, 5, 2, 4, 10]]
331 |
332 | elif model == 'lite-mono-8m':
333 | self.num_ch_enc = np.array([64, 128, 224])
334 | self.depth = [4, 4, 10]
335 | self.dims = [64, 128, 224]
336 | if height == 192 and width == 640:
337 | self.dilation = [[1, 2, 3], [1, 2, 3], [1, 2, 3, 1, 2, 3, 2, 4, 6]]
338 | elif height == 320 and width == 1024:
339 | self.dilation = [[1, 2, 3], [1, 2, 3], [1, 2, 3, 1, 2, 3, 2, 4, 6]]
340 |
341 | for g in global_block_type:
342 | assert g in ['None', 'LGFI']
343 |
344 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
345 | stem1 = nn.Sequential(
346 | Conv(in_chans, self.dims[0], kSize=3, stride=2, padding=1, bn_act=True),
347 | Conv(self.dims[0], self.dims[0], kSize=3, stride=1, padding=1, bn_act=True),
348 | Conv(self.dims[0], self.dims[0], kSize=3, stride=1, padding=1, bn_act=True),
349 | )
350 |
351 | self.stem2 = nn.Sequential(
352 | Conv(self.dims[0]+3, self.dims[0], kSize=3, stride=2, padding=1, bn_act=False),
353 | )
354 |
355 | self.downsample_layers.append(stem1)
356 |
357 | self.input_downsample = nn.ModuleList()
358 | for i in range(1, 5):
359 | self.input_downsample.append(AvgPool(i))
360 |
361 | for i in range(2):
362 | downsample_layer = nn.Sequential(
363 | Conv(self.dims[i]*2+3, self.dims[i+1], kSize=3, stride=2, padding=1, bn_act=False),
364 | )
365 | self.downsample_layers.append(downsample_layer)
366 |
367 | self.stages = nn.ModuleList()
368 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depth))]
369 | cur = 0
370 | for i in range(3):
371 | stage_blocks = []
372 | for j in range(self.depth[i]):
373 | if j > self.depth[i] - global_block[i] - 1:
374 | if global_block_type[i] == 'LGFI':
375 | stage_blocks.append(LGFI(dim=self.dims[i], drop_path=dp_rates[cur + j],
376 | expan_ratio=expan_ratio,
377 | use_pos_emb=use_pos_embd_xca[i], num_heads=heads[i],
378 | layer_scale_init_value=layer_scale_init_value,
379 | ))
380 |
381 | else:
382 | raise NotImplementedError
383 | else:
384 | stage_blocks.append(DilatedConv(dim=self.dims[i], k=3, dilation=self.dilation[i][j], drop_path=dp_rates[cur + j],
385 | layer_scale_init_value=layer_scale_init_value,
386 | expan_ratio=expan_ratio))
387 |
388 | self.stages.append(nn.Sequential(*stage_blocks))
389 | cur += self.depth[i]
390 |
391 | self.apply(self._init_weights)
392 |
393 | def _init_weights(self, m):
394 | if isinstance(m, (nn.Conv2d, nn.Linear)):
395 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
396 |
397 | elif isinstance(m, (LayerNorm, nn.LayerNorm)):
398 | nn.init.constant_(m.bias, 0)
399 | nn.init.constant_(m.weight, 1.0)
400 |
401 | elif isinstance(m, nn.BatchNorm2d):
402 | nn.init.constant_(m.weight, 1)
403 | nn.init.constant_(m.bias, 0)
404 |
405 | def forward_features(self, x):
406 | features = []
407 | x = (x - 0.45) / 0.225
408 |
409 | x_down = []
410 | for i in range(4):
411 | x_down.append(self.input_downsample[i](x))
412 |
413 | tmp_x = []
414 | x = self.downsample_layers[0](x)
415 | x = self.stem2(torch.cat((x, x_down[0]), dim=1))
416 | tmp_x.append(x)
417 |
418 | for s in range(len(self.stages[0])-1):
419 | x = self.stages[0][s](x)
420 | x = self.stages[0][-1](x)
421 | tmp_x.append(x)
422 | features.append(x)
423 |
424 | for i in range(1, 3):
425 | tmp_x.append(x_down[i])
426 | x = torch.cat(tmp_x, dim=1)
427 | x = self.downsample_layers[i](x)
428 |
429 | tmp_x = [x]
430 | for s in range(len(self.stages[i]) - 1):
431 | x = self.stages[i][s](x)
432 | x = self.stages[i][-1](x)
433 | tmp_x.append(x)
434 |
435 | features.append(x)
436 |
437 | return features
438 |
439 | def forward(self, x):
440 | x = self.forward_features(x)
441 |
442 | return x
443 |
--------------------------------------------------------------------------------
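As the docstring of `XCA` notes, attention is computed over channels rather than tokens, so each head builds a `head_dim x head_dim` cross-covariance map and the cost grows linearly with the number of tokens. A small shape check with illustrative sizes, assuming the repository root is on `PYTHONPATH`:

```python
# Sketch: XCA keeps the token dimension untouched; the attention map is per-channel.
import torch
from networks.depth_encoder import XCA

B, N, C = 2, 48 * 160, 48            # e.g. tokens from a 48x160 feature map with 48 channels
xca = XCA(dim=C, num_heads=8)        # per-head attention map is (C // 8) x (C // 8)

x = torch.rand(B, N, C)
with torch.no_grad():
    y = xca(x)
print(y.shape)                        # torch.Size([2, 7680, 48]) -- same shape as the input
```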
/networks/pose_decoder.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | import torch
3 | import torch.nn as nn
4 | from collections import OrderedDict
5 | from timm.models.layers import trunc_normal_
6 |
7 |
8 | class PoseDecoder(nn.Module):
9 | def __init__(self, num_ch_enc, num_input_features, num_frames_to_predict_for=None, stride=1):
10 | super(PoseDecoder, self).__init__()
11 |
12 | self.num_ch_enc = num_ch_enc
13 | self.num_input_features = num_input_features
14 |
15 | if num_frames_to_predict_for is None:
16 | num_frames_to_predict_for = num_input_features - 1
17 | self.num_frames_to_predict_for = num_frames_to_predict_for
18 |
19 | self.convs = OrderedDict()
20 | self.convs[("squeeze")] = nn.Conv2d(self.num_ch_enc[-1], 256, 1)
21 | self.convs[("pose", 0)] = nn.Conv2d(num_input_features * 256, 256, 3, stride, 1)
22 | self.convs[("pose", 1)] = nn.Conv2d(256, 256, 3, stride, 1)
23 | self.convs[("pose", 2)] = nn.Conv2d(256, 6 * num_frames_to_predict_for, 1)
24 |
25 | self.relu = nn.ReLU()
26 |
27 | self.net = nn.ModuleList(list(self.convs.values()))
28 |
29 | self.apply(self._init_weights)
30 |
31 | def _init_weights(self, m):
32 | if isinstance(m, (nn.Conv2d, nn.Linear)):
33 |             # initialize conv/linear weights with a truncated normal distribution
34 |             trunc_normal_(m.weight, std=.02)
35 |             if m.bias is not None:
36 |                 nn.init.constant_(m.bias, 0)
37 |
38 | def forward(self, input_features):
39 | last_features = [f[-1] for f in input_features]
40 |
41 | cat_features = [self.relu(self.convs["squeeze"](f)) for f in last_features]
42 | cat_features = torch.cat(cat_features, 1)
43 |
44 | out = cat_features
45 | for i in range(3):
46 | out = self.convs[("pose", i)](out)
47 | if i != 2:
48 | out = self.relu(out)
49 |
50 | out = out.mean(3).mean(2)
51 |
52 | out = 0.01 * out.view(-1, self.num_frames_to_predict_for, 1, 6)
53 |
54 | axisangle = out[..., :3]
55 | translation = out[..., 3:]
56 |
57 | return axisangle, translation
58 |
--------------------------------------------------------------------------------
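`PoseDecoder` expects a list of encoder feature pyramids and regresses a scaled 6-DoF pose (axis-angle plus translation) per predicted frame. A minimal sketch mirroring the `separate_resnet` configuration used in `trainer.py`, with illustrative input sizes:

```python
# Sketch: ResnetEncoder + PoseDecoder as wired by the trainer for pose estimation.
import torch
import networks

pose_encoder = networks.ResnetEncoder(18, False, num_input_images=2)
pose_decoder = networks.PoseDecoder(pose_encoder.num_ch_enc,
                                    num_input_features=1,
                                    num_frames_to_predict_for=2)

frame_0 = torch.rand(1, 3, 192, 640)
frame_1 = torch.rand(1, 3, 192, 640)
with torch.no_grad():
    feats = pose_encoder(torch.cat([frame_0, frame_1], dim=1))  # 6-channel stacked input
    axisangle, translation = pose_decoder([feats])

print(axisangle.shape, translation.shape)  # (1, 2, 1, 3) each: one pose per predicted frame
```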
/networks/resnet_encoder.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torchvision.models as models
6 | import torch.utils.model_zoo as model_zoo
7 | from torchvision import transforms
8 |
9 |
10 | class ResNetMultiImageInput(models.ResNet):
11 | """Constructs a resnet model with varying number of input images.
12 | Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
13 | """
14 | def __init__(self, block, layers, num_classes=1000, num_input_images=1):
15 | super(ResNetMultiImageInput, self).__init__(block, layers)
16 | self.inplanes = 64
17 | self.conv1 = nn.Conv2d(
18 | num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
19 | self.bn1 = nn.BatchNorm2d(64)
20 | self.relu = nn.ReLU(inplace=True)
21 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
22 | self.layer1 = self._make_layer(block, 64, layers[0])
23 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
24 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
25 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
26 |
27 | for m in self.modules():
28 | if isinstance(m, nn.Conv2d):
29 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
30 | elif isinstance(m, nn.BatchNorm2d):
31 | nn.init.constant_(m.weight, 1)
32 | nn.init.constant_(m.bias, 0)
33 |
34 |
35 | def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1):
36 | """Constructs a ResNet model.
37 | Args:
38 | num_layers (int): Number of resnet layers. Must be 18 or 50
39 | pretrained (bool): If True, returns a model pre-trained on ImageNet
40 | num_input_images (int): Number of frames stacked as input
41 | """
42 | assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet"
43 | blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers]
44 | block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers]
45 | model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images)
46 |
47 | if pretrained:
48 | loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)])
49 | loaded['conv1.weight'] = torch.cat(
50 | [loaded['conv1.weight']] * num_input_images, 1) / num_input_images
51 | model.load_state_dict(loaded)
52 | return model
53 |
54 |
55 | class ResnetEncoder(nn.Module):
56 | """Pytorch module for a resnet encoder
57 | """
58 | def __init__(self, num_layers, pretrained, num_input_images=1):
59 | super(ResnetEncoder, self).__init__()
60 |
61 | self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
62 | std=[0.229, 0.224, 0.225])
63 |
64 | self.num_ch_enc = np.array([64, 64, 128, 256, 512])
65 |
66 | resnets = {18: models.resnet18,
67 | 34: models.resnet34,
68 | 50: models.resnet50,
69 | 101: models.resnet101,
70 | 152: models.resnet152}
71 |
72 | if num_layers not in resnets:
73 | raise ValueError("{} is not a valid number of resnet layers".format(num_layers))
74 |
75 | if num_input_images > 1:
76 | self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images)
77 | else:
78 | self.encoder = resnets[num_layers](pretrained)
79 |
80 | if num_layers > 34:
81 | self.num_ch_enc[1:] *= 4
82 |
83 | def forward(self, input_image):
84 | self.features = []
85 | x = (input_image - 0.45) / 0.225
86 | x = self.encoder.conv1(x)
87 | x = self.encoder.bn1(x)
88 | self.features.append(self.encoder.relu(x))
89 | self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1])))
90 | self.features.append(self.encoder.layer2(self.features[-1]))
91 | self.features.append(self.encoder.layer3(self.features[-1]))
92 | self.features.append(self.encoder.layer4(self.features[-1]))
93 |
94 | return self.features
95 |
--------------------------------------------------------------------------------
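When more than one frame is stacked at the input, `resnet_multiimage_input` tiles the pretrained `conv1` filters across the extra frames and divides by the frame count so the activation scale is preserved. The same shape manipulation in isolation, without downloading any weights:

```python
# Sketch: how the pretrained conv1 kernel is adapted for num_input_images stacked frames.
import torch

num_input_images = 2
conv1_weight = torch.rand(64, 3, 7, 7)    # stand-in for the ImageNet-pretrained conv1 kernel
tiled = torch.cat([conv1_weight] * num_input_images, 1) / num_input_images
print(tiled.shape)                        # torch.Size([64, 6, 7, 7])
```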
/options.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import os
4 | import argparse
5 |
6 | file_dir = os.path.dirname(__file__) # the directory that options.py resides in
7 |
8 |
9 | class LiteMonoOptions:
10 | def __init__(self):
11 | self.parser = argparse.ArgumentParser(description="Lite-Mono options")
12 |
13 | # PATHS
14 | self.parser.add_argument("--data_path",
15 | type=str,
16 | help="path to the training data",
17 | default=os.path.join(file_dir, "kitti_data"))
18 | self.parser.add_argument("--log_dir",
19 | type=str,
20 | help="log directory",
21 | default="./tmp")
22 |
23 | # TRAINING options
24 | self.parser.add_argument("--model_name",
25 | type=str,
26 | help="the name of the folder to save the model in",
27 | default="lite-mono")
28 | self.parser.add_argument("--split",
29 | type=str,
30 | help="which training split to use",
31 | choices=["eigen_zhou", "eigen_full", "odom", "benchmark"],
32 | default="eigen_zhou")
33 | self.parser.add_argument("--model",
34 | type=str,
35 | help="which model to load",
36 | choices=["lite-mono", "lite-mono-small", "lite-mono-tiny", "lite-mono-8m"],
37 | default="lite-mono")
38 | self.parser.add_argument("--weight_decay",
39 | type=float,
40 | help="weight decay in AdamW",
41 | default=1e-2)
42 | self.parser.add_argument("--drop_path",
43 | type=float,
44 | help="drop path rate",
45 | default=0.2)
46 | self.parser.add_argument("--num_layers",
47 | type=int,
48 | help="number of resnet layers",
49 | default=18,
50 | choices=[18, 34, 50, 101, 152])
51 | self.parser.add_argument("--dataset",
52 | type=str,
53 | help="dataset to train on",
54 | default="kitti",
55 | choices=["kitti", "kitti_odom", "kitti_depth", "kitti_test"])
56 | self.parser.add_argument("--png",
57 | help="if set, trains from raw KITTI png files (instead of jpgs)",
58 | action="store_true")
59 | self.parser.add_argument("--height",
60 | type=int,
61 | help="input image height",
62 | default=192)
63 | self.parser.add_argument("--width",
64 | type=int,
65 | help="input image width",
66 | default=640)
67 | self.parser.add_argument("--disparity_smoothness",
68 | type=float,
69 | help="disparity smoothness weight",
70 | default=1e-3)
71 | self.parser.add_argument("--scales",
72 | nargs="+",
73 | type=int,
74 | help="scales used in the loss",
75 | default=[0, 1, 2])
76 | self.parser.add_argument("--min_depth",
77 | type=float,
78 | help="minimum depth",
79 | default=0.1)
80 | self.parser.add_argument("--max_depth",
81 | type=float,
82 | help="maximum depth",
83 | default=100.0)
84 | self.parser.add_argument("--use_stereo",
85 | help="if set, uses stereo pair for training",
86 | action="store_true")
87 | self.parser.add_argument("--frame_ids",
88 | nargs="+",
89 | type=int,
90 | help="frames to load",
91 | default=[0, -1, 1])
92 |
93 | self.parser.add_argument("--profile",
94 | type=bool,
95 | help="profile once at the beginning of the training",
96 | default=True)
97 |
98 | # OPTIMIZATION options
99 | self.parser.add_argument("--batch_size",
100 | type=int,
101 | help="batch size",
102 | default=16)
103 | self.parser.add_argument("--lr",
104 | nargs="+",
105 | type=float,
106 |                                  help="learning rates of DepthNet and PoseNet: "
107 |                                       "initial learning rate, minimum learning rate, "
108 |                                       "and first cycle length (epochs), "
109 |                                       "given first for DepthNet and then for PoseNet.",
110 | default=[0.0001, 5e-6, 31, 0.0001, 1e-5, 31])
111 | self.parser.add_argument("--num_epochs",
112 | type=int,
113 | help="number of epochs",
114 | default=50)
115 | self.parser.add_argument("--scheduler_step_size",
116 | type=int,
117 | help="step size of the scheduler",
118 | default=15)
119 |
120 | # ABLATION options
121 | self.parser.add_argument("--v1_multiscale",
122 | help="if set, uses monodepth v1 multiscale",
123 | action="store_true")
124 | self.parser.add_argument("--avg_reprojection",
125 | help="if set, uses average reprojection loss",
126 | action="store_true")
127 | self.parser.add_argument("--disable_automasking",
128 | help="if set, doesn't do auto-masking",
129 | action="store_true")
130 | self.parser.add_argument("--predictive_mask",
131 | help="if set, uses a predictive masking scheme as in Zhou et al",
132 | action="store_true")
133 | self.parser.add_argument("--no_ssim",
134 | help="if set, disables ssim in the loss",
135 | action="store_true")
136 | self.parser.add_argument("--mypretrain",
137 | type=str,
138 |                                  help="if set, load pretrained encoder weights from this path")
139 | self.parser.add_argument("--weights_init",
140 | type=str,
141 | help="pretrained or scratch",
142 | default="pretrained",
143 | choices=["pretrained", "scratch"])
144 | self.parser.add_argument("--pose_model_input",
145 | type=str,
146 | help="how many images the pose network gets",
147 | default="pairs",
148 | choices=["pairs", "all"])
149 | self.parser.add_argument("--pose_model_type",
150 | type=str,
151 |                                  help="posecnn, separate_resnet, or shared",
152 | default="separate_resnet",
153 | choices=["posecnn", "separate_resnet", "shared"])
154 |
155 | # SYSTEM options
156 | self.parser.add_argument("--no_cuda",
157 | help="if set disables CUDA",
158 | action="store_true")
159 | self.parser.add_argument("--num_workers",
160 | type=int,
161 | help="number of dataloader workers",
162 | default=12)
163 |
164 | # LOADING options
165 | self.parser.add_argument("--load_weights_folder",
166 | type=str,
167 | help="name of model to load")
168 | self.parser.add_argument("--models_to_load",
169 | nargs="+",
170 | type=str,
171 | help="models to load",
172 | default=["encoder", "depth", "pose_encoder", "pose"])
173 |
174 | # LOGGING options
175 | self.parser.add_argument("--log_frequency",
176 | type=int,
177 | help="number of batches between each tensorboard log",
178 | default=250)
179 | self.parser.add_argument("--save_frequency",
180 | type=int,
181 | help="number of epochs between each save",
182 | default=1)
183 |
184 | # EVALUATION options
185 | self.parser.add_argument("--disable_median_scaling",
186 | help="if set disables median scaling in evaluation",
187 | action="store_true")
188 | self.parser.add_argument("--pred_depth_scale_factor",
189 | help="if set multiplies predictions by this number",
190 | type=float,
191 | default=1)
192 | self.parser.add_argument("--ext_disp_to_eval",
193 | type=str,
194 | help="optional path to a .npy disparities file to evaluate")
195 | self.parser.add_argument("--eval_split",
196 | type=str,
197 | default="eigen",
198 | choices=[
199 | "eigen"],
200 | help="which split to run eval on")
201 | self.parser.add_argument("--save_pred_disps",
202 | help="if set saves predicted disparities",
203 | action="store_true")
204 | self.parser.add_argument("--no_eval",
205 | help="if set disables evaluation",
206 | action="store_true")
207 | self.parser.add_argument("--eval_out_dir",
208 | help="if set will output the disparities to this folder",
209 | type=str)
210 | self.parser.add_argument("--post_process",
211 | help="if set will perform the flipping post processing "
212 | "from the original monodepth paper",
213 | action="store_true")
214 |
215 | def parse(self):
216 | self.options = self.parser.parse_args()
217 | return self.options
218 |
--------------------------------------------------------------------------------
/test_simple.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import os
4 | import sys
5 | import glob
6 | import argparse
7 | import numpy as np
8 | import PIL.Image as pil
9 | import matplotlib as mpl
10 | import matplotlib.cm as cm
11 |
12 | import torch
13 | from torchvision import transforms, datasets
14 |
15 | import networks
16 | from layers import disp_to_depth
17 | import cv2
18 | import heapq
19 | from PIL import ImageFile
20 | ImageFile.LOAD_TRUNCATED_IMAGES = True
21 |
22 |
23 | def parse_args():
24 | parser = argparse.ArgumentParser(
25 | description='Simple testing function for Lite-Mono models.')
26 |
27 | parser.add_argument('--image_path', type=str,
28 | help='path to a test image or folder of images', required=True)
29 |
30 | parser.add_argument('--load_weights_folder', type=str,
31 | help='path of a pretrained model to use',
32 | )
33 |
34 | parser.add_argument('--test',
35 | action='store_true',
36 | help='if set, read images from a .txt file',
37 | )
38 |
39 | parser.add_argument('--model', type=str,
40 | help='name of a pretrained model to use',
41 | default="lite-mono",
42 | choices=[
43 | "lite-mono",
44 | "lite-mono-small",
45 | "lite-mono-tiny",
46 | "lite-mono-8m"])
47 |
48 | parser.add_argument('--ext', type=str,
49 | help='image extension to search for in folder', default="jpg")
50 | parser.add_argument("--no_cuda",
51 | help='if set, disables CUDA',
52 | action='store_true')
53 |
54 | return parser.parse_args()
55 |
56 |
57 | def test_simple(args):
58 | """Function to predict for a single image or folder of images
59 | """
60 | assert args.load_weights_folder is not None, \
61 | "You must specify the --load_weights_folder parameter"
62 |
63 | if torch.cuda.is_available() and not args.no_cuda:
64 | device = torch.device("cuda")
65 | else:
66 | device = torch.device("cpu")
67 |
68 | print("-> Loading model from ", args.load_weights_folder)
69 | encoder_path = os.path.join(args.load_weights_folder, "encoder.pth")
70 | decoder_path = os.path.join(args.load_weights_folder, "depth.pth")
71 |
72 | encoder_dict = torch.load(encoder_path)
73 | decoder_dict = torch.load(decoder_path)
74 |
75 | # extract the height and width of image that this model was trained with
76 | feed_height = encoder_dict['height']
77 | feed_width = encoder_dict['width']
78 |
79 | # LOADING PRETRAINED MODEL
80 | print(" Loading pretrained encoder")
81 | encoder = networks.LiteMono(model=args.model,
82 | height=feed_height,
83 | width=feed_width)
84 |
85 | model_dict = encoder.state_dict()
86 | encoder.load_state_dict({k: v for k, v in encoder_dict.items() if k in model_dict})
87 |
88 | encoder.to(device)
89 | encoder.eval()
90 |
91 | print(" Loading pretrained decoder")
92 | depth_decoder = networks.DepthDecoder(encoder.num_ch_enc, scales=range(3))
93 | depth_model_dict = depth_decoder.state_dict()
94 | depth_decoder.load_state_dict({k: v for k, v in decoder_dict.items() if k in depth_model_dict})
95 |
96 | depth_decoder.to(device)
97 | depth_decoder.eval()
98 |
99 | # FINDING INPUT IMAGES
100 | if os.path.isfile(args.image_path) and not args.test:
101 | # Only testing on a single image
102 | paths = [args.image_path]
103 | output_directory = os.path.dirname(args.image_path)
104 | elif os.path.isfile(args.image_path) and args.test:
105 | gt_path = os.path.join('splits', 'eigen', "gt_depths.npz")
106 | gt_depths = np.load(gt_path, fix_imports=True, encoding='latin1', allow_pickle=True)["data"]
107 |
108 | side_map = {"2": 2, "3": 3, "l": 2, "r": 3}
109 | # reading images from .txt file
110 | paths = []
111 | with open(args.image_path) as f:
112 | filenames = f.readlines()
113 | for i in range(len(filenames)):
114 | filename = filenames[i]
115 | line = filename.split()
116 | folder = line[0]
117 | if len(line) == 3:
118 | frame_index = int(line[1])
119 | side = line[2]
120 |
121 | f_str = "{:010d}{}".format(frame_index, '.jpg')
122 | image_path = os.path.join(
123 | 'kitti_data',
124 | folder,
125 | "image_0{}/data".format(side_map[side]),
126 | f_str)
127 | paths.append(image_path)
128 |         output_directory = os.path.dirname(args.image_path)  # save predictions next to the .txt file
129 | elif os.path.isdir(args.image_path):
130 | # Searching folder for images
131 | paths = glob.glob(os.path.join(args.image_path, '*.{}'.format(args.ext)))
132 | output_directory = args.image_path
133 | else:
134 | raise Exception("Can not find args.image_path: {}".format(args.image_path))
135 |
136 | print("-> Predicting on {:d} test images".format(len(paths)))
137 |
138 | # PREDICTING ON EACH IMAGE IN TURN
139 | with torch.no_grad():
140 | for idx, image_path in enumerate(paths):
141 |
142 | if image_path.endswith("_disp.jpg"):
143 | # don't try to predict disparity for a disparity image!
144 | continue
145 |
146 | # Load image and preprocess
147 | input_image = pil.open(image_path).convert('RGB')
148 | original_width, original_height = input_image.size
149 | input_image = input_image.resize((feed_width, feed_height), pil.LANCZOS)
150 | input_image = transforms.ToTensor()(input_image).unsqueeze(0)
151 |
152 | # PREDICTION
153 | input_image = input_image.to(device)
154 | features = encoder(input_image)
155 | outputs = depth_decoder(features)
156 |
157 | disp = outputs[("disp", 0)]
158 |
159 | disp_resized = torch.nn.functional.interpolate(
160 | disp, (original_height, original_width), mode="bilinear", align_corners=False)
161 |
162 | # Saving numpy file
163 | output_name = os.path.splitext(os.path.basename(image_path))[0]
164 | # output_name = os.path.splitext(image_path)[0].split('/')[-1]
165 | scaled_disp, depth = disp_to_depth(disp, 0.1, 100)
166 |
167 | name_dest_npy = os.path.join(output_directory, "{}_disp.npy".format(output_name))
168 | np.save(name_dest_npy, scaled_disp.cpu().numpy())
169 |
170 | # Saving colormapped depth image
171 | disp_resized_np = disp_resized.squeeze().cpu().numpy()
172 | vmax = np.percentile(disp_resized_np, 95)
173 | normalizer = mpl.colors.Normalize(vmin=disp_resized_np.min(), vmax=vmax)
174 | mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
175 | colormapped_im = (mapper.to_rgba(disp_resized_np)[:, :, :3] * 255).astype(np.uint8)
176 | im = pil.fromarray(colormapped_im)
177 |
178 | name_dest_im = os.path.join(output_directory, "{}_disp.jpeg".format(output_name))
179 | im.save(name_dest_im)
180 |
181 | print(" Processed {:d} of {:d} images - saved predictions to:".format(
182 | idx + 1, len(paths)))
183 | print(" - {}".format(name_dest_im))
184 | print(" - {}".format(name_dest_npy))
185 |
186 |
187 | print('-> Done!')
188 |
189 |
190 | if __name__ == '__main__':
191 | args = parse_args()
192 | test_simple(args)
193 |
--------------------------------------------------------------------------------
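`test_simple` only needs an object exposing the parsed argument names, so it can also be driven without the command line. A hedged sketch with hypothetical paths:

```python
# Sketch: invoking test_simple programmatically; the paths below are hypothetical placeholders.
from argparse import Namespace
from test_simple import test_simple

args = Namespace(
    image_path="assets/example.jpg",           # hypothetical test image
    load_weights_folder="weights/lite-mono",   # hypothetical folder with encoder.pth / depth.pth
    test=False,
    model="lite-mono",
    ext="jpg",
    no_cuda=True,
)
test_simple(args)   # writes example_disp.npy and example_disp.jpeg next to the input image
```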
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | from options import LiteMonoOptions
4 | from trainer import Trainer
5 |
6 | options = LiteMonoOptions()
7 | opts = options.parse()
8 |
9 |
10 | if __name__ == "__main__":
11 | trainer = Trainer(opts)
12 | trainer.train()
13 |
--------------------------------------------------------------------------------
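`train.py` above is just a thin wrapper: it parses `LiteMonoOptions` and hands the result to `Trainer`. The same entry point can be driven from Python, for example in a notebook; the data path below is a hypothetical placeholder:

```python
# Sketch: launching training programmatically instead of `python train.py ...`.
import sys
from options import LiteMonoOptions
from trainer import Trainer

sys.argv = ["train.py",
            "--data_path", "/path/to/kitti_data",   # hypothetical dataset location
            "--model", "lite-mono",
            "--batch_size", "16"]

opts = LiteMonoOptions().parse()
Trainer(opts).train()
```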
/trainer.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 |
4 | import time
5 | import torch.optim as optim
6 | from torch.utils.data import DataLoader
7 | from tensorboardX import SummaryWriter
8 |
9 | import json
10 |
11 | from utils import *
12 | from kitti_utils import *
13 | from layers import *
14 |
15 | import datasets
16 | import networks
17 | from linear_warmup_cosine_annealing_warm_restarts_weight_decay import ChainedScheduler
18 |
19 |
20 | # torch.backends.cudnn.benchmark = True
21 |
22 |
23 | def time_sync():
24 | # PyTorch-accurate time
25 | if torch.cuda.is_available():
26 | torch.cuda.synchronize()
27 | return time.time()
28 |
29 |
30 | class Trainer:
31 | def __init__(self, options):
32 | self.opt = options
33 | self.log_path = os.path.join(self.opt.log_dir, self.opt.model_name)
34 |
35 | # checking height and width are multiples of 32
36 | assert self.opt.height % 32 == 0, "'height' must be a multiple of 32"
37 | assert self.opt.width % 32 == 0, "'width' must be a multiple of 32"
38 |
39 | self.models = {}
40 | self.models_pose = {}
41 | self.parameters_to_train = []
42 | self.parameters_to_train_pose = []
43 |
44 | self.device = torch.device("cpu" if self.opt.no_cuda else "cuda")
45 | self.profile = self.opt.profile
46 |
47 | self.num_scales = len(self.opt.scales)
48 |         self.num_input_frames = len(self.opt.frame_ids)
49 | self.num_pose_frames = 2 if self.opt.pose_model_input == "pairs" else self.num_input_frames
50 |
51 | assert self.opt.frame_ids[0] == 0, "frame_ids must start with 0"
52 |
53 | self.use_pose_net = not (self.opt.use_stereo and self.opt.frame_ids == [0])
54 |
55 | if self.opt.use_stereo:
56 | self.opt.frame_ids.append("s")
57 |
58 | self.models["encoder"] = networks.LiteMono(model=self.opt.model,
59 | drop_path_rate=self.opt.drop_path,
60 | width=self.opt.width, height=self.opt.height)
61 |
62 | self.models["encoder"].to(self.device)
63 | self.parameters_to_train += list(self.models["encoder"].parameters())
64 |
65 | self.models["depth"] = networks.DepthDecoder(self.models["encoder"].num_ch_enc,
66 | self.opt.scales)
67 | self.models["depth"].to(self.device)
68 | self.parameters_to_train += list(self.models["depth"].parameters())
69 |
70 | if self.use_pose_net:
71 | if self.opt.pose_model_type == "separate_resnet":
72 | self.models_pose["pose_encoder"] = networks.ResnetEncoder(
73 | self.opt.num_layers,
74 | self.opt.weights_init == "pretrained",
75 | num_input_images=self.num_pose_frames)
76 |
77 | self.models_pose["pose_encoder"].to(self.device)
78 | self.parameters_to_train_pose += list(self.models_pose["pose_encoder"].parameters())
79 |
80 | self.models_pose["pose"] = networks.PoseDecoder(
81 | self.models_pose["pose_encoder"].num_ch_enc,
82 | num_input_features=1,
83 | num_frames_to_predict_for=2)
84 |
85 | elif self.opt.pose_model_type == "shared":
86 | self.models_pose["pose"] = networks.PoseDecoder(
87 | self.models["encoder"].num_ch_enc, self.num_pose_frames)
88 |
89 | elif self.opt.pose_model_type == "posecnn":
90 | self.models_pose["pose"] = networks.PoseCNN(
91 | self.num_input_frames if self.opt.pose_model_input == "all" else 2)
92 |
93 | self.models_pose["pose"].to(self.device)
94 | self.parameters_to_train_pose += list(self.models_pose["pose"].parameters())
95 |
96 | if self.opt.predictive_mask:
97 | assert self.opt.disable_automasking, \
98 | "When using predictive_mask, please disable automasking with --disable_automasking"
99 |
100 |             # Our implementation of the predictive masking baseline has the same architecture
101 | # as our depth decoder. We predict a separate mask for each source frame.
102 | self.models["predictive_mask"] = networks.DepthDecoder(
103 | self.models["encoder"].num_ch_enc, self.opt.scales,
104 | num_output_channels=(len(self.opt.frame_ids) - 1))
105 | self.models["predictive_mask"].to(self.device)
106 | self.parameters_to_train += list(self.models["predictive_mask"].parameters())
107 |
108 | self.model_optimizer = optim.AdamW(self.parameters_to_train, self.opt.lr[0], weight_decay=self.opt.weight_decay)
109 | if self.use_pose_net:
110 | self.model_pose_optimizer = optim.AdamW(self.parameters_to_train_pose, self.opt.lr[3], weight_decay=self.opt.weight_decay)
111 |
112 | self.model_lr_scheduler = ChainedScheduler(
113 | self.model_optimizer,
114 | T_0=int(self.opt.lr[2]),
115 | T_mul=1,
116 | eta_min=self.opt.lr[1],
117 | last_epoch=-1,
118 | max_lr=self.opt.lr[0],
119 | warmup_steps=0,
120 | gamma=0.9
121 | )
122 | self.model_pose_lr_scheduler = ChainedScheduler(
123 | self.model_pose_optimizer,
124 | T_0=int(self.opt.lr[5]),
125 | T_mul=1,
126 | eta_min=self.opt.lr[4],
127 | last_epoch=-1,
128 | max_lr=self.opt.lr[3],
129 | warmup_steps=0,
130 | gamma=0.9
131 | )
132 |
133 | if self.opt.load_weights_folder is not None:
134 | self.load_model()
135 |
136 | if self.opt.mypretrain is not None:
137 | self.load_pretrain()
138 |
139 | print("Training model named:\n ", self.opt.model_name)
140 | print("Models and tensorboard events files are saved to:\n ", self.opt.log_dir)
141 | print("Training is using:\n ", self.device)
142 |
143 | # data
144 | datasets_dict = {"kitti": datasets.KITTIRAWDataset,
145 | "kitti_odom": datasets.KITTIOdomDataset}
146 | self.dataset = datasets_dict[self.opt.dataset]
147 |
148 | fpath = os.path.join(os.path.dirname(__file__), "splits", self.opt.split, "{}_files.txt")
149 |
150 | train_filenames = readlines(fpath.format("train"))
151 | val_filenames = readlines(fpath.format("val"))
152 | img_ext = '.png' if self.opt.png else '.jpg'
153 |
154 | num_train_samples = len(train_filenames)
155 | self.num_total_steps = num_train_samples // self.opt.batch_size * self.opt.num_epochs
156 |
157 | train_dataset = self.dataset(
158 | self.opt.data_path, train_filenames, self.opt.height, self.opt.width,
159 | self.opt.frame_ids, 4, is_train=True, img_ext=img_ext)
160 | self.train_loader = DataLoader(
161 | train_dataset, self.opt.batch_size, True,
162 | num_workers=self.opt.num_workers, pin_memory=True, drop_last=True)
163 | val_dataset = self.dataset(
164 | self.opt.data_path, val_filenames, self.opt.height, self.opt.width,
165 | self.opt.frame_ids, 4, is_train=False, img_ext=img_ext)
166 | self.val_loader = DataLoader(
167 | val_dataset, self.opt.batch_size, True,
168 | num_workers=self.opt.num_workers, pin_memory=True, drop_last=True)
169 | self.val_iter = iter(self.val_loader)
170 |
171 | self.writers = {}
172 | for mode in ["train", "val"]:
173 | self.writers[mode] = SummaryWriter(os.path.join(self.log_path, mode))
174 |
175 | if not self.opt.no_ssim:
176 | self.ssim = SSIM()
177 | self.ssim.to(self.device)
178 |
179 | self.backproject_depth = {}
180 | self.project_3d = {}
181 | for scale in self.opt.scales:
182 | h = self.opt.height // (2 ** scale)
183 | w = self.opt.width // (2 ** scale)
184 |
185 | self.backproject_depth[scale] = BackprojectDepth(self.opt.batch_size, h, w)
186 | self.backproject_depth[scale].to(self.device)
187 |
188 | self.project_3d[scale] = Project3D(self.opt.batch_size, h, w)
189 | self.project_3d[scale].to(self.device)
190 |
191 | self.depth_metric_names = [
192 | "de/abs_rel", "de/sq_rel", "de/rms", "de/log_rms", "da/a1", "da/a2", "da/a3"]
193 |
194 | print("Using split:\n ", self.opt.split)
195 | print("There are {:d} training items and {:d} validation items\n".format(
196 | len(train_dataset), len(val_dataset)))
197 |
198 | self.save_opts()
199 |
200 | def set_train(self):
201 | """Convert all models to training mode
202 | """
203 | for m in self.models.values():
204 | m.train()
205 |
206 | def set_eval(self):
207 | """Convert all models to testing/evaluation mode
208 | """
209 | for m in self.models.values():
210 | m.eval()
211 |
212 | def train(self):
213 | """Run the entire training pipeline
214 | """
215 | self.epoch = 0
216 | self.step = 0
217 | self.start_time = time.time()
218 | for self.epoch in range(self.opt.num_epochs):
219 | self.run_epoch()
220 | if (self.epoch + 1) % self.opt.save_frequency == 0:
221 | self.save_model()
222 |
223 | def run_epoch(self):
224 | """Run a single epoch of training and validation
225 | """
226 |
227 | print("Training")
228 | self.set_train()
229 |
230 | self.model_lr_scheduler.step()
231 | if self.use_pose_net:
232 | self.model_pose_lr_scheduler.step()
233 |
234 | for batch_idx, inputs in enumerate(self.train_loader):
235 |
236 | before_op_time = time.time()
237 |
238 | outputs, losses = self.process_batch(inputs)
239 |
240 | self.model_optimizer.zero_grad()
241 | if self.use_pose_net:
242 | self.model_pose_optimizer.zero_grad()
243 | losses["loss"].backward()
244 | self.model_optimizer.step()
245 | if self.use_pose_net:
246 | self.model_pose_optimizer.step()
247 |
248 | duration = time.time() - before_op_time
249 |
250 | # log less frequently after the first 2000 steps to save time & disk space
251 | early_phase = batch_idx % self.opt.log_frequency == 0 and self.step < 20000
252 | late_phase = self.step % 2000 == 0
253 |
254 | if early_phase or late_phase:
255 | self.log_time(batch_idx, duration, losses["loss"].cpu().data)
256 |
257 | if "depth_gt" in inputs:
258 | self.compute_depth_losses(inputs, outputs, losses)
259 |
260 | self.log("train", inputs, outputs, losses)
261 | self.val()
262 |
263 | self.step += 1
264 |
265 | def process_batch(self, inputs):
266 | """Pass a minibatch through the network and generate images and losses
267 | """
268 | for key, ipt in inputs.items():
269 | inputs[key] = ipt.to(self.device)
270 |
271 | if self.opt.pose_model_type == "shared":
272 | # If we are using a shared encoder for both depth and pose (as advocated
273 | # in monodepthv1), then all images are fed separately through the depth encoder.
274 | all_color_aug = torch.cat([inputs[("color_aug", i, 0)] for i in self.opt.frame_ids])
275 | all_features = self.models["encoder"](all_color_aug)
276 | all_features = [torch.split(f, self.opt.batch_size) for f in all_features]
277 |
278 | features = {}
279 | for i, k in enumerate(self.opt.frame_ids):
280 | features[k] = [f[i] for f in all_features]
281 |
282 | outputs = self.models["depth"](features[0])
283 | else:
284 | # Otherwise, we only feed the image with frame_id 0 through the depth encoder
285 |
286 | features = self.models["encoder"](inputs["color_aug", 0, 0])
287 |
288 | outputs = self.models["depth"](features)
289 |
290 | if self.opt.predictive_mask:
291 | outputs["predictive_mask"] = self.models["predictive_mask"](features)
292 |
293 | if self.use_pose_net:
294 | outputs.update(self.predict_poses(inputs, features))
295 |
296 | self.generate_images_pred(inputs, outputs)
297 | losses = self.compute_losses(inputs, outputs)
298 |
299 | return outputs, losses
300 |
301 | def predict_poses(self, inputs, features):
302 | """Predict poses between input frames for monocular sequences.
303 | """
304 | outputs = {}
305 | if self.num_pose_frames == 2:
306 | # In this setting, we compute the pose to each source frame via a
307 | # separate forward pass through the pose network.
308 |
309 | # select what features the pose network takes as input
310 | if self.opt.pose_model_type == "shared":
311 | pose_feats = {f_i: features[f_i] for f_i in self.opt.frame_ids}
312 | else:
313 | pose_feats = {f_i: inputs["color_aug", f_i, 0] for f_i in self.opt.frame_ids}
314 |
315 | for f_i in self.opt.frame_ids[1:]:
316 | if f_i != "s":
317 | # To maintain ordering we always pass frames in temporal order
318 | if f_i < 0:
319 | pose_inputs = [pose_feats[f_i], pose_feats[0]]
320 | else:
321 | pose_inputs = [pose_feats[0], pose_feats[f_i]]
322 |
323 | if self.opt.pose_model_type == "separate_resnet":
324 | pose_inputs = [self.models_pose["pose_encoder"](torch.cat(pose_inputs, 1))]
325 | elif self.opt.pose_model_type == "posecnn":
326 | pose_inputs = torch.cat(pose_inputs, 1)
327 |
328 | axisangle, translation = self.models_pose["pose"](pose_inputs)
329 | outputs[("axisangle", 0, f_i)] = axisangle
330 | outputs[("translation", 0, f_i)] = translation
331 |
332 | # Invert the matrix if the frame id is negative
333 | outputs[("cam_T_cam", 0, f_i)] = transformation_from_parameters(
334 | axisangle[:, 0], translation[:, 0], invert=(f_i < 0))
335 |
336 | else:
337 | # Here we input all frames to the pose net (and predict all poses) together
338 | if self.opt.pose_model_type in ["separate_resnet", "posecnn"]:
339 | pose_inputs = torch.cat(
340 | [inputs[("color_aug", i, 0)] for i in self.opt.frame_ids if i != "s"], 1)
341 |
342 | if self.opt.pose_model_type == "separate_resnet":
343 | pose_inputs = [self.models["pose_encoder"](pose_inputs)]
344 |
345 | elif self.opt.pose_model_type == "shared":
346 | pose_inputs = [features[i] for i in self.opt.frame_ids if i != "s"]
347 |
348 | axisangle, translation = self.models_pose["pose"](pose_inputs)
349 |
350 | for i, f_i in enumerate(self.opt.frame_ids[1:]):
351 | if f_i != "s":
352 | outputs[("axisangle", 0, f_i)] = axisangle
353 | outputs[("translation", 0, f_i)] = translation
354 | outputs[("cam_T_cam", 0, f_i)] = transformation_from_parameters(
355 | axisangle[:, i], translation[:, i])
356 |
357 | return outputs
358 |
359 | def val(self):
360 | """Validate the model on a single minibatch
361 | """
362 | self.set_eval()
363 | try:
364 |             inputs = next(self.val_iter)
365 | except StopIteration:
366 | self.val_iter = iter(self.val_loader)
367 |             inputs = next(self.val_iter)
368 |
369 | with torch.no_grad():
370 | outputs, losses = self.process_batch(inputs)
371 |
372 | if "depth_gt" in inputs:
373 | self.compute_depth_losses(inputs, outputs, losses)
374 |
375 | self.log("val", inputs, outputs, losses)
376 | del inputs, outputs, losses
377 |
378 | self.set_train()
379 |
380 | def generate_images_pred(self, inputs, outputs):
381 | """Generate the warped (reprojected) color images for a minibatch.
382 | Generated images are saved into the `outputs` dictionary.
383 | """
384 | for scale in self.opt.scales:
385 | disp = outputs[("disp", scale)]
386 | if self.opt.v1_multiscale:
387 | source_scale = scale
388 | else:
389 | disp = F.interpolate(
390 | disp, [self.opt.height, self.opt.width], mode="bilinear", align_corners=False)
391 | source_scale = 0
392 |
393 | _, depth = disp_to_depth(disp, self.opt.min_depth, self.opt.max_depth)
394 |
395 | outputs[("depth", 0, scale)] = depth
396 |
397 | for i, frame_id in enumerate(self.opt.frame_ids[1:]):
398 |
399 | if frame_id == "s":
400 | T = inputs["stereo_T"]
401 | else:
402 | T = outputs[("cam_T_cam", 0, frame_id)]
403 |
404 | # from the authors of https://arxiv.org/abs/1712.00175
405 | if self.opt.pose_model_type == "posecnn":
406 |
407 | axisangle = outputs[("axisangle", 0, frame_id)]
408 | translation = outputs[("translation", 0, frame_id)]
409 |
410 | inv_depth = 1 / depth
411 | mean_inv_depth = inv_depth.mean(3, True).mean(2, True)
412 |
413 | T = transformation_from_parameters(
414 | axisangle[:, 0], translation[:, 0] * mean_inv_depth[:, 0], frame_id < 0)
415 |
416 | cam_points = self.backproject_depth[source_scale](
417 | depth, inputs[("inv_K", source_scale)])
418 | pix_coords = self.project_3d[source_scale](
419 | cam_points, inputs[("K", source_scale)], T)
420 |
421 | outputs[("sample", frame_id, scale)] = pix_coords
422 |
423 | outputs[("color", frame_id, scale)] = F.grid_sample(
424 | inputs[("color", frame_id, source_scale)],
425 | outputs[("sample", frame_id, scale)],
426 | padding_mode="border", align_corners=True)
427 |
428 | if not self.opt.disable_automasking:
429 | outputs[("color_identity", frame_id, scale)] = \
430 | inputs[("color", frame_id, source_scale)]
431 |
432 | def compute_reprojection_loss(self, pred, target):
433 | """Computes reprojection loss between a batch of predicted and target images
434 | """
435 | abs_diff = torch.abs(target - pred)
436 | l1_loss = abs_diff.mean(1, True)
437 |
438 | if self.opt.no_ssim:
439 | reprojection_loss = l1_loss
440 | else:
441 | ssim_loss = self.ssim(pred, target).mean(1, True)
442 | reprojection_loss = 0.85 * ssim_loss + 0.15 * l1_loss
443 |
444 | return reprojection_loss
445 |
446 | def compute_losses(self, inputs, outputs):
447 | """Compute the reprojection and smoothness losses for a minibatch
448 | """
449 |
450 | losses = {}
451 | total_loss = 0
452 |
453 | for scale in self.opt.scales:
454 | loss = 0
455 | reprojection_losses = []
456 |
457 | if self.opt.v1_multiscale:
458 | source_scale = scale
459 | else:
460 | source_scale = 0
461 |
462 | disp = outputs[("disp", scale)]
463 | color = inputs[("color", 0, scale)]
464 | target = inputs[("color", 0, source_scale)]
465 |
466 | for frame_id in self.opt.frame_ids[1:]:
467 | pred = outputs[("color", frame_id, scale)]
468 | reprojection_losses.append(self.compute_reprojection_loss(pred, target))
469 |
470 | reprojection_losses = torch.cat(reprojection_losses, 1)
471 |
472 | if not self.opt.disable_automasking:
473 | identity_reprojection_losses = []
474 | for frame_id in self.opt.frame_ids[1:]:
475 | pred = inputs[("color", frame_id, source_scale)]
476 | identity_reprojection_losses.append(
477 | self.compute_reprojection_loss(pred, target))
478 |
479 | identity_reprojection_losses = torch.cat(identity_reprojection_losses, 1)
480 |
481 | if self.opt.avg_reprojection:
482 | identity_reprojection_loss = identity_reprojection_losses.mean(1, keepdim=True)
483 | else:
484 | # save both images, and do min all at once below
485 | identity_reprojection_loss = identity_reprojection_losses
486 |
487 | elif self.opt.predictive_mask:
488 | # use the predicted mask
489 | mask = outputs["predictive_mask"]["disp", scale]
490 | if not self.opt.v1_multiscale:
491 | mask = F.interpolate(
492 | mask, [self.opt.height, self.opt.width],
493 | mode="bilinear", align_corners=False)
494 |
495 | reprojection_losses *= mask
496 |
497 | # add a loss pushing mask to 1 (using nn.BCELoss for stability)
498 | weighting_loss = 0.2 * nn.BCELoss()(mask, torch.ones(mask.shape).cuda())
499 | loss += weighting_loss.mean()
500 |
501 | if self.opt.avg_reprojection:
502 | reprojection_loss = reprojection_losses.mean(1, keepdim=True)
503 | else:
504 | reprojection_loss = reprojection_losses
505 |
506 | if not self.opt.disable_automasking:
507 | # add random numbers to break ties
508 | identity_reprojection_loss += torch.randn(
509 | identity_reprojection_loss.shape, device=self.device) * 0.00001
510 |
511 | combined = torch.cat((identity_reprojection_loss, reprojection_loss), dim=1)
512 | else:
513 | combined = reprojection_loss
514 |
515 | if combined.shape[1] == 1:
516 | to_optimise = combined
517 | else:
518 | to_optimise, idxs = torch.min(combined, dim=1)
519 |
520 | if not self.opt.disable_automasking:
521 | outputs["identity_selection/{}".format(scale)] = (
522 | idxs > identity_reprojection_loss.shape[1] - 1).float()
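                # idxs indexes the concatenation [identity losses, reprojection
                # losses]; an index past the identity block means a warped image
                # won the per-pixel minimum, i.e. the pixel survives the automask.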
523 |
524 | loss += to_optimise.mean()
525 |
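            # Edge-aware smoothness on mean-normalised disparity: dividing by the
            # mean keeps the term from simply shrinking the disparity, and the
            # 2 ** scale factor below down-weights the coarser scales.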
526 | mean_disp = disp.mean(2, True).mean(3, True)
527 | norm_disp = disp / (mean_disp + 1e-7)
528 | smooth_loss = get_smooth_loss(norm_disp, color)
529 |
530 | loss += self.opt.disparity_smoothness * smooth_loss / (2 ** scale)
531 | total_loss += loss
532 | losses["loss/{}".format(scale)] = loss
533 |
534 | total_loss /= self.num_scales
535 | losses["loss"] = total_loss
536 | return losses
537 |
538 | def compute_depth_losses(self, inputs, outputs, losses):
539 | """Compute depth metrics, to allow monitoring during training
540 |
541 | This isn't particularly accurate as it averages over the entire batch,
542 | so is only used to give an indication of validation performance
543 | """
544 | depth_pred = outputs[("depth", 0, 0)]
545 | depth_pred = torch.clamp(F.interpolate(
546 | depth_pred, [375, 1242], mode="bilinear", align_corners=False), 1e-3, 80)
547 | depth_pred = depth_pred.detach()
548 |
549 | depth_gt = inputs["depth_gt"]
550 | mask = depth_gt > 0
551 |
552 | # garg/eigen crop
553 | crop_mask = torch.zeros_like(mask)
554 | crop_mask[:, :, 153:371, 44:1197] = 1
555 | mask = mask * crop_mask
556 |
557 | depth_gt = depth_gt[mask]
558 | depth_pred = depth_pred[mask]
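        # Monocular self-supervision is scale-ambiguous, so the prediction is
        # aligned to the ground truth with per-batch median scaling before the
        # error metrics are computed.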
559 | depth_pred *= torch.median(depth_gt) / torch.median(depth_pred)
560 |
561 | depth_pred = torch.clamp(depth_pred, min=1e-3, max=80)
562 |
563 | depth_errors = compute_depth_errors(depth_gt, depth_pred)
564 |
565 | for i, metric in enumerate(self.depth_metric_names):
566 | losses[metric] = np.array(depth_errors[i].cpu())
567 |
568 | def log_time(self, batch_idx, duration, loss):
569 | """Print a logging statement to the terminal
570 | """
571 | samples_per_sec = self.opt.batch_size / duration
572 | time_sofar = time.time() - self.start_time
573 | training_time_left = (
574 | self.num_total_steps / self.step - 1.0) * time_sofar if self.step > 0 else 0
575 |         print_string = "epoch {:>3} | lr {:.6f} | lr_p {:.6f} | batch {:>6} | examples/s: {:5.1f}" + \
576 | " | loss: {:.5f} | time elapsed: {} | time left: {}"
577 | print(print_string.format(self.epoch, self.model_optimizer.state_dict()['param_groups'][0]['lr'],
578 | self.model_pose_optimizer.state_dict()['param_groups'][0]['lr'],
579 | batch_idx, samples_per_sec, loss,
580 | sec_to_hm_str(time_sofar), sec_to_hm_str(training_time_left)))
581 |
582 | def log(self, mode, inputs, outputs, losses):
583 | """Write an event to the tensorboard events file
584 | """
585 | writer = self.writers[mode]
586 | for l, v in losses.items():
587 | writer.add_scalar("{}".format(l), v, self.step)
588 |
589 |         for j in range(min(4, self.opt.batch_size)):  # write a maximum of four images
590 | for s in self.opt.scales:
591 | for frame_id in self.opt.frame_ids:
592 | writer.add_image(
593 | "color_{}_{}/{}".format(frame_id, s, j),
594 | inputs[("color", frame_id, s)][j].data, self.step)
595 | if s == 0 and frame_id != 0:
596 | writer.add_image(
597 | "color_pred_{}_{}/{}".format(frame_id, s, j),
598 | outputs[("color", frame_id, s)][j].data, self.step)
599 |
600 | writer.add_image(
601 | "disp_{}/{}".format(s, j),
602 | normalize_image(outputs[("disp", s)][j]), self.step)
603 |
604 | if self.opt.predictive_mask:
605 | for f_idx, frame_id in enumerate(self.opt.frame_ids[1:]):
606 | writer.add_image(
607 | "predictive_mask_{}_{}/{}".format(frame_id, s, j),
608 | outputs["predictive_mask"][("disp", s)][j, f_idx][None, ...],
609 | self.step)
610 |
611 | elif not self.opt.disable_automasking:
612 | writer.add_image(
613 | "automask_{}/{}".format(s, j),
614 | outputs["identity_selection/{}".format(s)][j][None, ...], self.step)
615 |
616 | def save_opts(self):
617 | """Save options to disk so we know what we ran this experiment with
618 | """
619 | models_dir = os.path.join(self.log_path, "models")
620 | if not os.path.exists(models_dir):
621 | os.makedirs(models_dir)
622 | to_save = self.opt.__dict__.copy()
623 |
624 | with open(os.path.join(models_dir, 'opt.json'), 'w') as f:
625 | json.dump(to_save, f, indent=2)
626 |
627 | def save_model(self):
628 | """Save model weights to disk
629 | """
630 | save_folder = os.path.join(self.log_path, "models", "weights_{}".format(self.epoch))
631 | if not os.path.exists(save_folder):
632 | os.makedirs(save_folder)
633 |
634 | for model_name, model in self.models.items():
635 | save_path = os.path.join(save_folder, "{}.pth".format(model_name))
636 | to_save = model.state_dict()
637 | if model_name == 'encoder':
638 | # save the sizes - these are needed at prediction time
639 | to_save['height'] = self.opt.height
640 | to_save['width'] = self.opt.width
641 | to_save['use_stereo'] = self.opt.use_stereo
642 | torch.save(to_save, save_path)
643 |
644 | for model_name, model in self.models_pose.items():
645 | save_path = os.path.join(save_folder, "{}.pth".format(model_name))
646 | to_save = model.state_dict()
647 | torch.save(to_save, save_path)
648 |
649 | save_path = os.path.join(save_folder, "{}.pth".format("adam"))
650 | torch.save(self.model_optimizer.state_dict(), save_path)
651 |
652 | save_path = os.path.join(save_folder, "{}.pth".format("adam_pose"))
653 | if self.use_pose_net:
654 | torch.save(self.model_pose_optimizer.state_dict(), save_path)
655 |
656 | def load_pretrain(self):
657 | self.opt.mypretrain = os.path.expanduser(self.opt.mypretrain)
658 | path = self.opt.mypretrain
659 | model_dict = self.models["encoder"].state_dict()
660 | pretrained_dict = torch.load(path)['model']
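        # Keep only weights that also exist in the depth encoder; keys starting
        # with 'norm' (presumably the normalisation feeding the classification
        # head used during ImageNet pretraining) are skipped explicitly.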
661 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in model_dict and not k.startswith('norm'))}
662 | model_dict.update(pretrained_dict)
663 | self.models["encoder"].load_state_dict(model_dict)
664 | print('mypretrain loaded.')
665 |
666 | def load_model(self):
667 | """Load model(s) from disk
668 | """
669 | self.opt.load_weights_folder = os.path.expanduser(self.opt.load_weights_folder)
670 |
671 | assert os.path.isdir(self.opt.load_weights_folder), \
672 | "Cannot find folder {}".format(self.opt.load_weights_folder)
673 | print("loading model from folder {}".format(self.opt.load_weights_folder))
674 |
675 | for n in self.opt.models_to_load:
676 | print("Loading {} weights...".format(n))
677 | path = os.path.join(self.opt.load_weights_folder, "{}.pth".format(n))
678 |
679 | if n in ['pose_encoder', 'pose']:
680 | model_dict = self.models_pose[n].state_dict()
681 | pretrained_dict = torch.load(path)
682 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
683 | model_dict.update(pretrained_dict)
684 | self.models_pose[n].load_state_dict(model_dict)
685 | else:
686 | model_dict = self.models[n].state_dict()
687 | pretrained_dict = torch.load(path)
688 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
689 | model_dict.update(pretrained_dict)
690 | self.models[n].load_state_dict(model_dict)
691 |
692 | # loading adam state
693 |
694 | optimizer_load_path = os.path.join(self.opt.load_weights_folder, "adam.pth")
695 | optimizer_pose_load_path = os.path.join(self.opt.load_weights_folder, "adam_pose.pth")
696 | if os.path.isfile(optimizer_load_path):
697 | print("Loading Adam weights")
698 | optimizer_dict = torch.load(optimizer_load_path)
699 | optimizer_pose_dict = torch.load(optimizer_pose_load_path)
700 | self.model_optimizer.load_state_dict(optimizer_dict)
701 | self.model_pose_optimizer.load_state_dict(optimizer_pose_dict)
702 | else:
703 | print("Cannot find Adam weights so Adam is randomly initialized")
704 |
705 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | import os
3 | import hashlib
4 | import zipfile
5 | from six.moves import urllib
6 |
7 |
8 | def readlines(filename):
9 | """Read all the lines in a text file and return as a list
10 | """
11 | with open(filename, 'r') as f:
12 | lines = f.read().splitlines()
13 | return lines
14 |
15 |
16 | def normalize_image(x):
17 | """Rescale image pixels to span range [0, 1]
18 | """
19 | ma = float(x.max().cpu().data)
20 | mi = float(x.min().cpu().data)
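    # Guard against constant images: if max equals min the denominator would be
    # zero, so fall back to a large value and return an (almost) all-zero map.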
21 | d = ma - mi if ma != mi else 1e5
22 | return (x - mi) / d
23 |
24 |
25 | def sec_to_hm(t):
26 | """Convert time in seconds to time in hours, minutes and seconds
27 | e.g. 10239 -> (2, 50, 39)
28 | """
29 | t = int(t)
30 | s = t % 60
31 | t //= 60
32 | m = t % 60
33 | t //= 60
34 | return t, m, s
35 |
36 |
37 | def sec_to_hm_str(t):
38 | """Convert time in seconds to a nice string
39 | e.g. 10239 -> '02h50m39s'
40 | """
41 | h, m, s = sec_to_hm(t)
42 | return "{:02d}h{:02d}m{:02d}s".format(h, m, s)
43 |
44 |
45 | def download_model_if_doesnt_exist(model_name):
46 | """If pretrained kitti model doesn't exist, download and unzip it
47 | """
48 |     # values are tuples of (download URL, md5 checksum)
49 | download_paths = {
50 | "mono_640x192":
51 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono_640x192.zip",
52 | "a964b8356e08a02d009609d9e3928f7c"),
53 | "stereo_640x192":
54 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/stereo_640x192.zip",
55 | "3dfb76bcff0786e4ec07ac00f658dd07"),
56 | "mono+stereo_640x192":
57 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono%2Bstereo_640x192.zip",
58 | "c024d69012485ed05d7eaa9617a96b81"),
59 | "mono_no_pt_640x192":
60 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono_no_pt_640x192.zip",
61 | "9c2f071e35027c895a4728358ffc913a"),
62 | "stereo_no_pt_640x192":
63 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/stereo_no_pt_640x192.zip",
64 | "41ec2de112905f85541ac33a854742d1"),
65 | "mono+stereo_no_pt_640x192":
66 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono%2Bstereo_no_pt_640x192.zip",
67 | "46c3b824f541d143a45c37df65fbab0a"),
68 | "mono_1024x320":
69 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono_1024x320.zip",
70 | "0ab0766efdfeea89a0d9ea8ba90e1e63"),
71 | "stereo_1024x320":
72 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/stereo_1024x320.zip",
73 | "afc2f2126d70cf3fdf26b550898b501a"),
74 | "mono+stereo_1024x320":
75 | ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono%2Bstereo_1024x320.zip",
76 | "cdc5fc9b23513c07d5b19235d9ef08f7"),
77 | }
78 |
79 | if not os.path.exists("models"):
80 | os.makedirs("models")
81 |
82 | model_path = os.path.join("models", model_name)
83 |
84 | def check_file_matches_md5(checksum, fpath):
85 | if not os.path.exists(fpath):
86 | return False
87 | with open(fpath, 'rb') as f:
88 | current_md5checksum = hashlib.md5(f.read()).hexdigest()
89 | return current_md5checksum == checksum
90 |
91 | # see if we have the model already downloaded...
92 | if not os.path.exists(os.path.join(model_path, "encoder.pth")):
93 |
94 | model_url, required_md5checksum = download_paths[model_name]
95 |
96 | if not check_file_matches_md5(required_md5checksum, model_path + ".zip"):
97 | print("-> Downloading pretrained model to {}".format(model_path + ".zip"))
98 | urllib.request.urlretrieve(model_url, model_path + ".zip")
99 |
100 | if not check_file_matches_md5(required_md5checksum, model_path + ".zip"):
101 | print(" Failed to download a file which matches the checksum - quitting")
102 | quit()
103 |
104 | print(" Unzipping model...")
105 | with zipfile.ZipFile(model_path + ".zip", 'r') as f:
106 | f.extractall(model_path)
107 |
108 | print(" Model unzipped to {}".format(model_path))
109 |
--------------------------------------------------------------------------------