├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── assets
│   └── framework.png
├── configs
│   └── nuscenes
│       └── main.yaml
├── dataset
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── data_util.py
│   ├── nuscenes
│   │   ├── eval_MF.txt
│   │   ├── eval_SF.txt
│   │   └── train.txt
│   ├── nuscenes_dataset.py
│   └── nuscenes_mask
│       ├── CAM_BACK_LEFT_mask.png
│       ├── CAM_BACK_RIGHT_mask.png
│       ├── CAM_BACK_mask.png
│       ├── CAM_FRONT_LEFT_mask.png
│       ├── CAM_FRONT_RIGHT_mask.png
│       └── CAM_FRONT_mask.png
├── eval.py
├── external
│   ├── dataset
│   │   └── __init__.py
│   ├── layers
│   │   └── __init__.py
│   └── utils
│       └── __init__.py
├── models
│   ├── __init__.py
│   ├── base_model.py
│   ├── drivingforward_model.py
│   ├── gaussian
│   │   ├── GaussianRender.py
│   │   ├── __init__.py
│   │   ├── extractor.py
│   │   ├── gaussian_network.py
│   │   ├── gaussian_renderer
│   │   │   └── __init__.py
│   │   └── utils.py
│   ├── geometry
│   │   ├── __init__.py
│   │   ├── geometry_util.py
│   │   ├── pose.py
│   │   └── view_rendering.py
│   └── losses
│       ├── __init__.py
│       ├── base_loss.py
│       ├── loss_util.py
│       ├── multi_cam_loss.py
│       └── single_cam_loss.py
├── network
│   ├── __init__.py
│   ├── blocks.py
│   ├── depth_network.py
│   ├── pose_network.py
│   └── volumetric_fusionnet.py
├── requirements.txt
├── train.py
├── trainer
│   ├── __init__.py
│   └── trainer.py
└── utils
    ├── __init__.py
    ├── logger.py
    ├── misc.py
    └── visualize.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | ._.DS_Store
3 | .__pycache__
4 | */__pycache__/*
5 | **/__pycache__/
6 | .ipynb_checkpoints
7 | */.ipynb_checkpoints/*
8 | results
9 | weights_SF
10 | weights_MF
11 | *.pyc
12 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "external/packnet_sfm"]
2 | path = external/packnet_sfm
3 | url = https://github.com/TRI-ML/packnet-sfm
4 | [submodule "external/dgp"]
5 | path = external/dgp
6 | url = https://github.com/TRI-ML/dgp
7 | branch = v1.5
8 | [submodule "models/gaussian/gaussian-splatting"]
9 | path = models/gaussian/gaussian-splatting
10 | url = https://github.com/graphdeco-inria/gaussian-splatting
11 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Qijian Tian
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # DrivingForward: Feed-forward 3D Gaussian Splatting for Driving Scene Reconstruction from Flexible Surround-view Input
3 |
4 | Qijian Tian¹
5 | ·
6 | Xin Tan²
7 | ·
8 | Yuan Xie²
9 | ·
10 | Lizhuang Ma¹
11 |
12 |
13 | ¹Shanghai Jiao Tong University
14 |
15 | ²East China Normal University
16 |
17 | AAAI 2025
18 |
19 |
20 |
21 | ## Introduction
22 |
23 | We propose a feed-forward Gaussian Splatting model that reconstructs driving scenes from flexible sparse surround-view input.
24 |
25 |
26 |
27 | Given sparse surround-view input from vehicle-mounted cameras, our model learns
28 | scale-aware localization for Gaussian primitives from the small overlap of spatial and temporal context views. A Gaussian
29 | network predicts the remaining Gaussian parameters from each image individually. This feed-forward pipeline enables real-time
30 | reconstruction of driving scenes, while the independent prediction from single-frame images supports flexible input modes. At the
31 | inference stage, we include only the depth network and the Gaussian network, as shown in the lower part of the figure.
32 |
33 | ## Installation
34 |
35 | To get started, clone this project, create a conda virtual environment using Python 3.8, and install the requirements:
36 |
37 | ```bash
38 | git clone https://github.com/fangzhou2000/DrivingForward
39 | cd DrivingForward
40 | git submodule update --init --recursive
41 | conda create -n DrivingForward python=3.8
42 | conda activate DrivingForward
43 | pip install torch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 --index-url https://download.pytorch.org/whl/cu113
44 | pip install -r requirements.txt
45 | cd models/gaussian/gaussian-splatting
46 | pip install submodules/diff-gaussian-rasterization
47 | cd ../../..
48 | ```
49 |
50 | ## Datasets
51 |
52 | ### nuScenes
53 | * Download [nuScenes](https://www.nuscenes.org/nuscenes) official dataset
54 | * Place the dataset in `input_data/nuscenes/`
55 |
56 | Data should be as follows:
57 | ```
58 | ├── input_data
59 | │   ├── nuscenes
60 | │   │   ├── maps
61 | │   │   ├── samples
62 | │   │   ├── sweeps
63 | │   │   ├── v1.0-test
64 | │   │   ├── v1.0-trainval
65 | ```
66 |
67 | ## Running the Code
68 |
69 | ### Evaluation
70 |
71 | Get the [pretrained models](https://drive.google.com/drive/folders/1IASOPK1RQeP-nLQvJUn7WQUtb_fwGlVS), save them to the root directory of the project, and unzip them.
72 |
73 | For SF mode, run the following:
74 | ```
75 | python -W ignore eval.py --weight_path ./weights_SF --novel_view_mode SF
76 | ```
77 |
78 |
79 | For MF mode, run the following:
80 | ```
81 | python -W ignore eval.py --weight_path ./weights_MF --novel_view_mode MF
82 | ```
83 |
84 | ### Training
85 |
86 | For SF mode, run the following:
87 | ```
88 | python -W ignore train.py --novel_view_mode SF
89 | ```
90 |
91 | For MF mode, run the following:
92 | ```
93 | python -W ignore train.py --novel_view_mode MF
94 | ```
95 |
96 | ## BibTeX
97 | ```
98 | @inproceedings{tian2025drivingforward,
99 | title={DrivingForward: Feed-forward 3D Gaussian Splatting for Driving Scene Reconstruction from Flexible Surround-view Input},
100 | author={Qijian Tian and Xin Tan and Yuan Xie and Lizhuang Ma},
101 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
102 | year={2025}
103 | }
104 | ```
105 |
106 | ## Acknowledgements
107 |
108 | The project is partially based on some awesome repos: [MVSplat](https://github.com/donydchen/mvsplat), [GPS-Gaussian](https://github.com/aipixel/GPS-Gaussian), and [VFDepth](https://github.com/42dot/VFDepth). Many thanks to these projects for their excellent contributions!
109 |
--------------------------------------------------------------------------------
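Note on the pipeline described in the README's Introduction: at inference, only the depth network and the Gaussian network are run. The sketch below is a conceptual, hypothetical stand-in (tiny stub modules and illustrative channel counts, not the repository's API); the actual entry point is `DrivingForwardModel.estimate` in `models/drivingforward_model.py`.

```python
# Conceptual sketch of the feed-forward inference flow (hypothetical stub networks; the real
# DepthNetwork and GaussianNetwork live in network/ and models/gaussian/).
import torch
import torch.nn as nn

class TinyDepthNet(nn.Module):
    """Stand-in for the depth network: predicts a per-pixel disparity map."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 1, kernel_size=3, padding=1)

    def forward(self, rgb):
        return torch.sigmoid(self.conv(rgb))  # (N, 1, H, W), normalized disparity

class TinyGaussianNet(nn.Module):
    """Stand-in for the Gaussian network: predicts per-pixel Gaussian parameters."""
    def __init__(self):
        super().__init__()
        # rotation (4) + scale (3) + opacity (1) + color (3) = 11 channels, purely illustrative
        self.conv = nn.Conv2d(4, 11, kernel_size=3, padding=1)

    def forward(self, rgb, depth):
        return self.conv(torch.cat([rgb, depth], dim=1))

rgb = torch.rand(6, 3, 352, 640)          # six surround-view images at the training resolution
depth = TinyDepthNet()(rgb)               # scale-aware depth -> 3D centers of Gaussian primitives
params = TinyGaussianNet()(rgb, depth)    # remaining Gaussian parameters, predicted per image
print(depth.shape, params.shape)          # torch.Size([6, 1, 352, 640]) torch.Size([6, 11, 352, 640])
```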
/assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fangzhou2000/DrivingForward/5d0a2c7d6a6358f55318faac80e351cf75dff032/assets/framework.png
--------------------------------------------------------------------------------
/configs/nuscenes/main.yaml:
--------------------------------------------------------------------------------
1 | ddp:
2 | ddp_enable: False
3 | world_size: 1
4 | gpus: [0]
5 |
6 | model:
7 | # mode: 'MF' or 'SF'
8 | novel_view_mode: 'MF'
9 |
10 | num_layers: 18
11 | weights_init: True
12 |
13 | fusion_level: 2
14 | fusion_feat_in_dim: 256
15 | use_skips: False
16 |
17 | voxel_unit_size: [1.0, 1.0, 1.5]
18 | voxel_size: [100, 100, 20]
19 | voxel_str_p: [-50.0, -50.0, -15.0]
20 | voxel_pre_dim: [64]
21 | proj_d_bins: 50
22 | proj_d_str: 2
23 | proj_d_end: 50
24 |
25 | data:
26 | data_path: './input_data/nuscenes/'
27 | log_dir: './results/'
28 | dataset: 'nuscenes'
29 | back_context: 1
30 | forward_context: 1
31 | depth_type: 'lidar'
32 | cameras: ['CAM_FRONT', 'CAM_FRONT_LEFT', 'CAM_FRONT_RIGHT', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT', 'CAM_BACK']
33 | # gt_ego_pose and gt_depth are not used during training, only for visualization
34 | train_requirements: (gt_pose, gt_ego_pose, gt_depth, mask)
35 | val_requirements: (gt_pose, gt_ego_pose, gt_depth, mask)
36 |
37 | training:
38 | # basic
39 | height: 352
40 | width: 640
41 | scales: [0]
42 | frame_ids: [0, -1, 1]
43 |
44 | # optimization
45 | batch_size: 1
46 | num_workers: 4
47 | learning_rate: 0.0001
48 | num_epochs: 10
49 | scheduler_step_size: 15
50 |
51 | # model / loss setting
52 | ## depth range
53 | min_depth: 1.5
54 | max_depth: 80.0
55 |
56 | ## spatio & temporal
57 | spatio: True
58 | spatio_temporal: True
59 |
60 | ## gaussian
61 | gaussian: True
62 |
63 | ## intensity align
64 | intensity_align: True
65 |
66 | ## focal length scaling
67 | focal_length_scale: 300
68 |
69 | # Loss hyperparams
70 | loss:
71 | disparity_smoothness: 0.001
72 | spatio_coeff: 0.03
73 | spatio_tempo_coeff: 0.1
74 | gaussian_coeff: 0.01
75 | pose_loss_coeff: 0.0
76 |
77 | eval:
78 | eval_batch_size: 1
79 | eval_num_workers: 4
80 | eval_min_depth: 0
81 | eval_max_depth: 80
82 | eval_visualize: False
83 | save_images: False
84 | save_path: './results/images'
85 |
86 | load:
87 | pretrain: False
88 | models_to_load: ['depth_net', 'gs_net']
89 | load_dir: path to weights directory
90 |
91 | logging:
92 | early_phase: 5000
93 | log_frequency: 1000
94 | late_log_frequency: 5000
95 | save_frequency: 1
96 |
--------------------------------------------------------------------------------
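The repository loads this file via `utils.get_config` (see `eval.py` and `train.py`), and `DrivingForwardModel.read_config` then flattens every second-level key into a model attribute (e.g. `cfg['training']['batch_size']` becomes `self.batch_size`). A minimal sketch of that flattening pattern, assuming plain PyYAML instead of the project's own loader:

```python
# Minimal sketch of the config-flattening pattern used in DrivingForwardModel.read_config
# (assumes PyYAML is installed; the project itself goes through utils.get_config).
import yaml

class CfgHolder:
    def read_config(self, cfg):
        for section in cfg:                  # iterate top-level sections (ddp, model, data, training, ...)
            for k, v in cfg[section].items():
                setattr(self, k, v)          # e.g. self.novel_view_mode, self.batch_size, ...

with open('configs/nuscenes/main.yaml') as f:
    cfg = yaml.safe_load(f)

holder = CfgHolder()
holder.read_config(cfg)
print(holder.novel_view_mode, holder.min_depth, holder.max_depth)
```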
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_dataset import construct_dataset
2 |
3 | __all__ = ['construct_dataset']
--------------------------------------------------------------------------------
/dataset/base_dataset.py:
--------------------------------------------------------------------------------
1 | from external.dataset import get_transforms
2 |
3 |
4 | def construct_dataset(cfg, mode, **kwargs):
5 | """
6 | This function constructs datasets.
7 | """
8 | # dataset arguments for the dataloader
9 | if mode == 'train':
10 | dataset_args = {
11 | 'cameras': cfg['data']['cameras'],
12 | 'back_context': cfg['data']['back_context'],
13 | 'forward_context': cfg['data']['forward_context'],
14 | 'data_transform': get_transforms('train', **kwargs),
15 | 'depth_type': cfg['data']['depth_type'] if 'gt_depth' in cfg['data']['train_requirements'] else None,
16 | 'with_pose': 'gt_pose' in cfg['data']['train_requirements'],
17 | 'with_ego_pose': 'gt_ego_pose' in cfg['data']['train_requirements'],
18 | 'with_mask': 'mask' in cfg['data']['train_requirements']
19 | }
20 |
21 | elif mode == 'val' or mode == 'eval':
22 | dataset_args = {
23 | 'cameras': cfg['data']['cameras'],
24 | 'back_context': cfg['data']['back_context'],
25 | 'forward_context': cfg['data']['forward_context'],
26 | 'data_transform': get_transforms('train', **kwargs), # for aligning inputs without any augmentations
27 | 'depth_type': cfg['data']['depth_type'] if 'gt_depth' in cfg['data']['val_requirements'] else None,
28 | 'with_pose': 'gt_pose' in cfg['data']['val_requirements'],
29 | 'with_ego_pose': 'gt_ego_pose' in cfg['data']['val_requirements'],
30 | 'with_mask': 'mask' in cfg['data']['val_requirements']
31 | }
32 |
33 | # NuScenes dataset
34 | if cfg['data']['dataset'] == 'nuscenes':
35 | from dataset.nuscenes_dataset import NuScenesdataset
36 | if mode == 'train':
37 | split = 'train'
38 | else:
39 | if cfg['model']['novel_view_mode'] == 'MF':
40 | split = 'eval_MF'
41 | elif cfg['model']['novel_view_mode'] == 'SF':
42 | split = 'eval_SF'
43 | else:
44 | raise ValueError('Unknown novel view mode: ' + cfg['model']['novel_view_mode'])
45 | dataset = NuScenesdataset(
46 | cfg['data']['data_path'], split,
47 | **dataset_args
48 | )
49 | else:
50 | raise ValueError('Unknown dataset: ' + cfg['data']['dataset'])
51 | return dataset
--------------------------------------------------------------------------------
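A hypothetical usage sketch of `construct_dataset`, mirroring how `DrivingForwardModel` builds its dataloaders (the keyword arguments are the augmentation options forwarded to packnet-sfm's `get_transforms`); it assumes the nuScenes data and the git submodules are already in place:

```python
# Hypothetical usage of construct_dataset (values mirror models/drivingforward_model.py);
# requires the nuScenes dataset under ./input_data/nuscenes/ and initialized submodules.
import yaml
from dataset import construct_dataset

with open('configs/nuscenes/main.yaml') as f:
    cfg = yaml.safe_load(f)

augmentation = {
    'image_shape': (352, 640),           # training.height x training.width
    'jittering': (0.0, 0.0, 0.0, 0.0),   # no color jitter outside training
    'crop_train_borders': (),
    'crop_eval_borders': (),
}
eval_dataset = construct_dataset(cfg, 'eval', **augmentation)
print(len(eval_dataset))
```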
/dataset/data_util.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import PIL.Image as pil
5 |
6 | import torch.nn.functional as F
7 | import torchvision.transforms as transforms
8 |
9 | _DEL_KEYS= ['rgb', 'rgb_context', 'rgb_original', 'rgb_context_original', 'intrinsics', 'contexts', 'splitname', 'ego_pose']
10 |
11 |
12 | def transform_mask_sample(sample, data_transform):
13 | """
14 | This function transforms masks to match input rgb images.
15 | """
16 | image_shape = data_transform.keywords['image_shape']
17 | # resize transform
18 | resize_transform = transforms.Resize(image_shape, interpolation=pil.ANTIALIAS)
19 | sample['mask'] = resize_transform(sample['mask'])
20 | # totensor transform
21 | tensor_transform = transforms.ToTensor()
22 | sample['mask'] = tensor_transform(sample['mask'])
23 | return sample
24 |
25 |
26 | def img_loader(path):
27 | """
28 | This function loads rgb image.
29 | """
30 | with open(path, 'rb') as f:
31 | with pil.open(f) as img:
32 | return img.convert('RGB')
33 |
34 |
35 | def mask_loader_scene(path, mask_idx, cam):
36 | """
37 | This function loads the mask that corresponds to the scene and camera.
38 | """
39 | fname = os.path.join(path, str(mask_idx), '{}_mask.png'.format(cam.upper()))
40 | with open(fname, 'rb') as f:
41 | with pil.open(f) as img:
42 | return img.convert('L')
43 |
44 |
45 | def align_dataset(sample, scales, contexts):
46 | """
47 | This function reorganizes samples to match our trainer configuration.
48 | """
49 | K = sample['intrinsics']
50 | aug_images = sample['rgb']
51 | aug_contexts = sample['rgb_context']
52 | org_images = sample['rgb_original']
53 | org_contexts= sample['rgb_context_original']
54 | ego_poses = sample['ego_pose']
55 |
56 | n_cam, _, h, w = aug_images.shape
57 |
58 | # initialize intrinsics
59 | resized_K = np.expand_dims(np.eye(4), 0).repeat(n_cam, axis=0)
60 | resized_K[:, :3, :3] = K
61 |
62 | # augment images and intrinsics in accordance with scales
63 | for scale in scales:
64 | scaled_K = resized_K.copy()
65 | scaled_K[:,:2,:] /= (2**scale)
66 |
67 | sample[('K', scale)] = scaled_K.copy()
68 | sample[('inv_K', scale)]= np.linalg.pinv(scaled_K).copy()
69 |
70 | resized_org = F.interpolate(org_images,
71 | size=(h//(2**scale), w//(2**scale)),
72 | mode='bilinear',
73 | align_corners=False)
74 | resized_aug = F.interpolate(aug_images,
75 | size=(h//(2**scale), w//(2**scale)),
76 | mode='bilinear',
77 | align_corners=False)
78 |
79 | sample[('color', 0, scale)] = resized_org
80 | sample[('color_aug', 0, scale)] = resized_aug
81 |
82 | # for context data
83 | for idx, frame in enumerate(contexts):
84 | sample[('color', frame, 0)] = org_contexts[idx]
85 | sample[('color_aug',frame, 0)] = aug_contexts[idx]
86 | sample[('cam_T_cam', 0, frame)] = ego_poses[idx]
87 |
88 | # delete unused arrays
89 | for key in list(sample.keys()):
90 | if key in _DEL_KEYS:
91 | del sample[key]
92 | return sample
93 |
--------------------------------------------------------------------------------
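A worked example of the intrinsics scaling performed in `align_dataset` above: dividing the first two rows of the 4x4 intrinsics matrix by 2^scale keeps K consistent with an image downsampled by the same factor, since the focal lengths and principal point shrink together. The numbers below are illustrative, nuScenes-like values.

```python
# Illustrative example of the per-scale intrinsics update in align_dataset
# (scaled_K[:, :2, :] /= 2**scale for the batched case).
import numpy as np

K = np.array([[1266.4,    0.0, 816.3, 0.0],
              [   0.0, 1266.4, 491.5, 0.0],
              [   0.0,    0.0,   1.0, 0.0],
              [   0.0,    0.0,   0.0, 1.0]])  # illustrative 4x4 intrinsics

scale = 1                      # image downsampled by 2**1
scaled_K = K.copy()
scaled_K[:2, :] /= 2 ** scale  # fx, fy, cx, cy are all halved
print(scaled_K[0, 0], scaled_K[:2, 2])  # 633.2 [408.15 245.75]
```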
/dataset/nuscenes/eval_SF.txt:
--------------------------------------------------------------------------------
1 | fd8420396768425eabec9bdddf7e64b6
2 | 88a090c54cc844cc878f00d08274a683
3 | b9c60cfaf1814c8bb363037dec7abd35
4 | 43cfc03be6f842ae8942c995f8e3f2fb
5 | 234b0f37af694571876a745e5bbb5e08
6 | fbd3aa2098234fb1869d34d920e16eda
7 | 134c7f63827c4b858a2a4ce0582d8f6e
8 | f0e4797551024f9487a016ee9d9e29e1
9 | d62b0f9db06440c695781c52ed71909a
10 | e1b5c775be134db696ce22413d195c72
11 | 0f39f34febc84a6689f08599105d1421
12 | 15c59626df6b4f96bb7ef7c5c1797159
13 | 8e46878ced734164aae0e83567aa02a5
14 | 6e0955c28ecb405b98d91b089761dbbb
15 | 19ffc60b7e9f429cbca0dbd5e2fc2e03
16 | 78be3bd5f1f146cdb1b7120bbf8a218c
17 | f9475d1f1f96401c81ce511a93578a96
18 | 1c81b06c812b47788462c31f58b6c1d0
19 | c292d90b90ec4dbd826dfd11aa361df1
20 | b9db804449d34414ba98868e9034abd3
21 | 028a3f3a237e469cb099f6d9c8e7f2f4
22 | d5324133e1d241658a7b763831b3af2e
23 | 806011e8a7114eca971fd0f6b95bcd4b
24 | 3e8750f331d7499e9b5123e9eb70f2e2
25 | f1b7b463072e4630b2c1d5ba8b6cd50b
26 | 8d56b71c09c04ece827107b3c103498f
27 | 04ebf63518914e0883bdffdb0b6c3f4c
28 | ad8905836f364a87a7233eb0aef915ea
29 | 339ead96177c4e338fde8235c188cfaa
30 | 4d0d06979c984a72bfb9e6c5e3a35f84
31 | 80c91574d0174206a74435200aba8ba8
32 | 79070dca796f4ffb868d84f1b067f9c2
33 | 30e55a3ec6184d8cb1944b39ba19d622
34 | adce54dcf6404d37bd170ffc4c2a7836
35 | 51f37263b06c417e94138f2d9348f909
36 | c28c950afda14f7b93ff956799873dff
37 | e261f474caa34797b0759aa2ca8a576d
38 | 1d477a7c21784a98b2e6a9289a958fe2
39 | 040a956bf0c04ad09fffc286cab5bc56
40 | 8fb167b9767d420b9c6368719bafb7b4
41 | 3502d36d2c76428f9b79f2854e96b696
42 | 7e2fd590b46d42cdbe8dfbf035761561
43 | f4b5b6fc59e34da6bd18d4d8e299dab8
44 | 269601cc1e4d45a3a36fbf291d45a447
45 | b98e65c90a7a445a907a3bb1467b551b
46 | 09361337f2bf4e819aac2a310481e62f
47 | 097c7295f13a4746a2c46d036c21bd3e
48 | 418cbb70159341c1ada4f3ba1fb69713
49 | c525629d7b3749d8968b486bb5412adf
50 | c35efe5eaf784275b8c7e31fb50aa902
51 | 4e26b42796024b45a9a852738e3701bb
52 | 6123a0c091d24c39ab274cdbad994bc4
53 | 107dd32140844573924260f9ad9390bf
54 | 5ec2ee2d55cc4582acd0f92476845a74
55 | ce03b6f0c1d2455aa18f6275917485bc
56 | 89c9204978674480aec9d2a9a377eecc
57 | 5b90526901414d2aa5db2b9fdaf858eb
58 | 8687ba92abd3406aa797115b874ebeba
59 | e0b85d628af34d6bbbbb0da216aa5bca
60 | 12a42a64274a4453b166d804b625d61f
61 | 17e6f41d3b1f40d4bc85412fe8b67688
62 | 6cb74a87748b4fab82ad1b0869b27e01
63 | b4b5e858cd9d4cf9bdb01098ea7d3c45
64 | 029af235f15d4f6f80713b778edc3c61
65 | 5c558d97108341ebb7b30b1c841ff57f
66 | 9eebbe06762c47aa8c671008df5fb474
67 | e493c6923a0a4d5fb78b75dcc0ffc5ab
68 | 970f41b6fc4544b388fe7ccd578988a9
69 | 3b86f8218554493f928b39824655872a
70 | f0c47597c226445fa1e14c82d879f9a4
71 | 67d2d74087714e4994a68c7347acf55a
72 | 57f23d2f55e6455696c6d3cc70a4f501
73 | 0fae8aae30d44b4faf7a9854eac4bc09
74 | 14384b3ce9d243ce827bf0c1598dd35f
75 | 6a808b09e5f34d33ba1de76cc8dab423
76 | 78c80690450a4e8aba63e1f2ca27c207
77 | fe5b1b2e8aa2430e90c5c90825f9f1f2
78 | 3bb017b691c34f8f89985c0c2220ccec
79 | 43f05ac397484672b9bdf86dcc0b0289
80 | 3ac231a1ef4c481c9c15921cafa87307
81 | 8b77b8e2cbf144bea91f4c1856b05a7c
82 | ccd7c315063047ddb231fceb7ceff64b
83 | 4cecbedd7c9c406ba5a0b22a3fa2e49c
84 | 5a3a15f4ebdb4e6b83eb58628dc8e83f
85 | e836d06299ce46b884b9187ac6017579
86 | bb10039077c74099b51fb67300391659
87 | 485c78dbfed04fd5a6aad2a88b116fe9
88 | 7949609177914534a89a8c7985ae4db7
89 | 87c3788da01f4f63953f35383d0f5095
90 | 9699d6a8d9384f8885e8c5318bc621ab
91 | e829650532104af6bdb060d75924421f
92 | abad69b7e6194c95b18a054aa082cfb8
93 | 159f9a70a6d64e27a38b3b2668b5a68e
94 | f8e8dd98088d458abb0a664330403168
95 | c1676a2feac74eee8aa38ca3901787d6
96 | 604d12ebcf784c3f945189f79262f19c
97 | 154dbc8467d64780a1acffda0193d46b
98 | 2eef1a0ab4484a1faa73e7c2a048149e
99 | 94cd383effba4e688570233a060a13f4
100 | 63e3bc947eed4823b00e1beadecd3be3
101 | 4e04d32a9949430287c14d49851f6f55
102 | bc94beb70f694ee8a0399c4e7e1fc2d0
103 | 2fb7b5d1eab24b13abb4e0703850dba9
104 | e693989d8ad141608999200df71aca71
105 | a4782f7ade234010b5b50553403f6e79
106 | 3abf81a7c3894000a4c508e6ced0caca
107 | 43da50265dfa4876bdf1efed38e7a261
108 | d056b9bdd56f44669540c6c323042d30
109 | 8487e4c995874bb2aeb41d6484717d10
110 | 1f048fa210c64e80b9d57e3dc757702c
111 | f8de6e9a14d34f2791c21a07ce5d47c5
112 | 7f8368b7a14341b3bef03a7aec63f9b5
113 | b5989651183643369174912bc5641d3b
114 | 8c5c08c818d247ee9de5aa08e6762003
115 | bd0c27135b3c43c6b0e59f03fdf441cb
116 | e8a25125ab404726ab03976d3871c460
117 | 852e22dbdbbf422e99224c488d7618bf
118 | 73c335a2a207457ea36cc855b2417bd9
119 | c07a480048b04e989f6fc1e002e5c5b8
120 | c201dea88033406fb304747d32b823a6
121 | e6fcfc5cd0254eee901c1bbc0a30a6ba
122 | 45eda670b5a840acbb730aac15e63b19
123 | 5d785cd90a884e0ebf9c5047a527df76
124 | fd0915986a6444ee8296894d2a307291
125 | dad9942ea8f1409985140f500c69c88e
126 | e57fa7124f6741e697b1841bc023a90f
127 | b139c133286247d48093bba9151920b2
128 | b2a93b022374432d917cc7db15ca023a
129 | 637fc435f95f4bc1b72f09811f2d78d2
130 | b5cbf927fe0d47c39a9c1d1b05279c55
131 | 23d02c83708b4f5e9579bbd99370a672
132 | aa298fdc84fb489a911a3f363d3beda9
133 | 923b34d42c5048c0a4db75c65de52f3c
134 | 45c4bd03ee6f4a369dd42abf08dc3957
135 | e80bfd61023745b0a97caae3ff4a1727
136 | cb4504edb87a40d4b650cd5860f6c3b7
137 | e74b8ce0b1824e0c8167f2534d8d4f7d
138 | 79f6489272c24d3ebc5e225ce6ff2aea
139 | 617899966d3e4fea9042a1f1a26b764e
140 | 42c250251b504ee8a5f1e49def135715
141 | b62a484b94a04ae4b1b2729fb2064696
142 | dc2def2a7bc94693b67eba2ffbe13a56
143 | 21bb21c84c3e43908cd554db2686278f
144 | b6a9d120895e4207a0527ad4b3284d42
145 | 2920a3315ccd4fefb30fb0468ebb2351
146 | b58985f7ddc2454d95cb84a65d1fe9fe
147 | 23779301ebc34c1284e539ddf057f0b4
148 | 7e351605549547f8af19432917f23b73
149 | f9ad282e31ce47d5a16dffc3b25c456a
150 | 127c5ddf95e947ef96888e56d5a355f6
--------------------------------------------------------------------------------
/dataset/nuscenes_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import PIL.Image as pil
5 |
6 | import torch
7 | from torch.utils.data import Dataset
8 |
9 | from nuscenes.nuscenes import NuScenes
10 | from pyquaternion import Quaternion
11 |
12 | from .data_util import img_loader, mask_loader_scene, align_dataset, transform_mask_sample
13 |
14 |
15 | def is_numpy(data):
16 | """Checks if data is a numpy array."""
17 | return isinstance(data, np.ndarray)
18 |
19 | def is_tensor(data):
20 | """Checks if data is a torch tensor."""
21 | return type(data) == torch.Tensor
22 |
23 | def is_list(data):
24 | """Checks if data is a list."""
25 | return isinstance(data, list)
26 |
27 | def stack_sample(sample):
28 | """Stack a sample from multiple sensors"""
29 | # If there is only one sensor don't do anything
30 | if len(sample) == 1:
31 | return sample[0]
32 |
33 | # Otherwise, stack sample
34 | stacked_sample = {}
35 | for key in sample[0]:
36 | # Global keys (do not stack)
37 | if key in ['idx', 'dataset_idx', 'sensor_name', 'filename', 'token']:
38 | stacked_sample[key] = sample[0][key]
39 | else:
40 | # Stack torch tensors
41 | if is_tensor(sample[0][key]):
42 | stacked_sample[key] = torch.stack([s[key] for s in sample], 0)
43 | # Stack numpy arrays
44 | elif is_numpy(sample[0][key]):
45 | stacked_sample[key] = np.stack([s[key] for s in sample], 0)
46 | # Stack list
47 | elif is_list(sample[0][key]):
48 | stacked_sample[key] = []
49 | # Stack list of torch tensors
50 | if is_tensor(sample[0][key][0]):
51 | for i in range(len(sample[0][key])):
52 | stacked_sample[key].append(
53 | torch.stack([s[key][i] for s in sample], 0))
54 | # Stack list of numpy arrays
55 | if is_numpy(sample[0][key][0]):
56 | for i in range(len(sample[0][key])):
57 | stacked_sample[key].append(
58 | np.stack([s[key][i] for s in sample], 0))
59 |
60 | # Return stacked sample
61 | return stacked_sample
62 |
63 | class NuScenesdataset(Dataset):
64 | """
65 | Loaders for NuScenes dataset
66 | """
67 | def __init__(self, path, split,
68 | cameras=None,
69 | back_context=0,
70 | forward_context=0,
71 | data_transform=None,
72 | depth_type=None,
73 | scale_range=2,
74 | with_pose=None,
75 | with_ego_pose=None,
76 | with_mask=None,
77 | ):
78 | super().__init__()
79 | version = 'v1.0-trainval'
80 | self.path = path
81 | self.split = split
82 | self.dataset_idx = 0
83 |
84 | self.cameras = cameras
85 | self.scales = np.arange(scale_range+2)
86 | self.num_cameras = len(cameras)
87 |
88 | self.bwd = back_context
89 | self.fwd = forward_context
90 |
91 | self.has_context = back_context + forward_context > 0
92 | self.data_transform = data_transform
93 |
94 | self.with_depth = depth_type is not None
95 | self.with_pose = with_pose
96 | self.with_ego_pose = with_ego_pose
97 |
98 | self.loader = img_loader
99 |
100 | self.with_mask = with_mask
101 | cur_path = os.path.dirname(os.path.realpath(__file__))
102 | self.mask_path = os.path.join(cur_path, 'nuscenes_mask')
103 | self.mask_loader = mask_loader_scene
104 |
105 | self.dataset = NuScenes(version=version, dataroot=self.path, verbose=True)
106 |
107 | # list of scenes for training and validation of model
108 | with open('dataset/nuscenes/{}.txt'.format(self.split), 'r') as f:
109 | self.filenames = f.readlines()
110 |
111 | def get_current(self, key, cam_sample):
112 | """
113 | This function returns samples for current contexts
114 | """
115 | # get current timestamp rgb sample
116 | if key == 'rgb':
117 | rgb_path = cam_sample['filename']
118 | return self.loader(os.path.join(self.path, rgb_path))
119 | # get current timestamp camera intrinsics
120 | elif key == 'intrinsics':
121 | cam_param = self.dataset.get('calibrated_sensor',
122 | cam_sample['calibrated_sensor_token'])
123 | return np.array(cam_param['camera_intrinsic'], dtype=np.float32)
124 | # get current timestamp camera extrinsics
125 | elif key == 'extrinsics':
126 | cam_param = self.dataset.get('calibrated_sensor',
127 | cam_sample['calibrated_sensor_token'])
128 | return self.get_tranformation_mat(cam_param)
129 | else:
130 | raise ValueError('Unknown key: ' +key)
131 |
132 | def get_context(self, key, cam_sample):
133 | """
134 | This function returns samples for backward and forward contexts
135 | """
136 | bwd_context, fwd_context = [], []
137 | if self.bwd != 0:
138 | if self.split == 'eval_SF': # validation
139 | bwd_sample = cam_sample
140 | else:
141 | bwd_sample = self.dataset.get('sample_data', cam_sample['prev'])
142 | bwd_context = [self.get_current(key, bwd_sample)]
143 |
144 | if self.fwd != 0:
145 | fwd_sample = self.dataset.get('sample_data', cam_sample['next'])
146 | fwd_context = [self.get_current(key, fwd_sample)]
147 | return bwd_context + fwd_context
148 |
149 | def get_cam_T_cam(self, cam_sample):
150 | # cam 0 to world
151 | # cam 0 to ego 0
152 | cam_to_ego = self.dataset.get(
153 | 'calibrated_sensor', cam_sample['calibrated_sensor_token'])
154 | cam_to_ego_rotation = Quaternion(cam_to_ego['rotation'])
155 | cam_to_ego_translation = np.array(cam_to_ego['translation'])[:, None]
156 | cam_to_ego = np.vstack([
157 | np.hstack((cam_to_ego_rotation.rotation_matrix,
158 | cam_to_ego_translation)),
159 | np.array([0, 0, 0, 1])
160 | ])
161 |
162 | # ego 0 to world
163 | world_to_ego = self.dataset.get(
164 | 'ego_pose', cam_sample['ego_pose_token'])
165 | world_to_ego_rotation = Quaternion(world_to_ego['rotation']).inverse
166 | world_to_ego_translation = - np.array(world_to_ego['translation'])[:, None]
167 | world_to_ego = np.vstack([
168 | np.hstack((world_to_ego_rotation.rotation_matrix,
169 | world_to_ego_rotation.rotation_matrix @ world_to_ego_translation)),
170 | np.array([0, 0, 0, 1])
171 | ])
172 | ego_to_world = np.linalg.inv(world_to_ego)
173 |
174 | cam_T_cam = []
175 |
176 | # cam_T_cam, 0, -1
177 | if self.bwd != 0:
178 | if self.split == 'eval_SF': # validation
179 | bwd_sample = cam_sample
180 | else:
181 | bwd_sample = self.dataset.get('sample_data', cam_sample['prev'])
182 |
183 | # world to ego -1
184 | world_to_ego_bwd = self.dataset.get(
185 | 'ego_pose', bwd_sample['ego_pose_token'])
186 | world_to_ego_bwd_rotation = Quaternion(world_to_ego_bwd['rotation']).inverse
187 | world_to_ego_bwd_translation = - np.array(world_to_ego_bwd['translation'])[:, None]
188 | world_to_ego_bwd = np.vstack([
189 | np.hstack((world_to_ego_bwd_rotation.rotation_matrix,
190 | world_to_ego_bwd_rotation.rotation_matrix @ world_to_ego_bwd_translation)),
191 | np.array([0, 0, 0, 1])
192 | ])
193 |
194 | # ego -1 to cam -1
195 | cam_to_ego_bwd = self.dataset.get(
196 | 'calibrated_sensor', bwd_sample['calibrated_sensor_token'])
197 | cam_to_ego_bwd_rotation = Quaternion(cam_to_ego_bwd['rotation'])
198 | cam_to_ego_bwd_translation = np.array(cam_to_ego_bwd['translation'])[:, None]
199 | cam_to_ego_bwd = np.vstack([
200 | np.hstack((cam_to_ego_bwd_rotation.rotation_matrix,
201 | cam_to_ego_bwd_translation)),
202 | np.array([0, 0, 0, 1])
203 | ])
204 | ego_to_cam_bwd = np.linalg.inv(cam_to_ego_bwd)
205 |
206 | cam_T_cam_bwd = ego_to_cam_bwd @ world_to_ego_bwd @ ego_to_world @ cam_to_ego
207 |
208 | cam_T_cam.append(cam_T_cam_bwd)
209 |
210 | # cam_T_cam, 0, 1
211 | if self.fwd != 0:
212 | fwd_sample = self.dataset.get('sample_data', cam_sample['next'])
213 |
214 | # world to ego 1
215 | world_to_ego_fwd = self.dataset.get(
216 | 'ego_pose', fwd_sample['ego_pose_token'])
217 | world_to_ego_fwd_rotation = Quaternion(world_to_ego_fwd['rotation']).inverse
218 | world_to_ego_fwd_translation = - np.array(world_to_ego_fwd['translation'])[:, None]
219 | world_to_ego_fwd = np.vstack([
220 | np.hstack((world_to_ego_fwd_rotation.rotation_matrix,
221 | world_to_ego_fwd_rotation.rotation_matrix @ world_to_ego_fwd_translation)),
222 | np.array([0, 0, 0, 1])
223 | ])
224 |
225 | # ego 1 to cam 1
226 | cam_to_ego_fwd = self.dataset.get(
227 | 'calibrated_sensor', fwd_sample['calibrated_sensor_token'])
228 | cam_to_ego_fwd_rotation = Quaternion(cam_to_ego_fwd['rotation'])
229 | cam_to_ego_fwd_translation = np.array(cam_to_ego_fwd['translation'])[:, None]
230 | cam_to_ego_fwd = np.vstack([
231 | np.hstack((cam_to_ego_fwd_rotation.rotation_matrix,
232 | cam_to_ego_fwd_translation)),
233 | np.array([0, 0, 0, 1])
234 | ])
235 | ego_to_cam_fwd = np.linalg.inv(cam_to_ego_fwd)
236 |
237 | cam_T_cam_fwd = ego_to_cam_fwd @ world_to_ego_fwd @ ego_to_world @ cam_to_ego
238 |
239 | cam_T_cam.append(cam_T_cam_fwd)
240 |
241 | return cam_T_cam
242 |
243 | def generate_depth_map(self, sample, sensor, cam_sample):
244 | """
245 | This function returns the depth map for the nuScenes dataset.
246 | The resulting depth map is cached in nuscenes/samples/DEPTH_MAP.
247 | """
248 | # generate depth filename
249 | filename = '{}/{}.npz'.format(
250 | os.path.join(os.path.dirname(self.path), 'samples'),
251 | 'DEPTH_MAP/{}/{}'.format(sensor, cam_sample['filename']))
252 |
253 | # load and return if exists
254 | if os.path.exists(filename):
255 | return np.load(filename, allow_pickle=True)['depth']
256 | else:
257 | lidar_sample = self.dataset.get(
258 | 'sample_data', sample['data']['LIDAR_TOP'])
259 |
260 | # lidar points
261 | lidar_file = os.path.join(
262 | self.path, lidar_sample['filename'])
263 | lidar_points = np.fromfile(lidar_file, dtype=np.float32)
264 | lidar_points = lidar_points.reshape(-1, 5)[:, :3]
265 |
266 | # lidar -> world
267 | lidar_pose = self.dataset.get(
268 | 'ego_pose', lidar_sample['ego_pose_token'])
269 | lidar_rotation= Quaternion(lidar_pose['rotation'])
270 | lidar_translation = np.array(lidar_pose['translation'])[:, None]
271 | lidar_to_world = np.vstack([
272 | np.hstack((lidar_rotation.rotation_matrix, lidar_translation)),
273 | np.array([0, 0, 0, 1])
274 | ])
275 |
276 | # lidar -> ego
277 | sensor_sample = self.dataset.get(
278 | 'calibrated_sensor', lidar_sample['calibrated_sensor_token'])
279 | lidar_to_ego_rotation = Quaternion(
280 | sensor_sample['rotation']).rotation_matrix
281 | lidar_to_ego_translation = np.array(
282 | sensor_sample['translation']).reshape(1, 3)
283 |
284 | ego_lidar_points = np.dot(
285 | lidar_points[:, :3], lidar_to_ego_rotation.T)
286 | ego_lidar_points += lidar_to_ego_translation
287 |
288 | homo_ego_lidar_points = np.concatenate(
289 | (ego_lidar_points, np.ones((ego_lidar_points.shape[0], 1))), axis=1)
290 |
291 |
292 | # world -> ego
293 | ego_pose = self.dataset.get(
294 | 'ego_pose', cam_sample['ego_pose_token'])
295 | ego_rotation = Quaternion(ego_pose['rotation']).inverse
296 | ego_translation = - np.array(ego_pose['translation'])[:, None]
297 | world_to_ego = np.vstack([
298 | np.hstack((ego_rotation.rotation_matrix,
299 | ego_rotation.rotation_matrix @ ego_translation)),
300 | np.array([0, 0, 0, 1])
301 | ])
302 |
303 | # Ego -> sensor
304 | sensor_sample = self.dataset.get(
305 | 'calibrated_sensor', cam_sample['calibrated_sensor_token'])
306 | sensor_rotation = Quaternion(sensor_sample['rotation'])
307 | sensor_translation = np.array(
308 | sensor_sample['translation'])[:, None]
309 | sensor_to_ego = np.vstack([
310 | np.hstack((sensor_rotation.rotation_matrix,
311 | sensor_translation)),
312 | np.array([0, 0, 0, 1])
313 | ])
314 | ego_to_sensor = np.linalg.inv(sensor_to_ego)
315 |
316 | # lidar -> sensor
317 | lidar_to_sensor = ego_to_sensor @ world_to_ego @ lidar_to_world
318 | homo_ego_lidar_points = torch.from_numpy(homo_ego_lidar_points).float()
319 | cam_lidar_points = np.matmul(lidar_to_sensor, homo_ego_lidar_points.T).T
320 |
321 | # depth > 0
322 | depth_mask = cam_lidar_points[:, 2] > 0
323 | cam_lidar_points = cam_lidar_points[depth_mask]
324 |
325 | # sensor -> image
326 | intrinsics = np.eye(4)
327 | intrinsics[:3, :3] = sensor_sample['camera_intrinsic']
328 | pixel_points = np.matmul(intrinsics, cam_lidar_points.T).T
329 | pixel_points[:, :2] /= pixel_points[:, 2:3]
330 |
331 | # load image for pixel range
332 | image_filename = os.path.join(
333 | self.path, cam_sample['filename'])
334 | img = pil.open(image_filename)
335 | h, w, _ = np.array(img).shape
336 |
337 | # mask points in pixel range
338 | pixel_mask = (pixel_points[:, 0] >= 0) & (pixel_points[:, 0] <= w-1)\
339 | & (pixel_points[:,1] >= 0) & (pixel_points[:,1] <= h-1)
340 | valid_points = pixel_points[pixel_mask].round().int()
341 | valid_depth = cam_lidar_points[:, 2][pixel_mask]
342 |
343 | depth = np.zeros([h, w])
344 | depth[valid_points[:, 1], valid_points[:,0]] = valid_depth
345 |
346 | # save depth map
347 | os.makedirs(os.path.dirname(filename), exist_ok=True)
348 | np.savez_compressed(filename, depth=depth)
349 | return depth
350 |
351 | def get_tranformation_mat(self, pose):
352 | """
353 | This function transforms pose information in accordance with DDAD dataset format
354 | """
355 | extrinsics = Quaternion(pose['rotation']).transformation_matrix
356 | extrinsics[:3, 3] = np.array(pose['translation'])
357 | return extrinsics.astype(np.float32)
358 |
359 | def __len__(self):
360 | return len(self.filenames)
361 |
362 | def __getitem__(self, idx):
363 | # get nuscenes dataset sample
364 | frame_idx = self.filenames[idx].strip().split()[0]
365 | sample_nusc = self.dataset.get('sample', frame_idx)
366 |
367 | sample = []
368 | contexts = []
369 | if self.bwd:
370 | contexts.append(-1)
371 | if self.fwd:
372 | contexts.append(1)
373 |
374 | # loop over all cameras
375 | for cam in self.cameras:
376 | cam_sample = self.dataset.get(
377 | 'sample_data', sample_nusc['data'][cam])
378 |
379 | data = {
380 | 'idx': idx,
381 | 'token': frame_idx,
382 | 'sensor_name': cam,
383 | 'contexts': contexts,
384 | 'filename': cam_sample['filename'],
385 | 'rgb': self.get_current('rgb', cam_sample),
386 | 'intrinsics': self.get_current('intrinsics', cam_sample)
387 | }
388 |
389 | # if depth is returned
390 | if self.with_depth:
391 | data.update({
392 | 'depth': self.generate_depth_map(sample_nusc, cam, cam_sample)
393 | })
394 | # if pose is returned
395 | if self.with_pose:
396 | data.update({
397 | 'extrinsics':self.get_current('extrinsics', cam_sample)
398 | })
399 | # if ego_pose is returned
400 | if self.with_ego_pose:
401 | data.update({
402 | 'ego_pose': self.get_cam_T_cam(cam_sample)
403 | })
404 | # if mask is returned
405 | if self.with_mask:
406 | data.update({
407 | 'mask': self.mask_loader(self.mask_path, '', cam)
408 | })
409 | # if context is returned
410 | if self.has_context:
411 | data.update({
412 | 'rgb_context': self.get_context('rgb', cam_sample)
413 | })
414 |
415 | sample.append(data)
416 |
417 | # apply same data transformations for all sensors
418 | if self.data_transform:
419 | sample = [self.data_transform(smp) for smp in sample]
420 | sample = [transform_mask_sample(smp, self.data_transform) for smp in sample]
421 |
422 | # stack and align dataset for our trainer
423 | sample = stack_sample(sample)
424 | sample = align_dataset(sample, self.scales, contexts)
425 | return sample
426 |
--------------------------------------------------------------------------------
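The relative pose computed in `get_cam_T_cam` above is a chain of nuScenes `calibrated_sensor` and `ego_pose` transforms. For a context frame at ±1:

    cam_T_cam(0 -> ±1) = ego_to_cam(±1) @ world_to_ego(±1) @ ego_to_world(0) @ cam_to_ego(0)

Reading right to left, a point in the current camera frame (0) is lifted to the current ego frame, then to the world frame, and finally mapped into the ego and camera frames of the context timestamp. This is exactly the product `ego_to_cam_bwd @ world_to_ego_bwd @ ego_to_world @ cam_to_ego` (and its forward counterpart) formed in the code.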
/dataset/nuscenes_mask/CAM_BACK_LEFT_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fangzhou2000/DrivingForward/5d0a2c7d6a6358f55318faac80e351cf75dff032/dataset/nuscenes_mask/CAM_BACK_LEFT_mask.png
--------------------------------------------------------------------------------
/dataset/nuscenes_mask/CAM_BACK_RIGHT_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fangzhou2000/DrivingForward/5d0a2c7d6a6358f55318faac80e351cf75dff032/dataset/nuscenes_mask/CAM_BACK_RIGHT_mask.png
--------------------------------------------------------------------------------
/dataset/nuscenes_mask/CAM_BACK_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fangzhou2000/DrivingForward/5d0a2c7d6a6358f55318faac80e351cf75dff032/dataset/nuscenes_mask/CAM_BACK_mask.png
--------------------------------------------------------------------------------
/dataset/nuscenes_mask/CAM_FRONT_LEFT_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fangzhou2000/DrivingForward/5d0a2c7d6a6358f55318faac80e351cf75dff032/dataset/nuscenes_mask/CAM_FRONT_LEFT_mask.png
--------------------------------------------------------------------------------
/dataset/nuscenes_mask/CAM_FRONT_RIGHT_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fangzhou2000/DrivingForward/5d0a2c7d6a6358f55318faac80e351cf75dff032/dataset/nuscenes_mask/CAM_FRONT_RIGHT_mask.png
--------------------------------------------------------------------------------
/dataset/nuscenes_mask/CAM_FRONT_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fangzhou2000/DrivingForward/5d0a2c7d6a6358f55318faac80e351cf75dff032/dataset/nuscenes_mask/CAM_FRONT_mask.png
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | torch.backends.cudnn.deterministic = True
5 | torch.backends.cudnn.benchmark = False
6 | torch.backends.cuda.matmul.allow_tf32 = False
7 |
8 | import utils
9 | from models import DrivingForwardModel
10 | from trainer import DrivingForwardTrainer
11 |
12 |
13 | def parse_args():
14 | parser = argparse.ArgumentParser(description='evaluation script')
15 | parser.add_argument('--config_file', default='./configs/nuscenes/main.yaml', type=str, help='config yaml file path')
16 | parser.add_argument('--weight_path', default='./weights', type=str, help='weight path')
17 | parser.add_argument('--novel_view_mode', default='MF', type=str, help='MF or SF')
18 | args = parser.parse_args()
19 | return args
20 |
21 | def test(cfg):
22 | print("Evaluating reconstruction")
23 | model = DrivingForwardModel(cfg, 0)
24 | trainer = DrivingForwardTrainer(cfg, 0, use_tb = False)
25 | trainer.evaluate(model)
26 |
27 | if __name__ == '__main__':
28 | args = parse_args()
29 | cfg = utils.get_config(args.config_file, mode='eval', weight_path=args.weight_path, novel_view_mode=args.novel_view_mode)
30 |
31 | test(cfg)
32 |
--------------------------------------------------------------------------------
/external/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from external.packnet_sfm.packnet_sfm.datasets.transforms import get_transforms
2 | from external.packnet_sfm.packnet_sfm.datasets.dgp_dataset import DGPDataset
3 | from external.packnet_sfm.packnet_sfm.datasets.dgp_dataset import stack_sample
4 | from external.packnet_sfm.packnet_sfm.datasets.dgp_dataset import SynchronizedSceneDataset
5 |
6 | __all__ = ['get_transforms', 'stack_sample', 'DGPDataset', 'SynchronizedSceneDataset']
--------------------------------------------------------------------------------
/external/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 42dot. All rights reserved.
2 | from external.packnet_sfm.packnet_sfm.networks.layers.resnet.resnet_encoder import ResnetEncoder
3 | from external.packnet_sfm.packnet_sfm.networks.layers.resnet.pose_decoder import PoseDecoder
4 | from external.packnet_sfm.packnet_sfm.networks.layers.resnet.depth_decoder import DepthDecoder
5 |
6 | __all__ = ['ResnetEncoder', 'PoseDecoder', 'DepthDecoder']
7 |
--------------------------------------------------------------------------------
/external/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 42dot. All rights reserved.
2 | from external.dgp.dgp.utils.camera import Camera
3 | from external.dgp.dgp.utils.camera import generate_depth_map
4 | from external.packnet_sfm.packnet_sfm.utils.misc import make_list
5 |
6 | __all__ = ['Camera', 'generate_depth_map', 'make_list']
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .drivingforward_model import DrivingForwardModel
2 |
3 | __all__ = ['DrivingForwardModel']
--------------------------------------------------------------------------------
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | _OPTIMIZER_NAME ='adam'
5 |
6 |
7 | class BaseModel:
8 | def __init__(self, cfg):
9 | self._dataloaders = {}
10 | self.mode = None
11 | self.models = None
12 | self.optimizer = None
13 | self.lr_scheduler = None
14 | self.ddp_enable = False
15 |
16 | def read_config(self, cfg):
17 | raise NotImplementedError('Not implemented for BaseModel')
18 |
19 | def prepare_dataset(self):
20 | raise NotImplementedError('Not implemented for BaseModel')
21 |
22 | def set_optimizer(self):
23 | raise NotImplementedError('Not implemented for BaseModel')
24 |
25 | def train_dataloader(self):
26 | return self._dataloaders['train']
27 |
28 | def val_dataloader(self):
29 | return self._dataloaders['val']
30 |
31 | def eval_dataloader(self):
32 | return self._dataloaders['eval']
33 |
34 | def set_train(self):
35 | self.mode = 'train'
36 | for m in self.models.values():
37 | m.train()
38 |
39 | def set_val(self):
40 | self.mode = 'val'
41 | for m in self.models.values():
42 | m.eval()
43 |
44 | def set_eval(self):
45 | self.mode = 'eval'
46 | for m in self.models.values():
47 | m.eval()
48 |
49 | def save_model(self, epoch):
50 | curr_model_weights_dir = os.path.join(self.save_weights_root, f'weights_{epoch}')
51 | os.makedirs(curr_model_weights_dir, exist_ok=True)
52 |
53 | for model_name, model in self.models.items():
54 | model_file_path = os.path.join(curr_model_weights_dir, f'{model_name}.pth')
55 | to_save = model.state_dict()
56 | torch.save(to_save, model_file_path)
57 |
58 | # save optimizer
59 | optim_file_path = os.path.join(curr_model_weights_dir, f'{_OPTIMIZER_NAME}.pth')
60 | torch.save(self.optimizer.state_dict(), optim_file_path)
61 |
62 | def load_weights(self):
63 | assert os.path.isdir(self.load_weights_dir), f'\tCannot find {self.load_weights_dir}'
64 | print(f'Loading a model from {self.load_weights_dir}')
65 |
66 | # to retrain
67 | if self.pretrain and self.ddp_enable:
68 | map_location = {'cuda:%d' % 0: 'cuda:%d' % (self.world_size-1)}
69 |
70 | for n in self.models_to_load:
71 | print(f'Loading {n} weights...')
72 | path = os.path.join(self.load_weights_dir, f'{n}.pth')
73 | model_dict = self.models[n].state_dict()
74 |
75 | # distribute gpus for ddp retraining
76 | if self.pretrain and self.ddp_enable:
77 | pre_trained_dict = torch.load(path, map_location=map_location)
78 | else:
79 | pre_trained_dict = torch.load(path)
80 |
81 | # load parameters
82 | pre_trained_dict = {k: v for k, v in pre_trained_dict.items() if k in model_dict}
83 | model_dict.update(pre_trained_dict)
84 | self.models[n].load_state_dict(model_dict)
85 |
86 | if self.mode == 'train':
87 | # loading adam state
88 | optim_file_path = os.path.join(self.load_weights_dir, f'{_OPTIMIZER_NAME}.pth')
89 | if os.path.isfile(optim_file_path):
90 | try:
91 | print(f'Loading {_OPTIMIZER_NAME} weights')
92 | optimizer_dict = torch.load(optim_file_path)
93 | self.optimizer.load_state_dict(optimizer_dict)
94 | except ValueError:
95 | print(f'\tCannot load {_OPTIMIZER_NAME} - the optimizer will be randomly initialized')
96 | else:
97 | print(f'\tCannot find {_OPTIMIZER_NAME} weights, so the optimizer will be randomly initialized')
--------------------------------------------------------------------------------
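The weight loading above uses the common partial state-dict pattern: keep only checkpoint entries whose names already exist in the current model, merge them into the model's state dict, and reload. A standalone sketch with a hypothetical checkpoint dict:

```python
# Standalone sketch of the partial state-dict loading pattern used in BaseModel.load_weights.
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
model_dict = model.state_dict()                 # keys: 'weight', 'bias'

# Hypothetical checkpoint: one matching key, one stale key that should be ignored.
pre_trained_dict = {'weight': torch.zeros(2, 4), 'old_head.bias': torch.zeros(2)}
pre_trained_dict = {k: v for k, v in pre_trained_dict.items() if k in model_dict}
model_dict.update(pre_trained_dict)             # only 'weight' is overwritten
model.load_state_dict(model_dict)
print(model.weight.abs().sum().item())          # 0.0 -> the matching parameter was loaded
```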
/models/drivingforward_model.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | from torch.utils.data import DataLoader, Subset
7 | torch.manual_seed(0)
8 |
9 | from dataset import construct_dataset
10 | from network import *
11 |
12 | from .base_model import BaseModel
13 | from .geometry import Pose, ViewRendering
14 | from .losses import MultiCamLoss, SingleCamLoss
15 |
16 | from .gaussian import GaussianNetwork, depth2pc, pts2render, focal2fov, getProjectionMatrix, getWorld2View2, rotate_sh
17 | from einops import rearrange
18 |
19 | _NO_DEVICE_KEYS = ['idx', 'dataset_idx', 'sensor_name', 'filename', 'token']
20 |
21 |
22 | class DrivingForwardModel(BaseModel):
23 | def __init__(self, cfg, rank):
24 | super(DrivingForwardModel, self).__init__(cfg)
25 | self.rank = rank
26 | self.read_config(cfg)
27 | self.prepare_dataset(cfg, rank)
28 | self.models = self.prepare_model(cfg, rank)
29 | self.losses = self.init_losses(cfg, rank)
30 | self.view_rendering, self.pose = self.init_geometry(cfg, rank)
31 | self.set_optimizer()
32 |
33 | if self.pretrain and rank == 0:
34 | self.load_weights()
35 |
36 | self.left_cam_dict = {2:0, 0:1, 4:2, 1:3, 5:4, 3:5}
37 | self.right_cam_dict = {0:2, 1:0, 2:4, 3:1, 4:5, 5:3}
38 |
39 | def read_config(self, cfg):
40 | for attr in cfg.keys():
41 | for k, v in cfg[attr].items():
42 | setattr(self, k, v)
43 |
44 | def init_geometry(self, cfg, rank):
45 | view_rendering = ViewRendering(cfg, rank)
46 | pose = Pose(cfg)
47 | return view_rendering, pose
48 |
49 | def init_losses(self, cfg, rank):
50 | if self.spatio_temporal or self.spatio:
51 | loss_model = MultiCamLoss(cfg, rank)
52 | else :
53 | loss_model = SingleCamLoss(cfg, rank)
54 | return loss_model
55 |
56 | def prepare_model(self, cfg, rank):
57 | models = {}
58 | models['pose_net'] = self.set_posenet(cfg)
59 | models['depth_net'] = self.set_depthnet(cfg)
60 | if self.gaussian:
61 | models['gs_net'] = self.set_gaussiannet(cfg)
62 |
63 | return models
64 |
65 | def set_posenet(self, cfg):
66 | return PoseNetwork(cfg).cuda()
67 |
68 | def set_depthnet(self, cfg):
69 | return DepthNetwork(cfg).cuda()
70 |
71 | def set_gaussiannet(self, cfg):
72 | return GaussianNetwork(rgb_dim=3, depth_dim=1).cuda()
73 |
74 | def prepare_dataset(self, cfg, rank):
75 | if rank == 0:
76 | print('### Preparing Datasets')
77 |
78 | if self.mode == 'train':
79 | self.set_train_dataloader(cfg, rank)
80 | if rank == 0 :
81 | self.set_val_dataloader(cfg)
82 |
83 | if self.mode == 'eval':
84 | self.set_eval_dataloader(cfg)
85 |
86 | def set_train_dataloader(self, cfg, rank):
87 | # jittering augmentation and image resizing for the training data
88 | _augmentation = {
89 | 'image_shape': (int(self.height), int(self.width)),
90 | 'jittering': (0.2, 0.2, 0.2, 0.05),
91 | 'crop_train_borders': (),
92 | 'crop_eval_borders': ()
93 | }
94 |
95 | # construct train dataset
96 | train_dataset = construct_dataset(cfg, 'train', **_augmentation)
97 |
98 | dataloader_opts = {
99 | 'batch_size': self.batch_size,
100 | 'shuffle': True,
101 | 'num_workers': self.num_workers,
102 | 'pin_memory': True,
103 | 'drop_last': True
104 | }
105 |
106 | self._dataloaders['train'] = DataLoader(train_dataset, **dataloader_opts)
107 | num_train_samples = len(train_dataset)
108 | self.num_total_steps = num_train_samples // (self.batch_size * self.world_size) * self.num_epochs
109 |
110 | def set_val_dataloader(self, cfg):
111 | # Image resizing for the validation data
112 | _augmentation = {
113 | 'image_shape': (int(self.height), int(self.width)),
114 | 'jittering': (0.0, 0.0, 0.0, 0.0),
115 | 'crop_train_borders': (),
116 | 'crop_eval_borders': ()
117 | }
118 |
119 | # construct validation dataset
120 | val_dataset = construct_dataset(cfg, 'val', **_augmentation)
121 |
122 | dataloader_opts = {
123 | 'batch_size': self.batch_size,
124 | 'shuffle': False,
125 | 'num_workers': 0,
126 | 'pin_memory': True,
127 | 'drop_last': True
128 | }
129 |
130 | self._dataloaders['val'] = DataLoader(val_dataset, **dataloader_opts)
131 |
132 | def set_eval_dataloader(self, cfg):
133 | # Image resizing for the validation data
134 | _augmentation = {
135 | 'image_shape': (int(self.height), int(self.width)),
136 | 'jittering': (0.0, 0.0, 0.0, 0.0),
137 | 'crop_train_borders': (),
138 | 'crop_eval_borders': ()
139 | }
140 |
141 | dataloader_opts = {
142 | 'batch_size': self.eval_batch_size,
143 | 'shuffle': False,
144 | 'num_workers': self.eval_num_workers,
145 | 'pin_memory': True,
146 | 'drop_last': True
147 | }
148 |
149 | eval_dataset = construct_dataset(cfg, 'eval', **_augmentation)
150 |
151 | self._dataloaders['eval'] = DataLoader(eval_dataset, **dataloader_opts)
152 |
153 | def set_optimizer(self):
154 | parameters_to_train = []
155 | for v in self.models.values():
156 | parameters_to_train += list(v.parameters())
157 |
158 | self.optimizer = optim.Adam(
159 | parameters_to_train,
160 | self.learning_rate
161 | )
162 |
163 | self.lr_scheduler = optim.lr_scheduler.StepLR(
164 | self.optimizer,
165 | self.scheduler_step_size,
166 | 0.1
167 | )
168 |
169 | def process_batch(self, inputs, rank):
170 | """
171 | Pass a minibatch through the network and generate images, depth maps, and losses.
172 | """
173 | for key, ipt in inputs.items():
174 | if key not in _NO_DEVICE_KEYS:
175 | if 'context' in key:
176 | inputs[key] = [ipt[k].float().to(rank) for k in range(len(inputs[key]))]
177 | elif 'ego_pose' in key:
178 | inputs[key] = [ipt[k].float().to(rank) for k in range(len(inputs[key]))]
179 | else:
180 | inputs[key] = ipt.float().to(rank)
181 |
182 | outputs = self.estimate(inputs)
183 | losses = self.compute_losses(inputs, outputs)
184 | return outputs, losses
185 |
186 | def estimate(self, inputs):
187 | """
188 | This function estimates the outputs of the network.
189 | """
190 | # pre-calculate inverse of the extrinsic matrix
191 | inputs['extrinsics_inv'] = torch.inverse(inputs['extrinsics'])
192 |
193 | # init dictionary
194 | outputs = {}
195 | for cam in range(self.num_cams):
196 | outputs[('cam', cam)] = {}
197 |
198 | pose_pred = self.predict_pose(inputs)
199 | depth_feats = self.predict_depth(inputs)
200 |
201 | for cam in range(self.num_cams):
202 | if self.mode != 'train':
203 | outputs[('cam', cam)].update({('cam_T_cam', 0, 1): inputs[('cam_T_cam', 0, 1)][:, cam, ...]})
204 | outputs[('cam', cam)].update({('cam_T_cam', 0, -1): inputs[('cam_T_cam', 0, -1)][:, cam, ...]})
205 | elif self.mode == 'train':
206 | outputs[('cam', cam)].update(pose_pred[('cam', cam)])
207 | outputs[('cam', cam)].update(depth_feats[('cam', cam)])
208 |
209 | self.compute_depth_maps(inputs, outputs)
210 | return outputs
211 |
212 | def predict_pose(self, inputs):
213 | """
214 | This function predicts poses.
215 | """
216 | net = self.models['pose_net']
217 |
218 | pose = self.pose.compute_pose(net, inputs)
219 | return pose
220 |
221 | def predict_depth(self, inputs):
222 | """
223 | This function predicts disparity maps.
224 | """
225 | net = self.models['depth_net']
226 |
227 | depth_feats = net(inputs)
228 | return depth_feats
229 |
230 | def compute_depth_maps(self, inputs, outputs):
231 | """
232 | This function computes depth map for each viewpoint.
233 | """
234 | source_scale = 0
235 | for cam in range(self.num_cams):
236 | ref_K = inputs[('K', source_scale)][:, cam, ...]
237 | for scale in self.scales:
238 | disp = outputs[('cam', cam)][('disp', scale)]
239 | outputs[('cam', cam)][('depth', 0, scale)] = self.to_depth(disp, ref_K)
240 | if self.novel_view_mode == 'MF':
241 | disp_last = outputs[('cam', cam)][('disp', -1, scale)]
242 | outputs[('cam', cam)][('depth', -1, scale)] = self.to_depth(disp_last, ref_K)
243 | disp_next = outputs[('cam', cam)][('disp', 1, scale)]
244 | outputs[('cam', cam)][('depth', 1, scale)] = self.to_depth(disp_next, ref_K)
245 |
246 | def to_depth(self, disp_in, K_in):
247 | """
248 | This function converts the disparity value into a depth map, scaling it by the focal length.
249 | """
250 | min_disp = 1/self.max_depth
251 | max_disp = 1/self.min_depth
252 | disp_range = max_disp-min_disp
253 |
254 | disp_in = F.interpolate(disp_in, [self.height, self.width], mode='bilinear', align_corners=False)
255 | disp = min_disp + disp_range * disp_in
256 | depth = 1/disp
257 | return depth * K_in[:, 0:1, 0:1].unsqueeze(2)/self.focal_length_scale
258 |
259 | def get_gaussian_data(self, inputs, outputs, cam):
260 | """
261 | This function computes gaussian data for each viewpoint.
262 | """
263 | bs, _, height, width = inputs[('color', 0, 0)][:, cam, ...].shape
264 | zfar = self.max_depth
265 | znear = 0.01
266 |
267 | if self.novel_view_mode == 'MF':
268 | for frame_id in self.frame_ids:
269 | if frame_id == 0:
270 | outputs[('cam', cam)][('e2c_extr', frame_id, 0)] = inputs['extrinsics_inv'][:, cam, ...]
271 | outputs[('cam', cam)][('c2e_extr', frame_id, 0)] = inputs['extrinsics'][:, cam, ...]
272 | FovX_list = []
273 | FovY_list = []
274 | world_view_transform_list = []
275 | full_proj_transform_list = []
276 | camera_center_list = []
277 | for i in range(bs):
278 | intr = inputs[('K', 0)][:, cam, ...][i,:]
279 | extr = inputs['extrinsics_inv'][:, cam, ...][i,:]
280 | FovX = focal2fov(intr[0, 0], width)
281 | FovY = focal2fov(intr[1, 1], height)
282 | projection_matrix = getProjectionMatrix(znear=znear, zfar=zfar, K=intr, h=height, w=width).transpose(0, 1).cuda()
283 | world_view_transform = torch.tensor(extr).transpose(0, 1).cuda()
284 | # full_proj_transform: (E^T K^T) = (K E)^T
285 | full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0)
286 | camera_center = world_view_transform.inverse()[3, :3]
287 |
288 | FovX_list.append(FovX)
289 | FovY_list.append(FovY)
290 | world_view_transform_list.append(world_view_transform.unsqueeze(0))
291 | full_proj_transform_list.append(full_proj_transform.unsqueeze(0))
292 | camera_center_list.append(camera_center.unsqueeze(0))
293 |
294 | outputs[('cam', cam)][('FovX', frame_id, 0)] = torch.tensor(FovX_list).cuda()
295 | outputs[('cam', cam)][('FovY', frame_id, 0)] = torch.tensor(FovY_list).cuda()
296 | outputs[('cam', cam)][('world_view_transform', frame_id, 0)] = torch.cat(world_view_transform_list, dim=0)
297 | outputs[('cam', cam)][('full_proj_transform', frame_id, 0)] = torch.cat(full_proj_transform_list, dim=0)
298 | outputs[('cam', cam)][('camera_center', frame_id, 0)] = torch.cat(camera_center_list, dim=0)
299 | else:
300 | outputs[('cam', cam)][('e2c_extr', frame_id, 0)] = \
301 | torch.matmul(outputs[('cam', cam)][('cam_T_cam', 0, frame_id)], inputs['extrinsics_inv'][:, cam, ...])
302 | outputs[('cam', cam)][('c2e_extr', frame_id, 0)] = \
303 | torch.matmul(inputs['extrinsics'][:, cam, ...], torch.inverse(outputs[('cam', cam)][('cam_T_cam', 0, frame_id)]))
304 | outputs[('cam', cam)][('xyz', frame_id, 0)] = depth2pc(outputs[('cam', cam)][('depth', frame_id, 0)], outputs[('cam', cam)][('e2c_extr', frame_id, 0)], inputs[('K', 0)][:, cam, ...])
305 | valid = outputs[('cam', cam)][('depth', frame_id, 0)] != 0.0
306 | outputs[('cam', cam)][('pts_valid', frame_id, 0)] = valid.view(bs, -1)
307 | rot_maps, scale_maps, opacity_maps, sh_maps = \
308 | self.gs_net(inputs[('color', frame_id, 0)][:, cam, ...], outputs[('cam', cam)][('depth', frame_id, 0)], outputs[('cam', cam)][('img_feat', frame_id, 0)])
309 | c2w_rotations = rearrange(outputs[('cam', cam)][('c2e_extr', frame_id, 0)][..., :3, :3], "k i j -> k () () () i j")
310 | sh_maps = rotate_sh(sh_maps, c2w_rotations[..., None, :, :])
311 | outputs[('cam', cam)][('rot_maps', frame_id, 0)] = rot_maps
312 | outputs[('cam', cam)][('scale_maps', frame_id, 0)] = scale_maps
313 | outputs[('cam', cam)][('opacity_maps', frame_id, 0)] = opacity_maps
314 | outputs[('cam', cam)][('sh_maps', frame_id, 0)] = sh_maps
315 | elif self.novel_view_mode == 'SF':
316 | frame_id = 0
317 | outputs[('cam', cam)][('e2c_extr', frame_id, 0)] = inputs['extrinsics_inv'][:, cam, ...]
318 | outputs[('cam', cam)][('c2e_extr', frame_id, 0)] = inputs['extrinsics'][:, cam, ...]
319 | outputs[('cam', cam)][('xyz', frame_id, 0)] = depth2pc(outputs[('cam', cam)][('depth', frame_id, 0)], outputs[('cam', cam)][('e2c_extr', frame_id, 0)], inputs[('K', 0)][:, cam, ...])
320 | valid = outputs[('cam', cam)][('depth', frame_id, 0)] != 0.0
321 | outputs[('cam', cam)][('pts_valid', frame_id, 0)] = valid.view(bs, -1)
322 | rot_maps, scale_maps, opacity_maps, sh_maps = \
323 | self.gs_net(inputs[('color', frame_id, 0)][:, cam, ...], outputs[('cam', cam)][('depth', frame_id, 0)], outputs[('cam', cam)][('img_feat', frame_id, 0)])
324 | c2w_rotations = rearrange(outputs[('cam', cam)][('c2e_extr', frame_id, 0)][..., :3, :3], "k i j -> k () () () i j")
325 | sh_maps = rotate_sh(sh_maps, c2w_rotations[..., None, :, :])
326 | outputs[('cam', cam)][('rot_maps', frame_id, 0)] = rot_maps
327 | outputs[('cam', cam)][('scale_maps', frame_id, 0)] = scale_maps
328 | outputs[('cam', cam)][('opacity_maps', frame_id, 0)] = opacity_maps
329 | outputs[('cam', cam)][('sh_maps', frame_id, 0)] = sh_maps
330 |
331 | # novel view
332 | for frame_id in self.frame_ids[1:]:
333 | outputs[('cam', cam)][('e2c_extr', frame_id, 0)] = \
334 | torch.matmul(outputs[('cam', cam)][('cam_T_cam', 0, frame_id)], inputs['extrinsics_inv'][:, cam, ...])
335 | outputs[('cam', cam)][('c2e_extr', frame_id, 0)] = \
336 | torch.matmul(inputs['extrinsics'][:, cam, ...], torch.inverse(outputs[('cam', cam)][('cam_T_cam', 0, frame_id)]))
337 |
338 | FovX_list = []
339 | FovY_list = []
340 | world_view_transform_list = []
341 | full_proj_transform_list = []
342 | camera_center_list = []
343 | for i in range(bs):
344 | intr = inputs[('K', 0)][:, cam, ...][i,:]
345 | extr = inputs['extrinsics_inv'][:, cam, ...][i,:]
346 | T_i = outputs[('cam', cam)][('cam_T_cam', 0, frame_id)][i,:]
347 | FovX = focal2fov(intr[0, 0], width)
348 | FovY = focal2fov(intr[1, 1], height)
349 | projection_matrix = getProjectionMatrix(znear=znear, zfar=zfar, K=intr, h=height, w=width).transpose(0, 1).cuda()
350 | world_view_transform = torch.matmul(T_i, torch.tensor(extr).cuda()).transpose(0, 1)
351 | # full_proj_transform: (E^T K^T) = (K E)^T
352 | full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0)
353 | camera_center = world_view_transform.inverse()[3, :3]
354 | FovX_list.append(FovX)
355 | FovY_list.append(FovY)
356 | world_view_transform_list.append(world_view_transform.unsqueeze(0))
357 | full_proj_transform_list.append(full_proj_transform.unsqueeze(0))
358 | camera_center_list.append(camera_center.unsqueeze(0))
359 | outputs[('cam', cam)][('FovX', frame_id, 0)] = torch.tensor(FovX_list).cuda()
360 | outputs[('cam', cam)][('FovY', frame_id, 0)] = torch.tensor(FovY_list).cuda()
361 | outputs[('cam', cam)][('world_view_transform', frame_id, 0)] = torch.cat(world_view_transform_list, dim=0)
362 | outputs[('cam', cam)][('full_proj_transform', frame_id, 0)] = torch.cat(full_proj_transform_list, dim=0)
363 | outputs[('cam', cam)][('camera_center', frame_id, 0)] = torch.cat(camera_center_list, dim=0)
364 |
365 | def compute_losses(self, inputs, outputs):
366 | """
367 | This function computes losses.
368 | """
369 | losses = 0
370 | loss_fn = defaultdict(list)
371 | loss_mean = defaultdict(float)
372 |
373 | # compute gaussian data
374 | if self.gaussian:
375 | self.gs_net = self.models['gs_net']
376 | for cam in range(self.num_cams):
377 | self.get_gaussian_data(inputs, outputs, cam)
378 |
379 | # generate image and compute loss per camera
380 | for cam in range(self.num_cams):
381 | self.pred_cam_imgs(inputs, outputs, cam)
382 | if self.gaussian:
383 | self.pred_gaussian_imgs(inputs, outputs, cam)
384 | cam_loss, loss_dict = self.losses(inputs, outputs, cam)
385 |
386 | losses += cam_loss
387 | for k, v in loss_dict.items():
388 | loss_fn[k].append(v)
389 |
390 | losses /= self.num_cams
391 |
392 | for k in loss_fn.keys():
393 | loss_mean[k] = sum(loss_fn[k]) / float(len(loss_fn[k]))
394 |
395 | loss_mean['total_loss'] = losses
396 | return loss_mean
397 |
398 | def pred_cam_imgs(self, inputs, outputs, cam):
399 | """
400 | This function renders projected images using camera parameters and depth information.
401 | """
402 | rel_pose_dict = self.pose.compute_relative_cam_poses(inputs, outputs, cam)
403 | self.view_rendering(inputs, outputs, cam, rel_pose_dict)
404 |
405 | def pred_gaussian_imgs(self, inputs, outputs, cam):
406 | if self.novel_view_mode == 'MF':
407 | outputs[('cam', cam)][('gaussian_color', 0, 0)] = \
408 | pts2render(inputs=inputs,
409 | outputs=outputs,
410 | cam_num=self.num_cams,
411 | novel_cam=cam,
412 | novel_frame_id=0,
413 | bg_color=[1.0, 1.0, 1.0],
414 | mode=self.novel_view_mode)
415 | elif self.novel_view_mode == 'SF':
416 | for novel_frame_id in self.frame_ids[1:]:
417 | outputs[('cam', cam)][('gaussian_color', novel_frame_id, 0)] = \
418 | pts2render(inputs=inputs,
419 | outputs=outputs,
420 | cam_num=self.num_cams,
421 | novel_cam=cam,
422 | novel_frame_id=novel_frame_id,
423 | bg_color=[1.0, 1.0, 1.0],
424 | mode=self.novel_view_mode)
425 |
426 |
--------------------------------------------------------------------------------
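A quick stand-alone sketch of the disparity-to-depth mapping implemented in to_depth() above. The min_depth, max_depth and focal_length_scale values below are illustrative placeholders, not the repository's configured defaults.

import torch
import torch.nn.functional as F

# Hypothetical configuration values (for illustration only).
min_depth, max_depth, focal_length_scale = 1.0, 80.0, 300.0
height, width = 352, 640

disp_in = torch.rand(1, 1, height // 2, width // 2)   # sigmoid disparity from the depth decoder
K = torch.eye(4).unsqueeze(0)
K[:, 0, 0] = 450.0                                    # hypothetical focal length fx

min_disp, max_disp = 1 / max_depth, 1 / min_depth
disp = min_disp + (max_disp - min_disp) * F.interpolate(
    disp_in, [height, width], mode='bilinear', align_corners=False)
depth = (1 / disp) * K[:, 0:1, 0:1].unsqueeze(2) / focal_length_scale
print(depth.min().item(), depth.max().item())         # depth range scaled by fx / focal_length_scale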
/models/gaussian/GaussianRender.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from .gaussian_renderer import render
4 | from einops import rearrange
5 |
6 | left_cam_dict = {2:0, 0:1, 4:2, 1:3, 5:4, 3:5}
7 | right_cam_dict = {0:2, 1:0, 2:4, 3:1, 4:5, 5:3}
8 |
9 | def get_adj_cams(cam):
10 | adj_cams = [cam]
11 | adj_cams.append(left_cam_dict[cam])
12 | adj_cams.append(right_cam_dict[cam])
13 | return adj_cams
14 |
15 | def pts2render(inputs, outputs, cam_num, novel_cam, novel_frame_id, bg_color, mode='MF'):
16 | bs, _, height, width = inputs[('color', 0, 0)][:, novel_cam, ...].shape
17 | render_novel_list = []
18 | for i in range(bs):
19 | xyz_i_valid = []
20 | # rgb_i_valid = []
21 | rot_i_valid = []
22 | scale_i_valid = []
23 | opacity_i_valid = []
24 | sh_i_valid = []
25 | if mode == 'SF':
26 | frame_id = 0
27 | for cam in range(cam_num):
28 | valid_i = outputs[('cam', cam)][('pts_valid', frame_id, 0)][i, :]
29 | xyz_i = outputs[('cam', cam)][('xyz', frame_id, 0)][i, :, :]
30 | # rgb_i = inputs[('color', frame_id, 0)][:, cam, ...][i, :, :, :].permute(1, 2, 0).view(-1, 3) # HWC
31 |
32 | rot_i = outputs[('cam', cam)][('rot_maps', frame_id, 0)][i, :, :, :].permute(1, 2, 0).view(-1, 4)
33 | scale_i = outputs[('cam', cam)][('scale_maps', frame_id, 0)][i, :, :, :].permute(1, 2, 0).view(-1, 3)
34 | opacity_i = outputs[('cam', cam)][('opacity_maps', frame_id, 0)][i, :, :, :].permute(1, 2, 0).view(-1, 1)
35 | sh_i = rearrange(outputs[('cam', cam)][('sh_maps', frame_id, 0)][i, :, :, :], "p srf r xyz d_sh -> (p srf r) d_sh xyz").contiguous()
36 |
37 | xyz_i_valid.append(xyz_i[valid_i].view(-1, 3))
38 | # rgb_i_valid.append(rgb_i[valid_i].view(-1, 3))
39 | rot_i_valid.append(rot_i[valid_i].view(-1, 4))
40 | scale_i_valid.append(scale_i[valid_i].view(-1, 3))
41 | opacity_i_valid.append(opacity_i[valid_i].view(-1, 1))
42 | sh_i_valid.append(sh_i[valid_i])
43 |
44 | elif mode == 'MF':
45 | for frame_id in [-1, 1]:
46 | cam = novel_cam
47 | valid_i = outputs[('cam', cam)][('pts_valid', frame_id, 0)][i, :]
48 | xyz_i = outputs[('cam', cam)][('xyz', frame_id, 0)][i, :, :]
49 | # rgb_i = inputs[('color', frame_id, 0)][:, cam, ...][i, :, :, :].permute(1, 2, 0).view(-1, 3) # HWC
50 |
51 | rot_i = outputs[('cam', cam)][('rot_maps', frame_id, 0)][i, :, :, :].permute(1, 2, 0).view(-1, 4)
52 | scale_i = outputs[('cam', cam)][('scale_maps', frame_id, 0)][i, :, :, :].permute(1, 2, 0).view(-1, 3)
53 | opacity_i = outputs[('cam', cam)][('opacity_maps', frame_id, 0)][i, :, :, :].permute(1, 2, 0).view(-1, 1)
54 | sh_i = rearrange(outputs[('cam', cam)][('sh_maps', frame_id, 0)][i, :, :, :], "p srf r xyz d_sh -> (p srf r) d_sh xyz").contiguous()
55 |
56 | xyz_i_valid.append(xyz_i[valid_i].view(-1, 3))
57 | # rgb_i_valid.append(rgb_i[valid_i].view(-1, 3))
58 | rot_i_valid.append(rot_i[valid_i].view(-1, 4))
59 | scale_i_valid.append(scale_i[valid_i].view(-1, 3))
60 | opacity_i_valid.append(opacity_i[valid_i].view(-1, 1))
61 | sh_i_valid.append(sh_i[valid_i])
62 |
63 | pts_xyz_i = torch.concat(xyz_i_valid, dim=0)
64 | # pts_rgb_i = torch.concat(rgb_i_valid, dim=0)
65 | # pts_rgb_i = pts_rgb_i * 0.5 + 0.5
66 | rot_i = torch.concat(rot_i_valid, dim=0)
67 | scale_i = torch.concat(scale_i_valid, dim=0)
68 | opacity_i = torch.concat(opacity_i_valid, dim=0)
69 | sh_i = torch.concat(sh_i_valid, dim=0)
70 |
71 | novel_FovX_i = outputs[('cam', novel_cam)][('FovX', novel_frame_id, 0)][i]
72 | novel_FovY_i = outputs[('cam', novel_cam)][('FovY', novel_frame_id, 0)][i]
73 | novel_world_view_transform_i = outputs[('cam', novel_cam)][('world_view_transform', novel_frame_id, 0)][i]
74 | novel_full_proj_transform_i = outputs[('cam', novel_cam)][('full_proj_transform', novel_frame_id, 0)][i]
75 | novel_camera_center_i = outputs[('cam', novel_cam)][('camera_center', novel_frame_id, 0)][i]
76 |
77 | render_novel_i = render(novel_FovX=novel_FovX_i,
78 | novel_FovY=novel_FovY_i,
79 | novel_height=height,
80 | novel_width=width,
81 | novel_world_view_transform=novel_world_view_transform_i,
82 | novel_full_proj_transform=novel_full_proj_transform_i,
83 | novel_camera_center=novel_camera_center_i,
84 | pts_xyz=pts_xyz_i,
85 | pts_rgb=None,
86 | rotations=rot_i,
87 | scales=scale_i,
88 | opacity=opacity_i,
89 | shs=sh_i,
90 | bg_color=bg_color)
91 | render_novel_list.append(render_novel_i.unsqueeze(0))
92 |
93 | novel = torch.concat(render_novel_list, dim=0)
94 |
95 | return novel
96 |
--------------------------------------------------------------------------------
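pts2render() above flattens each per-camera attribute map from (C, H, W) to (H*W, C) and keeps only pixels whose predicted depth is non-zero before concatenating everything into one Gaussian point set. A minimal stand-alone sketch of that gather step, with made-up shapes:

import torch

H, W = 4, 6
rot_map = torch.randn(4, H, W)                    # per-pixel quaternions, laid out as (C, H, W)
depth = torch.rand(1, H, W)
depth[0, 0, 0] = 0.0                              # pretend one pixel has no valid depth
valid = (depth != 0.0).view(-1)                   # per-pixel validity mask, (H*W,)

rot_flat = rot_map.permute(1, 2, 0).view(-1, 4)   # (H*W, 4), same flatten as in pts2render
rot_valid = rot_flat[valid].view(-1, 4)           # keep only pixels with valid depth
print(rot_flat.shape, rot_valid.shape)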
/models/gaussian/__init__.py:
--------------------------------------------------------------------------------
1 | from .gaussian_network import GaussianNetwork
2 | from .GaussianRender import pts2render
3 | from .utils import depth2pc, rotate_sh, focal2fov, getProjectionMatrix, getWorld2View2
4 |
--------------------------------------------------------------------------------
/models/gaussian/extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ResidualBlock(nn.Module):
7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
8 | super(ResidualBlock, self).__init__()
9 |
10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
12 | self.relu = nn.ReLU(inplace=True)
13 |
14 | num_groups = planes // 8
15 |
16 | if norm_fn == 'group':
17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
19 | if not (stride == 1 and in_planes == planes):
20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
21 |
22 | elif norm_fn == 'batch':
23 | self.norm1 = nn.BatchNorm2d(planes)
24 | self.norm2 = nn.BatchNorm2d(planes)
25 | if not (stride == 1 and in_planes == planes):
26 | self.norm3 = nn.BatchNorm2d(planes)
27 |
28 | elif norm_fn == 'instance':
29 | self.norm1 = nn.InstanceNorm2d(planes)
30 | self.norm2 = nn.InstanceNorm2d(planes)
31 | if not (stride == 1 and in_planes == planes):
32 | self.norm3 = nn.InstanceNorm2d(planes)
33 |
34 | elif norm_fn == 'none':
35 | self.norm1 = nn.Sequential()
36 | self.norm2 = nn.Sequential()
37 | if not (stride == 1 and in_planes == planes):
38 | self.norm3 = nn.Sequential()
39 |
40 | if stride == 1 and in_planes == planes:
41 | self.downsample = None
42 |
43 | else:
44 | self.downsample = nn.Sequential(
45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
46 |
47 |
48 | def forward(self, x):
49 | y = x
50 | y = self.conv1(y)
51 | y = self.norm1(y)
52 | y = self.relu(y)
53 | y = self.conv2(y)
54 | y = self.norm2(y)
55 | y = self.relu(y)
56 |
57 | if self.downsample is not None:
58 | x = self.downsample(x)
59 |
60 | return self.relu(x+y)
61 |
62 |
63 | class UnetExtractor(nn.Module):
64 | def __init__(self, in_channel=3, encoder_dim=[32, 48, 96], norm_fn='group'):
65 | super().__init__()
66 | self.in_ds = nn.Sequential(
67 | nn.Conv2d(in_channel, 32, kernel_size=5, stride=2, padding=2),
68 | nn.GroupNorm(num_groups=8, num_channels=32),
69 | nn.ReLU(inplace=True)
70 | )
71 |
72 | self.res1 = nn.Sequential(
73 | ResidualBlock(32, encoder_dim[0], norm_fn=norm_fn),
74 | ResidualBlock(encoder_dim[0], encoder_dim[0], norm_fn=norm_fn)
75 | )
76 | self.res2 = nn.Sequential(
77 | ResidualBlock(encoder_dim[0], encoder_dim[1], stride=2, norm_fn=norm_fn),
78 | ResidualBlock(encoder_dim[1], encoder_dim[1], norm_fn=norm_fn)
79 | )
80 | self.res3 = nn.Sequential(
81 | ResidualBlock(encoder_dim[1], encoder_dim[2], stride=2, norm_fn=norm_fn),
82 | ResidualBlock(encoder_dim[2], encoder_dim[2], norm_fn=norm_fn),
83 | )
84 |
85 | def forward(self, x):
86 | x = self.in_ds(x)
87 | x1 = self.res1(x)
88 | x2 = self.res2(x1)
89 | x3 = self.res3(x2)
90 |
91 | return x1, x2, x3
92 |
93 |
94 | class MultiBasicEncoder(nn.Module):
95 | def __init__(self, output_dim=[128], encoder_dim=[64, 96, 128]):
96 | super(MultiBasicEncoder, self).__init__()
97 |
98 | # output convolution for feature
99 | self.conv2 = nn.Sequential(
100 | ResidualBlock(encoder_dim[2], encoder_dim[2], stride=1),
101 | nn.Conv2d(encoder_dim[2], encoder_dim[2]*2, 3, padding=1))
102 |
103 | # output convolution for context
104 | output_list = []
105 | for dim in output_dim:
106 | conv_out = nn.Sequential(
107 | ResidualBlock(encoder_dim[2], encoder_dim[2], stride=1),
108 | nn.Conv2d(encoder_dim[2], dim[2], 3, padding=1))
109 | output_list.append(conv_out)
110 |
111 | self.outputs08 = nn.ModuleList(output_list)
112 |
113 | def forward(self, x):
114 | feat1, feat2 = self.conv2(x).split(dim=0, split_size=x.shape[0]//2)
115 |
116 | outputs08 = [f(x) for f in self.outputs08]
117 | return outputs08, feat1, feat2
118 |
119 |
120 | if __name__ == '__main__':
121 |
122 | data = torch.ones((1, 3, 1024, 1024))
123 |
124 | model = UnetExtractor(in_channel=3, encoder_dim=[64, 96, 128])
125 |
126 | x1, x2, x3 = model(data)
127 | print(x1.shape, x2.shape, x3.shape)
128 |
--------------------------------------------------------------------------------
/models/gaussian/gaussian_network.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from torch import nn
4 | from .extractor import UnetExtractor, ResidualBlock
5 | from einops import rearrange
6 |
7 |
8 | class GaussianNetwork(nn.Module):
9 | def __init__(self, rgb_dim=3, depth_dim=1, norm_fn='group'):
10 | super().__init__()
11 | self.rgb_dims = [64, 64, 128]
12 | self.depth_dims = [32, 48, 96]
13 | self.decoder_dims = [48, 64, 96]
14 | self.head_dim = 32
15 |
16 | self.sh_degree = 4
17 | self.d_sh = (self.sh_degree + 1) ** 2
18 |
19 | self.register_buffer(
20 | "sh_mask",
21 | torch.ones((self.d_sh,), dtype=torch.float32),
22 | persistent=False,
23 | )
24 | for degree in range(1, self.sh_degree + 1):
25 | self.sh_mask[degree**2 : (degree + 1) ** 2] = 0.1 * 0.25**degree
26 |
27 | self.depth_encoder = UnetExtractor(in_channel=depth_dim, encoder_dim=self.depth_dims)
28 |
29 | self.decoder3 = nn.Sequential(
30 | ResidualBlock(self.rgb_dims[2]+self.depth_dims[2], self.decoder_dims[2], norm_fn=norm_fn),
31 | ResidualBlock(self.decoder_dims[2], self.decoder_dims[2], norm_fn=norm_fn)
32 | )
33 |
34 | self.decoder2 = nn.Sequential(
35 | ResidualBlock(self.rgb_dims[1]+self.depth_dims[1]+self.decoder_dims[2], self.decoder_dims[1], norm_fn=norm_fn),
36 | ResidualBlock(self.decoder_dims[1], self.decoder_dims[1], norm_fn=norm_fn)
37 | )
38 |
39 | self.decoder1 = nn.Sequential(
40 | ResidualBlock(self.rgb_dims[0]+self.depth_dims[0]+self.decoder_dims[1], self.decoder_dims[0], norm_fn=norm_fn),
41 | ResidualBlock(self.decoder_dims[0], self.decoder_dims[0], norm_fn=norm_fn)
42 | )
43 | self.up = nn.Upsample(scale_factor=2, mode="bilinear")
44 | self.out_conv = nn.Conv2d(self.decoder_dims[0]+rgb_dim+1, self.head_dim, kernel_size=3, padding=1)
45 | self.out_relu = nn.ReLU(inplace=True)
46 |
47 | self.rot_head = nn.Sequential(
48 | nn.Conv2d(self.head_dim, self.head_dim, kernel_size=3, padding=1),
49 | nn.ReLU(inplace=True),
50 | nn.Conv2d(self.head_dim, 4, kernel_size=1),
51 | )
52 | self.scale_head = nn.Sequential(
53 | nn.Conv2d(self.head_dim, self.head_dim, kernel_size=3, padding=1),
54 | nn.ReLU(inplace=True),
55 | nn.Conv2d(self.head_dim, 3, kernel_size=1),
56 | nn.Softplus(beta=100)
57 | )
58 | self.opacity_head = nn.Sequential(
59 | nn.Conv2d(self.head_dim, self.head_dim, kernel_size=3, padding=1),
60 | nn.ReLU(inplace=True),
61 | nn.Conv2d(self.head_dim, 1, kernel_size=1),
62 | nn.Sigmoid()
63 | )
64 | self.sh_head = nn.Sequential(
65 | nn.Conv2d(self.head_dim, self.head_dim, kernel_size=3, padding=1),
66 | nn.ReLU(inplace=True),
67 | nn.Conv2d(self.head_dim, 3 * self.d_sh, kernel_size=1),
68 | )
69 |
70 | def forward(self, img, depth, img_feat):
71 | # img_feat1: [4, 64, 176, 320]
72 | # img_feat2: [4, 64, 88, 160]
73 | # img_feat3: [4, 128, 44, 80]
74 | img_feat1, img_feat2, img_feat3 = img_feat
75 | # depth_feat1: [4, 32, 176, 320]
76 | # depth_feat2: [4, 48, 88, 160]
77 | # depth_feat3: [4, 96, 44, 80]
78 | depth_feat1, depth_feat2, depth_feat3 = self.depth_encoder(depth)
79 |
80 | feat3 = torch.concat([img_feat3, depth_feat3], dim=1)
81 | feat2 = torch.concat([img_feat2, depth_feat2], dim=1)
82 | feat1 = torch.concat([img_feat1, depth_feat1], dim=1)
83 |
84 | up3 = self.decoder3(feat3)
85 | up3 = self.up(up3)
86 | up2 = self.decoder2(torch.cat([up3, feat2], dim=1))
87 | up2 = self.up(up2)
88 | up1 = self.decoder1(torch.cat([up2, feat1], dim=1))
89 |
90 | up1 = self.up(up1)
91 | out = torch.cat([up1, img, depth], dim=1)
92 | out = self.out_conv(out)
93 | out = self.out_relu(out)
94 |
95 | # rot head
96 | rot_out = self.rot_head(out)
97 | rot_out = torch.nn.functional.normalize(rot_out, dim=1)
98 |
99 | # scale head
100 | scale_out = torch.clamp_max(self.scale_head(out), 0.01)
101 |
102 | # opacity head
103 | opacity_out = self.opacity_head(out)
104 |
105 | # sh head
106 | sh_out = self.sh_head(out)
107 | # sh_out: [(b * v), C, H, W]
108 |
109 | sh_out = rearrange(
110 | sh_out, "n c h w -> n (h w) c",
111 | )
112 | sh_out = rearrange(
113 | sh_out,
114 | "... (srf c) -> ... srf () c",
115 | srf=1,
116 | )
117 |
118 | sh_out = rearrange(sh_out, "... (xyz d_sh) -> ... xyz d_sh", xyz=3)
119 | # [(b * v), (H * W), 1, 1, 3, 25]
120 |
121 | # sh_out = sh_out.broadcast_to(sh_out.shape) * self.sh_mask
122 | sh_out = sh_out * self.sh_mask
123 |
124 |
125 | return rot_out, scale_out, opacity_out, sh_out
126 |
--------------------------------------------------------------------------------
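The sh_mask buffer registered in GaussianNetwork.__init__ above attenuates higher-degree spherical-harmonic coefficients by 0.1 * 0.25**degree. A small stand-alone sketch of the resulting per-coefficient weights for the sh_degree = 4 used here:

import torch

sh_degree = 4
d_sh = (sh_degree + 1) ** 2            # 25 SH coefficients
sh_mask = torch.ones(d_sh)
for degree in range(1, sh_degree + 1):
    sh_mask[degree**2:(degree + 1) ** 2] = 0.1 * 0.25 ** degree

# degree 0 -> 1.0, degree 1 -> 0.025, degree 2 -> 0.00625, degree 3 -> 0.0015625, degree 4 -> ~0.00039
print(sh_mask.tolist())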
/models/gaussian/gaussian_renderer/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
4 |
5 |
6 | def render(novel_FovX,
7 | novel_FovY,
8 | novel_height,
9 | novel_width,
10 | novel_world_view_transform,
11 | novel_full_proj_transform,
12 | novel_camera_center,
13 | pts_xyz,
14 | pts_rgb,
15 | rotations,
16 | scales,
17 | opacity,
18 | shs,
19 | bg_color):
20 | """
21 | Render the scene.
22 |
23 | Background tensor (bg_color) must be on GPU!
24 | """
25 | bg_color = torch.tensor(bg_color, dtype=torch.float32).cuda()
26 |
27 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
28 | screenspace_points = torch.zeros_like(pts_xyz, dtype=torch.float32, requires_grad=True).cuda()
29 | try:
30 | screenspace_points.retain_grad()
31 | except:
32 | pass
33 |
34 | # Set up rasterization configuration
35 | tanfovx = math.tan(novel_FovX * 0.5)
36 | tanfovy = math.tan(novel_FovY * 0.5)
37 |
38 | raster_settings = GaussianRasterizationSettings(
39 | image_height=int(novel_height),
40 | image_width=int(novel_width),
41 | tanfovx=tanfovx,
42 | tanfovy=tanfovy,
43 | bg=bg_color,
44 | scale_modifier=1.0,
45 | viewmatrix=novel_world_view_transform,
46 | projmatrix=novel_full_proj_transform,
47 | sh_degree=3,
48 | campos=novel_camera_center,
49 | prefiltered=False,
50 | debug=False,
51 | antialiasing=False
52 | )
53 |
54 | rasterizer = GaussianRasterizer(raster_settings=raster_settings)
55 |
56 | # Rasterize visible Gaussians to image, obtain their radii (on screen).
57 | rendered_image, _, _ = rasterizer(
58 | means3D=pts_xyz,
59 | means2D=screenspace_points,
60 | shs=shs,
61 | colors_precomp=pts_rgb,
62 | opacities=opacity,
63 | scales=scales,
64 | rotations=rotations,
65 | cov3D_precomp=None)
66 |
67 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
68 | # They will be excluded from value updates used in the splitting criteria.
69 |
70 | return rendered_image
71 |
--------------------------------------------------------------------------------
/models/gaussian/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from jaxtyping import Float
4 | from math import isqrt
5 | from e3nn.o3 import matrix_to_angles, wigner_D
6 | from einops import einsum
7 | import math
8 | import numpy as np
9 |
10 | def depth2pc(depth, extrinsic, intrinsic):
11 | B, C, H, W = depth.shape
12 | depth = depth[:, 0, :, :]
13 | rot = extrinsic[:, :3, :3]
14 | trans = extrinsic[:, :3, 3:]
15 |
16 | y, x = torch.meshgrid(torch.linspace(0.5, H-0.5, H, device=depth.device), torch.linspace(0.5, W-0.5, W, device=depth.device))
17 | pts_2d = torch.stack([x, y, torch.ones_like(x)], dim=-1).unsqueeze(0).repeat(B, 1, 1, 1) # B H W 3
18 |
19 | pts_2d[..., 2] = depth
20 | pts_2d[:, :, :, 0] -= intrinsic[:, None, None, 0, 2]
21 | pts_2d[:, :, :, 1] -= intrinsic[:, None, None, 1, 2]
22 | pts_2d_xy = pts_2d[:, :, :, :2] * pts_2d[:, :, :, 2:]
23 | pts_2d = torch.cat([pts_2d_xy, pts_2d[..., 2:]], dim=-1)
24 |
25 | pts_2d[..., 0] /= intrinsic[:, 0, 0][:, None, None]
26 | pts_2d[..., 1] /= intrinsic[:, 1, 1][:, None, None]
27 |
28 | pts_2d = pts_2d.view(B, -1, 3).permute(0, 2, 1)
29 |
30 | rot_t = rot.permute(0, 2, 1)
31 | pts = torch.bmm(rot_t, pts_2d) - torch.bmm(rot_t, trans)
32 |
33 | return pts.permute(0, 2, 1)
34 |
35 | def rotate_sh(
36 | sh_coefficients: Float[Tensor, "*#batch n"],
37 | rotations: Float[Tensor, "*#batch 3 3"],
38 | ) -> Float[Tensor, "*batch n"]:
39 | device = sh_coefficients.device
40 | dtype = sh_coefficients.dtype
41 |
42 | *_, n = sh_coefficients.shape
43 | alpha, beta, gamma = matrix_to_angles(rotations)
44 | result = []
45 | for degree in range(isqrt(n)):
46 | sh_rotations = wigner_D(degree, alpha, beta, gamma).type(dtype)
47 | sh_rotated = einsum(
48 | sh_rotations,
49 | sh_coefficients[..., degree**2 : (degree + 1) ** 2],
50 | "... i j, ... j -> ... i",
51 | )
52 | result.append(sh_rotated)
53 |
54 | return torch.cat(result, dim=-1)
55 |
56 | def focal2fov(focal, pixels):
57 | return 2*math.atan(pixels/(2*focal))
58 |
59 | def getProjectionMatrix(znear, zfar, K, h, w):
60 | near_fx = znear / K[0, 0]
61 | near_fy = znear / K[1, 1]
62 | left = - (w - K[0, 2]) * near_fx
63 | right = K[0, 2] * near_fx
64 | bottom = (K[1, 2] - h) * near_fy
65 | top = K[1, 2] * near_fy
66 |
67 | P = torch.zeros(4, 4)
68 | z_sign = 1.0
69 | P[0, 0] = 2.0 * znear / (right - left)
70 | P[1, 1] = 2.0 * znear / (top - bottom)
71 | P[0, 2] = (right + left) / (right - left)
72 | P[1, 2] = (top + bottom) / (top - bottom)
73 | P[3, 2] = z_sign
74 | P[2, 2] = z_sign * zfar / (zfar - znear)
75 | P[2, 3] = -(zfar * znear) / (zfar - znear)
76 | return P
77 |
78 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
79 | Rt = np.zeros((4, 4))
80 | Rt[:3, :3] = R.transpose()
81 | Rt[:3, 3] = t
82 | Rt[3, 3] = 1.0
83 |
84 | C2W = np.linalg.inv(Rt)
85 | cam_center = C2W[:3, 3]
86 | cam_center = (cam_center + translate) * scale
87 | C2W[:3, 3] = cam_center
88 | Rt = np.linalg.inv(C2W)
89 | return np.float32(Rt)
--------------------------------------------------------------------------------
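focal2fov() above converts a focal length in pixels to a full field-of-view angle in radians; the rasterizer later consumes tan(FoV/2), which is simply pixels / (2 * focal). A quick numeric check with a hypothetical focal length and image width (the function is re-declared here so the snippet is self-contained):

import math

def focal2fov(focal, pixels):
    return 2 * math.atan(pixels / (2 * focal))

focal, width = 450.0, 640                      # hypothetical intrinsics
fov = focal2fov(focal, width)
assert abs(math.tan(fov * 0.5) - width / (2 * focal)) < 1e-9
print(math.degrees(fov))                       # full horizontal field of view in degrees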
/models/geometry/__init__.py:
--------------------------------------------------------------------------------
1 | from .pose import Pose
2 | from .view_rendering import ViewRendering
3 |
4 | __all__ = ['Pose', 'ViewRendering']
--------------------------------------------------------------------------------
/models/geometry/geometry_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from pytorch3d.transforms import axis_angle_to_matrix
5 |
6 |
7 | def vec_to_matrix(rot_angle, trans_vec, invert=False):
8 | """
9 | This function transforms an axis-angle rotation and a translation vector into a 4x4 transformation matrix.
10 | """
11 | # initialize matrices
12 | b, _, _ = rot_angle.shape
13 | R_mat = torch.eye(4).repeat([b, 1, 1]).to(device=rot_angle.device)
14 | T_mat = torch.eye(4).repeat([b, 1, 1]).to(device=rot_angle.device)
15 |
16 | R_mat[:, :3, :3] = axis_angle_to_matrix(rot_angle).squeeze(1)
17 | t_vec = trans_vec.clone().contiguous().view(-1, 3, 1)
18 |
19 | if invert == True:
20 | R_mat = R_mat.transpose(1,2)
21 | t_vec = -1 * t_vec
22 |
23 | T_mat[:, :3, 3:] = t_vec
24 |
25 | if invert == True:
26 | P_mat = torch.matmul(R_mat, T_mat)
27 | else :
28 | P_mat = torch.matmul(T_mat, R_mat)
29 | return P_mat
30 |
31 |
32 | class Projection(nn.Module):
33 | """
34 | This class computes projection and reprojection function.
35 | """
36 | def __init__(self, batch_size, height, width, device):
37 | super().__init__()
38 | self.batch_size = batch_size
39 | self.width = width
40 | self.height = height
41 |
42 | # initialize img point grid
43 | img_points = np.meshgrid(range(width), range(height), indexing='xy')
44 | img_points = torch.from_numpy(np.stack(img_points, 0)).float()
45 | img_points = torch.stack([img_points[0].view(-1), img_points[1].view(-1)], 0).repeat(batch_size, 1, 1)
46 | img_points = img_points.to(device)
47 |
48 | self.to_homo = torch.ones([batch_size, 1, width*height]).to(device)
49 | self.homo_points = torch.cat([img_points, self.to_homo], 1)
50 |
51 | def backproject(self, invK, depth):
52 | """
53 | This function back-projects 2D image points to 3D.
54 | """
55 | depth = depth.view(self.batch_size, 1, -1)
56 |
57 | points3D = torch.matmul(invK[:, :3, :3], self.homo_points)
58 | points3D = depth*points3D
59 | return torch.cat([points3D, self.to_homo], 1)
60 |
61 | def reproject(self, K, points3D, T):
62 | """
63 | This function reprojects transformed 3D points to 2D image coordinates.
64 | """
65 | # project points
66 | points2D = (K @ T)[:,:3, :] @ points3D
67 |
68 | # normalize projected points for grid sample function
69 | norm_points2D = points2D[:, :2, :]/(points2D[:, 2:, :] + 1e-7)
70 | norm_points2D = norm_points2D.view(self.batch_size, 2, self.height, self.width)
71 | norm_points2D = norm_points2D.permute(0, 2, 3, 1)
72 |
73 | norm_points2D[..., 0 ] /= self.width - 1
74 | norm_points2D[..., 1 ] /= self.height - 1
75 | norm_points2D = (norm_points2D-0.5)*2
76 | return norm_points2D
77 |
78 | def forward(self, depth, T, bp_invK, rp_K):
79 | cam_points = self.backproject(bp_invK, depth)
80 | pix_coords = self.reproject(rp_K, cam_points, T)
81 | return pix_coords
--------------------------------------------------------------------------------
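Projection above backprojects a depth map with inv(K), applies a relative pose, and reprojects with K into grid_sample-ready coordinates in [-1, 1]. A minimal usage sketch with an identity pose and hypothetical intrinsics (assuming the repository root is on PYTHONPATH so the class is importable as below); the warp should then be an almost exact identity resampling:

import torch
import torch.nn.functional as F
from models.geometry.geometry_util import Projection   # assumed import path

B, H, W = 1, 8, 16
K = torch.eye(4).unsqueeze(0)
K[:, 0, 0] = K[:, 1, 1] = 10.0                 # hypothetical focal lengths
K[:, 0, 2], K[:, 1, 2] = W / 2, H / 2          # principal point at the image center
invK = torch.inverse(K)

proj = Projection(B, H, W, device='cpu')
depth = torch.ones(B, 1, H, W)
T = torch.eye(4).unsqueeze(0)                  # identity relative pose
pix_coords = proj(depth, T, invK, K)           # (B, H, W, 2), normalized to [-1, 1]

img = torch.rand(B, 3, H, W)
warped = F.grid_sample(img, pix_coords, mode='bilinear', align_corners=True)
print((warped - img).abs().max().item())       # ~0: identity pose gives an identity warp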
/models/geometry/pose.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .geometry_util import vec_to_matrix
4 |
5 |
6 | class Pose:
7 | """
8 | Class for multi-camera pose calculation
9 | """
10 | def __init__(self, cfg):
11 | self.read_config(cfg)
12 |
13 | def read_config(self, cfg):
14 | for attr in cfg.keys():
15 | for k, v in cfg[attr].items():
16 | setattr(self, k, v)
17 |
18 | def compute_pose(self, net, inputs):
19 | """
20 | This function computes multi-camera poses in accordance with the network structure.
21 | """
22 | pose = self.get_single_pose(net, inputs, None)
23 | pose = self.distribute_pose(pose, inputs['extrinsics'], inputs['extrinsics_inv'])
24 |
25 | return pose
26 |
27 | def get_single_pose(self, net, inputs, cam):
28 | """
29 | This function computes pose for a single camera.
30 | """
31 | output = {}
32 | for f_i in self.frame_ids[1:]:
33 | # To maintain ordering we always pass frames in temporal order
34 | frame_ids = [-1, 0] if f_i < 0 else [0, 1]
35 | axisangle, translation = net(inputs, frame_ids, cam)
36 | output[('cam_T_cam', 0, f_i)] = vec_to_matrix(axisangle[:, 0], translation[:, 0], invert=(f_i < 0))
37 | return output
38 |
39 | def distribute_pose(self, poses, exts, exts_inv):
40 | """
41 | This function distributes the pose to each camera by using the canonical pose and camera extrinsics.
42 | (default: reference camera 0)
43 | """
44 | outputs = {}
45 | for cam in range(self.num_cams):
46 | outputs[('cam',cam)] = {}
47 | # Reference camera (canonical)
48 | ref_ext = exts[:, 0, ...]
49 | ref_ext_inv = exts_inv[:, 0, ...]
50 | for f_i in self.frame_ids[1:]:
51 | ref_T = poses['cam_T_cam', 0, f_i].float() # canonical pose
52 | # Distribute the canonical pose to every camera
53 | for cam in range(self.num_cams):
54 | cur_ext = exts[:,cam,...]
55 | cur_ext_inv = exts_inv[:,cam,...]
56 | cur_T = cur_ext_inv @ ref_ext @ ref_T @ ref_ext_inv @ cur_ext
57 |
58 | outputs[('cam',cam)][('cam_T_cam', 0, f_i)] = cur_T
59 | return outputs
60 |
61 | def compute_relative_cam_poses(self, inputs, outputs, cam):
62 | """
63 | This function computes spatio & spatio-temporal transformations for images from different viewpoints.
64 | """
65 | ref_ext = inputs['extrinsics'][:, cam, ...]
66 | target_view = outputs[('cam', cam)]
67 |
68 | rel_pose_dict = {}
69 | # precompute the relative pose
70 | if self.spatio:
71 | # current time step (spatio)
72 | for cur_index in self.rel_cam_list[cam]:
73 | # for partial surround view training
74 | if cur_index >= self.num_cams:
75 | continue
76 |
77 | cur_ext_inv = inputs['extrinsics_inv'][:, cur_index, ...]
78 | rel_pose_dict[(0, cur_index)] = torch.matmul(cur_ext_inv, ref_ext)
79 |
80 | if self.spatio_temporal:
81 | # different time step (spatio-temporal)
82 | for frame_id in self.frame_ids[1:]:
83 | for cur_index in self.rel_cam_list[cam]:
84 | # for partial surround view training
85 | if cur_index >= self.num_cams:
86 | continue
87 |
88 | T = target_view[('cam_T_cam', 0, frame_id)]
89 | # assuming that extrinsic doesn't change
90 | rel_ext = rel_pose_dict[(0, cur_index)]
91 | rel_pose_dict[(frame_id, cur_index)] = torch.matmul(rel_ext, T) # using matmul speed up
92 | return rel_pose_dict
--------------------------------------------------------------------------------
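distribute_pose() above transfers the canonical ego-motion T (predicted for the reference camera) to every other camera c via cur_T = E_c^{-1} * E_ref * T * E_ref^{-1} * E_c. A small stand-alone sanity check of that rule: when the "current" camera is the reference camera itself, the distributed pose must equal the canonical pose. The random_se3 helper below is purely illustrative.

import torch

def random_se3():
    # illustrative helper: build a random rigid 4x4 transform
    q, _ = torch.linalg.qr(torch.randn(3, 3))
    T = torch.eye(4)
    T[:3, :3] = q * torch.sign(torch.det(q))   # make it a proper rotation (det = +1)
    T[:3, 3] = torch.randn(3)
    return T.unsqueeze(0)

ref_ext, ref_T = random_se3(), random_se3()
ref_ext_inv = torch.inverse(ref_ext)

cur_ext, cur_ext_inv = ref_ext, ref_ext_inv    # the "current" camera is the reference camera
cur_T = cur_ext_inv @ ref_ext @ ref_T @ ref_ext_inv @ cur_ext
print(torch.allclose(cur_T, ref_T, atol=1e-5)) # True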
/models/geometry/view_rendering.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .geometry_util import Projection
6 |
7 |
8 | class ViewRendering(nn.Module):
9 | """
10 | Class for rendering images from given camera parameters and pixel-wise depth information
11 | """
12 | def __init__(self, cfg, rank):
13 | super().__init__()
14 | self.read_config(cfg)
15 | self.rank = rank
16 | self.project = self.init_project_imgs(rank)
17 |
18 | def read_config(self, cfg):
19 | for attr in cfg.keys():
20 | for k, v in cfg[attr].items():
21 | setattr(self, k, v)
22 |
23 | def init_project_imgs(self, rank):
24 | project_imgs = {}
25 | project_imgs = Projection(
26 | self.batch_size, self.height, self.width, rank)
27 | return project_imgs
28 |
29 | def get_mean_std(self, feature, mask):
30 | """
31 | This function returns mean and standard deviation of the overlapped features.
32 | """
33 | _, c, h, w = mask.size()
34 | mean = (feature * mask).sum(dim=(1,2,3), keepdim=True) / (mask.sum(dim=(1,2,3), keepdim=True) + 1e-8)
35 | var = ((feature - mean) ** 2).sum(dim=(1,2,3), keepdim=True) / (c*h*w)
36 | return mean, torch.sqrt(var + 1e-16)
37 |
38 | def get_norm_image_single(self, src_img, src_mask, warp_img, warp_mask):
39 | """
40 | Obtain normalized warped images using the mean and standard deviation from the overlapping regions of the target frame.
41 | """
42 | warp_mask = warp_mask.detach()
43 |
44 | with torch.no_grad():
45 | mask = (src_mask * warp_mask).bool()
46 | if mask.size(1) != 3:
47 | mask = mask.repeat(1,3,1,1)
48 |
49 | mask_sum = mask.sum(dim=(-3,-2,-1))
50 | # skip when there is no overlap
51 | if torch.any(mask_sum == 0):
52 | return warp_img
53 |
54 | s_mean, s_std = self.get_mean_std(src_img, mask)
55 | w_mean, w_std = self.get_mean_std(warp_img, mask)
56 |
57 | norm_warp = (warp_img - w_mean) / (w_std + 1e-8) * s_std + s_mean
58 | return norm_warp * warp_mask.float()
59 |
60 | def get_virtual_image(self, src_img, src_mask, tar_depth, tar_invK, src_K, T, scale=0):
61 | """
62 | This function warps the source image to the target view using backprojection and reprojection.
63 | """
64 | # do reconstruction for target from source
65 | pix_coords = self.project(tar_depth, T, tar_invK, src_K)
66 |
67 | img_warped = F.grid_sample(src_img, pix_coords, mode='bilinear',
68 | padding_mode='zeros', align_corners=True)
69 | mask_warped = F.grid_sample(src_mask, pix_coords, mode='nearest',
70 | padding_mode='zeros', align_corners=True)
71 |
72 | # nan handling
73 | inf_img_regions = torch.isnan(img_warped)
74 | img_warped[inf_img_regions] = 2.0
75 | inf_mask_regions = torch.isnan(mask_warped)
76 | mask_warped[inf_mask_regions] = 0
77 |
78 | pix_coords = pix_coords.permute(0, 3, 1, 2)
79 | invalid_mask = torch.logical_or(pix_coords > 1,
80 | pix_coords < -1).sum(dim=1, keepdim=True) > 0
81 | return img_warped, (~invalid_mask).float() * mask_warped
82 |
83 | def get_virtual_depth(self, src_depth, src_mask, src_invK, src_K, tar_depth, tar_invK, tar_K, T, min_depth, max_depth, scale=0):
84 | """
85 | This function backward-warps the source depth into the target coordinate frame.
86 | src -> target
87 | """
88 | # transform source depth
89 | b, _, h, w = src_depth.size()
90 | src_points = self.project.backproject(src_invK, src_depth)
91 | src_points_warped = torch.matmul(T[:, :3, :], src_points)
92 | src_depth_warped = src_points_warped.reshape(b, 3, h, w)[:, 2:3, :, :]
93 |
94 | # reconstruct depth: backward-warp source depth to the target coordinate
95 | pix_coords = self.project(tar_depth, torch.inverse(T), tar_invK, src_K)
96 | depth_warped = F.grid_sample(src_depth_warped, pix_coords, mode='bilinear',
97 | padding_mode='zeros', align_corners=True)
98 | mask_warped = F.grid_sample(src_mask, pix_coords, mode='nearest',
99 | padding_mode='zeros', align_corners=True)
100 |
101 | # nan handling
102 | inf_depth = torch.isnan(depth_warped)
103 | depth_warped[inf_depth] = 2.0
104 | inf_regions = torch.isnan(mask_warped)
105 | mask_warped[inf_regions] = 0
106 |
107 | pix_coords = pix_coords.permute(0, 3, 1, 2)
108 | invalid_mask = torch.logical_or(pix_coords > 1, pix_coords < -1).sum(dim=1, keepdim=True) > 0
109 |
110 | # range handling
111 | valid_depth_min = (depth_warped > min_depth)
112 | depth_warped[~valid_depth_min] = min_depth
113 | valid_depth_max = (depth_warped < max_depth)
114 | depth_warped[~valid_depth_max] = max_depth
115 | return depth_warped, (~invalid_mask).float() * mask_warped * valid_depth_min * valid_depth_max
116 |
117 | def forward(self, inputs, outputs, cam, rel_pose_dict):
118 | # predict images for each scale(default = scale 0 only)
119 | source_scale = 0
120 |
121 | # ref inputs
122 | ref_color = inputs['color', 0, source_scale][:,cam, ...]
123 | ref_mask = inputs['mask'][:, cam, ...]
124 | ref_K = inputs[('K', source_scale)][:,cam, ...]
125 | ref_invK = inputs[('inv_K', source_scale)][:,cam, ...]
126 |
127 | # output
128 | target_view = outputs[('cam', cam)]
129 |
130 | for scale in self.scales:
131 | ref_depth = target_view[('depth', 0, scale)]
132 | for frame_id in self.frame_ids[1:]:
133 | # for temporal learning
134 | T = target_view[('cam_T_cam', 0, frame_id)]
135 | src_color = inputs['color', frame_id, source_scale][:, cam, ...]
136 | src_mask = inputs['mask'][:, cam, ...]
137 | warped_img, warped_mask = self.get_virtual_image(
138 | src_color,
139 | src_mask,
140 | ref_depth,
141 | ref_invK,
142 | ref_K,
143 | T,
144 | source_scale
145 | )
146 |
147 | if self.intensity_align:
148 | warped_img = self.get_norm_image_single(
149 | ref_color,
150 | ref_mask,
151 | warped_img,
152 | warped_mask
153 | )
154 |
155 | target_view[('color', frame_id, scale)] = warped_img
156 | target_view[('color_mask', frame_id, scale)] = warped_mask
157 |
158 | # spatio-temporal learning
159 | if self.spatio or self.spatio_temporal:
160 | for frame_id in self.frame_ids:
161 | overlap_img = torch.zeros_like(ref_color)
162 | overlap_mask = torch.zeros_like(ref_mask)
163 |
164 | for cur_index in self.rel_cam_list[cam]:
165 | # for partial surround view training
166 | if cur_index >= self.num_cams:
167 | continue
168 |
169 | src_color = inputs['color', frame_id, source_scale][:, cur_index, ...]
170 | src_mask = inputs['mask'][:, cur_index, ...]
171 | src_K = inputs[('K', source_scale)][:, cur_index, ...]
172 |
173 | rel_pose = rel_pose_dict[(frame_id, cur_index)]
174 | warped_img, warped_mask = self.get_virtual_image(
175 | src_color,
176 | src_mask,
177 | ref_depth,
178 | ref_invK,
179 | src_K,
180 | rel_pose,
181 | source_scale
182 | )
183 |
184 | if self.intensity_align:
185 | warped_img = self.get_norm_image_single(
186 | ref_color,
187 | ref_mask,
188 | warped_img,
189 | warped_mask
190 | )
191 |
192 | # assuming no overlap between warped images
193 | overlap_img = overlap_img + warped_img
194 | overlap_mask = overlap_mask + warped_mask
195 |
196 | target_view[('overlap', frame_id, scale)] = overlap_img
197 | target_view[('overlap_mask', frame_id, scale)] = overlap_mask
198 |
199 | outputs[('cam', cam)] = target_view
--------------------------------------------------------------------------------
/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .single_cam_loss import SingleCamLoss
2 | from .multi_cam_loss import MultiCamLoss
3 |
4 | __all__ = ['SingleCamLoss', 'MultiCamLoss']
--------------------------------------------------------------------------------
/models/losses/base_loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class BaseLoss(nn.Module):
5 | """
6 | Base class for loss calculation
7 | """
8 | def __init__(self, cfg, rank):
9 | super().__init__()
10 | self.rank = rank
11 | self.init_weights(cfg)
12 | self.init_attrib(cfg)
13 |
14 | def init_weights(self, cfg):
15 | for attr in cfg.keys():
16 | if attr == 'loss' or attr == 'training':
17 | for k, v in cfg[attr].items():
18 | setattr(self, k, v)
19 |
20 | def init_attrib(self, cfg):
21 | for attr in cfg.keys():
22 | for k, v in cfg[attr].items():
23 | if k == 'scales' or k=='frame_ids' or k == 'novel_view_mode':
24 | setattr(self, k, v)
25 |
26 | def get_logs(self, loss_dict, output, cam):
27 | """
28 | This function logs depth and pose information for monitoring the training process.
29 | """
30 | # log statistics
31 | depth_log = output[('depth', 0, 0)].clone().detach()
32 | loss_dict['depth/mean'] = depth_log.mean()
33 | loss_dict['depth/max'] = depth_log.max()
34 | loss_dict['depth/min'] = depth_log.min()
35 |
36 | if cam == 0:
37 | pose_t = output[('cam_T_cam', 0, -1)].clone().detach()
38 | loss_dict['pose/tx'] = pose_t[:, 0, 3].abs().mean()
39 | loss_dict['pose/ty'] = pose_t[:, 1, 3].abs().mean()
40 | loss_dict['pose/tz'] = pose_t[:, 2, 3].abs().mean()
41 |
42 | return loss_dict
43 |
44 | def forward(self, *args, **kwargs):
45 | raise NotImplementedError('Not implemented for BaseLoss')
--------------------------------------------------------------------------------
/models/losses/loss_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def compute_auto_masks(reprojection_loss, identity_reprojection_loss):
6 | """
7 | This function computes auto mask using reprojection loss and identity reprojection loss.
8 | """
9 | if identity_reprojection_loss is None:
10 | # without using auto(identity loss) mask
11 | reprojection_loss_mask = torch.ones_like(reprojection_loss)
12 | else:
13 | # using auto(identity loss) mask
14 | losses = torch.cat([reprojection_loss, identity_reprojection_loss], dim=1)
15 | idxs = torch.argmin(losses, dim=1, keepdim=True)
16 | reprojection_loss_mask = (idxs == 0).float()
17 | return reprojection_loss_mask
18 |
19 |
20 | def compute_masked_loss(loss, mask):
21 | """
22 | This function masks losses while avoiding zero division.
23 | """
24 | return (loss * mask).sum() / (mask.sum() + 1e-8)
25 |
26 |
27 | def compute_edg_smooth_loss(rgb, disp_map):
28 | """
29 | This function calculates edge-aware smoothness.
30 | """
31 | grad_rgb_x = (rgb[:, :, :, :-1] - rgb[:, :, :, 1:]).abs().mean(1, True)
32 | grad_rgb_y = (rgb[:, :, :-1, :] - rgb[:, :, 1:, :]).abs().mean(1, True)
33 |
34 | grad_disp_x = (disp_map[:, :, :, :-1] - disp_map[:, :, :, 1:]).abs()
35 | grad_disp_y = (disp_map[:, :, :-1, :] - disp_map[:, :, 1:, :]).abs()
36 |
37 | grad_disp_x *= (-1.0 * grad_rgb_x).exp()
38 | grad_disp_y *= (-1.0 * grad_rgb_y).exp()
39 | return grad_disp_x.mean() + grad_disp_y.mean()
40 |
41 |
42 | def compute_ssim_loss(pred, target):
43 | """
44 | This function calculates SSIM loss between predicted image and target image.
45 | """
46 | ref_pad = torch.nn.ReflectionPad2d(1)
47 | pred = ref_pad(pred)
48 | target = ref_pad(target)
49 |
50 | mu_pred = F.avg_pool2d(pred, kernel_size = 3, stride = 1)
51 | mu_target = F.avg_pool2d(target, kernel_size = 3, stride = 1)
52 |
53 | musq_pred = mu_pred.pow(2)
54 | musq_target = mu_target.pow(2)
55 | mu_pred_target = mu_pred*mu_target
56 |
57 | sigma_pred = F.avg_pool2d(pred.pow(2), kernel_size = 3, stride = 1)-musq_pred
58 | sigma_target = F.avg_pool2d(target.pow(2), kernel_size = 3, stride = 1)-musq_target
59 | sigma_pred_target = F.avg_pool2d(pred*target, kernel_size = 3, stride = 1)-mu_pred_target
60 |
61 | C1 = 0.01**2
62 | C2 = 0.03**2
63 |
64 | ssim_map = ((2*mu_pred_target + C1)*(2*sigma_pred_target + C2)) \
65 | /((musq_pred + musq_target + C1)*(sigma_pred + sigma_target + C2)+1e-8)
66 | return torch.clamp((1-ssim_map)/2, 0, 1)
67 |
68 |
69 | def compute_photometric_loss(pred=None, target=None):
70 | """
71 | This function calculates photometric reconstruction loss (0.85*SSIM + 0.15*L1)
72 | """
73 | abs_diff = torch.abs(target - pred)
74 | l1_loss = abs_diff.mean(1, True)
75 | ssim_loss = compute_ssim_loss(pred, target).mean(1, True)
76 | rep_loss = 0.85 * ssim_loss + 0.15 * l1_loss
77 | return rep_loss
78 |
--------------------------------------------------------------------------------
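A minimal usage sketch for the photometric loss above (0.85 * SSIM + 0.15 * L1): it returns a per-pixel loss map of shape (B, 1, H, W), which the loss classes later reduce with compute_masked_loss. The import path assumes the repository root is on PYTHONPATH.

import torch
from models.losses.loss_util import compute_photometric_loss, compute_masked_loss   # assumed import path

pred = torch.rand(2, 3, 32, 64)
target = torch.rand(2, 3, 32, 64)

per_pixel = compute_photometric_loss(pred=pred, target=target)   # (B, 1, H, W)
mask = torch.ones_like(per_pixel)                                # keep every pixel in this toy example
loss = compute_masked_loss(per_pixel, mask)                      # masked mean -> scalar
print(per_pixel.shape, loss.item())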
/models/losses/multi_cam_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytorch3d.transforms import matrix_to_euler_angles
3 |
4 | from .loss_util import compute_photometric_loss, compute_masked_loss
5 | from .single_cam_loss import SingleCamLoss
6 |
7 | from lpips import LPIPS
8 |
9 | class MultiCamLoss(SingleCamLoss):
10 | """
11 | Class for multi-camera(spatio & temporal) loss calculation
12 | """
13 | def __init__(self, cfg, rank):
14 | super(MultiCamLoss, self).__init__(cfg, rank)
15 |
16 | self.lpips = LPIPS(net="vgg").cuda(rank)
17 |
18 | def compute_spatio_loss(self, inputs, target_view, cam=None, scale=None, ref_mask=None):
19 | """
20 | This function computes spatial loss.
21 | """
22 | # self occlusion mask * overlap region mask
23 | spatio_mask = ref_mask * target_view[('overlap_mask', 0, scale)]
24 | loss_args = {
25 | 'pred': target_view[('overlap', 0, scale)],
26 | 'target': inputs['color',0, 0][:,cam, ...]
27 | }
28 | spatio_loss = compute_photometric_loss(**loss_args)
29 |
30 | target_view[('overlap_mask', 0, scale)] = spatio_mask
31 | return compute_masked_loss(spatio_loss, spatio_mask)
32 |
33 | def compute_spatio_tempo_loss(self, inputs, target_view, cam=None, scale=None, ref_mask=None, reproj_loss_mask=None) :
34 | """
35 | This function computes spatio-temporal loss.
36 | """
37 | spatio_tempo_losses = []
38 | spatio_tempo_masks = []
39 | for frame_id in self.frame_ids[1:]:
40 |
41 | pred_mask = ref_mask * target_view[('overlap_mask', frame_id, scale)]
42 | pred_mask = pred_mask * reproj_loss_mask
43 |
44 | loss_args = {
45 | 'pred': target_view[('overlap', frame_id, scale)],
46 | 'target': inputs['color',0, 0][:,cam, ...]
47 | }
48 |
49 | spatio_tempo_losses.append(compute_photometric_loss(**loss_args))
50 | spatio_tempo_masks.append(pred_mask)
51 |
52 | # concatenate losses and masks
53 | spatio_tempo_losses = torch.cat(spatio_tempo_losses, 1)
54 | spatio_tempo_masks = torch.cat(spatio_tempo_masks, 1)
55 |
56 | # for the loss, take the minimum value across the spatio-temporal reprojection losses of the adjacent frames
57 | # for the mask, take the maximum value across the masks so the loss is applied wherever any mask is True
58 | spatio_tempo_loss, _ = torch.min(spatio_tempo_losses, dim=1, keepdim=True)
59 | spatio_tempo_mask, _ = torch.max(spatio_tempo_masks.float(), dim=1, keepdim=True)
60 |
61 | return compute_masked_loss(spatio_tempo_loss, spatio_tempo_mask)
62 |
63 | def compute_pose_con_loss(self, inputs, outputs, cam=None, scale=None, ref_mask=None, reproj_loss_mask=None) :
64 | """
65 | This function computes pose consistency loss in "Full surround monodepth from multiple cameras"
66 | """
67 | ref_output = outputs[('cam', 0)]
68 | ref_ext = inputs['extrinsics'][:, 0, ...]
69 | ref_ext_inv = inputs['extrinsics_inv'][:, 0, ...]
70 |
71 | cur_output = outputs[('cam', cam)]
72 | cur_ext = inputs['extrinsics'][:, cam, ...]
73 | cur_ext_inv = inputs['extrinsics_inv'][:, cam, ...]
74 |
75 | trans_loss = 0.
76 | angle_loss = 0.
77 |
78 | for frame_id in self.frame_ids[1:]:
79 | ref_T = ref_output[('cam_T_cam', 0, frame_id)]
80 | cur_T = cur_output[('cam_T_cam', 0, frame_id)]
81 |
82 | cur_T_aligned = ref_ext_inv@cur_ext@cur_T@cur_ext_inv@ref_ext
83 |
84 | ref_ang = matrix_to_euler_angles(ref_T[:,:3,:3], 'XYZ')
85 | cur_ang = matrix_to_euler_angles(cur_T_aligned[:,:3,:3], 'XYZ')
86 |
87 | ang_diff = torch.norm(ref_ang - cur_ang, p=2, dim=1).mean()
88 | t_diff = torch.norm(ref_T[:,:3,3] - cur_T_aligned[:,:3,3], p=2, dim=1).mean()
89 |
90 | trans_loss += t_diff
91 | angle_loss += ang_diff
92 |
93 | pose_loss = (trans_loss + 10 * angle_loss) / len(self.frame_ids[1:])
94 | return pose_loss
95 |
96 | def compute_gaussian_loss(self, inputs, target_view, cam=0, scale=0):
97 | """
98 | This function computes gaussian loss.
99 | """
100 | # self occlusion mask * overlap region mask
101 | if self.novel_view_mode == 'MF':
102 | pred = target_view[('gaussian_color', 0, scale)]
103 | gt = inputs['color', 0, 0][:,cam, ...]
104 | lpips_loss = self.lpips(pred, gt, normalize=True).mean()
105 | l2_loss = ((pred - gt)**2).mean()
106 | return 1 * l2_loss + 0.05 * lpips_loss
107 | elif self.novel_view_mode == 'SF':
108 | gaussian_loss = 0.0
109 | for frame_id in self.frame_ids[1:]:
110 | pred = target_view[('gaussian_color', frame_id, scale)]
111 | gt = inputs['color', frame_id, 0][:,cam, ...]
112 | lpips_loss = self.lpips(pred, gt, normalize=True).mean()
113 | l2_loss = ((pred - gt)**2).mean()
114 | gaussian_loss += 1 * l2_loss + 0.05 * lpips_loss
115 | return gaussian_loss / len(self.frame_ids[1:])
116 |
117 |
118 | def forward(self, inputs, outputs, cam):
119 | loss_dict = {}
120 | cam_loss = 0. # loss across the multi-scale
121 | target_view = outputs[('cam', cam)]
122 | for scale in self.scales:
123 | kargs = {
124 | 'cam': cam,
125 | 'scale': scale,
126 | 'ref_mask': inputs['mask'][:,cam,...]
127 | }
128 |
129 | reprojection_loss = self.compute_reproj_loss(inputs, target_view, **kargs)
130 | smooth_loss = self.compute_smooth_loss(inputs, target_view, **kargs)
131 | spatio_loss = self.compute_spatio_loss(inputs, target_view, **kargs)
132 |
133 | kargs['reproj_loss_mask'] = target_view[('reproj_mask', scale)]
134 | spatio_tempo_loss = self.compute_spatio_tempo_loss(inputs, target_view, **kargs)
135 |
136 | if self.gaussian:
137 | gaussian_loss = self.compute_gaussian_loss(inputs, target_view, cam, scale)
138 |
139 | pose_loss = 0
140 |
141 | cam_loss += reprojection_loss
142 | cam_loss += self.disparity_smoothness * smooth_loss / (2 ** scale)
143 | cam_loss += self.spatio_coeff * spatio_loss + self.spatio_tempo_coeff * spatio_tempo_loss
144 | if self.gaussian:
145 | cam_loss += self.gaussian_coeff * gaussian_loss
146 | cam_loss += self.pose_loss_coeff * pose_loss
147 |
148 | ##########################
149 | # for logger
150 | ##########################
151 | if scale == 0:
152 | loss_dict['reproj_loss'] = reprojection_loss.item()
153 | loss_dict['spatio_loss'] = spatio_loss.item()
154 | if self.gaussian:
155 | loss_dict['gaussian_loss'] = gaussian_loss.item()
156 | loss_dict['spatio_tempo_loss'] = spatio_tempo_loss.item()
157 | loss_dict['smooth'] = smooth_loss.item()
158 |
159 | # log statistics
160 | self.get_logs(loss_dict, target_view, cam)
161 |
162 | cam_loss /= len(self.scales)
163 | return cam_loss, loss_dict
--------------------------------------------------------------------------------
/models/losses/single_cam_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .loss_util import compute_photometric_loss, compute_edg_smooth_loss, compute_masked_loss, compute_auto_masks
4 | from .base_loss import BaseLoss
5 |
6 | _EPSILON = 0.00001
7 |
8 |
9 | class SingleCamLoss(BaseLoss):
10 | """
11 | Class for single camera(temporal only) loss calculation
12 | """
13 | def __init__(self, cfg, rank):
14 | super().__init__(cfg, rank)
15 |
16 | def compute_reproj_loss(self, inputs, target_view, cam=0, scale=0, ref_mask=None):
17 | """
18 | This function computes reprojection loss using auto mask.
19 | """
20 | reprojection_losses = []
21 | for frame_id in self.frame_ids[1:]:
22 | reproj_loss_args = {
23 | 'pred': target_view[('color', frame_id, scale)],
24 | 'target': inputs['color',0, 0][:, cam, ...]
25 | }
26 | reprojection_losses.append(
27 | compute_photometric_loss(**reproj_loss_args)
28 | )
29 |
30 | reprojection_losses = torch.cat(reprojection_losses, 1)
31 | reprojection_loss, _ = torch.min(reprojection_losses, dim=1, keepdim=True)
32 |
33 | identity_reprojection_losses = []
34 | for frame_id in self.frame_ids[1:]:
35 | identity_reproj_loss_args = {
36 | 'pred': inputs[('color', frame_id, 0)][:, cam, ...],
37 | 'target': inputs['color',0, 0][:, cam, ...]
38 | }
39 | identity_reprojection_losses.append(
40 | compute_photometric_loss(**identity_reproj_loss_args)
41 | )
42 |
43 | identity_reprojection_losses = torch.cat(identity_reprojection_losses, 1)
44 | identity_reprojection_losses = identity_reprojection_losses + \
45 | _EPSILON * torch.randn(identity_reprojection_losses.shape).to(self.rank)
46 | identity_reprojection_loss, _ = torch.min(identity_reprojection_losses, dim=1, keepdim=True)
47 |
48 | # find minimum losses
49 | reprojection_auto_mask = compute_auto_masks(reprojection_loss, identity_reprojection_loss)
50 | reprojection_auto_mask *= ref_mask
51 |
52 | target_view[('reproj_loss', scale)] = reprojection_auto_mask * reprojection_loss
53 | target_view[('reproj_mask', scale)] = reprojection_auto_mask
54 | return compute_masked_loss(reprojection_loss, reprojection_auto_mask)
55 |
56 | def compute_smooth_loss(self, inputs, target_view, cam = 0, scale = 0, ref_mask=None):
57 | """
58 | This function computes edge-aware smoothness loss for the disparity map.
59 | """
60 | color = inputs['color', 0, scale][:, cam, ...]
61 | disp = target_view[('disp', scale)]
62 | mean_disp = disp.mean(2, True).mean(3, True)
63 | norm_disp = disp / (mean_disp + 1e-8)
64 | return compute_edg_smooth_loss(color, norm_disp)
65 |
66 | def forward(self, inputs, outputs, cam):
67 | loss_dict = {}
68 | cam_loss = 0. # loss across the multi-scale
69 | target_view = outputs[('cam', cam)]
70 | for scale in self.scales:
71 | kargs = {
72 | 'cam': cam,
73 | 'scale': scale,
74 | 'ref_mask': inputs['mask'][:,cam,...]
75 | }
76 |
77 | reprojection_loss = self.compute_reproj_loss(inputs, target_view, **kargs)
78 | smooth_loss = self.compute_smooth_loss(inputs, target_view, **kargs)
79 |
80 | cam_loss += reprojection_loss
81 | cam_loss += self.disparity_smoothness * smooth_loss / (2 ** scale)
82 |
83 | ##########################
84 | # for logger
85 | ##########################
86 | if scale == 0:
87 | loss_dict['reproj_loss'] = reprojection_loss.item()
88 | loss_dict['smooth'] = smooth_loss.item()
89 |
90 | # log statistics
91 | self.get_logs(loss_dict, target_view, cam)
92 |
93 | cam_loss /= len(self.scales)
94 | return cam_loss, loss_dict
--------------------------------------------------------------------------------
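The auto-mask used in compute_reproj_loss above keeps a pixel only where the warped reprojection loss beats the identity (no-warp) reprojection loss, which suppresses static-camera and moving-object pixels. A tiny stand-alone illustration of compute_auto_masks (import path assumed):

import torch
from models.losses.loss_util import compute_auto_masks   # assumed import path

reproj = torch.tensor([[[[0.2, 0.9]]]])     # warped reprojection loss for two pixels
identity = torch.tensor([[[[0.5, 0.1]]]])   # identity reprojection loss for the same pixels
mask = compute_auto_masks(reproj, identity)
print(mask)                                 # tensor([[[[1., 0.]]]]): keep pixel 0, drop pixel 1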
/network/__init__.py:
--------------------------------------------------------------------------------
1 | from .pose_network import PoseNetwork
2 | from .depth_network import DepthNetwork
3 |
4 | __all__ = ['DepthNetwork', 'PoseNetwork']
--------------------------------------------------------------------------------
/network/blocks.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | def pack_cam_feat(x):
6 | if isinstance(x, dict):
7 | for k, v in x.items():
8 | b, n_cam = v.shape[:2]
9 | x[k] = v.view(b*n_cam, *v.shape[2:])
10 | return x
11 | else:
12 | b, n_cam = x.shape[:2]
13 | x = x.view(b*n_cam, *x.shape[2:])
14 | return x
15 |
16 |
17 | def unpack_cam_feat(x, b, n_cam):
18 | if isinstance(x, dict):
19 | for k, v in x.items():
20 | x[k] = v.view(b, n_cam, *v.shape[1:])
21 | return x
22 | else:
23 | x = x.view(b, n_cam, *x.shape[1:])
24 | return x
25 |
26 |
27 | def upsample(x):
28 | return F.interpolate(x, scale_factor=2, mode='nearest')
29 |
30 |
31 | def conv2d(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, nonlin = 'LRU', padding_mode = 'reflect', norm = False):
32 | if nonlin== 'LRU':
33 | act = nn.LeakyReLU(0.1, inplace=True)
34 | elif nonlin == 'ELU':
35 | act = nn.ELU(inplace=True)
36 | else:
37 | act = nn.Identity()
38 |
39 | if norm:
40 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,
41 | padding=((kernel_size - 1) * dilation) // 2, bias=False, padding_mode=padding_mode)
42 | bnorm = nn.BatchNorm2d(out_planes)
43 | else:
44 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,
45 | padding=((kernel_size - 1) * dilation) // 2, bias=True, padding_mode=padding_mode)
46 | bnorm = nn.Identity()
47 | return nn.Sequential(conv, bnorm, act)
48 |
49 |
50 | def conv1d(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, nonlin='LRU', padding_mode='reflect', norm = False):
51 | if nonlin== 'LRU':
52 | act = nn.LeakyReLU(0.1, inplace=True)
53 | elif nonlin == 'ELU':
54 | act = nn.ELU(inplace=True)
55 | else:
56 | act = nn.Identity()
57 |
58 | if norm:
59 | conv = nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,
60 | padding=((kernel_size - 1) * dilation) // 2, bias=False, padding_mode=padding_mode)
61 | bnorm = nn.BatchNorm1d(out_planes)
62 | else:
63 | conv = nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,
64 | padding=((kernel_size - 1) * dilation) // 2, bias=True, padding_mode=padding_mode)
65 | bnorm = nn.Identity()
66 |
67 | return nn.Sequential(conv, bnorm, act)
--------------------------------------------------------------------------------
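pack_cam_feat and unpack_cam_feat above fold the camera dimension into the batch dimension (so one encoder pass handles all surround-view images) and restore it afterwards. A quick round-trip sketch, with the import path assumed:

import torch
from network.blocks import pack_cam_feat, unpack_cam_feat   # assumed import path

b, n_cam, c, h, w = 2, 6, 3, 8, 16
x = torch.rand(b, n_cam, c, h, w)

packed = pack_cam_feat(x.clone())            # (b * n_cam, c, h, w)
unpacked = unpack_cam_feat(packed, b, n_cam) # back to (b, n_cam, c, h, w)
print(packed.shape, torch.equal(unpacked, x))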
/network/depth_network.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from .blocks import upsample, conv2d, pack_cam_feat, unpack_cam_feat
8 | from .volumetric_fusionnet import VFNet
9 |
10 | from external.layers import ResnetEncoder
11 |
12 | class DepthNetwork(nn.Module):
13 | """
14 |     Depth estimation network with surround-view feature fusion
15 | """
16 | def __init__(self, cfg):
17 | super(DepthNetwork, self).__init__()
18 | self.read_config(cfg)
19 |
20 | # feature encoder
21 | # resnet feat: 64(1/2), 64(1/4), 128(1/8), 256(1/16), 512(1/32)
22 | self.encoder = ResnetEncoder(self.num_layers, self.weights_init, 1) # number of layers, pretrained, number of input images
23 |         del self.encoder.encoder.fc # remove the unused classification head
24 | enc_feat_dim = sum(self.encoder.num_ch_enc[self.fusion_level:])
25 | self.conv1x1 = conv2d(enc_feat_dim, self.fusion_feat_in_dim, kernel_size=1, padding_mode = 'reflect')
26 |
27 | # fusion net
28 | fusion_feat_out_dim = self.encoder.num_ch_enc[self.fusion_level]
29 | self.fusion_net = VFNet(cfg, self.fusion_feat_in_dim, fusion_feat_out_dim, model ='depth')
30 |
31 | # depth decoder
32 | num_ch_enc = self.encoder.num_ch_enc[:(self.fusion_level+1)]
33 | num_ch_dec = [16, 32, 64, 128, 256]
34 | self.decoder = DepthDecoder(self.fusion_level, num_ch_enc, num_ch_dec, self.scales, use_skips = self.use_skips)
35 |
36 | def read_config(self, cfg):
37 | for attr in cfg.keys():
38 | for k, v in cfg[attr].items():
39 | setattr(self, k, v)
40 |
41 | def forward(self, inputs):
42 | '''
43 | dict_keys(['idx', 'sensor_name', 'filename', 'extrinsics', 'mask',
44 | ('K', 0), ('inv_K', 0), ('color', 0, 0), ('color_aug', 0, 0),
45 | ('K', 1), ('inv_K', 1), ('color', 0, 1), ('color_aug', 0, 1),
46 | ('K', 2), ('inv_K', 2), ('color', 0, 2), ('color_aug', 0, 2),
47 | ('K', 3), ('inv_K', 3), ('color', 0, 3), ('color_aug', 0, 3),
48 | ('color', -1, 0), ('color_aug', -1, 0), ('color', 1, 0), ('color_aug', 1, 0), 'extrinsics_inv'])
49 | '''
50 |
51 | outputs = {}
52 |
53 | # dictionary initialize
54 |         for cam in range(self.num_cams): # self.num_cams = 6
55 | outputs[('cam', cam)] = {} # outputs = {('cam', 0): {}, ..., ('cam', 5): {}}
56 |
57 | lev = self.fusion_level # 2
58 |
59 | # packed images for surrounding view
60 | sf_images = torch.stack([inputs[('color_aug', 0, 0)][:, cam, ...] for cam in range(self.num_cams)], 1)
61 | if self.novel_view_mode == 'MF':
62 | sf_images_last = torch.stack([inputs[('color_aug', -1, 0)][:, cam, ...] for cam in range(self.num_cams)], 1)
63 | sf_images_next = torch.stack([inputs[('color_aug', 1, 0)][:, cam, ...] for cam in range(self.num_cams)], 1)
64 | packed_input = pack_cam_feat(sf_images)
65 | if self.novel_view_mode == 'MF':
66 | packed_input_last = pack_cam_feat(sf_images_last)
67 | packed_input_next = pack_cam_feat(sf_images_next)
68 |
69 | # feature encoder
70 | packed_feats = self.encoder(packed_input)
71 | if self.novel_view_mode == 'MF':
72 | packed_feats_last = self.encoder(packed_input_last)
73 | packed_feats_next = self.encoder(packed_input_next)
74 | # aggregate feature H / 2^(lev+1) x W / 2^(lev+1)
75 | _, _, up_h, up_w = packed_feats[lev].size()
76 |
77 | packed_feats_list = packed_feats[lev:lev+1] \
78 | + [F.interpolate(feat, [up_h, up_w], mode='bilinear', align_corners=True) for feat in packed_feats[lev+1:]]
79 | if self.novel_view_mode == 'MF':
80 | packed_feats_last_list = packed_feats_last[lev:lev+1] \
81 | + [F.interpolate(feat, [up_h, up_w], mode='bilinear', align_corners=True) for feat in packed_feats_last[lev+1:]]
82 | packed_feats_next_list = packed_feats_next[lev:lev+1] \
83 | + [F.interpolate(feat, [up_h, up_w], mode='bilinear', align_corners=True) for feat in packed_feats_next[lev+1:]]
84 |
85 | packed_feats_agg = self.conv1x1(torch.cat(packed_feats_list, dim=1))
86 | if self.novel_view_mode == 'MF':
87 | packed_feats_agg_last = self.conv1x1(torch.cat(packed_feats_last_list, dim=1))
88 | packed_feats_agg_next = self.conv1x1(torch.cat(packed_feats_next_list, dim=1))
89 |
90 | feats_agg = unpack_cam_feat(packed_feats_agg, self.batch_size, self.num_cams)
91 | if self.novel_view_mode == 'MF':
92 | feats_agg_last = unpack_cam_feat(packed_feats_agg_last, self.batch_size, self.num_cams)
93 | feats_agg_next = unpack_cam_feat(packed_feats_agg_next, self.batch_size, self.num_cams)
94 |
95 | # fusion_net, backproject each feature into the 3D voxel space
96 | fusion_dict = self.fusion_net(inputs, feats_agg)
97 | if self.novel_view_mode == 'MF':
98 | fusion_dict_last = self.fusion_net(inputs, feats_agg_last)
99 | fusion_dict_next = self.fusion_net(inputs, feats_agg_next)
100 |
101 | feat_in = packed_feats[:lev] + [fusion_dict['proj_feat']]
102 | img_feat = []
103 | for i in range(len(feat_in)):
104 | img_feat.append(unpack_cam_feat(feat_in[i], self.batch_size, self.num_cams))
105 |
106 | if self.novel_view_mode == 'MF':
107 | feat_in_last = packed_feats_last[:lev] + [fusion_dict_last['proj_feat']]
108 | img_feat_last = []
109 | for i in range(len(feat_in_last)):
110 | img_feat_last.append(unpack_cam_feat(feat_in_last[i], self.batch_size, self.num_cams))
111 |
112 | feat_in_next = packed_feats_next[:lev] + [fusion_dict_next['proj_feat']]
113 | img_feat_next = []
114 | for i in range(len(feat_in_next)):
115 | img_feat_next.append(unpack_cam_feat(feat_in_next[i], self.batch_size, self.num_cams))
116 |
117 | packed_depth_outputs = self.decoder(feat_in)
118 | if self.novel_view_mode == 'MF':
119 | packed_depth_outputs_last = self.decoder(feat_in_last)
120 | packed_depth_outputs_next = self.decoder(feat_in_next)
121 |
122 | depth_outputs = unpack_cam_feat(packed_depth_outputs, self.batch_size, self.num_cams)
123 | if self.novel_view_mode == 'MF':
124 | depth_outputs_last = unpack_cam_feat(packed_depth_outputs_last, self.batch_size, self.num_cams)
125 | depth_outputs_next = unpack_cam_feat(packed_depth_outputs_next, self.batch_size, self.num_cams)
126 |
127 | for cam in range(self.num_cams):
128 | for k in depth_outputs.keys():
129 | outputs[('cam', cam)][k] = depth_outputs[k][:, cam, ...]
130 | outputs[('cam', cam)][('img_feat', 0, 0)] = [feat[:, cam, ...] for feat in img_feat]
131 | if self.novel_view_mode == 'MF':
132 | outputs[('cam', cam)][('img_feat', -1, 0)] = [feat[:, cam, ...] for feat in img_feat_last]
133 | outputs[('cam', cam)][('img_feat', 1, 0)] = [feat[:, cam, ...] for feat in img_feat_next]
134 | outputs[('cam', cam)][('disp', -1, 0)] = depth_outputs_last[('disp', 0)][:, cam, ...]
135 | outputs[('cam', cam)][('disp', 1, 0)] = depth_outputs_next[('disp', 0)][:, cam, ...]
136 |
137 | return outputs
138 |
139 |
140 | class DepthDecoder(nn.Module):
141 | def __init__(self, level_in, num_ch_enc, num_ch_dec, scales=range(2), use_skips=False):
142 | super(DepthDecoder, self).__init__()
143 |
144 | self.num_output_channels = 1
145 | self.scales = scales
146 | self.use_skips = use_skips
147 |
148 | self.level_in = level_in
149 | self.num_ch_enc = num_ch_enc
150 | self.num_ch_dec = num_ch_dec
151 |
152 | self.convs = OrderedDict()
153 | for i in range(self.level_in, -1, -1):
154 | num_ch_in = self.num_ch_enc[-1] if i == self.level_in else self.num_ch_dec[i + 1]
155 | num_ch_out = self.num_ch_dec[i]
156 | self.convs[('upconv', i, 0)] = conv2d(num_ch_in, num_ch_out, kernel_size=3, nonlin = 'ELU')
157 |
158 | num_ch_in = self.num_ch_dec[i]
159 | if self.use_skips and i > 0:
160 | num_ch_in += self.num_ch_enc[i - 1]
161 | num_ch_out = self.num_ch_dec[i]
162 | self.convs[('upconv', i, 1)] = conv2d(num_ch_in, num_ch_out, kernel_size=3, nonlin = 'ELU')
163 |
164 | for s in self.scales:
165 | self.convs[('dispconv', s)] = conv2d(self.num_ch_dec[s], self.num_output_channels, 3, nonlin = None)
166 |
167 | self.decoder = nn.ModuleList(list(self.convs.values()))
168 | self.sigmoid = nn.Sigmoid()
169 |
170 | def forward(self, input_features):
171 | outputs = {}
172 |
173 | # decode
174 | x = input_features[-1]
175 | for i in range(self.level_in, -1, -1):
176 | x = self.convs[('upconv', i, 0)](x)
177 | x = [upsample(x)]
178 | if self.use_skips and i > 0:
179 | x += [input_features[i - 1]]
180 | x = torch.cat(x, 1)
181 | x = self.convs[('upconv', i, 1)](x)
182 | if i in self.scales:
183 | outputs[('disp', i)] = self.sigmoid(self.convs[('dispconv', i)](x))
184 | return outputs
--------------------------------------------------------------------------------
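A rough shape sketch for the DepthDecoder above: with fusion_level=2, a ResNet-18-style encoder (64/64/128 channels at 1/2, 1/4, 1/8 resolution) and an assumed 352x640 input, the requested scale comes back at full resolution. The sizes are assumptions for illustration, and the import assumes the repo root and its external submodules are importable.

import torch
from network.depth_network import DepthDecoder

decoder = DepthDecoder(level_in=2,
                       num_ch_enc=[64, 64, 128],
                       num_ch_dec=[16, 32, 64, 128, 256],
                       scales=[0],
                       use_skips=True)

feats = [torch.randn(2, 64, 176, 320),   # 1/2-resolution skip feature
         torch.randn(2, 64, 88, 160),    # 1/4-resolution skip feature
         torch.randn(2, 128, 44, 80)]    # fused feature at 1/8 resolution

out = decoder(feats)
print(out[('disp', 0)].shape)            # torch.Size([2, 1, 352, 640]), sigmoid output in (0, 1)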
/network/pose_network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .blocks import conv2d, pack_cam_feat, unpack_cam_feat
6 | from .volumetric_fusionnet import VFNet
7 |
8 | from external.layers import ResnetEncoder, PoseDecoder
9 |
10 |
11 | class PoseNetwork(nn.Module):
12 | """
13 | Canonical motion estimation module.
14 | """
15 | def __init__(self, cfg):
16 | super(PoseNetwork, self).__init__()
17 | self.read_config(cfg)
18 |
19 | # feature encoder
20 | # resnet feat: 64(1/2), 64(1/4), 128(1/8), 256(1/16), 512(1/32)
21 | self.encoder = ResnetEncoder(self.num_layers, self.weights_init, 2) # number of layers, pretrained, number of input images
22 | del self.encoder.encoder.fc
23 | enc_feat_dim = sum(self.encoder.num_ch_enc[self.fusion_level:])
24 | self.conv1x1 = conv2d(enc_feat_dim, self.fusion_feat_in_dim, kernel_size=1, padding_mode='reflect')
25 |
26 | # fusion net
27 | fusion_feat_out_dim = self.encoder.num_ch_enc[self.fusion_level]
28 | self.fusion_net = VFNet(cfg, self.fusion_feat_in_dim, fusion_feat_out_dim, model = 'pose')
29 |
30 |         # pose decoder
31 | self.pose_decoder = PoseDecoder(num_ch_enc = [fusion_feat_out_dim],
32 | num_input_features=1,
33 | num_frames_to_predict_for=1,
34 | stride=2)
35 |
36 | def read_config(self, cfg):
37 | for attr in cfg.keys():
38 | for k, v in cfg[attr].items():
39 | setattr(self, k, v)
40 |
41 | def forward(self, inputs, frame_ids, _):
42 | outputs = {}
43 |
44 | # initialize dictionary
45 | for cam in range(self.num_cams):
46 | outputs[('cam', cam)] = {}
47 |
48 | lev = self.fusion_level
49 |
50 | # packed images for surrounding view
51 | cur_image = inputs[('color_aug', frame_ids[0], 0)]
52 | next_image = inputs[('color_aug', frame_ids[1], 0)]
53 |
54 | pose_images = torch.cat([cur_image, next_image], 2)
55 | packed_pose_images = pack_cam_feat(pose_images)
56 |
57 | packed_feats = self.encoder(packed_pose_images)
58 |
59 | # aggregate feature H / 2^(lev+1) x W / 2^(lev+1)
60 | _, _, up_h, up_w = packed_feats[lev].size()
61 |
62 | packed_feats_list = packed_feats[lev:lev+1] \
63 | + [F.interpolate(feat, [up_h, up_w], mode='bilinear', align_corners=True) for feat in packed_feats[lev+1:]]
64 |
65 | packed_feats_agg = self.conv1x1(torch.cat(packed_feats_list, dim=1))
66 | feats_agg = unpack_cam_feat(packed_feats_agg, self.batch_size, self.num_cams)
67 |
68 | # fusion_net, backproject each feature into the 3D voxel space
69 | bev_feat = self.fusion_net(inputs, feats_agg)
70 | axis_angle, translation = self.pose_decoder([[bev_feat]])
71 | return axis_angle, torch.clamp(translation, -4.0, 4.0) # for DDAD dataset
--------------------------------------------------------------------------------
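The pose decoder returns an axis-angle rotation and a clamped translation; assembling them into a 4x4 rigid transform is handled elsewhere in the repo. A hedged sketch of what that assembly could look like, using the [b, 1, 1, 3] PoseDecoder output convention and the pytorch3d conversion already used in this project (shapes and values are assumptions):

import torch
from pytorch3d.transforms import axis_angle_to_matrix

axis_angle = torch.zeros(2, 1, 1, 3)               # dummy rotation (identity)
translation = torch.randn(2, 1, 1, 3).clamp(-4.0, 4.0)

R = axis_angle_to_matrix(axis_angle.view(-1, 3))   # [2, 3, 3]
T = torch.eye(4).repeat(2, 1, 1)                   # [2, 4, 4]
T[:, :3, :3] = R
T[:, :3, 3] = translation.view(-1, 3)
print(T.shape)                                     # torch.Size([2, 4, 4])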
/network/volumetric_fusionnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from pytorch3d.transforms import axis_angle_to_matrix
5 |
6 | from .blocks import conv2d, conv1d, pack_cam_feat
7 |
8 |
9 | class VFNet(nn.Module):
10 | """
11 |     Surround-view fusion module that estimates a single 3D feature volume from surround-view images
12 | """
13 | def __init__(self, cfg, feat_in_dim, feat_out_dim, model='depth'):
14 | super(VFNet, self).__init__()
15 | self.read_config(cfg)
16 | self.eps = 1e-8
17 | self.model = model
18 | # define the 3D voxel space(follows the DDAD extrinsic coordinate -- x: forward, y: left, z: up)
19 | # define voxel end range in accordance with voxel_str_p, voxel_size, voxel_unit_size
20 | self.voxel_end_p = [self.voxel_str_p[i] + self.voxel_unit_size[i] * (self.voxel_size[i] - 1) for i in range(3)]
21 |
22 | # define a voxel space, [1, 3, z, y, x], each voxel contains its 3D position
23 | voxel_grid = self.create_voxel_grid(self.voxel_str_p, self.voxel_end_p, self.voxel_size)
24 | b, _, self.z_dim, self.y_dim, self.x_dim = voxel_grid.size()
25 | self.n_voxels = self.z_dim * self.y_dim * self.x_dim
26 | ones = torch.ones(self.batch_size, 1, self.n_voxels)
27 | self.voxel_pts = torch.cat([voxel_grid.view(b, 3, self.n_voxels), ones], dim=1)
28 |
29 | # define grids in pixel space
30 | self.img_h = self.height // (2 ** (self.fusion_level+1))
31 | self.img_w = self.width // (2 ** (self.fusion_level+1))
32 | self.num_pix = self.img_h * self.img_w
33 | self.pixel_grid = self.create_pixel_grid(self.batch_size, self.img_h, self.img_w)
34 | self.pixel_ones = torch.ones(self.batch_size, 1, self.proj_d_bins, self.num_pix)
35 |
36 | # define a depth grid for projection
37 | depth_bins = torch.linspace(self.proj_d_str, self.proj_d_end, self.proj_d_bins)
38 | self.depth_grid = self.create_depth_grid(self.batch_size, self.num_pix, self.proj_d_bins, depth_bins)
39 |
40 | # depth fusion(process overlap and non-overlap regions)
41 | if model == 'depth':
42 | # voxel - preprocessing layer
43 | self.v_dim_o = [(feat_in_dim + 1) * 2] + self.voxel_pre_dim
44 | self.v_dim_no = [feat_in_dim + 1] + self.voxel_pre_dim
45 |
46 | self.conv_overlap = conv1d(self.v_dim_o[0], self.v_dim_o[1], kernel_size=1)
47 | self.conv_non_overlap = conv1d(self.v_dim_no[0], self.v_dim_no[1], kernel_size=1)
48 |
49 | encoder_dims = self.proj_d_bins * self.v_dim_o[-1]
50 | stride = 1
51 |
52 | else:
53 | encoder_dims = (feat_in_dim + 1)*self.z_dim
54 | stride = 2
55 |
56 | # channel dimension reduction
57 | self.reduce_dim = nn.Sequential(*conv2d(encoder_dims, 256, kernel_size=3, stride = stride).children(),
58 | *conv2d(256, feat_out_dim, kernel_size=3, stride = stride).children())
59 |
60 | def read_config(self, cfg):
61 | for attr in cfg.keys():
62 | for k, v in cfg[attr].items():
63 | setattr(self, k, v)
64 |
65 | def create_voxel_grid(self, str_p, end_p, v_size):
66 | grids = [torch.linspace(str_p[i], end_p[i], v_size[i]) for i in range(3)]
67 |
68 | x_dim, y_dim, z_dim = v_size
69 | grids[0] = grids[0].view(1, 1, 1, 1, x_dim)
70 | grids[1] = grids[1].view(1, 1, 1, y_dim, 1)
71 | grids[2] = grids[2].view(1, 1, z_dim, 1, 1)
72 |
73 | grids = [grid.expand(self.batch_size, 1, z_dim, y_dim, x_dim) for grid in grids]
74 | return torch.cat(grids, 1)
75 |
76 | def create_pixel_grid(self, batch_size, height, width):
77 | grid_xy = torch.meshgrid(torch.arange(width), torch.arange(height), indexing='xy')
78 | pix_coords = torch.stack(grid_xy, axis=0).unsqueeze(0).view(1, 2, height * width)
79 | pix_coords = pix_coords.repeat(batch_size, 1, 1)
80 | ones = torch.ones(batch_size, 1, height * width)
81 | pix_coords = torch.cat([pix_coords, ones], 1)
82 | return pix_coords
83 |
84 | def create_depth_grid(self, batch_size, n_pixels, n_depth_bins, depth_bins):
85 | depth_layers = []
86 | for d in depth_bins:
87 | depth_layer = torch.ones((1, n_pixels)) * d
88 | depth_layers.append(depth_layer)
89 | depth_layers = torch.cat(depth_layers, dim=0).view(1, 1, n_depth_bins, n_pixels)
90 | depth_layers = depth_layers.expand(batch_size, 3, n_depth_bins, n_pixels)
91 | return depth_layers
92 |
93 | def type_check(self, sample_tensor):
94 | d_dtype, d_device = sample_tensor.dtype, sample_tensor.device
95 | if (self.voxel_pts.dtype != d_dtype) or (self.voxel_pts.device != d_device):
96 | self.voxel_pts = self.voxel_pts.to(device=d_device, dtype=d_dtype)
97 | self.pixel_grid = self.pixel_grid.to(device=d_device, dtype=d_dtype)
98 | self.depth_grid = self.depth_grid.to(device=d_device, dtype=d_dtype)
99 | self.pixel_ones = self.pixel_ones.to(device=d_device, dtype=d_dtype)
100 |
101 | def backproject_into_voxel(self, feats_agg, input_mask, intrinsics, extrinsics_inv):
102 | voxel_feat_list = []
103 | voxel_mask_list = []
104 |
105 | for cam in range(self.num_cams):
106 | feats_img = feats_agg[:, cam, ...]
107 | _, _, h_dim, w_dim = feats_img.size()
108 |
109 | mask_img = input_mask[:, cam, ...]
110 | mask_img = F.interpolate(mask_img, [h_dim, w_dim], mode='bilinear', align_corners=True)
111 |
112 | # 3D points in the voxel grid -> 3D points referenced at each view. [b, 3, n_voxels]
113 | ext_inv_mat = extrinsics_inv[:, cam, :3, :]
114 | v_pts_local = torch.matmul(ext_inv_mat, self.voxel_pts)
115 |
116 | # calculate pixel coordinate that each point are projected in the image. [b, n_voxels, 1, 2]
117 | K_mat = intrinsics[:, cam, :, :]
118 | pix_coords = self.calculate_sample_pixel_coords(K_mat, v_pts_local, w_dim, h_dim)
119 |
120 | # compute validity mask. [b, 1, n_voxels]
121 | valid_mask = self.calculate_valid_mask(mask_img, pix_coords, v_pts_local)
122 |
123 | # retrieve each per-pixel feature. [b, feat_dim, n_voxels, 1]
124 | feat_warped = F.grid_sample(feats_img, pix_coords, mode='bilinear', padding_mode='zeros', align_corners=True)
125 | # concatenate relative depth as the feature. [b, feat_dim + 1, n_voxels]
126 | feat_warped = torch.cat([feat_warped.squeeze(-1), v_pts_local[:, 2:3, :]/(self.voxel_size[0])], dim=1)
127 | feat_warped = feat_warped * valid_mask.float()
128 |
129 | voxel_feat_list.append(feat_warped) # [[batch_size, feat_dim, n_voxels], ...]
130 | voxel_mask_list.append(valid_mask)
131 |
132 | # compute overlap region
133 | voxel_mask_count = torch.sum(torch.cat(voxel_mask_list, dim=1), dim=1, keepdim=True)
134 |
135 | if self.model == 'depth':
136 | # discriminatively process overlap and non_overlap regions using different MLPs
137 | voxel_non_overlap = self.preprocess_non_overlap(voxel_feat_list, voxel_mask_list, voxel_mask_count)
138 | voxel_overlap = self.preprocess_overlap(voxel_feat_list, voxel_mask_list, voxel_mask_count)
139 | voxel_feat = voxel_non_overlap + voxel_overlap # [batch_size, feat_dim, n_voxels]
140 |
141 | elif self.model == 'pose':
142 | voxel_feat = torch.sum(torch.stack(voxel_feat_list, dim=1), dim=1, keepdim=False)
143 | voxel_feat = voxel_feat/(voxel_mask_count+1e-7)
144 |
145 | return voxel_feat
146 |
147 | def calculate_sample_pixel_coords(self, K, v_pts, w_dim, h_dim):
148 | """
149 |         This function calculates pixel coords for each point ([batch, n_voxels, 1, 2]) to sample the per-pixel feature.
150 | """
151 | cam_points = torch.matmul(K[:, :3, :3], v_pts) # [batch_size, 3, n_voxels]
152 | pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
153 |
154 | if not torch.all(torch.isfinite(pix_coords)):
155 | pix_coords = torch.clamp(pix_coords, min=-w_dim*2, max=w_dim*2)
156 |
157 | pix_coords = pix_coords.view(self.batch_size, 2, self.n_voxels, 1) #[batch_size, 2, n_voxels] -> [batch_size, 2, n_voxels, 1]
158 | pix_coords = pix_coords.permute(0, 2, 3, 1) # [batch_size, 2, n_voxels, 1] -> [batch_size, n_voxels, 1, 2]
159 | pix_coords[:, :, :, 0] = pix_coords[:, :, :, 0] / (w_dim - 1)
160 | pix_coords[:, :, :, 1] = pix_coords[:, :, :, 1] / (h_dim - 1)
161 | pix_coords = (pix_coords - 0.5) * 2 # [batch_size, n_voxels, 1, 2]
162 | return pix_coords
163 |
164 | def calculate_valid_mask(self, mask_img, pix_coords, v_pts_local):
165 | """
166 |         This function creates a validity mask in voxel coordinates by projecting the self-occlusion mask into the 3D voxel space.
167 | """
168 | # compute validity mask, [b, 1, n_voxels, 1]
169 | mask_selfocc = (F.grid_sample(mask_img, pix_coords, mode='nearest', padding_mode='zeros', align_corners=True) > 0.5)
170 | # discard points behind the camera, [b, 1, n_voxels]
171 | mask_depth = (v_pts_local[:, 2:3, :] > 0)
172 | # compute validity mask, [b, 1, n_voxels, 1]
173 | pix_coords_mask = pix_coords.permute(0, 3, 1, 2)
174 | mask_oob = ~(torch.logical_or(pix_coords_mask > 1, pix_coords_mask < -1).sum(dim=1, keepdim=True) > 0)
175 | valid_mask = mask_selfocc.squeeze(-1) * mask_depth * mask_oob.squeeze(-1)
176 | return valid_mask
177 |
178 | def preprocess_non_overlap(self, voxel_feat_list, voxel_mask_list, voxel_mask_count):
179 | """
180 |         This function applies 1x1 convolutions to features from non-overlapping regions.
181 | """
182 | non_overlap_mask = (voxel_mask_count == 1)
183 | voxel = sum(voxel_feat_list)
184 | voxel = voxel * non_overlap_mask.float()
185 |
186 | for conv_no in self.conv_non_overlap:
187 | voxel = conv_no(voxel)
188 | return voxel * non_overlap_mask.float()
189 |
190 | def preprocess_overlap(self, voxel_feat_list, voxel_mask_list, voxel_mask_count):
191 | """
192 |         This function applies 1x1 convolutions to features from overlapping regions.
193 | Camera configuration [0,1,2] or [0,1,2,3,4,5]:
194 | 3 1
195 | rear cam <- 5 0 -> front cam
196 | 4 2
197 | """
198 | overlap_mask = (voxel_mask_count == 2)
199 | if self.num_cams == 3:
200 | feat1 = voxel_feat_list[0]
201 | feat2 = voxel_feat_list[1] + voxel_feat_list[2]
202 | elif self.num_cams == 6:
203 | feat1 = voxel_feat_list[0] + voxel_feat_list[3] + voxel_feat_list[4]
204 | feat2 = voxel_feat_list[1] + voxel_feat_list[2] + voxel_feat_list[5]
205 | else:
206 | raise NotImplementedError
207 |
208 | voxel = torch.cat([feat1, feat2], dim=1)
209 | for conv_o in self.conv_overlap:
210 | voxel = conv_o(voxel)
211 | return voxel * overlap_mask.float()
212 |
213 | def project_voxel_into_image(self, voxel_feat, inv_K, extrinsics):
214 | """
215 |         This function projects voxels into 2D image coordinates.
216 | [b, feat_dim, n_voxels] -> [b, feat_dim, d, h, w]
217 | """
218 | # define depth bin
219 | # [b, feat_dim, n_voxels] -> [b, feat_dim, d, h, w]
220 | b, feat_dim, _ = voxel_feat.size()
221 | voxel_feat = voxel_feat.view(b, feat_dim, self.z_dim, self.y_dim, self.x_dim)
222 |
223 | proj_feats = []
224 | for cam in range(self.num_cams):
225 | # construct 3D point grid for each view
226 | cam_points = torch.matmul(inv_K[:, cam, :3, :3], self.pixel_grid)
227 | cam_points = self.depth_grid * cam_points.view(self.batch_size, 3, 1, self.num_pix)
228 | cam_points = torch.cat([cam_points, self.pixel_ones], dim=1) # [b, 4, n_depthbins, n_pixels]
229 | cam_points = cam_points.view(self.batch_size, 4, -1) # [b, 4, n_depthbins * n_pixels]
230 |
231 | # apply extrinsic: local 3D point -> global coordinate, [b, 3, n_depthbins * n_pixels]
232 | points = torch.matmul(extrinsics[:, cam, :3, :], cam_points)
233 |
234 | # 3D grid_sample [b, n_voxels, 3], value: (x, y, z) point
235 | grid = points.permute(0, 2, 1)
236 |
237 | for i in range(3):
238 | v_length = self.voxel_end_p[i] - self.voxel_str_p[i]
239 | grid[:, :, i] = (grid[:, :, i] - self.voxel_str_p[i]) / v_length * 2. - 1.
240 |
241 | grid = grid.view(self.batch_size, self.proj_d_bins, self.img_h, self.img_w, 3)
242 | proj_feat = F.grid_sample(voxel_feat, grid, mode='bilinear', padding_mode='zeros', align_corners=True)
243 | proj_feat = proj_feat.view(b, self.proj_d_bins * self.v_dim_o[-1], self.img_h, self.img_w)
244 |
245 | # conv, reduce dimension
246 | proj_feat = self.reduce_dim(proj_feat)
247 | proj_feats.append(proj_feat)
248 | return proj_feats
249 |
250 | def augment_extrinsics(self, ext):
251 | """
252 |         This function randomly augments the rotation of the extrinsics [batch, cam, 4, 4] for depth synthesis.
253 | """
254 | with torch.no_grad():
255 | b, cam, _, _ = ext.size()
256 | ext_aug = ext.clone()
257 |
258 | # rotation augmentation
259 | angle = torch.rand(b, cam, 3)
260 | self.aug_angle = [15, 15, 40]
261 | for i in range(3):
262 | angle[:, :, i] = (angle[:, :, i] - 0.5) * self.aug_angle[i]
263 | angle_mat = axis_angle_to_matrix(angle) # 3x3
264 | tform_mat = torch.eye(4).repeat(b, cam, 1, 1)
265 | tform_mat[:, :, :3, :3] = angle_mat
266 | tform_mat = tform_mat.to(device=ext.device, dtype=ext.dtype)
267 |
268 | ext_aug = tform_mat @ ext_aug
269 | return ext_aug
270 |
271 | def forward(self, inputs, feats_agg):
272 | mask = inputs['mask']
273 | K = inputs['K', self.fusion_level+1]
274 | inv_K = inputs['inv_K', self.fusion_level+1]
275 | extrinsics = inputs['extrinsics']
276 | extrinsics_inv = inputs['extrinsics_inv']
277 |
278 | fusion_dict = {}
279 | for cam in range(self.num_cams):
280 | fusion_dict[('cam', cam)] = {}
281 |
282 | # device, dtype check, match dtype and device
283 | sample_tensor = feats_agg[0, 0, ...] # B, n_cam, c, h, w
284 | self.type_check(sample_tensor)
285 |
286 | # backproject each per-pixel feature into 3D space (or sample per-pixel features for each voxel)
287 | voxel_feat = self.backproject_into_voxel(feats_agg, mask, K, extrinsics_inv) # [batch_size, feat_dim, n_voxels]
288 |
289 | if self.model == 'depth':
290 | # for each pixel, collect voxel features -> output image feature
291 | proj_feats = self.project_voxel_into_image(voxel_feat, inv_K, extrinsics)
292 | fusion_dict['proj_feat'] = pack_cam_feat(torch.stack(proj_feats, 1))
293 | return fusion_dict
294 |
295 | elif self.model == 'pose':
296 | b, c, _ = voxel_feat.shape
297 | voxel_feat = voxel_feat.view(b, c*self.z_dim,
298 | self.y_dim, self.x_dim)
299 | bev_feat= self.reduce_dim(voxel_feat)
300 | return bev_feat
301 |
--------------------------------------------------------------------------------
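A small numeric check of the normalization used in calculate_sample_pixel_coords above: with align_corners=True, pixel index 0 maps to -1 and index (dim - 1) maps to +1 for grid_sample (project_voxel_into_image applies the analogous mapping over the voxel extent). The width is an assumed value for illustration:

import torch

w_dim = 80                                     # assumed feature-map width
u = torch.tensor([0.0, (w_dim - 1) / 2, w_dim - 1.0])
u_norm = (u / (w_dim - 1) - 0.5) * 2           # same mapping as calculate_sample_pixel_coords
print(u_norm)                                  # -> [-1., 0., 1.]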
/requirements.txt:
--------------------------------------------------------------------------------
1 | anyio==4.5.2
2 | argon2-cffi==23.1.0
3 | argon2-cffi-bindings==21.2.0
4 | arrow==1.3.0
5 | asttokens==2.4.1
6 | async-lru==2.0.4
7 | attrs==24.2.0
8 | babel==2.16.0
9 | backcall==0.2.0
10 | beautifulsoup4==4.12.3
11 | bleach==6.1.0
12 | boto3==1.34.74
13 | botocore==1.34.162
14 | cachetools==5.5.0
15 | certifi==2024.8.30
16 | cffi==1.17.1
17 | charset-normalizer==3.4.0
18 | comm==0.2.2
19 | cycler==0.12.1
20 | debugpy==1.8.7
21 | decorator==5.1.1
22 | defusedxml==0.7.1
23 | descartes==1.1.0
24 | diskcache==5.6.3
25 | e3nn==0.4.0
26 | einops==0.7.0
27 | exceptiongroup==1.2.2
28 | executing==2.1.0
29 | fastjsonschema==2.20.0
30 | fire==0.7.0
31 | fonttools==4.54.1
32 | fqdn==1.5.1
33 | fvcore==0.1.5.post20221221
34 | h11==0.14.0
35 | httpcore==1.0.6
36 | httpx==0.27.2
37 | idna==3.10
38 | imageio==2.35.1
39 | importlib_metadata==8.5.0
40 | importlib_resources==6.4.5
41 | iopath==0.1.10
42 | ipykernel==6.29.5
43 | ipython==8.12.3
44 | ipywidgets==8.1.5
45 | isoduration==20.11.0
46 | jaxtyping==0.2.19
47 | jedi==0.19.1
48 | Jinja2==3.1.4
49 | jmespath==1.0.1
50 | joblib==1.4.2
51 | json5==0.9.25
52 | jsonpointer==3.0.0
53 | jsonschema==4.23.0
54 | jsonschema-specifications==2023.12.1
55 | jupyter==1.1.1
56 | jupyter-console==6.6.3
57 | jupyter-events==0.10.0
58 | jupyter-lsp==2.2.5
59 | jupyter_client==8.6.3
60 | jupyter_core==5.7.2
61 | jupyter_server==2.14.2
62 | jupyter_server_terminals==0.5.3
63 | jupyterlab==4.2.5
64 | jupyterlab_pygments==0.3.0
65 | jupyterlab_server==2.27.3
66 | jupyterlab_widgets==3.0.13
67 | kiwisolver==1.4.7
68 | lazy_loader==0.4
69 | lpips==0.1.4
70 | MarkupSafe==2.1.5
71 | matplotlib==3.5.3
72 | matplotlib-inline==0.1.7
73 | mistune==3.0.2
74 | mpmath==1.3.0
75 | nbclient==0.10.0
76 | nbconvert==7.16.4
77 | nbformat==5.10.4
78 | nest-asyncio==1.6.0
79 | networkx==3.1
80 | notebook==7.2.2
81 | notebook_shim==0.2.4
82 | numpy==1.24.4
83 | nuscenes-devkit==1.1.9
84 | opencv-python==4.10.0.84
85 | opt-einsum-fx==0.1.4
86 | opt_einsum==3.4.0
87 | overrides==7.7.0
88 | packaging==24.1
89 | pandas==2.0.3
90 | pandocfilters==1.5.1
91 | parso==0.8.4
92 | pexpect==4.9.0
93 | pickleshare==0.7.5
94 | Pillow==9.5.0
95 | pkgutil_resolve_name==1.3.10
96 | platformdirs==4.3.6
97 | portalocker==2.10.1
98 | prometheus_client==0.21.0
99 | prompt_toolkit==3.0.48
100 | protobuf==3.20.3
101 | psutil==6.1.0
102 | ptyprocess==0.7.0
103 | pure_eval==0.2.3
104 | pycocotools==2.0.7
105 | pycparser==2.22
106 | Pygments==2.18.0
107 | pyparsing==3.1.4
108 | pyquaternion==0.9.5
109 | python-dateutil==2.9.0.post0
110 | python-json-logger==2.0.7
111 | pytorch3d==0.3.0
112 | pytz==2024.2
113 | PyWavelets==1.4.1
114 | PyYAML==6.0.1
115 | pyzmq==26.2.0
116 | referencing==0.35.1
117 | requests==2.32.3
118 | rfc3339-validator==0.1.4
119 | rfc3986-validator==0.1.1
120 | rpds-py==0.20.0
121 | s3transfer==0.10.3
122 | scikit-image==0.21.0
123 | scikit-learn==1.3.2
124 | scipy==1.10.1
125 | Send2Trash==1.8.3
126 | shapely==2.0.6
127 | six==1.16.0
128 | sniffio==1.3.1
129 | soupsieve==2.6
130 | stack-data==0.6.3
131 | sympy==1.13.3
132 | tabulate==0.9.0
133 | tenacity==8.2.3
134 | tensorboardX==2.6.2.2
135 | termcolor==2.4.0
136 | terminado==0.18.1
137 | threadpoolctl==3.5.0
138 | tifffile==2023.7.10
139 | tinycss2==1.3.0
140 | tomli==2.0.2
141 | tornado==6.4.1
142 | tqdm==4.66.5
143 | traitlets==5.14.3
144 | trimesh==4.5.2
145 | typeguard==4.3.0
146 | types-python-dateutil==2.9.0.20241003
147 | typing_extensions==4.12.2
148 | tzdata==2024.2
149 | uri-template==1.3.0
150 | urllib3==1.26.20
151 | wcwidth==0.2.13
152 | webcolors==24.8.0
153 | webencodings==0.5.1
154 | websocket-client==1.8.0
155 | widgetsnbextension==4.0.13
156 | xarray==2023.1.0
157 | yacs==0.1.8
158 | zipp==3.20.2
159 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | torch.backends.cudnn.deterministic = True
4 | torch.backends.cudnn.benchmark = False
5 | torch.backends.cuda.matmul.allow_tf32 = False
6 | torch.manual_seed(0)
7 |
8 | import utils
9 | from models import DrivingForwardModel
10 | from trainer import DrivingForwardTrainer
11 |
12 | def parse_args():
13 | parser = argparse.ArgumentParser(description='training script')
14 | parser.add_argument('--config_file', default ='./configs/nuscenes/main.yaml', type=str, help='config yaml file')
15 |     parser.add_argument('--novel_view_mode', default='MF', type=str, help='MF or SF')
16 | args = parser.parse_args()
17 | return args
18 |
19 | def train(cfg):
20 | model = DrivingForwardModel(cfg, 0)
21 | trainer = DrivingForwardTrainer(cfg, 0)
22 | trainer.learn(model)
23 |
24 | if __name__ == '__main__':
25 | args = parse_args()
26 | cfg = utils.get_config(args.config_file, mode='train', novel_view_mode=args.novel_view_mode)
27 |
28 | train(cfg)
29 |
--------------------------------------------------------------------------------
/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainer import DrivingForwardTrainer
2 |
3 | __all__ = ['DrivingForwardTrainer']
4 |
--------------------------------------------------------------------------------
/trainer/trainer.py:
--------------------------------------------------------------------------------
1 | import time
2 | from collections import defaultdict
3 | from tqdm import tqdm
4 |
5 | import torch
6 | import torch.distributed as dist
7 | from torch import Tensor
8 |
9 | from utils import Logger
10 |
11 | from lpips import LPIPS
12 | from jaxtyping import Float, UInt8
13 | from skimage.metrics import structural_similarity
14 | from einops import reduce
15 |
16 | from PIL import Image
17 | from pathlib import Path
18 | from einops import rearrange, repeat
19 | from typing import Union
20 | import numpy as np
21 |
22 | from tqdm import tqdm
23 |
24 | FloatImage = Union[
25 | Float[Tensor, "height width"],
26 | Float[Tensor, "channel height width"],
27 | Float[Tensor, "batch channel height width"],
28 | ]
29 |
30 | class DrivingForwardTrainer:
31 | """
32 | Trainer class for training and evaluation
33 | """
34 | def __init__(self, cfg, rank, use_tb=True):
35 | self.read_config(cfg)
36 | self.rank = rank
37 | if rank == 0:
38 | self.logger = Logger(cfg, use_tb)
39 | self.depth_metric_names = self.logger.get_metric_names()
40 |
41 | self.lpips = LPIPS(net="vgg").cuda(rank)
42 |
43 | def read_config(self, cfg):
44 | for attr in cfg.keys():
45 | for k, v in cfg[attr].items():
46 | setattr(self, k, v)
47 |
48 | def learn(self, model):
49 | """
50 |         This function sets up and runs the training process.
51 | """
52 | train_dataloader = model.train_dataloader()
53 | if self.rank == 0:
54 | val_dataloader = model.val_dataloader()
55 | self.val_iter = iter(val_dataloader)
56 |
57 | self.step = 0
58 | start_time = time.time()
59 | for self.epoch in range(self.num_epochs):
60 |
61 | self.train(model, train_dataloader, start_time)
62 |
63 | # save model after each epoch using rank 0 gpu
64 | if self.rank == 0:
65 | model.save_model(self.epoch)
66 | print('-'*110)
67 |
68 | if self.ddp_enable:
69 | dist.barrier()
70 |
71 | if self.rank == 0:
72 | self.logger.close_tb()
73 |
74 | def train(self, model, data_loader, start_time):
75 | """
76 |         This function trains the model for one epoch.
77 | """
78 | model.set_train()
79 | pbar = tqdm(total=len(data_loader), desc='training on epoch {}'.format(self.epoch), mininterval=100)
80 | for batch_idx, inputs in enumerate(data_loader):
81 | before_op_time = time.time()
82 | model.optimizer.zero_grad(set_to_none=True)
83 | outputs, losses = model.process_batch(inputs, self.rank)
84 | losses['total_loss'].backward()
85 | model.optimizer.step()
86 |
87 | if self.rank == 0:
88 | self.logger.update(
89 | 'train',
90 | self.epoch,
91 | self.world_size,
92 | batch_idx,
93 | self.step,
94 | start_time,
95 | before_op_time,
96 | inputs,
97 | outputs,
98 | losses
99 | )
100 |
101 | if self.logger.is_checkpoint(self.step):
102 | self.validate(model)
103 |
104 | self.step += 1
105 | pbar.update(1)
106 |
107 | pbar.close()
108 | model.lr_scheduler.step()
109 |
110 | @torch.no_grad()
111 | def validate(self, model, vis_results=False):
112 | """
113 |         This function validates the model on a batch from the validation dataset to monitor the training process.
114 | """
115 | val_dataloader = model.val_dataloader()
116 | val_iter = iter(val_dataloader)
117 |
118 | # Ensure the model is in validation mode
119 | model.set_val()
120 |
121 | avg_reconstruction_metric = defaultdict(float)
122 |
123 | inputs = next(val_iter)
124 | outputs, _ = model.process_batch(inputs, self.rank)
125 |
126 | psnr, ssim, lpips= self.compute_reconstruction_metrics(inputs, outputs)
127 |
128 | avg_reconstruction_metric['psnr'] += psnr
129 | avg_reconstruction_metric['ssim'] += ssim
130 | avg_reconstruction_metric['lpips'] += lpips
131 |
132 | print('Validation reconstruction result...\n')
133 | print(f"\n{inputs['token'][0]}")
134 | self.logger.print_perf(avg_reconstruction_metric, 'reconstruction')
135 |
136 | # Set the model back to training mode
137 | model.set_train()
138 |
139 | @torch.no_grad()
140 | def evaluate(self, model):
141 | """
142 |         This function evaluates the model on the evaluation dataset of samples with context frames.
143 | """
144 | eval_dataloader = model.eval_dataloader()
145 |
146 | # load model
147 | model.load_weights()
148 | model.set_eval()
149 |
150 | avg_reconstruction_metric = defaultdict(float)
151 |
152 | count = 0
153 |
154 | process = tqdm(eval_dataloader)
155 | for batch_idx, inputs in enumerate(process):
156 | outputs, _ = model.process_batch(inputs, self.rank)
157 |
158 | psnr, ssim, lpips= self.compute_reconstruction_metrics(inputs, outputs)
159 |
160 | avg_reconstruction_metric['psnr'] += psnr
161 | avg_reconstruction_metric['ssim'] += ssim
162 | avg_reconstruction_metric['lpips'] += lpips
163 | count += 1
164 |
165 | process.set_description(f"PSNR: {psnr:.4f}, SSIM: {ssim:.4f}, LPIPS: {lpips:.4f}")
166 |
167 | print(f"\n{inputs['token'][0]}")
168 | print(f"avg PSNR: {avg_reconstruction_metric['psnr']/count:.4f}, avg SSIM: {avg_reconstruction_metric['ssim']/count:.4f}, avg LPIPS: {avg_reconstruction_metric['lpips']/count:.4f}")
169 |
170 | avg_reconstruction_metric['psnr'] /= len(eval_dataloader)
171 | avg_reconstruction_metric['ssim'] /= len(eval_dataloader)
172 | avg_reconstruction_metric['lpips'] /= len(eval_dataloader)
173 |
174 | print('Evaluation reconstruction result...\n')
175 | self.logger.print_perf(avg_reconstruction_metric, 'reconstruction')
176 |
177 | def save_image(
178 | self,
179 | image: FloatImage,
180 | path: Union[Path, str],
181 | ) -> None:
182 | """Save an image. Assumed to be in range 0-1."""
183 |
184 | # Create the parent directory if it doesn't already exist.
185 | path = Path(path)
186 | path.parent.mkdir(exist_ok=True, parents=True)
187 |
188 | # Save the image.
189 | Image.fromarray(self.prep_image(image)).save(path)
190 |
191 |
192 | def prep_image(self, image: FloatImage) -> UInt8[np.ndarray, "height width channel"]:
193 | # Handle batched images.
194 | if image.ndim == 4:
195 | image = rearrange(image, "b c h w -> c h (b w)")
196 |
197 | # Handle single-channel images.
198 | if image.ndim == 2:
199 | image = rearrange(image, "h w -> () h w")
200 |
201 | # Ensure that there are 3 or 4 channels.
202 | channel, _, _ = image.shape
203 | if channel == 1:
204 | image = repeat(image, "() h w -> c h w", c=3)
205 | assert image.shape[0] in (3, 4)
206 |
207 | image = (image.detach().clip(min=0, max=1) * 255).type(torch.uint8)
208 | return rearrange(image, "c h w -> h w c").cpu().numpy()
209 |
210 | @torch.no_grad()
211 | def compute_reconstruction_metrics(self, inputs, outputs):
212 | """
213 | This function computes reconstruction metrics.
214 | """
215 | psnr = 0.0
216 | ssim = 0.0
217 | lpips = 0.0
218 | if self.novel_view_mode == 'SF':
219 | frame_id = 1
220 | elif self.novel_view_mode == 'MF':
221 | frame_id = 0
222 | else:
223 | raise ValueError(f"Invalid novel view mode: {self.novel_view_mode}")
224 | for cam in range(self.num_cams):
225 | rgb_gt = inputs[('color', frame_id, 0)][:, cam, ...]
226 | image = outputs[('cam', cam)][('gaussian_color', frame_id, 0)]
227 | psnr += self.compute_psnr(rgb_gt, image).mean()
228 | ssim += self.compute_ssim(rgb_gt, image).mean()
229 | lpips += self.compute_lpips(rgb_gt, image).mean()
230 | if self.save_images:
231 | assert self.eval_batch_size == 1
232 | if self.novel_view_mode == 'SF':
233 | self.save_image(image, Path(self.save_path) / inputs['token'][0] / f"{cam}.png")
234 | self.save_image(rgb_gt, Path(self.save_path) / inputs['token'][0] / f"{cam}_gt.png")
235 | self.save_image(inputs[('color', 0, 0)][:, cam, ...], Path(self.save_path) / inputs['token'][0] / f"{cam}_0_gt.png")
236 | elif self.novel_view_mode == 'MF':
237 | self.save_image(image, Path(self.save_path) / inputs['token'][0] / f"{cam}.png")
238 | self.save_image(rgb_gt, Path(self.save_path) / inputs['token'][0] / f"{cam}_gt.png")
239 | self.save_image(inputs[('color', -1, 0)][:, cam, ...], Path(self.save_path) / inputs['token'][0] / f"{cam}_prev_gt.png")
240 | self.save_image(inputs[('color', 1, 0)][:, cam, ...], Path(self.save_path) / inputs['token'][0] / f"{cam}_next_gt.png")
241 | psnr /= self.num_cams
242 | ssim /= self.num_cams
243 | lpips /= self.num_cams
244 | return psnr, ssim, lpips
245 |
246 | @torch.no_grad()
247 | def compute_psnr(
248 | self,
249 | ground_truth: Float[Tensor, "batch channel height width"],
250 | predicted: Float[Tensor, "batch channel height width"],
251 | ) -> Float[Tensor, " batch"]:
252 | ground_truth = ground_truth.clip(min=0, max=1)
253 | predicted = predicted.clip(min=0, max=1)
254 | mse = reduce((ground_truth - predicted) ** 2, "b c h w -> b", "mean")
255 | return -10 * mse.log10()
256 |
257 | @torch.no_grad()
258 | def compute_lpips(
259 | self,
260 | ground_truth: Float[Tensor, "batch channel height width"],
261 | predicted: Float[Tensor, "batch channel height width"],
262 | ) -> Float[Tensor, " batch"]:
263 | value = self.lpips.forward(ground_truth, predicted, normalize=True)
264 | return value[:, 0, 0, 0]
265 |
266 | @torch.no_grad()
267 | def compute_ssim(
268 | self,
269 | ground_truth: Float[Tensor, "batch channel height width"],
270 | predicted: Float[Tensor, "batch channel height width"],
271 | ) -> Float[Tensor, " batch"]:
272 | ssim = [
273 | structural_similarity(
274 | gt.detach().cpu().numpy(),
275 | hat.detach().cpu().numpy(),
276 | win_size=11,
277 | gaussian_weights=True,
278 | channel_axis=0,
279 | data_range=1.0,
280 | )
281 | for gt, hat in zip(ground_truth, predicted)
282 | ]
283 | return torch.tensor(ssim, dtype=predicted.dtype, device=predicted.device)
284 |
--------------------------------------------------------------------------------
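A quick sanity check of the PSNR definition used in compute_psnr above (-10 * log10(MSE) on images in [0, 1]), with made-up tensors:

import torch
from einops import reduce

gt = torch.zeros(1, 3, 4, 4)
pred = torch.full((1, 3, 4, 4), 0.1)           # constant error of 0.1 -> MSE = 0.01

mse = reduce((gt - pred) ** 2, "b c h w -> b", "mean")
psnr = -10 * mse.log10()
print(psnr)                                    # ~20 dB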
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .logger import Logger
2 | from .misc import get_config
3 |
4 | __all__ = ['Logger', 'get_config']
5 |
6 |
7 | import sys
8 |
9 |
10 | _LIBS = ['./external/packnet_sfm', './external/dgp', './external/monodepth2']
11 |
12 | def setup_env():
13 | if not _LIBS[0] in sys.path:
14 | for lib in _LIBS:
15 | sys.path.append(lib)
16 |
17 | setup_env()
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from collections import defaultdict
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | import PIL.Image as pil
9 |
10 | from tensorboardX import SummaryWriter
11 |
12 | from .visualize import colormap
13 | from .misc import pretty_ts, cal_depth_error
14 |
15 |
16 | def set_tb_title(*args):
17 | """
18 |     This function builds the title string for a tensorboard plot.
19 | """
20 | title = ''
21 | for i, s in enumerate(args):
22 | if not i%2: title += '/'
23 | s = s if isinstance(s, str) else str(s)
24 | title += s
25 | return title[1:]
26 |
27 |
28 | def resize_for_tb(image):
29 | """
30 |     This function resizes images for tensorboard plotting.
31 | """
32 | h, w = image.size()[-2:]
33 | return F.interpolate(image, [h//2, w//2], mode='bilinear', align_corners=True)
34 |
35 |
36 | def plot_tb(writer, step, img, title, j=0):
37 | """
38 |     This function plots images on tensorboard.
39 | """
40 | img_resized = resize_for_tb(img)
41 | writer.add_image(title, img_resized[j].data, step)
42 |
43 |
44 | def plot_norm_tb(writer, step, img, title, j=0):
45 | """
46 |     This function plots normalized images on tensorboard.
47 | """
48 | img_resized = torch.clamp(resize_for_tb(img), 0., 1.)
49 | writer.add_image(title, img_resized[j].data, step)
50 |
51 |
52 | def plot_disp_tb(writer, step, disp, title, j=0):
53 | """
54 |     This function plots disparity maps on tensorboard.
55 | """
56 | disp_resized = resize_for_tb(disp).float()
57 | disp_resized = colormap(disp_resized[j, 0])
58 | writer.add_image(title, disp_resized, step)
59 |
60 |
61 | class Logger:
62 | """
63 | Logger class to monitor training
64 | """
65 | def __init__(self, cfg, use_tb=True):
66 | self.read_config(cfg)
67 | os.makedirs(self.log_path, exist_ok=True)
68 |
69 | if use_tb:
70 | self.init_tb()
71 |
72 | if self.eval_visualize:
73 | self.init_vis()
74 |
75 | self._metric_names = ['abs_rel', 'sq_rel', 'rms', 'log_rms', 'a1', 'a2', 'a3']
76 |
77 | def read_config(self, cfg):
78 | for attr in cfg.keys():
79 | for k, v in cfg[attr].items():
80 | setattr(self, k, v)
81 |
82 | def init_tb(self):
83 | self.writers = {}
84 | for mode in ['train', 'val']:
85 | self.writers[mode] = SummaryWriter(os.path.join(self.log_path, mode))
86 |
87 | def close_tb(self):
88 | for mode in ['train', 'val']:
89 | self.writers[mode].close()
90 |
91 | def init_vis(self):
92 | vis_path = os.path.join(self.log_path, 'vis_results')
93 | os.makedirs(vis_path, exist_ok=True)
94 |
95 | self.cam_paths = []
96 | for cam_id in range(self.num_cams):
97 | cam_path = os.path.join(vis_path, f'cam{cam_id:d}')
98 | os.makedirs(cam_path, exist_ok=True)
99 | self.cam_paths.append(cam_path)
100 |
101 | def get_metric_names(self):
102 | return self._metric_names
103 |
104 | def to_disp(self, depth_in, K_in):
105 | """
106 |         This function transforms depth values into normalized disparity values, scaling by the focal length.
107 | """
108 | disp = depth_in * self.focal_length_scale / K_in[:, 0:1, 0:1].unsqueeze(2)
109 | disp = 1 / (disp + 0.00001)
110 |
111 | min_disp = 1/self.max_depth
112 | max_disp = 1/self.min_depth
113 | disp_range = max_disp - min_disp
114 |
115 | disp = (disp - min_disp) / disp_range
116 |
117 | return disp
118 |
119 | def update(self, mode, epoch, world_size, batch_idx, step, start_time, before_op_time, inputs, outputs, losses):
120 | """
121 | Display logs with respect to the log frequency
122 | """
123 | # iteration duration
124 | duration = time.time() - before_op_time
125 |
126 | if self.is_checkpoint(step):
127 | self.log_time(epoch, batch_idx * world_size, duration, losses, start_time)
128 | self.log_tb(mode, inputs, outputs, losses, step)
129 |
130 | def is_checkpoint(self, step):
131 | """
132 | Log less frequently after the early phase steps
133 | """
134 | early_phase = (step % self.log_frequency == 0) and (step < self.early_phase)
135 | late_phase = step % self.late_log_frequency == 0
136 | return (early_phase or late_phase)
137 |
138 | def log_time(self, epoch, batch_idx, duration, loss, start_time):
139 | """
140 |         This function prints the epoch, iteration, duration, loss, and elapsed time.
141 | """
142 | rep_loss = loss['total_loss'].item()
143 | samples_per_sec = self.batch_size / duration
144 | time_sofar = time.time() - start_time
145 | print("")
146 | print(f'epoch: {epoch:2d} | batch: {batch_idx:6d} |' + \
147 | f'examples/s: {samples_per_sec:5.1f} | loss: {rep_loss:.3f} | time elapsed: {pretty_ts(time_sofar)}')
148 |
149 | def log_tb(self, mode, inputs, outputs, losses, step):
150 | """
151 | This function logs outputs for monitoring using tensorboard.
152 | """
153 | writer = self.writers[mode]
154 | # loss
155 | for l, v in losses.items():
156 | writer.add_scalar(f'{l}', v, step)
157 |
158 | scale = 0 # plot the maximum scale
159 | for cam_id in range(self.num_cams):
160 | target_view = outputs[('cam', cam_id)]
161 |
162 | plot_tb(writer, step, inputs[('color', 0, scale)][:, cam_id, ...], set_tb_title('cam', cam_id, 'frame_0')) # frame_id 0
163 | plot_tb(writer, step, inputs[('color', 1, scale)][:, cam_id, ...], set_tb_title('cam', cam_id, 'frame_1')) # frame_id 1
164 | plot_tb(writer, step, inputs[('color', -1, scale)][:, cam_id, ...], set_tb_title('cam', cam_id, 'frame_-1')) # frame_id -1
165 | plot_disp_tb(writer, step, target_view[('disp', scale)], set_tb_title('cam', cam_id, 'disp')) # disparity
166 | depth_gt = inputs['depth'][:, cam_id, ...]
167 | far = depth_gt == 0
168 | depth_gt[far] += 80.0
169 | disp_gt = self.to_disp(depth_gt, inputs[('K', 0)][:, cam_id, ...])
170 | plot_disp_tb(writer, step, disp_gt, set_tb_title('cam', cam_id, 'disp_gt'))
171 |
172 | plot_tb(writer, step, target_view[('reproj_loss', scale)], set_tb_title('cam', cam_id, 'reproj')) # reprojection image
173 | plot_tb(writer, step, target_view[('reproj_mask', scale)], set_tb_title('cam', cam_id, 'reproj_mask')) # reprojection mask
174 | plot_tb(writer, step, inputs['mask'][:, cam_id, ...], set_tb_title('cam', cam_id, 'self_occ_mask'))
175 |
176 | if self.spatio:
177 | plot_norm_tb(writer, step, target_view[('overlap', 0, scale)], set_tb_title('cam', cam_id, 'sp'))
178 | plot_tb(writer, step, target_view[('overlap_mask', 0, scale)], set_tb_title('cam', cam_id, 'sp_mask'))
179 |
180 | if self.spatio_temporal:
181 | for frame_id in self.frame_ids:
182 | if frame_id == 0:
183 | continue
184 | plot_norm_tb(writer, step, target_view[('color', frame_id, scale)], set_tb_title('cam', cam_id, 'pred_', frame_id))
185 | plot_norm_tb(writer, step, target_view[('overlap', frame_id, scale)], set_tb_title('cam', cam_id, 'sp_tm_', frame_id))
186 | plot_tb(writer, step, target_view[('overlap_mask', frame_id, scale)], set_tb_title('cam', cam_id, 'sp_tm_mask_', frame_id))
187 |
188 | if self.gaussian:
189 | if self.novel_view_mode == 'SF':
190 | frame_id = 1
191 | elif self.novel_view_mode == 'MF':
192 | frame_id = 0
193 | else:
194 | raise ValueError(f'Novel view mode {self.novel_view_mode} not supported.')
195 | plot_norm_tb(writer, step, target_view[('gaussian_color', frame_id, scale)], set_tb_title('cam', cam_id, f'gaussian_pred_frame_{frame_id}'))
196 |
197 | def log_result(self, inputs, outputs, idx, syn_visualize=False):
198 | """
199 | This function logs outputs for visualization.
200 | """
201 | scale = 0
202 | for cam_id in range(self.num_cams):
203 | target_view = outputs[('cam', cam_id)]
204 | disps = target_view['disp', scale]
205 | for jdx, disp in enumerate(disps):
206 | disp = colormap(disp)[0,...].transpose(1,2,0)
207 | disp = pil.fromarray((disp * 255).astype(np.uint8))
208 | cur_idx = idx*self.batch_size + jdx
209 | disp.save(os.path.join(self.cam_paths[cam_id], f'{cur_idx:03d}_disp.jpg'))
210 |
211 | def compute_depth_losses(self, inputs, outputs, vis_scale=False):
212 | """
213 |         This function computes depth metrics to monitor the training process on the validation dataset.
214 | """
215 | min_eval_depth = self.eval_min_depth
216 | max_eval_depth = self.eval_max_depth
217 |
218 | med_scale = []
219 |
220 | error_metric_dict = defaultdict(float)
221 | error_median_dict = defaultdict(float)
222 |
223 | for cam in range(self.num_cams):
224 | target_view = outputs['cam', cam]
225 |
226 | depth_gt = inputs['depth'][:, cam, ...]
227 |
228 | _, _, h, w = depth_gt.shape
229 |
230 | depth_pred = target_view[('depth', 0, 0)].to(depth_gt.device)
231 | depth_pred = torch.clamp(F.interpolate(
232 | depth_pred, [h, w], mode='bilinear', align_corners=False),
233 | min_eval_depth, max_eval_depth)
234 | depth_pred = depth_pred.detach()
235 |
236 | mask = (depth_gt > min_eval_depth) * (depth_gt < max_eval_depth) * inputs['mask'][:, cam, ...]
237 | mask = mask.bool()
238 |
239 | depth_gt = depth_gt[mask]
240 | depth_pred = depth_pred[mask]
241 |
242 | # calculate median scale
243 | scale_val = torch.median(depth_gt) / torch.median(depth_pred)
244 | med_scale.append(round(scale_val.cpu().numpy().item(), 2))
245 |
246 | depth_pred_metric = torch.clamp(depth_pred, min=min_eval_depth, max=max_eval_depth)
247 | depth_errors_metric = cal_depth_error(depth_pred_metric, depth_gt)
248 |
249 | depth_pred_median = torch.clamp(depth_pred * scale_val, min=min_eval_depth, max=max_eval_depth)
250 | depth_errors_median = cal_depth_error(depth_pred_median, depth_gt)
251 |
252 | for i in range(len(depth_errors_metric)):
253 | key = self._metric_names[i]
254 | error_metric_dict[key] += depth_errors_metric[i]
255 | error_median_dict[key] += depth_errors_median[i]
256 |
257 |             if vis_scale:
258 | # print median scale
259 | print(f' | median scale = {med_scale}')
260 |
261 | for key in error_metric_dict.keys():
262 | error_metric_dict[key] = error_metric_dict[key].cpu().numpy() / self.num_cams
263 | error_median_dict[key] = error_median_dict[key].cpu().numpy() / self.num_cams
264 |
265 | return error_metric_dict, error_median_dict
266 |
267 | def print_perf(self, loss, scale):
268 | """
269 |         This function prints the given metrics (e.g., depth or reconstruction accuracy).
270 | """
271 | perf = ' '*3 + scale
272 | for k, v in loss.items():
273 | perf += ' | ' + str(k) + f': {v:.3f}'
274 | print(perf)
275 |
--------------------------------------------------------------------------------
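set_tb_title above prepends a '/' before every other argument, so camera id and plot name end up in a single tensorboard tag. A tiny check, assuming the repo root is on the Python path:

from utils.logger import set_tb_title

print(set_tb_title('cam', 0, 'disp'))          # -> cam0/disp
print(set_tb_title('cam', 3, 'sp_tm_', -1))    # -> cam3/sp_tm_-1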
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | from collections import defaultdict
4 |
5 | import torch
6 |
7 | _NUSC_CAM_LIST = ['CAM_FRONT', 'CAM_FRONT_LEFT', 'CAM_FRONT_RIGHT', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT', 'CAM_BACK']
8 | _REL_CAM_DICT = {0: [1,2], 1: [0,3], 2: [0,4], 3: [1,5], 4: [2,5], 5: [3,4]}
9 |
10 |
11 | def camera2ind(cameras):
12 | """
13 |     This function transforms a list of camera names into indices.
14 | """
15 | indices = []
16 | for cam in cameras:
17 | if cam in _NUSC_CAM_LIST:
18 | ind = _NUSC_CAM_LIST.index(cam)
19 | else:
20 | ind = None
21 | indices.append(ind)
22 | return indices
23 |
24 |
25 | def get_relcam(cameras):
26 | """
27 |     This function returns adjacent (overlapping) camera indices for the given camera list.
28 | """
29 | relcam_dict = defaultdict(list)
30 | indices = camera2ind(cameras)
31 | for ind in indices:
32 | relcam_dict[ind] = []
33 | relcam_cand = _REL_CAM_DICT[ind]
34 | for cand in relcam_cand:
35 | if cand in indices:
36 | relcam_dict[ind].append(cand)
37 | return relcam_dict
38 |
39 |
40 | def get_config(config, mode='train', weight_path='./', novel_view_mode='MF'):
41 | """
42 |     This function reads the configuration file and returns it as a dictionary.
43 | """
44 | with open(config, 'r') as stream:
45 | cfg = yaml.load(stream, Loader=yaml.FullLoader)
46 |
47 | cfg_name = os.path.splitext(os.path.basename(config))[0]
48 | print('Experiment: ', cfg_name)
49 |
50 | _log_path = os.path.join(cfg['data']['log_dir'], cfg_name)
51 | cfg['data']['log_path'] = _log_path
52 | cfg['data']['save_weights_root'] = os.path.join(_log_path, 'models')
53 | cfg['data']['num_cams'] = len(cfg['data']['cameras'])
54 | cfg['data']['rel_cam_list'] = get_relcam(cfg['data']['cameras'])
55 |
56 | cfg['model']['mode'] = mode
57 | cfg['model']['novel_view_mode'] = novel_view_mode
58 |
59 | cfg['load']['load_weights_dir'] = weight_path
60 |
61 | if mode == 'eval':
62 | cfg['ddp']['world_size'] = 1
63 | cfg['ddp']['gpus'] = [0]
64 | cfg['training']['batch_size'] = cfg['eval']['eval_batch_size']
65 | cfg['training']['depth_flip'] = False
66 | return cfg
67 |
68 |
69 | def pretty_ts(ts):
70 | """
71 |     This function formats the elapsed time in a user-friendly way.
72 | """
73 | second = int(ts)
74 | minute = second // 60
75 | hour = minute // 60
76 | return f'{hour:02d}h{(minute%60):02d}m{(second%60):02d}s'
77 |
78 |
79 | def cal_depth_error(pred, target):
80 | """
81 | This function calculates depth error using various metrics.
82 | """
83 | abs_rel = torch.mean(torch.abs(pred-target) / target)
84 | sq_rel = torch.mean((pred-target).pow(2) / target)
85 | rmse = torch.sqrt(torch.mean((pred-target).pow(2)))
86 | rmse_log = torch.sqrt(torch.mean((torch.log(target) - torch.log(pred)).pow(2)))
87 |
88 | thresh = torch.max((target/pred), (pred/ target))
89 | a1 = (thresh < 1.25).float().mean()
90 | a2 = (thresh < 1.25**2).float().mean()
91 | a3 = (thresh < 1.25**3).float().mean()
92 | return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
--------------------------------------------------------------------------------
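For the full six-camera nuScenes rig, get_relcam simply reproduces _REL_CAM_DICT, since every neighbour is present in the list. A short example:

from utils.misc import get_relcam

cams = ['CAM_FRONT', 'CAM_FRONT_LEFT', 'CAM_FRONT_RIGHT',
        'CAM_BACK_LEFT', 'CAM_BACK_RIGHT', 'CAM_BACK']
print(dict(get_relcam(cams)))
# {0: [1, 2], 1: [0, 3], 2: [0, 4], 3: [1, 5], 4: [2, 5], 5: [3, 4]}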
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 | import torch
4 |
5 | _DEGTORAD = 0.0174533
6 |
7 |
8 | def aug_depth_params(K, n_steps= 75):
9 | """
10 | This function augments camera parameters for depth synthesis.
11 | """
12 | # augmented parameters for visualization
13 | aug_params = []
14 |
15 | # roll augmentations
16 | roll_aug = [i for i in range(0, n_steps + 1, 2)] + [i for i in range(n_steps, -n_steps - 1, -2)] + [i for i in range(-n_steps, 1, 2)]
17 | ang_y, ang_z = 0.0, 0.0
18 | for angle in roll_aug:
19 | ang_x = _DEGTORAD * (angle / n_steps * 10.)
20 | aug_params.append([torch.inverse(K), ang_x, ang_y, ang_z])
21 |
22 | # pitch augmentations
23 | pitch_aug = [i for i in range(0, 50 + 1, 2)] + [i for i in range(50, -50 - 1, -2)] + [i for i in range(-50, 1, 2)]
24 | ang_x, ang_z = 0.0, 0.0
25 | for angle in pitch_aug:
26 | ang_y = _DEGTORAD * (angle / 10.)
27 | aug_params.append([torch.inverse(K), ang_x, ang_y, ang_z])
28 |
29 | # focal length augmentations
30 | focal_ratio = K[:, 1, 0, 0] / K[:, 0, 0, 0]
31 | focal_ratio_aug = focal_ratio / 1.5
32 | ang_x, ang_y, ang_z = 0.0, 0.0, 0.0
33 |
34 | for f_idx in range(100 + 1):
35 | f_scale = (f_idx / 100. * focal_ratio_aug + (1 - f_idx / 100.))[:, None]
36 | K_aug = K.clone()
37 | K_aug[:, :, 0, 0] *= f_scale
38 | K_aug[:, :, 1, 1] *= f_scale
39 | aug_params.append([torch.inverse(K_aug), ang_x, ang_y, ang_z])
40 |
41 | for f_idx in range(50 + 1):
42 | f_scale = (f_idx / 50. * focal_ratio + (1 - f_idx / 50.) * focal_ratio_aug)[:, None]
43 | K_aug = K.clone()
44 | K_aug[:, :, 0, 0] *= f_scale
45 | K_aug[:, :, 1, 1] *= f_scale
46 | aug_params.append([torch.inverse(K_aug), ang_x, ang_y, ang_z])
47 |
48 | # yaw augmentations
49 | yaw_aug = [i for i in range(360)]
50 | inv_K_aug = torch.inverse(K_aug)
51 | ang_x, ang_y = 0.0, 0.0
52 | for i in yaw_aug:
53 | ratio_i = i / 360.
54 | ang_z = _DEGTORAD * 360 * ratio_i
55 | aug_params.append([inv_K_aug, ang_x, ang_y, ang_z])
56 | return aug_params
57 |
58 |
59 | def colormap(vis, normalize=True, torch_transpose=True):
60 | """
61 |     This function visualizes a disparity map using the 'plasma' colormap.
62 | """
63 | disparity_map = plt.get_cmap('plasma', 256) # for plotting
64 |
65 | if isinstance(vis, torch.Tensor):
66 | vis = vis.detach().cpu().numpy()
67 |
68 | if normalize:
69 | ma = float(vis.max())
70 | mi = float(vis.min())
71 | d = ma - mi if ma != mi else 1e5
72 | vis = (vis - mi) / d
73 |
74 | if vis.ndim == 4:
75 | vis = vis.transpose([0, 2, 3, 1])
76 | vis = disparity_map(vis)
77 | vis = vis[:, :, :, 0, :3]
78 | if torch_transpose:
79 | vis = vis.transpose(0, 3, 1, 2)
80 | elif vis.ndim == 3:
81 | vis = disparity_map(vis)
82 | vis = vis[:, :, :, :3]
83 | if torch_transpose:
84 | vis = vis.transpose(0, 3, 1, 2)
85 | elif vis.ndim == 2:
86 | vis = disparity_map(vis)
87 | vis = vis[..., :3]
88 | if torch_transpose:
89 | vis = vis.transpose(2, 0, 1)
90 | return vis
--------------------------------------------------------------------------------
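colormap accepts 2D, 3D, or 4D inputs; on a single [H, W] disparity tensor it returns a channel-first RGB array suitable for saving or tensorboard. A minimal sketch with a hypothetical map size:

import torch
from utils.visualize import colormap

disp = torch.rand(44, 80)                      # hypothetical disparity map
vis = colormap(disp)                           # numpy array, shape (3, 44, 80), values in [0, 1]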