├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── assets └── framework.png ├── configs └── nuscenes │ └── main.yaml ├── dataset ├── __init__.py ├── base_dataset.py ├── data_util.py ├── nuscenes │ ├── eval_MF.txt │ ├── eval_SF.txt │ └── train.txt ├── nuscenes_dataset.py └── nuscenes_mask │ ├── CAM_BACK_LEFT_mask.png │ ├── CAM_BACK_RIGHT_mask.png │ ├── CAM_BACK_mask.png │ ├── CAM_FRONT_LEFT_mask.png │ ├── CAM_FRONT_RIGHT_mask.png │ └── CAM_FRONT_mask.png ├── eval.py ├── external ├── dataset │ └── __init__.py ├── layers │ └── __init__.py └── utils │ └── __init__.py ├── models ├── __init__.py ├── base_model.py ├── drivingforward_model.py ├── gaussian │ ├── GaussianRender.py │ ├── __init__.py │ ├── extractor.py │ ├── gaussian_network.py │ ├── gaussian_renderer │ │ └── __init__.py │ └── utils.py ├── geometry │ ├── __init__.py │ ├── geometry_util.py │ ├── pose.py │ └── view_rendering.py └── losses │ ├── __init__.py │ ├── base_loss.py │ ├── loss_util.py │ ├── multi_cam_loss.py │ └── single_cam_loss.py ├── network ├── __init__.py ├── blocks.py ├── depth_network.py ├── pose_network.py └── volumetric_fusionnet.py ├── requirements.txt ├── train.py ├── trainer ├── __init__.py └── trainer.py └── utils ├── __init__.py ├── logger.py ├── misc.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | ._.DS_Store 3 | .__pycache__ 4 | */__pycache__/* 5 | **/__pycache__/ 6 | .ipynb_checkpoints 7 | */.ipynb_checkpoints/* 8 | results 9 | weights_SF 10 | weights_MF 11 | *.pyc 12 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "external/packnet_sfm"] 2 | path = external/packnet_sfm 3 | url = https://github.com/TRI-ML/packnet-sfm 4 | [submodule "external/dgp"] 5 | path = external/dgp 6 | url = https://github.com/TRI-ML/dgp 7 | branch = v1.5 8 | [submodule "models/gaussian/gaussian-splatting"] 9 | path = models/gaussian/gaussian-splatting 10 | url = https://github.com/graphdeco-inria/gaussian-splatting 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Qijian Tian 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | # DrivingForward: Feed-forward 3D Gaussian Splatting for Driving Scene Reconstruction from Flexible Surround-view Input
3 | 
4 | Qijian Tian¹ · Xin Tan² · Yuan Xie² · Lizhuang Ma¹
5 | 
6 | ¹Shanghai Jiao Tong University · ²East China Normal University
7 | 
8 | AAAI 2025
9 | 
10 | Paper | Project Page | Pretrained Models
11 | 
20 | 21 | ## Introduction 22 | 23 | We propose a feed-forward Gaussian Splatting model that reconstructs driving scenes from flexible, sparse surround-view input. 24 | 25 | ![framework](assets/framework.png) 26 | 27 | Given sparse surround-view input from vehicle-mounted cameras, our model learns 28 | scale-aware localization of Gaussian primitives from the small overlap between spatial and temporal context views. A Gaussian 29 | network predicts the remaining Gaussian parameters from each image individually. This feed-forward pipeline enables real-time reconstruction 30 | of driving scenes, and the independent prediction from single-frame images supports flexible input modes. At the inference stage, 31 | only the depth network and the Gaussian network are used, as shown in the lower part of the figure. 32 | 33 | ## Installation 34 | 35 | To get started, clone this project, create a conda virtual environment with Python 3.8, and install the requirements: 36 | 37 | ```bash 38 | git clone https://github.com/fangzhou2000/DrivingForward 39 | cd DrivingForward 40 | git submodule update --init --recursive 41 | conda create -n DrivingForward python=3.8 42 | conda activate DrivingForward 43 | pip install torch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 --index-url https://download.pytorch.org/whl/cu113 44 | pip install -r requirements.txt 45 | cd models/gaussian/gaussian-splatting 46 | pip install submodules/diff-gaussian-rasterization 47 | cd ../../.. 48 | ``` 49 | 50 | ## Datasets 51 | 52 | ### nuScenes 53 | * Download the official [nuScenes](https://www.nuscenes.org/nuscenes) dataset 54 | * Place the dataset in `input_data/nuscenes/` 55 | 56 | The data should be organized as follows: 57 | ``` 58 | ├── input_data 59 | │ ├── nuscenes 60 | │ │ ├── maps 61 | │ │ ├── samples 62 | │ │ ├── sweeps 63 | │ │ ├── v1.0-test 64 | │ │ ├── v1.0-trainval 65 | ``` 66 | 67 | ## Running the Code 68 | 69 | ### Evaluation 70 | 71 | Download the [pretrained models](https://drive.google.com/drive/folders/1IASOPK1RQeP-nLQvJUn7WQUtb_fwGlVS), save them to the root directory of the project, and unzip them. 72 | 73 | For SF mode, run the following: 74 | ``` 75 | python -W ignore eval.py --weight_path ./weights_SF --novel_view_mode SF 76 | ``` 77 | 78 | 79 | For MF mode, run the following: 80 | ``` 81 | python -W ignore eval.py --weight_path ./weights_MF --novel_view_mode MF 82 | ``` 83 | 84 | ### Training 85 | 86 | For SF mode, run the following: 87 | ``` 88 | python -W ignore train.py --novel_view_mode SF 89 | ``` 90 | 91 | For MF mode, run the following: 92 | ``` 93 | python -W ignore train.py --novel_view_mode MF 94 | ``` 95 | 96 | ## BibTeX 97 | ``` 98 | @inproceedings{tian2025drivingforward, 99 | title={DrivingForward: Feed-forward 3D Gaussian Splatting for Driving Scene Reconstruction from Flexible Surround-view Input}, 100 | author={Qijian Tian and Xin Tan and Yuan Xie and Lizhuang Ma}, 101 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 102 | year={2025} 103 | } 104 | ``` 105 | 106 | ## Acknowledgements 107 | 108 | The project is partially based on some awesome repos: [MVSplat](https://github.com/donydchen/mvsplat), [GPS-Gaussian](https://github.com/aipixel/GPS-Gaussian), and [VFDepth](https://github.com/42dot/VFDepth). Many thanks to these projects for their excellent contributions!
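
For orientation, the snippet below sketches what a single evaluation step boils down to, using only entry points that exist in this repo (`utils.get_config`, `DrivingForwardModel`, `process_batch`). It is an illustrative sketch, not the supported workflow: weight loading, metric computation, and logging are handled by `DrivingForwardTrainer`, so use `eval.py` above for actual evaluation.

```python
import torch

import utils
from models import DrivingForwardModel

# Build the eval configuration the same way eval.py does (MF mode shown here).
cfg = utils.get_config('./configs/nuscenes/main.yaml', mode='eval',
                       weight_path='./weights_MF', novel_view_mode='MF')

model = DrivingForwardModel(cfg, 0)   # rank 0, single GPU
model.set_eval()                      # put depth_net / gs_net (and pose_net) in eval mode

with torch.no_grad():
    for inputs in model.eval_dataloader():
        # Depth network -> scale-aware depth; Gaussian network -> per-pixel Gaussian
        # parameters; the primitives are then splatted to the novel views.
        outputs, _ = model.process_batch(inputs, 0)
        # Rendered novel view of camera 0 at the target frame (MF mode):
        novel_rgb = outputs[('cam', 0)][('gaussian_color', 0, 0)]
        break  # one batch is enough for illustration
```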
109 | -------------------------------------------------------------------------------- /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fangzhou2000/DrivingForward/5d0a2c7d6a6358f55318faac80e351cf75dff032/assets/framework.png -------------------------------------------------------------------------------- /configs/nuscenes/main.yaml: -------------------------------------------------------------------------------- 1 | ddp: 2 | ddp_enable: False 3 | world_size: 1 4 | gpus: [0] 5 | 6 | model: 7 | # mode: 'MF' or 'SF' 8 | novel_view_mode: 'MF' 9 | 10 | num_layers: 18 11 | weights_init: True 12 | 13 | fusion_level: 2 14 | fusion_feat_in_dim: 256 15 | use_skips: False 16 | 17 | voxel_unit_size: [1.0, 1.0, 1.5] 18 | voxel_size: [100, 100, 20] 19 | voxel_str_p: [-50.0, -50.0, -15.0] 20 | voxel_pre_dim: [64] 21 | proj_d_bins: 50 22 | proj_d_str: 2 23 | proj_d_end: 50 24 | 25 | data: 26 | data_path: './input_data/nuscenes/' 27 | log_dir: './results/' 28 | dataset: 'nuscenes' 29 | back_context: 1 30 | forward_context: 1 31 | depth_type: 'lidar' 32 | cameras: ['CAM_FRONT', 'CAM_FRONT_LEFT', 'CAM_FRONT_RIGHT', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT', 'CAM_BACK'] 33 | # gt_ego_pose, gt_depth is not used during training, only for visulization 34 | train_requirements: (gt_pose, gt_ego_pose, gt_depth, mask) 35 | val_requirements: (gt_pose, gt_ego_pose, gt_depth, mask) 36 | 37 | training: 38 | # basic 39 | height: 352 40 | width: 640 41 | scales: [0] 42 | frame_ids: [0, -1, 1] 43 | 44 | # optimization 45 | batch_size: 1 46 | num_workers: 4 47 | learning_rate: 0.0001 48 | num_epochs: 10 49 | scheduler_step_size: 15 50 | 51 | # model / loss setting 52 | ## depth range 53 | min_depth: 1.5 54 | max_depth: 80.0 55 | 56 | ## spatio & temporal 57 | spatio: True 58 | spatio_temporal: True 59 | 60 | ## gaussian 61 | gaussian: True 62 | 63 | ## intensity align 64 | intensity_align: True 65 | 66 | ## focal length scaling 67 | focal_length_scale: 300 68 | 69 | # Loss hyperparams 70 | loss: 71 | disparity_smoothness: 0.001 72 | spatio_coeff: 0.03 73 | spatio_tempo_coeff: 0.1 74 | gaussian_coeff: 0.01 75 | pose_loss_coeff: 0.0 76 | 77 | eval: 78 | eval_batch_size: 1 79 | eval_num_workers: 4 80 | eval_min_depth: 0 81 | eval_max_depth: 80 82 | eval_visualize: False 83 | save_images: False 84 | save_path: './results/images' 85 | 86 | load: 87 | pretrain: False 88 | models_to_load: ['depth_net', 'gs_net'] 89 | load_dir: path to weights directory 90 | 91 | logging: 92 | early_phase: 5000 93 | log_frequency: 1000 94 | late_log_frequency: 5000 95 | save_frequency: 1 96 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import construct_dataset 2 | 3 | __all__ = ['construct_dataset'] -------------------------------------------------------------------------------- /dataset/base_dataset.py: -------------------------------------------------------------------------------- 1 | from external.dataset import get_transforms 2 | 3 | 4 | def construct_dataset(cfg, mode, **kwargs): 5 | """ 6 | This function constructs datasets. 
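Depending on `mode` ('train', 'val', or 'eval'), it assembles the dataset arguments from cfg['data'] (cameras, temporal context, depth type, and the gt_pose / gt_ego_pose / mask requirements) and instantiates the nuScenes dataset with the matching split (train, eval_MF, or eval_SF); other dataset names raise a ValueError.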
7 | """ 8 | # dataset arguments for the dataloader 9 | if mode == 'train': 10 | dataset_args = { 11 | 'cameras': cfg['data']['cameras'], 12 | 'back_context': cfg['data']['back_context'], 13 | 'forward_context': cfg['data']['forward_context'], 14 | 'data_transform': get_transforms('train', **kwargs), 15 | 'depth_type': cfg['data']['depth_type'] if 'gt_depth' in cfg['data']['train_requirements'] else None, 16 | 'with_pose': 'gt_pose' in cfg['data']['train_requirements'], 17 | 'with_ego_pose': 'gt_ego_pose' in cfg['data']['train_requirements'], 18 | 'with_mask': 'mask' in cfg['data']['train_requirements'] 19 | } 20 | 21 | elif mode == 'val' or mode == 'eval': 22 | dataset_args = { 23 | 'cameras': cfg['data']['cameras'], 24 | 'back_context': cfg['data']['back_context'], 25 | 'forward_context': cfg['data']['forward_context'], 26 | 'data_transform': get_transforms('train', **kwargs), # for aligning inputs without any augmentations 27 | 'depth_type': cfg['data']['depth_type'] if 'gt_depth' in cfg['data']['val_requirements'] else None, 28 | 'with_pose': 'gt_pose' in cfg['data']['val_requirements'], 29 | 'with_ego_pose': 'gt_ego_pose' in cfg['data']['val_requirements'], 30 | 'with_mask': 'mask' in cfg['data']['val_requirements'] 31 | } 32 | 33 | # NuScenes dataset 34 | if cfg['data']['dataset'] == 'nuscenes': 35 | from dataset.nuscenes_dataset import NuScenesdataset 36 | if mode == 'train': 37 | split = 'train' 38 | else: 39 | if cfg['model']['novel_view_mode'] == 'MF': 40 | split = 'eval_MF' 41 | elif cfg['model']['novel_view_mode'] == 'SF': 42 | split = 'eval_SF' 43 | else: 44 | raise ValueError('Unknown novel view mode: ' + cfg['model']['novel_view_mode']) 45 | dataset = NuScenesdataset( 46 | cfg['data']['data_path'], split, 47 | **dataset_args 48 | ) 49 | else: 50 | raise ValueError('Unknown dataset: ' + cfg['data']['dataset']) 51 | return dataset -------------------------------------------------------------------------------- /dataset/data_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import PIL.Image as pil 5 | 6 | import torch.nn.functional as F 7 | import torchvision.transforms as transforms 8 | 9 | _DEL_KEYS= ['rgb', 'rgb_context', 'rgb_original', 'rgb_context_original', 'intrinsics', 'contexts', 'splitname', 'ego_pose'] 10 | 11 | 12 | def transform_mask_sample(sample, data_transform): 13 | """ 14 | This function transforms masks to match input rgb images. 15 | """ 16 | image_shape = data_transform.keywords['image_shape'] 17 | # resize transform 18 | resize_transform = transforms.Resize(image_shape, interpolation=pil.ANTIALIAS) 19 | sample['mask'] = resize_transform(sample['mask']) 20 | # totensor transform 21 | tensor_transform = transforms.ToTensor() 22 | sample['mask'] = tensor_transform(sample['mask']) 23 | return sample 24 | 25 | 26 | def img_loader(path): 27 | """ 28 | This function loads rgb image. 29 | """ 30 | with open(path, 'rb') as f: 31 | with pil.open(f) as img: 32 | return img.convert('RGB') 33 | 34 | 35 | def mask_loader_scene(path, mask_idx, cam): 36 | """ 37 | This function loads mask that correspondes to the scene and camera. 38 | """ 39 | fname = os.path.join(path, str(mask_idx), '{}_mask.png'.format(cam.upper())) 40 | with open(fname, 'rb') as f: 41 | with pil.open(f) as img: 42 | return img.convert('L') 43 | 44 | 45 | def align_dataset(sample, scales, contexts): 46 | """ 47 | This function reorganize samples to match our trainer configuration. 
48 | """ 49 | K = sample['intrinsics'] 50 | aug_images = sample['rgb'] 51 | aug_contexts = sample['rgb_context'] 52 | org_images = sample['rgb_original'] 53 | org_contexts= sample['rgb_context_original'] 54 | ego_poses = sample['ego_pose'] 55 | 56 | n_cam, _, w, h = aug_images.shape 57 | 58 | # initialize intrinsics 59 | resized_K = np.expand_dims(np.eye(4), 0).repeat(n_cam, axis=0) 60 | resized_K[:, :3, :3] = K 61 | 62 | # augment images and intrinsics in accordance with scales 63 | for scale in scales: 64 | scaled_K = resized_K.copy() 65 | scaled_K[:,:2,:] /= (2**scale) 66 | 67 | sample[('K', scale)] = scaled_K.copy() 68 | sample[('inv_K', scale)]= np.linalg.pinv(scaled_K).copy() 69 | 70 | resized_org = F.interpolate(org_images, 71 | size=(w//(2**scale),h//(2**scale)), 72 | mode = 'bilinear', 73 | align_corners=False) 74 | resized_aug = F.interpolate(aug_images, 75 | size=(w//(2**scale),h//(2**scale)), 76 | mode = 'bilinear', 77 | align_corners=False) 78 | 79 | sample[('color', 0, scale)] = resized_org 80 | sample[('color_aug', 0, scale)] = resized_aug 81 | 82 | # for context data 83 | for idx, frame in enumerate(contexts): 84 | sample[('color', frame, 0)] = org_contexts[idx] 85 | sample[('color_aug',frame, 0)] = aug_contexts[idx] 86 | sample[('cam_T_cam', 0, frame)] = ego_poses[idx] 87 | 88 | # delete unused arrays 89 | for key in list(sample.keys()): 90 | if key in _DEL_KEYS: 91 | del sample[key] 92 | return sample 93 | -------------------------------------------------------------------------------- /dataset/nuscenes/eval_SF.txt: -------------------------------------------------------------------------------- 1 | fd8420396768425eabec9bdddf7e64b6 2 | 88a090c54cc844cc878f00d08274a683 3 | b9c60cfaf1814c8bb363037dec7abd35 4 | 43cfc03be6f842ae8942c995f8e3f2fb 5 | 234b0f37af694571876a745e5bbb5e08 6 | fbd3aa2098234fb1869d34d920e16eda 7 | 134c7f63827c4b858a2a4ce0582d8f6e 8 | f0e4797551024f9487a016ee9d9e29e1 9 | d62b0f9db06440c695781c52ed71909a 10 | e1b5c775be134db696ce22413d195c72 11 | 0f39f34febc84a6689f08599105d1421 12 | 15c59626df6b4f96bb7ef7c5c1797159 13 | 8e46878ced734164aae0e83567aa02a5 14 | 6e0955c28ecb405b98d91b089761dbbb 15 | 19ffc60b7e9f429cbca0dbd5e2fc2e03 16 | 78be3bd5f1f146cdb1b7120bbf8a218c 17 | f9475d1f1f96401c81ce511a93578a96 18 | 1c81b06c812b47788462c31f58b6c1d0 19 | c292d90b90ec4dbd826dfd11aa361df1 20 | b9db804449d34414ba98868e9034abd3 21 | 028a3f3a237e469cb099f6d9c8e7f2f4 22 | d5324133e1d241658a7b763831b3af2e 23 | 806011e8a7114eca971fd0f6b95bcd4b 24 | 3e8750f331d7499e9b5123e9eb70f2e2 25 | f1b7b463072e4630b2c1d5ba8b6cd50b 26 | 8d56b71c09c04ece827107b3c103498f 27 | 04ebf63518914e0883bdffdb0b6c3f4c 28 | ad8905836f364a87a7233eb0aef915ea 29 | 339ead96177c4e338fde8235c188cfaa 30 | 4d0d06979c984a72bfb9e6c5e3a35f84 31 | 80c91574d0174206a74435200aba8ba8 32 | 79070dca796f4ffb868d84f1b067f9c2 33 | 30e55a3ec6184d8cb1944b39ba19d622 34 | adce54dcf6404d37bd170ffc4c2a7836 35 | 51f37263b06c417e94138f2d9348f909 36 | c28c950afda14f7b93ff956799873dff 37 | e261f474caa34797b0759aa2ca8a576d 38 | 1d477a7c21784a98b2e6a9289a958fe2 39 | 040a956bf0c04ad09fffc286cab5bc56 40 | 8fb167b9767d420b9c6368719bafb7b4 41 | 3502d36d2c76428f9b79f2854e96b696 42 | 7e2fd590b46d42cdbe8dfbf035761561 43 | f4b5b6fc59e34da6bd18d4d8e299dab8 44 | 269601cc1e4d45a3a36fbf291d45a447 45 | b98e65c90a7a445a907a3bb1467b551b 46 | 09361337f2bf4e819aac2a310481e62f 47 | 097c7295f13a4746a2c46d036c21bd3e 48 | 418cbb70159341c1ada4f3ba1fb69713 49 | c525629d7b3749d8968b486bb5412adf 50 | c35efe5eaf784275b8c7e31fb50aa902 51 | 
4e26b42796024b45a9a852738e3701bb 52 | 6123a0c091d24c39ab274cdbad994bc4 53 | 107dd32140844573924260f9ad9390bf 54 | 5ec2ee2d55cc4582acd0f92476845a74 55 | ce03b6f0c1d2455aa18f6275917485bc 56 | 89c9204978674480aec9d2a9a377eecc 57 | 5b90526901414d2aa5db2b9fdaf858eb 58 | 8687ba92abd3406aa797115b874ebeba 59 | e0b85d628af34d6bbbbb0da216aa5bca 60 | 12a42a64274a4453b166d804b625d61f 61 | 17e6f41d3b1f40d4bc85412fe8b67688 62 | 6cb74a87748b4fab82ad1b0869b27e01 63 | b4b5e858cd9d4cf9bdb01098ea7d3c45 64 | 029af235f15d4f6f80713b778edc3c61 65 | 5c558d97108341ebb7b30b1c841ff57f 66 | 9eebbe06762c47aa8c671008df5fb474 67 | e493c6923a0a4d5fb78b75dcc0ffc5ab 68 | 970f41b6fc4544b388fe7ccd578988a9 69 | 3b86f8218554493f928b39824655872a 70 | f0c47597c226445fa1e14c82d879f9a4 71 | 67d2d74087714e4994a68c7347acf55a 72 | 57f23d2f55e6455696c6d3cc70a4f501 73 | 0fae8aae30d44b4faf7a9854eac4bc09 74 | 14384b3ce9d243ce827bf0c1598dd35f 75 | 6a808b09e5f34d33ba1de76cc8dab423 76 | 78c80690450a4e8aba63e1f2ca27c207 77 | fe5b1b2e8aa2430e90c5c90825f9f1f2 78 | 3bb017b691c34f8f89985c0c2220ccec 79 | 43f05ac397484672b9bdf86dcc0b0289 80 | 3ac231a1ef4c481c9c15921cafa87307 81 | 8b77b8e2cbf144bea91f4c1856b05a7c 82 | ccd7c315063047ddb231fceb7ceff64b 83 | 4cecbedd7c9c406ba5a0b22a3fa2e49c 84 | 5a3a15f4ebdb4e6b83eb58628dc8e83f 85 | e836d06299ce46b884b9187ac6017579 86 | bb10039077c74099b51fb67300391659 87 | 485c78dbfed04fd5a6aad2a88b116fe9 88 | 7949609177914534a89a8c7985ae4db7 89 | 87c3788da01f4f63953f35383d0f5095 90 | 9699d6a8d9384f8885e8c5318bc621ab 91 | e829650532104af6bdb060d75924421f 92 | abad69b7e6194c95b18a054aa082cfb8 93 | 159f9a70a6d64e27a38b3b2668b5a68e 94 | f8e8dd98088d458abb0a664330403168 95 | c1676a2feac74eee8aa38ca3901787d6 96 | 604d12ebcf784c3f945189f79262f19c 97 | 154dbc8467d64780a1acffda0193d46b 98 | 2eef1a0ab4484a1faa73e7c2a048149e 99 | 94cd383effba4e688570233a060a13f4 100 | 63e3bc947eed4823b00e1beadecd3be3 101 | 4e04d32a9949430287c14d49851f6f55 102 | bc94beb70f694ee8a0399c4e7e1fc2d0 103 | 2fb7b5d1eab24b13abb4e0703850dba9 104 | e693989d8ad141608999200df71aca71 105 | a4782f7ade234010b5b50553403f6e79 106 | 3abf81a7c3894000a4c508e6ced0caca 107 | 43da50265dfa4876bdf1efed38e7a261 108 | d056b9bdd56f44669540c6c323042d30 109 | 8487e4c995874bb2aeb41d6484717d10 110 | 1f048fa210c64e80b9d57e3dc757702c 111 | f8de6e9a14d34f2791c21a07ce5d47c5 112 | 7f8368b7a14341b3bef03a7aec63f9b5 113 | b5989651183643369174912bc5641d3b 114 | 8c5c08c818d247ee9de5aa08e6762003 115 | bd0c27135b3c43c6b0e59f03fdf441cb 116 | e8a25125ab404726ab03976d3871c460 117 | 852e22dbdbbf422e99224c488d7618bf 118 | 73c335a2a207457ea36cc855b2417bd9 119 | c07a480048b04e989f6fc1e002e5c5b8 120 | c201dea88033406fb304747d32b823a6 121 | e6fcfc5cd0254eee901c1bbc0a30a6ba 122 | 45eda670b5a840acbb730aac15e63b19 123 | 5d785cd90a884e0ebf9c5047a527df76 124 | fd0915986a6444ee8296894d2a307291 125 | dad9942ea8f1409985140f500c69c88e 126 | e57fa7124f6741e697b1841bc023a90f 127 | b139c133286247d48093bba9151920b2 128 | b2a93b022374432d917cc7db15ca023a 129 | 637fc435f95f4bc1b72f09811f2d78d2 130 | b5cbf927fe0d47c39a9c1d1b05279c55 131 | 23d02c83708b4f5e9579bbd99370a672 132 | aa298fdc84fb489a911a3f363d3beda9 133 | 923b34d42c5048c0a4db75c65de52f3c 134 | 45c4bd03ee6f4a369dd42abf08dc3957 135 | e80bfd61023745b0a97caae3ff4a1727 136 | cb4504edb87a40d4b650cd5860f6c3b7 137 | e74b8ce0b1824e0c8167f2534d8d4f7d 138 | 79f6489272c24d3ebc5e225ce6ff2aea 139 | 617899966d3e4fea9042a1f1a26b764e 140 | 42c250251b504ee8a5f1e49def135715 141 | b62a484b94a04ae4b1b2729fb2064696 142 | dc2def2a7bc94693b67eba2ffbe13a56 143 | 
21bb21c84c3e43908cd554db2686278f 144 | b6a9d120895e4207a0527ad4b3284d42 145 | 2920a3315ccd4fefb30fb0468ebb2351 146 | b58985f7ddc2454d95cb84a65d1fe9fe 147 | 23779301ebc34c1284e539ddf057f0b4 148 | 7e351605549547f8af19432917f23b73 149 | f9ad282e31ce47d5a16dffc3b25c456a 150 | 127c5ddf95e947ef96888e56d5a355f6 -------------------------------------------------------------------------------- /dataset/nuscenes_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import PIL.Image as pil 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | from nuscenes.nuscenes import NuScenes 10 | from pyquaternion import Quaternion 11 | 12 | from .data_util import img_loader, mask_loader_scene, align_dataset, transform_mask_sample 13 | 14 | 15 | def is_numpy(data): 16 | """Checks if data is a numpy array.""" 17 | return isinstance(data, np.ndarray) 18 | 19 | def is_tensor(data): 20 | """Checks if data is a torch tensor.""" 21 | return type(data) == torch.Tensor 22 | 23 | def is_list(data): 24 | """Checks if data is a list.""" 25 | return isinstance(data, list) 26 | 27 | def stack_sample(sample): 28 | """Stack a sample from multiple sensors""" 29 | # If there is only one sensor don't do anything 30 | if len(sample) == 1: 31 | return sample[0] 32 | 33 | # Otherwise, stack sample 34 | stacked_sample = {} 35 | for key in sample[0]: 36 | # Global keys (do not stack) 37 | if key in ['idx', 'dataset_idx', 'sensor_name', 'filename', 'token']: 38 | stacked_sample[key] = sample[0][key] 39 | else: 40 | # Stack torch tensors 41 | if is_tensor(sample[0][key]): 42 | stacked_sample[key] = torch.stack([s[key] for s in sample], 0) 43 | # Stack numpy arrays 44 | elif is_numpy(sample[0][key]): 45 | stacked_sample[key] = np.stack([s[key] for s in sample], 0) 46 | # Stack list 47 | elif is_list(sample[0][key]): 48 | stacked_sample[key] = [] 49 | # Stack list of torch tensors 50 | if is_tensor(sample[0][key][0]): 51 | for i in range(len(sample[0][key])): 52 | stacked_sample[key].append( 53 | torch.stack([s[key][i] for s in sample], 0)) 54 | # Stack list of numpy arrays 55 | if is_numpy(sample[0][key][0]): 56 | for i in range(len(sample[0][key])): 57 | stacked_sample[key].append( 58 | np.stack([s[key][i] for s in sample], 0)) 59 | 60 | # Return stacked sample 61 | return stacked_sample 62 | 63 | class NuScenesdataset(Dataset): 64 | """ 65 | Loaders for NuScenes dataset 66 | """ 67 | def __init__(self, path, split, 68 | cameras=None, 69 | back_context=0, 70 | forward_context=0, 71 | data_transform=None, 72 | depth_type=None, 73 | scale_range=2, 74 | with_pose=None, 75 | with_ego_pose=None, 76 | with_mask=None, 77 | ): 78 | super().__init__() 79 | version = 'v1.0-trainval' 80 | self.path = path 81 | self.split = split 82 | self.dataset_idx = 0 83 | 84 | self.cameras = cameras 85 | self.scales = np.arange(scale_range+2) 86 | self.num_cameras = len(cameras) 87 | 88 | self.bwd = back_context 89 | self.fwd = forward_context 90 | 91 | self.has_context = back_context + forward_context > 0 92 | self.data_transform = data_transform 93 | 94 | self.with_depth = depth_type is not None 95 | self.with_pose = with_pose 96 | self.with_ego_pose = with_ego_pose 97 | 98 | self.loader = img_loader 99 | 100 | self.with_mask = with_mask 101 | cur_path = os.path.dirname(os.path.realpath(__file__)) 102 | self.mask_path = os.path.join(cur_path, 'nuscenes_mask') 103 | self.mask_loader = mask_loader_scene 104 | 105 | self.dataset = NuScenes(version=version, 
dataroot=self.path, verbose=True) 106 | 107 | # list of scenes for training and validation of model 108 | with open('dataset/nuscenes/{}.txt'.format(self.split), 'r') as f: 109 | self.filenames = f.readlines() 110 | 111 | def get_current(self, key, cam_sample): 112 | """ 113 | This function returns samples for current contexts 114 | """ 115 | # get current timestamp rgb sample 116 | if key == 'rgb': 117 | rgb_path = cam_sample['filename'] 118 | return self.loader(os.path.join(self.path, rgb_path)) 119 | # get current timestamp camera intrinsics 120 | elif key == 'intrinsics': 121 | cam_param = self.dataset.get('calibrated_sensor', 122 | cam_sample['calibrated_sensor_token']) 123 | return np.array(cam_param['camera_intrinsic'], dtype=np.float32) 124 | # get current timestamp camera extrinsics 125 | elif key == 'extrinsics': 126 | cam_param = self.dataset.get('calibrated_sensor', 127 | cam_sample['calibrated_sensor_token']) 128 | return self.get_tranformation_mat(cam_param) 129 | else: 130 | raise ValueError('Unknown key: ' +key) 131 | 132 | def get_context(self, key, cam_sample): 133 | """ 134 | This function returns samples for backward and forward contexts 135 | """ 136 | bwd_context, fwd_context = [], [] 137 | if self.bwd != 0: 138 | if self.split == 'eval_SF': # validation 139 | bwd_sample = cam_sample 140 | else: 141 | bwd_sample = self.dataset.get('sample_data', cam_sample['prev']) 142 | bwd_context = [self.get_current(key, bwd_sample)] 143 | 144 | if self.fwd != 0: 145 | fwd_sample = self.dataset.get('sample_data', cam_sample['next']) 146 | fwd_context = [self.get_current(key, fwd_sample)] 147 | return bwd_context + fwd_context 148 | 149 | def get_cam_T_cam(self, cam_sample): 150 | # cam 0 to world 151 | # cam 0 to ego 0 152 | cam_to_ego = self.dataset.get( 153 | 'calibrated_sensor', cam_sample['calibrated_sensor_token']) 154 | cam_to_ego_rotation = Quaternion(cam_to_ego['rotation']) 155 | cam_to_ego_translation = np.array(cam_to_ego['translation'])[:, None] 156 | cam_to_ego = np.vstack([ 157 | np.hstack((cam_to_ego_rotation.rotation_matrix, 158 | cam_to_ego_translation)), 159 | np.array([0, 0, 0, 1]) 160 | ]) 161 | 162 | # ego 0 to world 163 | world_to_ego = self.dataset.get( 164 | 'ego_pose', cam_sample['ego_pose_token']) 165 | world_to_ego_rotation = Quaternion(world_to_ego['rotation']).inverse 166 | world_to_ego_translation = - np.array(world_to_ego['translation'])[:, None] 167 | world_to_ego = np.vstack([ 168 | np.hstack((world_to_ego_rotation.rotation_matrix, 169 | world_to_ego_rotation.rotation_matrix @ world_to_ego_translation)), 170 | np.array([0, 0, 0, 1]) 171 | ]) 172 | ego_to_world = np.linalg.inv(world_to_ego) 173 | 174 | cam_T_cam = [] 175 | 176 | # cam_T_cam, 0, -1 177 | if self.bwd != 0: 178 | if self.split == 'eval_SF': # validation 179 | bwd_sample = cam_sample 180 | else: 181 | bwd_sample = self.dataset.get('sample_data', cam_sample['prev']) 182 | 183 | # world to ego -1 184 | world_to_ego_bwd = self.dataset.get( 185 | 'ego_pose', bwd_sample['ego_pose_token']) 186 | world_to_ego_bwd_rotation = Quaternion(world_to_ego_bwd['rotation']).inverse 187 | world_to_ego_bwd_translation = - np.array(world_to_ego_bwd['translation'])[:, None] 188 | world_to_ego_bwd = np.vstack([ 189 | np.hstack((world_to_ego_bwd_rotation.rotation_matrix, 190 | world_to_ego_bwd_rotation.rotation_matrix @ world_to_ego_bwd_translation)), 191 | np.array([0, 0, 0, 1]) 192 | ]) 193 | 194 | # ego -1 to cam -1 195 | cam_to_ego_bwd = self.dataset.get( 196 | 'calibrated_sensor', 
bwd_sample['calibrated_sensor_token']) 197 | cam_to_ego_bwd_rotation = Quaternion(cam_to_ego_bwd['rotation']) 198 | cam_to_ego_bwd_translation = np.array(cam_to_ego_bwd['translation'])[:, None] 199 | cam_to_ego_bwd = np.vstack([ 200 | np.hstack((cam_to_ego_bwd_rotation.rotation_matrix, 201 | cam_to_ego_bwd_translation)), 202 | np.array([0, 0, 0, 1]) 203 | ]) 204 | ego_to_cam_bwd = np.linalg.inv(cam_to_ego_bwd) 205 | 206 | cam_T_cam_bwd = ego_to_cam_bwd @ world_to_ego_bwd @ ego_to_world @ cam_to_ego 207 | 208 | cam_T_cam.append(cam_T_cam_bwd) 209 | 210 | # cam_T_cam, 0, 1 211 | if self.fwd != 0: 212 | fwd_sample = self.dataset.get('sample_data', cam_sample['next']) 213 | 214 | # world to ego 1 215 | world_to_ego_fwd = self.dataset.get( 216 | 'ego_pose', fwd_sample['ego_pose_token']) 217 | world_to_ego_fwd_rotation = Quaternion(world_to_ego_fwd['rotation']).inverse 218 | world_to_ego_fwd_translation = - np.array(world_to_ego_fwd['translation'])[:, None] 219 | world_to_ego_fwd = np.vstack([ 220 | np.hstack((world_to_ego_fwd_rotation.rotation_matrix, 221 | world_to_ego_fwd_rotation.rotation_matrix @ world_to_ego_fwd_translation)), 222 | np.array([0, 0, 0, 1]) 223 | ]) 224 | 225 | # ego 1 to cam 1 226 | cam_to_ego_fwd = self.dataset.get( 227 | 'calibrated_sensor', fwd_sample['calibrated_sensor_token']) 228 | cam_to_ego_fwd_rotation = Quaternion(cam_to_ego_fwd['rotation']) 229 | cam_to_ego_fwd_translation = np.array(cam_to_ego_fwd['translation'])[:, None] 230 | cam_to_ego_fwd = np.vstack([ 231 | np.hstack((cam_to_ego_fwd_rotation.rotation_matrix, 232 | cam_to_ego_fwd_translation)), 233 | np.array([0, 0, 0, 1]) 234 | ]) 235 | ego_to_cam_fwd = np.linalg.inv(cam_to_ego_fwd) 236 | 237 | cam_T_cam_fwd = ego_to_cam_fwd @ world_to_ego_fwd @ ego_to_world @ cam_to_ego 238 | 239 | cam_T_cam.append(cam_T_cam_fwd) 240 | 241 | return cam_T_cam 242 | 243 | def generate_depth_map(self, sample, sensor, cam_sample): 244 | """ 245 | This function returns depth map for nuscenes dataset, 246 | result of depth map is saved in nuscenes/samples/DEPTH_MAP 247 | """ 248 | # generate depth filename 249 | filename = '{}/{}.npz'.format( 250 | os.path.join(os.path.dirname(self.path), 'samples'), 251 | 'DEPTH_MAP/{}/{}'.format(sensor, cam_sample['filename'])) 252 | 253 | # load and return if exists 254 | if os.path.exists(filename): 255 | return np.load(filename, allow_pickle=True)['depth'] 256 | else: 257 | lidar_sample = self.dataset.get( 258 | 'sample_data', sample['data']['LIDAR_TOP']) 259 | 260 | # lidar points 261 | lidar_file = os.path.join( 262 | self.path, lidar_sample['filename']) 263 | lidar_points = np.fromfile(lidar_file, dtype=np.float32) 264 | lidar_points = lidar_points.reshape(-1, 5)[:, :3] 265 | 266 | # lidar -> world 267 | lidar_pose = self.dataset.get( 268 | 'ego_pose', lidar_sample['ego_pose_token']) 269 | lidar_rotation= Quaternion(lidar_pose['rotation']) 270 | lidar_translation = np.array(lidar_pose['translation'])[:, None] 271 | lidar_to_world = np.vstack([ 272 | np.hstack((lidar_rotation.rotation_matrix, lidar_translation)), 273 | np.array([0, 0, 0, 1]) 274 | ]) 275 | 276 | # lidar -> ego 277 | sensor_sample = self.dataset.get( 278 | 'calibrated_sensor', lidar_sample['calibrated_sensor_token']) 279 | lidar_to_ego_rotation = Quaternion( 280 | sensor_sample['rotation']).rotation_matrix 281 | lidar_to_ego_translation = np.array( 282 | sensor_sample['translation']).reshape(1, 3) 283 | 284 | ego_lidar_points = np.dot( 285 | lidar_points[:, :3], lidar_to_ego_rotation.T) 286 | ego_lidar_points += 
lidar_to_ego_translation 287 | 288 | homo_ego_lidar_points = np.concatenate( 289 | (ego_lidar_points, np.ones((ego_lidar_points.shape[0], 1))), axis=1) 290 | 291 | 292 | # world -> ego 293 | ego_pose = self.dataset.get( 294 | 'ego_pose', cam_sample['ego_pose_token']) 295 | ego_rotation = Quaternion(ego_pose['rotation']).inverse 296 | ego_translation = - np.array(ego_pose['translation'])[:, None] 297 | world_to_ego = np.vstack([ 298 | np.hstack((ego_rotation.rotation_matrix, 299 | ego_rotation.rotation_matrix @ ego_translation)), 300 | np.array([0, 0, 0, 1]) 301 | ]) 302 | 303 | # Ego -> sensor 304 | sensor_sample = self.dataset.get( 305 | 'calibrated_sensor', cam_sample['calibrated_sensor_token']) 306 | sensor_rotation = Quaternion(sensor_sample['rotation']) 307 | sensor_translation = np.array( 308 | sensor_sample['translation'])[:, None] 309 | sensor_to_ego = np.vstack([ 310 | np.hstack((sensor_rotation.rotation_matrix, 311 | sensor_translation)), 312 | np.array([0, 0, 0, 1]) 313 | ]) 314 | ego_to_sensor = np.linalg.inv(sensor_to_ego) 315 | 316 | # lidar -> sensor 317 | lidar_to_sensor = ego_to_sensor @ world_to_ego @ lidar_to_world 318 | homo_ego_lidar_points = torch.from_numpy(homo_ego_lidar_points).float() 319 | cam_lidar_points = np.matmul(lidar_to_sensor, homo_ego_lidar_points.T).T 320 | 321 | # depth > 0 322 | depth_mask = cam_lidar_points[:, 2] > 0 323 | cam_lidar_points = cam_lidar_points[depth_mask] 324 | 325 | # sensor -> image 326 | intrinsics = np.eye(4) 327 | intrinsics[:3, :3] = sensor_sample['camera_intrinsic'] 328 | pixel_points = np.matmul(intrinsics, cam_lidar_points.T).T 329 | pixel_points[:, :2] /= pixel_points[:, 2:3] 330 | 331 | # load image for pixel range 332 | image_filename = os.path.join( 333 | self.path, cam_sample['filename']) 334 | img = pil.open(image_filename) 335 | h, w, _ = np.array(img).shape 336 | 337 | # mask points in pixel range 338 | pixel_mask = (pixel_points[:, 0] >= 0) & (pixel_points[:, 0] <= w-1)\ 339 | & (pixel_points[:,1] >= 0) & (pixel_points[:,1] <= h-1) 340 | valid_points = pixel_points[pixel_mask].round().int() 341 | valid_depth = cam_lidar_points[:, 2][pixel_mask] 342 | 343 | depth = np.zeros([h, w]) 344 | depth[valid_points[:, 1], valid_points[:,0]] = valid_depth 345 | 346 | # save depth map 347 | os.makedirs(os.path.dirname(filename), exist_ok=True) 348 | np.savez_compressed(filename, depth=depth) 349 | return depth 350 | 351 | def get_tranformation_mat(self, pose): 352 | """ 353 | This function transforms pose information in accordance with DDAD dataset format 354 | """ 355 | extrinsics = Quaternion(pose['rotation']).transformation_matrix 356 | extrinsics[:3, 3] = np.array(pose['translation']) 357 | return extrinsics.astype(np.float32) 358 | 359 | def __len__(self): 360 | return len(self.filenames) 361 | 362 | def __getitem__(self, idx): 363 | # get nuscenes dataset sample 364 | frame_idx = self.filenames[idx].strip().split()[0] 365 | sample_nusc = self.dataset.get('sample', frame_idx) 366 | 367 | sample = [] 368 | contexts = [] 369 | if self.bwd: 370 | contexts.append(-1) 371 | if self.fwd: 372 | contexts.append(1) 373 | 374 | # loop over all cameras 375 | for cam in self.cameras: 376 | cam_sample = self.dataset.get( 377 | 'sample_data', sample_nusc['data'][cam]) 378 | 379 | data = { 380 | 'idx': idx, 381 | 'token': frame_idx, 382 | 'sensor_name': cam, 383 | 'contexts': contexts, 384 | 'filename': cam_sample['filename'], 385 | 'rgb': self.get_current('rgb', cam_sample), 386 | 'intrinsics': self.get_current('intrinsics', cam_sample) 
387 | } 388 | 389 | # if depth is returned 390 | if self.with_depth: 391 | data.update({ 392 | 'depth': self.generate_depth_map(sample_nusc, cam, cam_sample) 393 | }) 394 | # if pose is returned 395 | if self.with_pose: 396 | data.update({ 397 | 'extrinsics':self.get_current('extrinsics', cam_sample) 398 | }) 399 | # if ego_pose is returned 400 | if self.with_ego_pose: 401 | data.update({ 402 | 'ego_pose': self.get_cam_T_cam(cam_sample) 403 | }) 404 | # if mask is returned 405 | if self.with_mask: 406 | data.update({ 407 | 'mask': self.mask_loader(self.mask_path, '', cam) 408 | }) 409 | # if context is returned 410 | if self.has_context: 411 | data.update({ 412 | 'rgb_context': self.get_context('rgb', cam_sample) 413 | }) 414 | 415 | sample.append(data) 416 | 417 | # apply same data transformations for all sensors 418 | if self.data_transform: 419 | sample = [self.data_transform(smp) for smp in sample] 420 | sample = [transform_mask_sample(smp, self.data_transform) for smp in sample] 421 | 422 | # stack and align dataset for our trainer 423 | sample = stack_sample(sample) 424 | sample = align_dataset(sample, self.scales, contexts) 425 | return sample 426 | -------------------------------------------------------------------------------- /dataset/nuscenes_mask/CAM_BACK_LEFT_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fangzhou2000/DrivingForward/5d0a2c7d6a6358f55318faac80e351cf75dff032/dataset/nuscenes_mask/CAM_BACK_LEFT_mask.png -------------------------------------------------------------------------------- /dataset/nuscenes_mask/CAM_BACK_RIGHT_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fangzhou2000/DrivingForward/5d0a2c7d6a6358f55318faac80e351cf75dff032/dataset/nuscenes_mask/CAM_BACK_RIGHT_mask.png -------------------------------------------------------------------------------- /dataset/nuscenes_mask/CAM_BACK_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fangzhou2000/DrivingForward/5d0a2c7d6a6358f55318faac80e351cf75dff032/dataset/nuscenes_mask/CAM_BACK_mask.png -------------------------------------------------------------------------------- /dataset/nuscenes_mask/CAM_FRONT_LEFT_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fangzhou2000/DrivingForward/5d0a2c7d6a6358f55318faac80e351cf75dff032/dataset/nuscenes_mask/CAM_FRONT_LEFT_mask.png -------------------------------------------------------------------------------- /dataset/nuscenes_mask/CAM_FRONT_RIGHT_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fangzhou2000/DrivingForward/5d0a2c7d6a6358f55318faac80e351cf75dff032/dataset/nuscenes_mask/CAM_FRONT_RIGHT_mask.png -------------------------------------------------------------------------------- /dataset/nuscenes_mask/CAM_FRONT_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fangzhou2000/DrivingForward/5d0a2c7d6a6358f55318faac80e351cf75dff032/dataset/nuscenes_mask/CAM_FRONT_mask.png -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | 
torch.backends.cudnn.deterministic = True 5 | torch.backends.cudnn.benchmark = False 6 | torch.backends.cuda.matmul.allow_tf32 = False 7 | 8 | import utils 9 | from models import DrivingForwardModel 10 | from trainer import DrivingForwardTrainer 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser(description='evaluation script') 15 | parser.add_argument('--config_file', default='./configs/nuscenes/main.yaml', type=str, help='config yaml file path') 16 | parser.add_argument('--weight_path', default='./weights', type=str, help='weight path') 17 | parser.add_argument('--novel_view_mode', default='MF', type=str, help='MF of SF') 18 | args = parser.parse_args() 19 | return args 20 | 21 | def test(cfg): 22 | print("Evaluating reconstruction") 23 | model = DrivingForwardModel(cfg, 0) 24 | trainer = DrivingForwardTrainer(cfg, 0, use_tb = False) 25 | trainer.evaluate(model) 26 | 27 | if __name__ == '__main__': 28 | args = parse_args() 29 | cfg = utils.get_config(args.config_file, mode='eval', weight_path=args.weight_path, novel_view_mode=args.novel_view_mode) 30 | 31 | test(cfg) 32 | -------------------------------------------------------------------------------- /external/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from external.packnet_sfm.packnet_sfm.datasets.transforms import get_transforms 2 | from external.packnet_sfm.packnet_sfm.datasets.dgp_dataset import DGPDataset 3 | from external.packnet_sfm.packnet_sfm.datasets.dgp_dataset import stack_sample 4 | from external.packnet_sfm.packnet_sfm.datasets.dgp_dataset import SynchronizedSceneDataset 5 | 6 | __all__ = ['get_transforms', 'stack_sample', 'DGPDataset', 'SynchronizedSceneDataset'] -------------------------------------------------------------------------------- /external/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 42dot. All rights reserved. 2 | from external.packnet_sfm.packnet_sfm.networks.layers.resnet.resnet_encoder import ResnetEncoder 3 | from external.packnet_sfm.packnet_sfm.networks.layers.resnet.pose_decoder import PoseDecoder 4 | from external.packnet_sfm.packnet_sfm.networks.layers.resnet.depth_decoder import DepthDecoder 5 | 6 | __all__ = ['ResnetEncoder', 'PoseDecoder', 'DepthDecoder'] 7 | -------------------------------------------------------------------------------- /external/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 42dot. All rights reserved. 
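# NOTE: the re-exports below (and those in external/dataset/__init__.py and
# external/layers/__init__.py) resolve against the git submodules declared in
# .gitmodules (external/packnet_sfm, external/dgp), so run
# `git submodule update --init --recursive` before importing these wrappers.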
2 | from external.dgp.dgp.utils.camera import Camera 3 | from external.dgp.dgp.utils.camera import generate_depth_map 4 | from external.packnet_sfm.packnet_sfm.utils.misc import make_list 5 | 6 | __all__ = ['Camera', 'generate_depth_map', 'make_list'] -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .drivingforward_model import DrivingForwardModel 2 | 3 | __all__ = ['DrivingForwardModel'] -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | _OPTIMIZER_NAME ='adam' 5 | 6 | 7 | class BaseModel: 8 | def __init__(self, cfg): 9 | self._dataloaders = {} 10 | self.mode = None 11 | self.models = None 12 | self.optimizer = None 13 | self.lr_scheduler = None 14 | self.ddp_enable = False 15 | 16 | def read_config(self, cfg): 17 | raise NotImplementedError('Not implemented for BaseModel') 18 | 19 | def prepare_dataset(self): 20 | raise NotImplementedError('Not implemented for BaseModel') 21 | 22 | def set_optimizer(self): 23 | raise NotImplementedError('Not implemented for BaseModel') 24 | 25 | def train_dataloader(self): 26 | return self._dataloaders['train'] 27 | 28 | def val_dataloader(self): 29 | return self._dataloaders['val'] 30 | 31 | def eval_dataloader(self): 32 | return self._dataloaders['eval'] 33 | 34 | def set_train(self): 35 | self.mode = 'train' 36 | for m in self.models.values(): 37 | m.train() 38 | 39 | def set_val(self): 40 | self.mode = 'val' 41 | for m in self.models.values(): 42 | m.eval() 43 | 44 | def set_eval(self): 45 | self.mode = 'eval' 46 | for m in self.models.values(): 47 | m.eval() 48 | 49 | def save_model(self, epoch): 50 | curr_model_weights_dir = os.path.join(self.save_weights_root, f'weights_{epoch}') 51 | os.makedirs(curr_model_weights_dir, exist_ok=True) 52 | 53 | for model_name, model in self.models.items(): 54 | model_file_path = os.path.join(curr_model_weights_dir, f'{model_name}.pth') 55 | to_save = model.state_dict() 56 | torch.save(to_save, model_file_path) 57 | 58 | # save optimizer 59 | optim_file_path = os.path.join(curr_model_weights_dir, f'{_OPTIMIZER_NAME}.pth') 60 | torch.save(self.optimizer.state_dict(), optim_file_path) 61 | 62 | def load_weights(self): 63 | assert os.path.isdir(self.load_weights_dir), f'\tCannot find {self.load_weights_dir}' 64 | print(f'Loading a model from {self.load_weights_dir}') 65 | 66 | # to retrain 67 | if self.pretrain and self.ddp_enable: 68 | map_location = {'cuda:%d' % 0: 'cuda:%d' % (self.world_size-1)} 69 | 70 | for n in self.models_to_load: 71 | print(f'Loading {n} weights...') 72 | path = os.path.join(self.load_weights_dir, f'{n}.pth') 73 | model_dict = self.models[n].state_dict() 74 | 75 | # distribute gpus for ddp retraining 76 | if self.pretrain and self.ddp_enable: 77 | pre_trained_dict = torch.load(path, map_location=map_location) 78 | else: 79 | pre_trained_dict = torch.load(path) 80 | 81 | # load parameters 82 | pre_trained_dict = {k: v for k, v in pre_trained_dict.items() if k in model_dict} 83 | model_dict.update(pre_trained_dict) 84 | self.models[n].load_state_dict(model_dict) 85 | 86 | if self.mode == 'train': 87 | # loading adam state 88 | optim_file_path = os.path.join(self.load_weights_dir, f'{_OPTIMIZER_NAME}.pth') 89 | if os.path.isfile(optim_file_path): 90 | try: 91 | print(f'Loading 
{_OPTIMIZER_NAME} weights') 92 | optimizer_dict = torch.load(optim_file_path) 93 | self.optimizer.load_state_dict(optimizer_dict) 94 | except ValueError: 95 | print(f'\tCannnot load {_OPTIMIZER_NAME} - the optimizer will be randomly initialized') 96 | else: 97 | print(f'\tCannot find {_OPTIMIZER_NAME} weights, so the optimizer will be randomly initialized') -------------------------------------------------------------------------------- /models/drivingforward_model.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torch.utils.data import DataLoader, Subset 7 | torch.manual_seed(0) 8 | 9 | from dataset import construct_dataset 10 | from network import * 11 | 12 | from .base_model import BaseModel 13 | from .geometry import Pose, ViewRendering 14 | from .losses import MultiCamLoss, SingleCamLoss 15 | 16 | from .gaussian import GaussianNetwork, depth2pc, pts2render, focal2fov, getProjectionMatrix, getWorld2View2, rotate_sh 17 | from einops import rearrange 18 | 19 | _NO_DEVICE_KEYS = ['idx', 'dataset_idx', 'sensor_name', 'filename', 'token'] 20 | 21 | 22 | class DrivingForwardModel(BaseModel): 23 | def __init__(self, cfg, rank): 24 | super(DrivingForwardModel, self).__init__(cfg) 25 | self.rank = rank 26 | self.read_config(cfg) 27 | self.prepare_dataset(cfg, rank) 28 | self.models = self.prepare_model(cfg, rank) 29 | self.losses = self.init_losses(cfg, rank) 30 | self.view_rendering, self.pose = self.init_geometry(cfg, rank) 31 | self.set_optimizer() 32 | 33 | if self.pretrain and rank == 0: 34 | self.load_weights() 35 | 36 | self.left_cam_dict = {2:0, 0:1, 4:2, 1:3, 5:4, 3:5} 37 | self.right_cam_dict = {0:2, 1:0, 2:4, 3:1, 4:5, 5:3} 38 | 39 | def read_config(self, cfg): 40 | for attr in cfg.keys(): 41 | for k, v in cfg[attr].items(): 42 | setattr(self, k, v) 43 | 44 | def init_geometry(self, cfg, rank): 45 | view_rendering = ViewRendering(cfg, rank) 46 | pose = Pose(cfg) 47 | return view_rendering, pose 48 | 49 | def init_losses(self, cfg, rank): 50 | if self.spatio_temporal or self.spatio: 51 | loss_model = MultiCamLoss(cfg, rank) 52 | else : 53 | loss_model = SingleCamLoss(cfg, rank) 54 | return loss_model 55 | 56 | def prepare_model(self, cfg, rank): 57 | models = {} 58 | models['pose_net'] = self.set_posenet(cfg) 59 | models['depth_net'] = self.set_depthnet(cfg) 60 | if self.gaussian: 61 | models['gs_net'] = self.set_gaussiannet(cfg) 62 | 63 | return models 64 | 65 | def set_posenet(self, cfg): 66 | return PoseNetwork(cfg).cuda() 67 | 68 | def set_depthnet(self, cfg): 69 | return DepthNetwork(cfg).cuda() 70 | 71 | def set_gaussiannet(self, cfg): 72 | return GaussianNetwork(rgb_dim=3, depth_dim=1).cuda() 73 | 74 | def prepare_dataset(self, cfg, rank): 75 | if rank == 0: 76 | print('### Preparing Datasets') 77 | 78 | if self.mode == 'train': 79 | self.set_train_dataloader(cfg, rank) 80 | if rank == 0 : 81 | self.set_val_dataloader(cfg) 82 | 83 | if self.mode == 'eval': 84 | self.set_eval_dataloader(cfg) 85 | 86 | def set_train_dataloader(self, cfg, rank): 87 | # jittering augmentation and image resizing for the training data 88 | _augmentation = { 89 | 'image_shape': (int(self.height), int(self.width)), 90 | 'jittering': (0.2, 0.2, 0.2, 0.05), 91 | 'crop_train_borders': (), 92 | 'crop_eval_borders': () 93 | } 94 | 95 | # construct train dataset 96 | train_dataset = construct_dataset(cfg, 'train', **_augmentation) 97 | 
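# NOTE: each dataset item already stacks all surround-view cameras (see stack_sample
# in dataset/nuscenes_dataset.py), so one DataLoader batch has shape
# [batch_size, num_cams, ...]; with the default nuScenes config that is 6 images
# per step for batch_size 1, before counting the temporal context frames.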
98 | dataloader_opts = { 99 | 'batch_size': self.batch_size, 100 | 'shuffle': True, 101 | 'num_workers': self.num_workers, 102 | 'pin_memory': True, 103 | 'drop_last': True 104 | } 105 | 106 | self._dataloaders['train'] = DataLoader(train_dataset, **dataloader_opts) 107 | num_train_samples = len(train_dataset) 108 | self.num_total_steps = num_train_samples // (self.batch_size * self.world_size) * self.num_epochs 109 | 110 | def set_val_dataloader(self, cfg): 111 | # Image resizing for the validation data 112 | _augmentation = { 113 | 'image_shape': (int(self.height), int(self.width)), 114 | 'jittering': (0.0, 0.0, 0.0, 0.0), 115 | 'crop_train_borders': (), 116 | 'crop_eval_borders': () 117 | } 118 | 119 | # construct validation dataset 120 | val_dataset = construct_dataset(cfg, 'val', **_augmentation) 121 | 122 | dataloader_opts = { 123 | 'batch_size': self.batch_size, 124 | 'shuffle': False, 125 | 'num_workers': 0, 126 | 'pin_memory': True, 127 | 'drop_last': True 128 | } 129 | 130 | self._dataloaders['val'] = DataLoader(val_dataset, **dataloader_opts) 131 | 132 | def set_eval_dataloader(self, cfg): 133 | # Image resizing for the validation data 134 | _augmentation = { 135 | 'image_shape': (int(self.height), int(self.width)), 136 | 'jittering': (0.0, 0.0, 0.0, 0.0), 137 | 'crop_train_borders': (), 138 | 'crop_eval_borders': () 139 | } 140 | 141 | dataloader_opts = { 142 | 'batch_size': self.eval_batch_size, 143 | 'shuffle': False, 144 | 'num_workers': self.eval_num_workers, 145 | 'pin_memory': True, 146 | 'drop_last': True 147 | } 148 | 149 | eval_dataset = construct_dataset(cfg, 'eval', **_augmentation) 150 | 151 | self._dataloaders['eval'] = DataLoader(eval_dataset, **dataloader_opts) 152 | 153 | def set_optimizer(self): 154 | parameters_to_train = [] 155 | for v in self.models.values(): 156 | parameters_to_train += list(v.parameters()) 157 | 158 | self.optimizer = optim.Adam( 159 | parameters_to_train, 160 | self.learning_rate 161 | ) 162 | 163 | self.lr_scheduler = optim.lr_scheduler.StepLR( 164 | self.optimizer, 165 | self.scheduler_step_size, 166 | 0.1 167 | ) 168 | 169 | def process_batch(self, inputs, rank): 170 | """ 171 | Pass a minibatch through the network and generate images, depth maps, and losses. 172 | """ 173 | for key, ipt in inputs.items(): 174 | if key not in _NO_DEVICE_KEYS: 175 | if 'context' in key: 176 | inputs[key] = [ipt[k].float().to(rank) for k in range(len(inputs[key]))] 177 | if 'ego_pose' in key: 178 | inputs[key] = [ipt[k].float().to(rank) for k in range(len(inputs[key]))] 179 | else: 180 | inputs[key] = ipt.float().to(rank) 181 | 182 | outputs = self.estimate(inputs) 183 | losses = self.compute_losses(inputs, outputs) 184 | return outputs, losses 185 | 186 | def estimate(self, inputs): 187 | """ 188 | This function estimates the outputs of the network. 
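It pre-computes the inverse extrinsics, runs the pose and depth networks, attaches either the predicted poses (training) or the ground-truth cam_T_cam transforms (val/eval) to each camera's output dict, and finally converts the predicted disparities into depth maps via compute_depth_maps.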
189 | """ 190 | # pre-calculate inverse of the extrinsic matrix 191 | inputs['extrinsics_inv'] = torch.inverse(inputs['extrinsics']) 192 | 193 | # init dictionary 194 | outputs = {} 195 | for cam in range(self.num_cams): 196 | outputs[('cam', cam)] = {} 197 | 198 | pose_pred = self.predict_pose(inputs) 199 | depth_feats = self.predict_depth(inputs) 200 | 201 | for cam in range(self.num_cams): 202 | if self.mode != 'train': 203 | outputs[('cam', cam)].update({('cam_T_cam', 0, 1): inputs[('cam_T_cam', 0, 1)][:, cam, ...]}) 204 | outputs[('cam', cam)].update({('cam_T_cam', 0, -1): inputs[('cam_T_cam', 0, -1)][:, cam, ...]}) 205 | elif self.mode == 'train': 206 | outputs[('cam', cam)].update(pose_pred[('cam', cam)]) 207 | outputs[('cam', cam)].update(depth_feats[('cam', cam)]) 208 | 209 | self.compute_depth_maps(inputs, outputs) 210 | return outputs 211 | 212 | def predict_pose(self, inputs): 213 | """ 214 | This function predicts poses. 215 | """ 216 | net = self.models['pose_net'] 217 | 218 | pose = self.pose.compute_pose(net, inputs) 219 | return pose 220 | 221 | def predict_depth(self, inputs): 222 | """ 223 | This function predicts disparity maps. 224 | """ 225 | net = self.models['depth_net'] 226 | 227 | depth_feats = net(inputs) 228 | return depth_feats 229 | 230 | def compute_depth_maps(self, inputs, outputs): 231 | """ 232 | This function computes depth map for each viewpoint. 233 | """ 234 | source_scale = 0 235 | for cam in range(self.num_cams): 236 | ref_K = inputs[('K', source_scale)][:, cam, ...] 237 | for scale in self.scales: 238 | disp = outputs[('cam', cam)][('disp', scale)] 239 | outputs[('cam', cam)][('depth', 0, scale)] = self.to_depth(disp, ref_K) 240 | if self.novel_view_mode == 'MF': 241 | disp_last = outputs[('cam', cam)][('disp', -1, scale)] 242 | outputs[('cam', cam)][('depth', -1, scale)] = self.to_depth(disp_last, ref_K) 243 | disp_next = outputs[('cam', cam)][('disp', 1, scale)] 244 | outputs[('cam', cam)][('depth', 1, scale)] = self.to_depth(disp_next, ref_K) 245 | 246 | def to_depth(self, disp_in, K_in): 247 | """ 248 | This function transforms disparity value into depth map while multiplying the value with the focal length. 249 | """ 250 | min_disp = 1/self.max_depth 251 | max_disp = 1/self.min_depth 252 | disp_range = max_disp-min_disp 253 | 254 | disp_in = F.interpolate(disp_in, [self.height, self.width], mode='bilinear', align_corners=False) 255 | disp = min_disp + disp_range * disp_in 256 | depth = 1/disp 257 | return depth * K_in[:, 0:1, 0:1].unsqueeze(2)/self.focal_length_scale 258 | 259 | def get_gaussian_data(self, inputs, outputs, cam): 260 | """ 261 | This function computes gaussian data for each viewpoint. 262 | """ 263 | bs, _, height, width = inputs[('color', 0, 0)][:, cam, ...].shape 264 | zfar = self.max_depth 265 | znear = 0.01 266 | 267 | if self.novel_view_mode == 'MF': 268 | for frame_id in self.frame_ids: 269 | if frame_id == 0: 270 | outputs[('cam', cam)][('e2c_extr', frame_id, 0)] = inputs['extrinsics_inv'][:, cam, ...] 271 | outputs[('cam', cam)][('c2e_extr', frame_id, 0)] = inputs['extrinsics'][:, cam, ...] 
272 | FovX_list = [] 273 | FovY_list = [] 274 | world_view_transform_list = [] 275 | full_proj_transform_list = [] 276 | camera_center_list = [] 277 | for i in range(bs): 278 | intr = inputs[('K', 0)][:, cam, ...][i,:] 279 | extr = inputs['extrinsics_inv'][:, cam, ...][i,:] 280 | FovX = focal2fov(intr[0, 0], width) 281 | FovY = focal2fov(intr[1, 1], height) 282 | projection_matrix = getProjectionMatrix(znear=znear, zfar=zfar, K=intr, h=height, w=width).transpose(0, 1).cuda() 283 | world_view_transform = torch.tensor(extr).transpose(0, 1).cuda() 284 | # full_proj_transform: (E^T K^T) = (K E)^T 285 | full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0) 286 | camera_center = world_view_transform.inverse()[3, :3] 287 | 288 | FovX_list.append(FovX) 289 | FovY_list.append(FovY) 290 | world_view_transform_list.append(world_view_transform.unsqueeze(0)) 291 | full_proj_transform_list.append(full_proj_transform.unsqueeze(0)) 292 | camera_center_list.append(camera_center.unsqueeze(0)) 293 | 294 | outputs[('cam', cam)][('FovX', frame_id, 0)] = torch.tensor(FovX_list).cuda() 295 | outputs[('cam', cam)][('FovY', frame_id, 0)] = torch.tensor(FovY_list).cuda() 296 | outputs[('cam', cam)][('world_view_transform', frame_id, 0)] = torch.cat(world_view_transform_list, dim=0) 297 | outputs[('cam', cam)][('full_proj_transform', frame_id, 0)] = torch.cat(full_proj_transform_list, dim=0) 298 | outputs[('cam', cam)][('camera_center', frame_id, 0)] = torch.cat(camera_center_list, dim=0) 299 | else: 300 | outputs[('cam', cam)][('e2c_extr', frame_id, 0)] = \ 301 | torch.matmul(outputs[('cam', cam)][('cam_T_cam', 0, frame_id)], inputs['extrinsics_inv'][:, cam, ...]) 302 | outputs[('cam', cam)][('c2e_extr', frame_id, 0)] = \ 303 | torch.matmul(inputs['extrinsics'][:, cam, ...], torch.inverse(outputs[('cam', cam)][('cam_T_cam', 0, frame_id)])) 304 | outputs[('cam', cam)][('xyz', frame_id, 0)] = depth2pc(outputs[('cam', cam)][('depth', frame_id, 0)], outputs[('cam', cam)][('e2c_extr', frame_id, 0)], inputs[('K', 0)][:, cam, ...]) 305 | valid = outputs[('cam', cam)][('depth', frame_id, 0)] != 0.0 306 | outputs[('cam', cam)][('pts_valid', frame_id, 0)] = valid.view(bs, -1) 307 | rot_maps, scale_maps, opacity_maps, sh_maps = \ 308 | self.gs_net(inputs[('color', frame_id, 0)][:, cam, ...], outputs[('cam', cam)][('depth', frame_id, 0)], outputs[('cam', cam)][('img_feat', frame_id, 0)]) 309 | c2w_rotations = rearrange(outputs[('cam', cam)][('c2e_extr', frame_id, 0)][..., :3, :3], "k i j -> k () () () i j") 310 | sh_maps = rotate_sh(sh_maps, c2w_rotations[..., None, :, :]) 311 | outputs[('cam', cam)][('rot_maps', frame_id, 0)] = rot_maps 312 | outputs[('cam', cam)][('scale_maps', frame_id, 0)] = scale_maps 313 | outputs[('cam', cam)][('opacity_maps', frame_id, 0)] = opacity_maps 314 | outputs[('cam', cam)][('sh_maps', frame_id, 0)] = sh_maps 315 | elif self.novel_view_mode == 'SF': 316 | frame_id = 0 317 | outputs[('cam', cam)][('e2c_extr', frame_id, 0)] = inputs['extrinsics_inv'][:, cam, ...] 318 | outputs[('cam', cam)][('c2e_extr', frame_id, 0)] = inputs['extrinsics'][:, cam, ...] 
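# SF mode: Gaussian primitives are predicted from the current frame only. depth2pc
# lifts the per-pixel depth into 3D points using the intrinsics and the ego-to-camera
# extrinsic, gs_net regresses per-pixel rotation / scale / opacity / SH maps, and the
# SH coefficients are rotated into the ego frame with the camera-to-ego rotation before
# being splatted to the adjacent frames (-1, +1), which serve as the novel views.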
319 | outputs[('cam', cam)][('xyz', frame_id, 0)] = depth2pc(outputs[('cam', cam)][('depth', frame_id, 0)], outputs[('cam', cam)][('e2c_extr', frame_id, 0)], inputs[('K', 0)][:, cam, ...]) 320 | valid = outputs[('cam', cam)][('depth', frame_id, 0)] != 0.0 321 | outputs[('cam', cam)][('pts_valid', frame_id, 0)] = valid.view(bs, -1) 322 | rot_maps, scale_maps, opacity_maps, sh_maps = \ 323 | self.gs_net(inputs[('color', frame_id, 0)][:, cam, ...], outputs[('cam', cam)][('depth', frame_id, 0)], outputs[('cam', cam)][('img_feat', frame_id, 0)]) 324 | c2w_rotations = rearrange(outputs[('cam', cam)][('c2e_extr', frame_id, 0)][..., :3, :3], "k i j -> k () () () i j") 325 | sh_maps = rotate_sh(sh_maps, c2w_rotations[..., None, :, :]) 326 | outputs[('cam', cam)][('rot_maps', frame_id, 0)] = rot_maps 327 | outputs[('cam', cam)][('scale_maps', frame_id, 0)] = scale_maps 328 | outputs[('cam', cam)][('opacity_maps', frame_id, 0)] = opacity_maps 329 | outputs[('cam', cam)][('sh_maps', frame_id, 0)] = sh_maps 330 | 331 | # novel view 332 | for frame_id in self.frame_ids[1:]: 333 | outputs[('cam', cam)][('e2c_extr', frame_id, 0)] = \ 334 | torch.matmul(outputs[('cam', cam)][('cam_T_cam', 0, frame_id)], inputs['extrinsics_inv'][:, cam, ...]) 335 | outputs[('cam', cam)][('c2e_extr', frame_id, 0)] = \ 336 | torch.matmul(inputs['extrinsics'][:, cam, ...], torch.inverse(outputs[('cam', cam)][('cam_T_cam', 0, frame_id)])) 337 | 338 | FovX_list = [] 339 | FovY_list = [] 340 | world_view_transform_list = [] 341 | full_proj_transform_list = [] 342 | camera_center_list = [] 343 | for i in range(bs): 344 | intr = inputs[('K', 0)][:, cam, ...][i,:] 345 | extr = inputs['extrinsics_inv'][:, cam, ...][i,:] 346 | T_i = outputs[('cam', cam)][('cam_T_cam', 0, frame_id)][i,:] 347 | FovX = focal2fov(intr[0, 0], width) 348 | FovY = focal2fov(intr[1, 1], height) 349 | projection_matrix = getProjectionMatrix(znear=znear, zfar=zfar, K=intr, h=height, w=width).transpose(0, 1).cuda() 350 | world_view_transform = torch.matmul(T_i, torch.tensor(extr).cuda()).transpose(0, 1) 351 | # full_proj_transform: (E^T K^T) = (K E)^T 352 | full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0) 353 | camera_center = world_view_transform.inverse()[3, :3] 354 | FovX_list.append(FovX) 355 | FovY_list.append(FovY) 356 | world_view_transform_list.append(world_view_transform.unsqueeze(0)) 357 | full_proj_transform_list.append(full_proj_transform.unsqueeze(0)) 358 | camera_center_list.append(camera_center.unsqueeze(0)) 359 | outputs[('cam', cam)][('FovX', frame_id, 0)] = torch.tensor(FovX_list).cuda() 360 | outputs[('cam', cam)][('FovY', frame_id, 0)] = torch.tensor(FovY_list).cuda() 361 | outputs[('cam', cam)][('world_view_transform', frame_id, 0)] = torch.cat(world_view_transform_list, dim=0) 362 | outputs[('cam', cam)][('full_proj_transform', frame_id, 0)] = torch.cat(full_proj_transform_list, dim=0) 363 | outputs[('cam', cam)][('camera_center', frame_id, 0)] = torch.cat(camera_center_list, dim=0) 364 | 365 | def compute_losses(self, inputs, outputs): 366 | """ 367 | This function computes losses. 
368 | """ 369 | losses = 0 370 | loss_fn = defaultdict(list) 371 | loss_mean = defaultdict(float) 372 | 373 | # compute gaussian data 374 | if self.gaussian: 375 | self.gs_net = self.models['gs_net'] 376 | for cam in range(self.num_cams): 377 | self.get_gaussian_data(inputs, outputs, cam) 378 | 379 | # generate image and compute loss per cameara 380 | for cam in range(self.num_cams): 381 | self.pred_cam_imgs(inputs, outputs, cam) 382 | if self.gaussian: 383 | self.pred_gaussian_imgs(inputs, outputs, cam) 384 | cam_loss, loss_dict = self.losses(inputs, outputs, cam) 385 | 386 | losses += cam_loss 387 | for k, v in loss_dict.items(): 388 | loss_fn[k].append(v) 389 | 390 | losses /= self.num_cams 391 | 392 | for k in loss_fn.keys(): 393 | loss_mean[k] = sum(loss_fn[k]) / float(len(loss_fn[k])) 394 | 395 | loss_mean['total_loss'] = losses 396 | return loss_mean 397 | 398 | def pred_cam_imgs(self, inputs, outputs, cam): 399 | """ 400 | This function renders projected images using camera parameters and depth information. 401 | """ 402 | rel_pose_dict = self.pose.compute_relative_cam_poses(inputs, outputs, cam) 403 | self.view_rendering(inputs, outputs, cam, rel_pose_dict) 404 | 405 | def pred_gaussian_imgs(self, inputs, outputs, cam): 406 | if self.novel_view_mode == 'MF': 407 | outputs[('cam', cam)][('gaussian_color', 0, 0)] = \ 408 | pts2render(inputs=inputs, 409 | outputs=outputs, 410 | cam_num=self.num_cams, 411 | novel_cam=cam, 412 | novel_frame_id=0, 413 | bg_color=[1.0, 1.0, 1.0], 414 | mode=self.novel_view_mode) 415 | elif self.novel_view_mode == 'SF': 416 | for novel_frame_id in self.frame_ids[1:]: 417 | outputs[('cam', cam)][('gaussian_color', novel_frame_id, 0)] = \ 418 | pts2render(inputs=inputs, 419 | outputs=outputs, 420 | cam_num=self.num_cams, 421 | novel_cam=cam, 422 | novel_frame_id=novel_frame_id, 423 | bg_color=[1.0, 1.0, 1.0], 424 | mode=self.novel_view_mode) 425 | 426 | -------------------------------------------------------------------------------- /models/gaussian/GaussianRender.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from .gaussian_renderer import render 4 | from einops import rearrange 5 | 6 | left_cam_dict = {2:0, 0:1, 4:2, 1:3, 5:4, 3:5} 7 | right_cam_dict = {0:2, 1:0, 2:4, 3:1, 4:5, 5:3} 8 | 9 | def get_adj_cams(cam): 10 | adj_cams = [cam] 11 | adj_cams.append(left_cam_dict[cam]) 12 | adj_cams.append(right_cam_dict[cam]) 13 | return adj_cams 14 | 15 | def pts2render(inputs, outputs, cam_num, novel_cam, novel_frame_id, bg_color, mode='MF'): 16 | bs, _, height, width = inputs[('color', 0, 0)][:, novel_cam, ...].shape 17 | render_novel_list = [] 18 | for i in range(bs): 19 | xyz_i_valid = [] 20 | # rgb_i_valid = [] 21 | rot_i_valid = [] 22 | scale_i_valid = [] 23 | opacity_i_valid = [] 24 | sh_i_valid = [] 25 | if mode == 'SF': 26 | frame_id = 0 27 | for cam in range(cam_num): 28 | valid_i = outputs[('cam', cam)][('pts_valid', frame_id, 0)][i, :] 29 | xyz_i = outputs[('cam', cam)][('xyz', frame_id, 0)][i, :, :] 30 | # rgb_i = inputs[('color', frame_id, 0)][:, cam, ...][i, :, :, :].permute(1, 2, 0).view(-1, 3) # HWC 31 | 32 | rot_i = outputs[('cam', cam)][('rot_maps', frame_id, 0)][i, :, :, :].permute(1, 2, 0).view(-1, 4) 33 | scale_i = outputs[('cam', cam)][('scale_maps', frame_id, 0)][i, :, :, :].permute(1, 2, 0).view(-1, 3) 34 | opacity_i = outputs[('cam', cam)][('opacity_maps', frame_id, 0)][i, :, :, :].permute(1, 2, 0).view(-1, 1) 35 | sh_i = rearrange(outputs[('cam', cam)][('sh_maps', frame_id, 
0)][i, :, :, :], "p srf r xyz d_sh -> (p srf r) d_sh xyz").contiguous() 36 | 37 | xyz_i_valid.append(xyz_i[valid_i].view(-1, 3)) 38 | # rgb_i_valid.append(rgb_i[valid_i].view(-1, 3)) 39 | rot_i_valid.append(rot_i[valid_i].view(-1, 4)) 40 | scale_i_valid.append(scale_i[valid_i].view(-1, 3)) 41 | opacity_i_valid.append(opacity_i[valid_i].view(-1, 1)) 42 | sh_i_valid.append(sh_i[valid_i]) 43 | 44 | elif mode == 'MF': 45 | for frame_id in [-1, 1]: 46 | cam = novel_cam 47 | valid_i = outputs[('cam', cam)][('pts_valid', frame_id, 0)][i, :] 48 | xyz_i = outputs[('cam', cam)][('xyz', frame_id, 0)][i, :, :] 49 | # rgb_i = inputs[('color', frame_id, 0)][:, cam, ...][i, :, :, :].permute(1, 2, 0).view(-1, 3) # HWC 50 | 51 | rot_i = outputs[('cam', cam)][('rot_maps', frame_id, 0)][i, :, :, :].permute(1, 2, 0).view(-1, 4) 52 | scale_i = outputs[('cam', cam)][('scale_maps', frame_id, 0)][i, :, :, :].permute(1, 2, 0).view(-1, 3) 53 | opacity_i = outputs[('cam', cam)][('opacity_maps', frame_id, 0)][i, :, :, :].permute(1, 2, 0).view(-1, 1) 54 | sh_i = rearrange(outputs[('cam', cam)][('sh_maps', frame_id, 0)][i, :, :, :], "p srf r xyz d_sh -> (p srf r) d_sh xyz").contiguous() 55 | 56 | xyz_i_valid.append(xyz_i[valid_i].view(-1, 3)) 57 | # rgb_i_valid.append(rgb_i[valid_i].view(-1, 3)) 58 | rot_i_valid.append(rot_i[valid_i].view(-1, 4)) 59 | scale_i_valid.append(scale_i[valid_i].view(-1, 3)) 60 | opacity_i_valid.append(opacity_i[valid_i].view(-1, 1)) 61 | sh_i_valid.append(sh_i[valid_i]) 62 | 63 | pts_xyz_i = torch.concat(xyz_i_valid, dim=0) 64 | # pts_rgb_i = torch.concat(rgb_i_valid, dim=0) 65 | # pts_rgb_i = pts_rgb_i * 0.5 + 0.5 66 | rot_i = torch.concat(rot_i_valid, dim=0) 67 | scale_i = torch.concat(scale_i_valid, dim=0) 68 | opacity_i = torch.concat(opacity_i_valid, dim=0) 69 | sh_i = torch.concat(sh_i_valid, dim=0) 70 | 71 | novel_FovX_i = outputs[('cam', novel_cam)][('FovX', novel_frame_id, 0)][i] 72 | novel_FovY_i = outputs[('cam', novel_cam)][('FovY', novel_frame_id, 0)][i] 73 | novel_world_view_transform_i = outputs[('cam', novel_cam)][('world_view_transform', novel_frame_id, 0)][i] 74 | novel_function_proj_transform_i = outputs[('cam', novel_cam)][('full_proj_transform', novel_frame_id, 0)][i] 75 | novel_camera_center_i = outputs[('cam', novel_cam)][('camera_center', novel_frame_id, 0)][i] 76 | 77 | render_novel_i = render(novel_FovX=novel_FovX_i, 78 | novel_FovY=novel_FovY_i, 79 | novel_height=height, 80 | novel_width=width, 81 | novel_world_view_transform=novel_world_view_transform_i, 82 | novel_full_proj_transform=novel_function_proj_transform_i, 83 | novel_camera_center=novel_camera_center_i, 84 | pts_xyz=pts_xyz_i, 85 | pts_rgb=None, 86 | rotations=rot_i, 87 | scales=scale_i, 88 | opacity=opacity_i, 89 | shs=sh_i, 90 | bg_color=bg_color) 91 | render_novel_list.append(render_novel_i.unsqueeze(0)) 92 | 93 | novel = torch.concat(render_novel_list, dim=0) 94 | 95 | return novel 96 | -------------------------------------------------------------------------------- /models/gaussian/__init__.py: -------------------------------------------------------------------------------- 1 | from .gaussian_network import GaussianNetwork 2 | from .GaussianRender import pts2render 3 | from .utils import depth2pc, rotate_sh, focal2fov, getProjectionMatrix, getWorld2View2 4 | -------------------------------------------------------------------------------- /models/gaussian/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 
import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 8 | super(ResidualBlock, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | num_groups = planes // 8 15 | 16 | if norm_fn == 'group': 17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | if not (stride == 1 and in_planes == planes): 20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 21 | 22 | elif norm_fn == 'batch': 23 | self.norm1 = nn.BatchNorm2d(planes) 24 | self.norm2 = nn.BatchNorm2d(planes) 25 | if not (stride == 1 and in_planes == planes): 26 | self.norm3 = nn.BatchNorm2d(planes) 27 | 28 | elif norm_fn == 'instance': 29 | self.norm1 = nn.InstanceNorm2d(planes) 30 | self.norm2 = nn.InstanceNorm2d(planes) 31 | if not (stride == 1 and in_planes == planes): 32 | self.norm3 = nn.InstanceNorm2d(planes) 33 | 34 | elif norm_fn == 'none': 35 | self.norm1 = nn.Sequential() 36 | self.norm2 = nn.Sequential() 37 | if not (stride == 1 and in_planes == planes): 38 | self.norm3 = nn.Sequential() 39 | 40 | if stride == 1 and in_planes == planes: 41 | self.downsample = None 42 | 43 | else: 44 | self.downsample = nn.Sequential( 45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 46 | 47 | 48 | def forward(self, x): 49 | y = x 50 | y = self.conv1(y) 51 | y = self.norm1(y) 52 | y = self.relu(y) 53 | y = self.conv2(y) 54 | y = self.norm2(y) 55 | y = self.relu(y) 56 | 57 | if self.downsample is not None: 58 | x = self.downsample(x) 59 | 60 | return self.relu(x+y) 61 | 62 | 63 | class UnetExtractor(nn.Module): 64 | def __init__(self, in_channel=3, encoder_dim=[32, 48, 96], norm_fn='group'): 65 | super().__init__() 66 | self.in_ds = nn.Sequential( 67 | nn.Conv2d(in_channel, 32, kernel_size=5, stride=2, padding=2), 68 | nn.GroupNorm(num_groups=8, num_channels=32), 69 | nn.ReLU(inplace=True) 70 | ) 71 | 72 | self.res1 = nn.Sequential( 73 | ResidualBlock(32, encoder_dim[0], norm_fn=norm_fn), 74 | ResidualBlock(encoder_dim[0], encoder_dim[0], norm_fn=norm_fn) 75 | ) 76 | self.res2 = nn.Sequential( 77 | ResidualBlock(encoder_dim[0], encoder_dim[1], stride=2, norm_fn=norm_fn), 78 | ResidualBlock(encoder_dim[1], encoder_dim[1], norm_fn=norm_fn) 79 | ) 80 | self.res3 = nn.Sequential( 81 | ResidualBlock(encoder_dim[1], encoder_dim[2], stride=2, norm_fn=norm_fn), 82 | ResidualBlock(encoder_dim[2], encoder_dim[2], norm_fn=norm_fn), 83 | ) 84 | 85 | def forward(self, x): 86 | x = self.in_ds(x) 87 | x1 = self.res1(x) 88 | x2 = self.res2(x1) 89 | x3 = self.res3(x2) 90 | 91 | return x1, x2, x3 92 | 93 | 94 | class MultiBasicEncoder(nn.Module): 95 | def __init__(self, output_dim=[128], encoder_dim=[64, 96, 128]): 96 | super(MultiBasicEncoder, self).__init__() 97 | 98 | # output convolution for feature 99 | self.conv2 = nn.Sequential( 100 | ResidualBlock(encoder_dim[2], encoder_dim[2], stride=1), 101 | nn.Conv2d(encoder_dim[2], encoder_dim[2]*2, 3, padding=1)) 102 | 103 | # output convolution for context 104 | output_list = [] 105 | for dim in output_dim: 106 | conv_out = nn.Sequential( 107 | ResidualBlock(encoder_dim[2], encoder_dim[2], stride=1), 108 | nn.Conv2d(encoder_dim[2], dim[2], 3, padding=1)) 109 | output_list.append(conv_out) 110 | 111 | 
self.outputs08 = nn.ModuleList(output_list) 112 | 113 | def forward(self, x): 114 | feat1, feat2 = self.conv2(x).split(dim=0, split_size=x.shape[0]//2) 115 | 116 | outputs08 = [f(x) for f in self.outputs08] 117 | return outputs08, feat1, feat2 118 | 119 | 120 | if __name__ == '__main__': 121 | 122 | data = torch.ones((1, 3, 1024, 1024)) 123 | 124 | model = UnetExtractor(in_channel=3, encoder_dim=[64, 96, 128]) 125 | 126 | x1, x2, x3 = model(data) 127 | print(x1.shape, x2.shape, x3.shape) 128 | -------------------------------------------------------------------------------- /models/gaussian/gaussian_network.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | from .extractor import UnetExtractor, ResidualBlock 5 | from einops import rearrange 6 | 7 | 8 | class GaussianNetwork(nn.Module): 9 | def __init__(self, rgb_dim=3, depth_dim=1, norm_fn='group'): 10 | super().__init__() 11 | self.rgb_dims = [64, 64, 128] 12 | self.depth_dims = [32, 48, 96] 13 | self.decoder_dims = [48, 64, 96] 14 | self.head_dim = 32 15 | 16 | self.sh_degree = 4 17 | self.d_sh = (self.sh_degree + 1) ** 2 18 | 19 | self.register_buffer( 20 | "sh_mask", 21 | torch.ones((self.d_sh,), dtype=torch.float32), 22 | persistent=False, 23 | ) 24 | for degree in range(1, self.sh_degree + 1): 25 | self.sh_mask[degree**2 : (degree + 1) ** 2] = 0.1 * 0.25**degree 26 | 27 | self.depth_encoder = UnetExtractor(in_channel=depth_dim, encoder_dim=self.depth_dims) 28 | 29 | self.decoder3 = nn.Sequential( 30 | ResidualBlock(self.rgb_dims[2]+self.depth_dims[2], self.decoder_dims[2], norm_fn=norm_fn), 31 | ResidualBlock(self.decoder_dims[2], self.decoder_dims[2], norm_fn=norm_fn) 32 | ) 33 | 34 | self.decoder2 = nn.Sequential( 35 | ResidualBlock(self.rgb_dims[1]+self.depth_dims[1]+self.decoder_dims[2], self.decoder_dims[1], norm_fn=norm_fn), 36 | ResidualBlock(self.decoder_dims[1], self.decoder_dims[1], norm_fn=norm_fn) 37 | ) 38 | 39 | self.decoder1 = nn.Sequential( 40 | ResidualBlock(self.rgb_dims[0]+self.depth_dims[0]+self.decoder_dims[1], self.decoder_dims[0], norm_fn=norm_fn), 41 | ResidualBlock(self.decoder_dims[0], self.decoder_dims[0], norm_fn=norm_fn) 42 | ) 43 | self.up = nn.Upsample(scale_factor=2, mode="bilinear") 44 | self.out_conv = nn.Conv2d(self.decoder_dims[0]+rgb_dim+1, self.head_dim, kernel_size=3, padding=1) 45 | self.out_relu = nn.ReLU(inplace=True) 46 | 47 | self.rot_head = nn.Sequential( 48 | nn.Conv2d(self.head_dim, self.head_dim, kernel_size=3, padding=1), 49 | nn.ReLU(inplace=True), 50 | nn.Conv2d(self.head_dim, 4, kernel_size=1), 51 | ) 52 | self.scale_head = nn.Sequential( 53 | nn.Conv2d(self.head_dim, self.head_dim, kernel_size=3, padding=1), 54 | nn.ReLU(inplace=True), 55 | nn.Conv2d(self.head_dim, 3, kernel_size=1), 56 | nn.Softplus(beta=100) 57 | ) 58 | self.opacity_head = nn.Sequential( 59 | nn.Conv2d(self.head_dim, self.head_dim, kernel_size=3, padding=1), 60 | nn.ReLU(inplace=True), 61 | nn.Conv2d(self.head_dim, 1, kernel_size=1), 62 | nn.Sigmoid() 63 | ) 64 | self.sh_head = nn.Sequential( 65 | nn.Conv2d(self.head_dim, self.head_dim, kernel_size=3, padding=1), 66 | nn.ReLU(inplace=True), 67 | nn.Conv2d(self.head_dim, 3 * self.d_sh, kernel_size=1), 68 | ) 69 | 70 | def forward(self, img, depth, img_feat): 71 | # img_feat1: [4, 64, 176, 320] 72 | # img_feat2: [4, 64, 88, 160] 73 | # img_feat3: [4, 128, 44, 80] 74 | img_feat1, img_feat2, img_feat3 = img_feat 75 | # depth_feat1: [4, 32, 176, 320] 76 | # depth_feat2: [4, 48, 88, 160] 
77 | # depth_feat3: [4, 96, 44, 80] 78 | depth_feat1, depth_feat2, depth_feat3 = self.depth_encoder(depth) 79 | 80 | feat3 = torch.concat([img_feat3, depth_feat3], dim=1) 81 | feat2 = torch.concat([img_feat2, depth_feat2], dim=1) 82 | feat1 = torch.concat([img_feat1, depth_feat1], dim=1) 83 | 84 | up3 = self.decoder3(feat3) 85 | up3 = self.up(up3) 86 | up2 = self.decoder2(torch.cat([up3, feat2], dim=1)) 87 | up2 = self.up(up2) 88 | up1 = self.decoder1(torch.cat([up2, feat1], dim=1)) 89 | 90 | up1 = self.up(up1) 91 | out = torch.cat([up1, img, depth], dim=1) 92 | out = self.out_conv(out) 93 | out = self.out_relu(out) 94 | 95 | # rot head 96 | rot_out = self.rot_head(out) 97 | rot_out = torch.nn.functional.normalize(rot_out, dim=1) 98 | 99 | # scale head 100 | scale_out = torch.clamp_max(self.scale_head(out), 0.01) 101 | 102 | # opacity head 103 | opacity_out = self.opacity_head(out) 104 | 105 | # sh head 106 | sh_out = self.sh_head(out) 107 | # sh_out: [(b * v), C, H, W] 108 | 109 | sh_out = rearrange( 110 | sh_out, "n c h w -> n (h w) c", 111 | ) 112 | sh_out = rearrange( 113 | sh_out, 114 | "... (srf c) -> ... srf () c", 115 | srf=1, 116 | ) 117 | 118 | sh_out = rearrange(sh_out, "... (xyz d_sh) -> ... xyz d_sh", xyz=3) 119 | # [(b * v), (H * W), 1, 1 3, 25] 120 | 121 | # sh_out = sh_out.broadcast_to(sh_out.shape) * self.sh_mask 122 | sh_out = sh_out * self.sh_mask 123 | 124 | 125 | return rot_out, scale_out, opacity_out, sh_out 126 | -------------------------------------------------------------------------------- /models/gaussian/gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer 4 | 5 | 6 | def render(novel_FovX, 7 | novel_FovY, 8 | novel_height, 9 | novel_width, 10 | novel_world_view_transform, 11 | novel_full_proj_transform, 12 | novel_camera_center, 13 | pts_xyz, 14 | pts_rgb, 15 | rotations, 16 | scales, 17 | opacity, 18 | shs, 19 | bg_color): 20 | """ 21 | Render the scene. 22 | 23 | Background tensor (bg_color) must be on GPU! 24 | """ 25 | bg_color = torch.tensor(bg_color, dtype=torch.float32).cuda() 26 | 27 | # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means 28 | screenspace_points = torch.zeros_like(pts_xyz, dtype=torch.float32, requires_grad=True).cuda() 29 | try: 30 | screenspace_points.retain_grad() 31 | except: 32 | pass 33 | 34 | # Set up rasterization configuration 35 | tanfovx = math.tan(novel_FovX * 0.5) 36 | tanfovy = math.tan(novel_FovY * 0.5) 37 | 38 | raster_settings = GaussianRasterizationSettings( 39 | image_height=int(novel_height), 40 | image_width=int(novel_width), 41 | tanfovx=tanfovx, 42 | tanfovy=tanfovy, 43 | bg=bg_color, 44 | scale_modifier=1.0, 45 | viewmatrix=novel_world_view_transform, 46 | projmatrix=novel_full_proj_transform, 47 | sh_degree=3, 48 | campos=novel_camera_center, 49 | prefiltered=False, 50 | debug=False, 51 | antialiasing=False 52 | ) 53 | 54 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 55 | 56 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 
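# Expected per-sample shapes at this point (N = number of Gaussians kept by the
# pts_valid mask in pts2render): means3D / means2D (N, 3), shs (N, d_sh, 3) with
# colors_precomp=None, opacities (N, 1), scales (N, 3), rotations (N, 4) unit
# quaternions (normalized by GaussianNetwork's rot head). Note that GaussianNetwork
# predicts sh_degree 4 (d_sh = 25) while sh_degree=3 is passed to the rasterizer
# settings above, so the degree-4 coefficients appear to go unused at render time.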
57 | rendered_image, _, _ = rasterizer( 58 | means3D=pts_xyz, 59 | means2D=screenspace_points, 60 | shs=shs, 61 | colors_precomp=pts_rgb, 62 | opacities=opacity, 63 | scales=scales, 64 | rotations=rotations, 65 | cov3D_precomp=None) 66 | 67 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 68 | # They will be excluded from value updates used in the splitting criteria. 69 | 70 | return rendered_image 71 | -------------------------------------------------------------------------------- /models/gaussian/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from jaxtyping import Float 4 | from math import isqrt 5 | from e3nn.o3 import matrix_to_angles, wigner_D 6 | from einops import einsum 7 | import math 8 | import numpy as np 9 | 10 | def depth2pc(depth, extrinsic, intrinsic): 11 | B, C, H, W = depth.shape 12 | depth = depth[:, 0, :, :] 13 | rot = extrinsic[:, :3, :3] 14 | trans = extrinsic[:, :3, 3:] 15 | 16 | y, x = torch.meshgrid(torch.linspace(0.5, H-0.5, H, device=depth.device), torch.linspace(0.5, W-0.5, W, device=depth.device)) 17 | pts_2d = torch.stack([x, y, torch.ones_like(x)], dim=-1).unsqueeze(0).repeat(B, 1, 1, 1) # B H W 3 18 | 19 | pts_2d[..., 2] = depth 20 | pts_2d[:, :, :, 0] -= intrinsic[:, None, None, 0, 2] 21 | pts_2d[:, :, :, 1] -= intrinsic[:, None, None, 1, 2] 22 | pts_2d_xy = pts_2d[:, :, :, :2] * pts_2d[:, :, :, 2:] 23 | pts_2d = torch.cat([pts_2d_xy, pts_2d[..., 2:]], dim=-1) 24 | 25 | pts_2d[..., 0] /= intrinsic[:, 0, 0][:, None, None] 26 | pts_2d[..., 1] /= intrinsic[:, 1, 1][:, None, None] 27 | 28 | pts_2d = pts_2d.view(B, -1, 3).permute(0, 2, 1) 29 | 30 | rot_t = rot.permute(0, 2, 1) 31 | pts = torch.bmm(rot_t, pts_2d) - torch.bmm(rot_t, trans) 32 | 33 | return pts.permute(0, 2, 1) 34 | 35 | def rotate_sh( 36 | sh_coefficients: Float[Tensor, "*#batch n"], 37 | rotations: Float[Tensor, "*#batch 3 3"], 38 | ) -> Float[Tensor, "*batch n"]: 39 | device = sh_coefficients.device 40 | dtype = sh_coefficients.dtype 41 | 42 | *_, n = sh_coefficients.shape 43 | alpha, beta, gamma = matrix_to_angles(rotations) 44 | result = [] 45 | for degree in range(isqrt(n)): 46 | sh_rotations = wigner_D(degree, alpha, beta, gamma).type(dtype) 47 | sh_rotated = einsum( 48 | sh_rotations, 49 | sh_coefficients[..., degree**2 : (degree + 1) ** 2], 50 | "... i j, ... j -> ... 
i", 51 | ) 52 | result.append(sh_rotated) 53 | 54 | return torch.cat(result, dim=-1) 55 | 56 | def focal2fov(focal, pixels): 57 | return 2*math.atan(pixels/(2*focal)) 58 | 59 | def getProjectionMatrix(znear, zfar, K, h, w): 60 | near_fx = znear / K[0, 0] 61 | near_fy = znear / K[1, 1] 62 | left = - (w - K[0, 2]) * near_fx 63 | right = K[0, 2] * near_fx 64 | bottom = (K[1, 2] - h) * near_fy 65 | top = K[1, 2] * near_fy 66 | 67 | P = torch.zeros(4, 4) 68 | z_sign = 1.0 69 | P[0, 0] = 2.0 * znear / (right - left) 70 | P[1, 1] = 2.0 * znear / (top - bottom) 71 | P[0, 2] = (right + left) / (right - left) 72 | P[1, 2] = (top + bottom) / (top - bottom) 73 | P[3, 2] = z_sign 74 | P[2, 2] = z_sign * zfar / (zfar - znear) 75 | P[2, 3] = -(zfar * znear) / (zfar - znear) 76 | return P 77 | 78 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 79 | Rt = np.zeros((4, 4)) 80 | Rt[:3, :3] = R.transpose() 81 | Rt[:3, 3] = t 82 | Rt[3, 3] = 1.0 83 | 84 | C2W = np.linalg.inv(Rt) 85 | cam_center = C2W[:3, 3] 86 | cam_center = (cam_center + translate) * scale 87 | C2W[:3, 3] = cam_center 88 | Rt = np.linalg.inv(C2W) 89 | return np.float32(Rt) -------------------------------------------------------------------------------- /models/geometry/__init__.py: -------------------------------------------------------------------------------- 1 | from .pose import Pose 2 | from .view_rendering import ViewRendering 3 | 4 | __all__ = ['Pose', 'ViewRendering'] -------------------------------------------------------------------------------- /models/geometry/geometry_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from pytorch3d.transforms import axis_angle_to_matrix 5 | 6 | 7 | def vec_to_matrix(rot_angle, trans_vec, invert=False): 8 | """ 9 | This function transforms rotation angle and translation vector into 4x4 matrix. 10 | """ 11 | # initialize matrices 12 | b, _, _ = rot_angle.shape 13 | R_mat = torch.eye(4).repeat([b, 1, 1]).to(device=rot_angle.device) 14 | T_mat = torch.eye(4).repeat([b, 1, 1]).to(device=rot_angle.device) 15 | 16 | R_mat[:, :3, :3] = axis_angle_to_matrix(rot_angle).squeeze(1) 17 | t_vec = trans_vec.clone().contiguous().view(-1, 3, 1) 18 | 19 | if invert == True: 20 | R_mat = R_mat.transpose(1,2) 21 | t_vec = -1 * t_vec 22 | 23 | T_mat[:, :3, 3:] = t_vec 24 | 25 | if invert == True: 26 | P_mat = torch.matmul(R_mat, T_mat) 27 | else : 28 | P_mat = torch.matmul(T_mat, R_mat) 29 | return P_mat 30 | 31 | 32 | class Projection(nn.Module): 33 | """ 34 | This class computes projection and reprojection function. 35 | """ 36 | def __init__(self, batch_size, height, width, device): 37 | super().__init__() 38 | self.batch_size = batch_size 39 | self.width = width 40 | self.height = height 41 | 42 | # initialize img point grid 43 | img_points = np.meshgrid(range(width), range(height), indexing='xy') 44 | img_points = torch.from_numpy(np.stack(img_points, 0)).float() 45 | img_points = torch.stack([img_points[0].view(-1), img_points[1].view(-1)], 0).repeat(batch_size, 1, 1) 46 | img_points = img_points.to(device) 47 | 48 | self.to_homo = torch.ones([batch_size, 1, width*height]).to(device) 49 | self.homo_points = torch.cat([img_points, self.to_homo], 1) 50 | 51 | def backproject(self, invK, depth): 52 | """ 53 | This function back-projects 2D image points to 3D. 
54 | """ 55 | depth = depth.view(self.batch_size, 1, -1) 56 | 57 | points3D = torch.matmul(invK[:, :3, :3], self.homo_points) 58 | points3D = depth*points3D 59 | return torch.cat([points3D, self.to_homo], 1) 60 | 61 | def reproject(self, K, points3D, T): 62 | """ 63 | This function reprojects transformed 3D points to 2D image coordinate. 64 | """ 65 | # project points 66 | points2D = (K @ T)[:,:3, :] @ points3D 67 | 68 | # normalize projected points for grid sample function 69 | norm_points2D = points2D[:, :2, :]/(points2D[:, 2:, :] + 1e-7) 70 | norm_points2D = norm_points2D.view(self.batch_size, 2, self.height, self.width) 71 | norm_points2D = norm_points2D.permute(0, 2, 3, 1) 72 | 73 | norm_points2D[..., 0 ] /= self.width - 1 74 | norm_points2D[..., 1 ] /= self.height - 1 75 | norm_points2D = (norm_points2D-0.5)*2 76 | return norm_points2D 77 | 78 | def forward(self, depth, T, bp_invK, rp_K): 79 | cam_points = self.backproject(bp_invK, depth) 80 | pix_coords = self.reproject(rp_K, cam_points, T) 81 | return pix_coords -------------------------------------------------------------------------------- /models/geometry/pose.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .geometry_util import vec_to_matrix 4 | 5 | 6 | class Pose: 7 | """ 8 | Class for multi-camera pose calculation 9 | """ 10 | def __init__(self, cfg): 11 | self.read_config(cfg) 12 | 13 | def read_config(self, cfg): 14 | for attr in cfg.keys(): 15 | for k, v in cfg[attr].items(): 16 | setattr(self, k, v) 17 | 18 | def compute_pose(self, net, inputs): 19 | """ 20 | This function computes multi-camera posse in accordance with the network structure. 21 | """ 22 | pose = self.get_single_pose(net, inputs, None) 23 | pose = self.distribute_pose(pose, inputs['extrinsics'], inputs['extrinsics_inv']) 24 | 25 | return pose 26 | 27 | def get_single_pose(self, net, inputs ,cam): 28 | """ 29 | This function computes pose for a single camera. 30 | """ 31 | output = {} 32 | for f_i in self.frame_ids[1:]: 33 | # To maintain ordering we always pass frames in temporal order 34 | frame_ids = [-1, 0] if f_i < 0 else [0, 1] 35 | axisangle, translation = net(inputs, frame_ids, cam) 36 | output[('cam_T_cam', 0, f_i)] = vec_to_matrix(axisangle[:, 0], translation[:, 0], invert=(f_i < 0)) 37 | return output 38 | 39 | def distribute_pose(self, poses, exts, exts_inv): 40 | """ 41 | This function distrubutes pose to each camera by using the canonical pose and camera extrinsics. 42 | (default: reference camera 0) 43 | """ 44 | outputs = {} 45 | for cam in range(self.num_cams): 46 | outputs[('cam',cam)] = {} 47 | # Refernce camera(canonical) 48 | ref_ext = exts[:, 0, ...] 49 | ref_ext_inv = exts_inv[:, 0, ...] 50 | for f_i in self.frame_ids[1:]: 51 | ref_T = poses['cam_T_cam', 0, f_i].float() # canonical pose 52 | # Relative cameras(canonical) 53 | for cam in range(self.num_cams): 54 | cur_ext = exts[:,cam,...] 55 | cur_ext_inv = exts_inv[:,cam,...] 56 | cur_T = cur_ext_inv @ ref_ext @ ref_T @ ref_ext_inv @ cur_ext 57 | 58 | outputs[('cam',cam)][('cam_T_cam', 0, f_i)] = cur_T 59 | return outputs 60 | 61 | def compute_relative_cam_poses(self, inputs, outputs, cam): 62 | """ 63 | This function computes spatio & spatio-temporal transformation for images from different viewpoints. 64 | """ 65 | ref_ext = inputs['extrinsics'][:, cam, ...] 
66 | target_view = outputs[('cam', cam)] 67 | 68 | rel_pose_dict = {} 69 | # precompute the relative pose 70 | if self.spatio: 71 | # current time step (spatio) 72 | for cur_index in self.rel_cam_list[cam]: 73 | # for partial surround view training 74 | if cur_index >= self.num_cams: 75 | continue 76 | 77 | cur_ext_inv = inputs['extrinsics_inv'][:, cur_index, ...] 78 | rel_pose_dict[(0, cur_index)] = torch.matmul(cur_ext_inv, ref_ext) 79 | 80 | if self.spatio_temporal: 81 | # different time step (spatio-temporal) 82 | for frame_id in self.frame_ids[1:]: 83 | for cur_index in self.rel_cam_list[cam]: 84 | # for partial surround view training 85 | if cur_index >= self.num_cams: 86 | continue 87 | 88 | T = target_view[('cam_T_cam', 0, frame_id)] 89 | # assuming that extrinsic doesn't change 90 | rel_ext = rel_pose_dict[(0, cur_index)] 91 | rel_pose_dict[(frame_id, cur_index)] = torch.matmul(rel_ext, T) # using matmul speed up 92 | return rel_pose_dict -------------------------------------------------------------------------------- /models/geometry/view_rendering.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .geometry_util import Projection 6 | 7 | 8 | class ViewRendering(nn.Module): 9 | """ 10 | Class for rendering images from given camera parameters and pixel wise depth information 11 | """ 12 | def __init__(self, cfg, rank): 13 | super().__init__() 14 | self.read_config(cfg) 15 | self.rank = rank 16 | self.project = self.init_project_imgs(rank) 17 | 18 | def read_config(self, cfg): 19 | for attr in cfg.keys(): 20 | for k, v in cfg[attr].items(): 21 | setattr(self, k, v) 22 | 23 | def init_project_imgs(self, rank): 24 | project_imgs = {} 25 | project_imgs = Projection( 26 | self.batch_size, self.height, self.width, rank) 27 | return project_imgs 28 | 29 | def get_mean_std(self, feature, mask): 30 | """ 31 | This function returns mean and standard deviation of the overlapped features. 32 | """ 33 | _, c, h, w = mask.size() 34 | mean = (feature * mask).sum(dim=(1,2,3), keepdim=True) / (mask.sum(dim=(1,2,3), keepdim=True) + 1e-8) 35 | var = ((feature - mean) ** 2).sum(dim=(1,2,3), keepdim=True) / (c*h*w) 36 | return mean, torch.sqrt(var + 1e-16) 37 | 38 | def get_norm_image_single(self, src_img, src_mask, warp_img, warp_mask): 39 | """ 40 | obtain normalized warped images using the mean and the variance from the overlapped regions of the target frame. 41 | """ 42 | warp_mask = warp_mask.detach() 43 | 44 | with torch.no_grad(): 45 | mask = (src_mask * warp_mask).bool() 46 | if mask.size(1) != 3: 47 | mask = mask.repeat(1,3,1,1) 48 | 49 | mask_sum = mask.sum(dim=(-3,-2,-1)) 50 | # skip when there is no overlap 51 | if torch.any(mask_sum == 0): 52 | return warp_img 53 | 54 | s_mean, s_std = self.get_mean_std(src_img, mask) 55 | w_mean, w_std = self.get_mean_std(warp_img, mask) 56 | 57 | norm_warp = (warp_img - w_mean) / (w_std + 1e-8) * s_std + s_mean 58 | return norm_warp * warp_mask.float() 59 | 60 | def get_virtual_image(self, src_img, src_mask, tar_depth, tar_invK, src_K, T, scale=0): 61 | """ 62 | This function warps source image to target image using backprojection and reprojection process. 
63 | """ 64 | # do reconstruction for target from source 65 | pix_coords = self.project(tar_depth, T, tar_invK, src_K) 66 | 67 | img_warped = F.grid_sample(src_img, pix_coords, mode='bilinear', 68 | padding_mode='zeros', align_corners=True) 69 | mask_warped = F.grid_sample(src_mask, pix_coords, mode='nearest', 70 | padding_mode='zeros', align_corners=True) 71 | 72 | # nan handling 73 | inf_img_regions = torch.isnan(img_warped) 74 | img_warped[inf_img_regions] = 2.0 75 | inf_mask_regions = torch.isnan(mask_warped) 76 | mask_warped[inf_mask_regions] = 0 77 | 78 | pix_coords = pix_coords.permute(0, 3, 1, 2) 79 | invalid_mask = torch.logical_or(pix_coords > 1, 80 | pix_coords < -1).sum(dim=1, keepdim=True) > 0 81 | return img_warped, (~invalid_mask).float() * mask_warped 82 | 83 | def get_virtual_depth(self, src_depth, src_mask, src_invK, src_K, tar_depth, tar_invK, tar_K, T, min_depth, max_depth, scale=0): 84 | """ 85 | This function backward-warp source depth into the target coordinate. 86 | src -> target 87 | """ 88 | # transform source depth 89 | b, _, h, w = src_depth.size() 90 | src_points = self.project.backproject(src_invK, src_depth) 91 | src_points_warped = torch.matmul(T[:, :3, :], src_points) 92 | src_depth_warped = src_points_warped.reshape(b, 3, h, w)[:, 2:3, :, :] 93 | 94 | # reconstruct depth: backward-warp source depth to the target coordinate 95 | pix_coords = self.project(tar_depth, torch.inverse(T), tar_invK, src_K) 96 | depth_warped = F.grid_sample(src_depth_warped, pix_coords, mode='bilinear', 97 | padding_mode='zeros', align_corners=True) 98 | mask_warped = F.grid_sample(src_mask, pix_coords, mode='nearest', 99 | padding_mode='zeros', align_corners=True) 100 | 101 | # nan handling 102 | inf_depth = torch.isnan(depth_warped) 103 | depth_warped[inf_depth] = 2.0 104 | inf_regions = torch.isnan(mask_warped) 105 | mask_warped[inf_regions] = 0 106 | 107 | pix_coords = pix_coords.permute(0, 3, 1, 2) 108 | invalid_mask = torch.logical_or(pix_coords > 1, pix_coords < -1).sum(dim=1, keepdim=True) > 0 109 | 110 | # range handling 111 | valid_depth_min = (depth_warped > min_depth) 112 | depth_warped[~valid_depth_min] = min_depth 113 | valid_depth_max = (depth_warped < max_depth) 114 | depth_warped[~valid_depth_max] = max_depth 115 | return depth_warped, (~invalid_mask).float() * mask_warped * valid_depth_min * valid_depth_max 116 | 117 | def forward(self, inputs, outputs, cam, rel_pose_dict): 118 | # predict images for each scale(default = scale 0 only) 119 | source_scale = 0 120 | 121 | # ref inputs 122 | ref_color = inputs['color', 0, source_scale][:,cam, ...] 123 | ref_mask = inputs['mask'][:, cam, ...] 124 | ref_K = inputs[('K', source_scale)][:,cam, ...] 125 | ref_invK = inputs[('inv_K', source_scale)][:,cam, ...] 126 | 127 | # output 128 | target_view = outputs[('cam', cam)] 129 | 130 | for scale in self.scales: 131 | ref_depth = target_view[('depth', 0, scale)] 132 | for frame_id in self.frame_ids[1:]: 133 | # for temporal learning 134 | T = target_view[('cam_T_cam', 0, frame_id)] 135 | src_color = inputs['color', frame_id, source_scale][:, cam, ...] 136 | src_mask = inputs['mask'][:, cam, ...] 
137 | warped_img, warped_mask = self.get_virtual_image( 138 | src_color, 139 | src_mask, 140 | ref_depth, 141 | ref_invK, 142 | ref_K, 143 | T, 144 | source_scale 145 | ) 146 | 147 | if self.intensity_align: 148 | warped_img = self.get_norm_image_single( 149 | ref_color, 150 | ref_mask, 151 | warped_img, 152 | warped_mask 153 | ) 154 | 155 | target_view[('color', frame_id, scale)] = warped_img 156 | target_view[('color_mask', frame_id, scale)] = warped_mask 157 | 158 | # spatio-temporal learning 159 | if self.spatio or self.spatio_temporal: 160 | for frame_id in self.frame_ids: 161 | overlap_img = torch.zeros_like(ref_color) 162 | overlap_mask = torch.zeros_like(ref_mask) 163 | 164 | for cur_index in self.rel_cam_list[cam]: 165 | # for partial surround view training 166 | if cur_index >= self.num_cams: 167 | continue 168 | 169 | src_color = inputs['color', frame_id, source_scale][:, cur_index, ...] 170 | src_mask = inputs['mask'][:, cur_index, ...] 171 | src_K = inputs[('K', source_scale)][:, cur_index, ...] 172 | 173 | rel_pose = rel_pose_dict[(frame_id, cur_index)] 174 | warped_img, warped_mask = self.get_virtual_image( 175 | src_color, 176 | src_mask, 177 | ref_depth, 178 | ref_invK, 179 | src_K, 180 | rel_pose, 181 | source_scale 182 | ) 183 | 184 | if self.intensity_align: 185 | warped_img = self.get_norm_image_single( 186 | ref_color, 187 | ref_mask, 188 | warped_img, 189 | warped_mask 190 | ) 191 | 192 | # assuming no overlap between warped images 193 | overlap_img = overlap_img + warped_img 194 | overlap_mask = overlap_mask + warped_mask 195 | 196 | target_view[('overlap', frame_id, scale)] = overlap_img 197 | target_view[('overlap_mask', frame_id, scale)] = overlap_mask 198 | 199 | outputs[('cam', cam)] = target_view -------------------------------------------------------------------------------- /models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .single_cam_loss import SingleCamLoss 2 | from .multi_cam_loss import MultiCamLoss 3 | 4 | __all__ = ['SingleCamLoss', 'MultiCamLoss'] -------------------------------------------------------------------------------- /models/losses/base_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BaseLoss(nn.Module): 5 | """ 6 | Base class loss calculation 7 | """ 8 | def __init__(self, cfg, rank): 9 | super().__init__() 10 | self.rank = rank 11 | self.init_weights(cfg) 12 | self.init_attrib(cfg) 13 | 14 | def init_weights(self, cfg): 15 | for attr in cfg.keys(): 16 | if attr == 'loss' or attr == 'training': 17 | for k, v in cfg[attr].items(): 18 | setattr(self, k, v) 19 | 20 | def init_attrib(self, cfg): 21 | for attr in cfg.keys(): 22 | for k, v in cfg[attr].items(): 23 | if k == 'scales' or k=='frame_ids' or k == 'novel_view_mode': 24 | setattr(self, k, v) 25 | 26 | def get_logs(self, loss_dict, output, cam): 27 | """ 28 | This function logs depth and pose information for monitoring training process. 
29 | """ 30 | # log statistics 31 | depth_log = output[('depth', 0, 0)].clone().detach() 32 | loss_dict['depth/mean'] = depth_log.mean() 33 | loss_dict['depth/max'] = depth_log.max() 34 | loss_dict['depth/min'] = depth_log.min() 35 | 36 | if cam == 0: 37 | pose_t = output[('cam_T_cam', 0, -1)].clone().detach() 38 | loss_dict['pose/tx'] = pose_t[:, 0, 3].abs().mean() 39 | loss_dict['pose/ty'] = pose_t[:, 1, 3].abs().mean() 40 | loss_dict['pose/tz'] = pose_t[:, 2, 3].abs().mean() 41 | 42 | return loss_dict 43 | 44 | def forward(self, *args, **kwargs): 45 | raise NotImplementedError('Not implemented for BaseLoss') -------------------------------------------------------------------------------- /models/losses/loss_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def compute_auto_masks(reprojection_loss, identity_reprojection_loss): 6 | """ 7 | This function computes auto mask using reprojection loss and identity reprojection loss. 8 | """ 9 | if identity_reprojection_loss is None: 10 | # without using auto(identity loss) mask 11 | reprojection_loss_mask = torch.ones_like(reprojection_loss) 12 | else: 13 | # using auto(identity loss) mask 14 | losses = torch.cat([reprojection_loss, identity_reprojection_loss], dim=1) 15 | idxs = torch.argmin(losses, dim=1, keepdim=True) 16 | reprojection_loss_mask = (idxs == 0).float() 17 | return reprojection_loss_mask 18 | 19 | 20 | def compute_masked_loss(loss, mask): 21 | """ 22 | This function masks losses while avoiding zero division. 23 | """ 24 | return (loss * mask).sum() / (mask.sum() + 1e-8) 25 | 26 | 27 | def compute_edg_smooth_loss(rgb, disp_map): 28 | """ 29 | This function calculates edge-aware smoothness. 30 | """ 31 | grad_rgb_x = (rgb[:, :, :, :-1] - rgb[:, :, :, 1:]).abs().mean(1, True) 32 | grad_rgb_y = (rgb[:, :, :-1, :] - rgb[:, :, 1:, :]).abs().mean(1, True) 33 | 34 | grad_disp_x = (disp_map[:, :, :, :-1] - disp_map[:, :, :, 1:]).abs() 35 | grad_disp_y = (disp_map[:, :, :-1, :] - disp_map[:, :, 1:, :]).abs() 36 | 37 | grad_disp_x *= (-1.0 * grad_rgb_x).exp() 38 | grad_disp_y *= (-1.0 * grad_rgb_y).exp() 39 | return grad_disp_x.mean() + grad_disp_y.mean() 40 | 41 | 42 | def compute_ssim_loss(pred, target): 43 | """ 44 | This function calculates SSIM loss between predicted image and target image. 
45 | """ 46 | ref_pad = torch.nn.ReflectionPad2d(1) 47 | pred = ref_pad(pred) 48 | target = ref_pad(target) 49 | 50 | mu_pred = F.avg_pool2d(pred, kernel_size = 3, stride = 1) 51 | mu_target = F.avg_pool2d(target, kernel_size = 3, stride = 1) 52 | 53 | musq_pred = mu_pred.pow(2) 54 | musq_target = mu_target.pow(2) 55 | mu_pred_target = mu_pred*mu_target 56 | 57 | sigma_pred = F.avg_pool2d(pred.pow(2), kernel_size = 3, stride = 1)-musq_pred 58 | sigma_target = F.avg_pool2d(target.pow(2), kernel_size = 3, stride = 1)-musq_target 59 | sigma_pred_target = F.avg_pool2d(pred*target, kernel_size = 3, stride = 1)-mu_pred_target 60 | 61 | C1 = 0.01**2 62 | C2 = 0.03**2 63 | 64 | ssim_map = ((2*mu_pred_target + C1)*(2*sigma_pred_target + C2)) \ 65 | /((musq_pred + musq_target + C1)*(sigma_pred + sigma_target + C2)+1e-8) 66 | return torch.clamp((1-ssim_map)/2, 0, 1) 67 | 68 | 69 | def compute_photometric_loss(pred=None, target=None): 70 | """ 71 | This function calculates photometric reconstruction loss (0.85*SSIM + 0.15*L1) 72 | """ 73 | abs_diff = torch.abs(target - pred) 74 | l1_loss = abs_diff.mean(1, True) 75 | ssim_loss = compute_ssim_loss(pred, target).mean(1, True) 76 | rep_loss = 0.85 * ssim_loss + 0.15 * l1_loss 77 | return rep_loss 78 | -------------------------------------------------------------------------------- /models/losses/multi_cam_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch3d.transforms import matrix_to_euler_angles 3 | 4 | from .loss_util import compute_photometric_loss, compute_masked_loss 5 | from .single_cam_loss import SingleCamLoss 6 | 7 | from lpips import LPIPS 8 | 9 | class MultiCamLoss(SingleCamLoss): 10 | """ 11 | Class for multi-camera(spatio & temporal) loss calculation 12 | """ 13 | def __init__(self, cfg, rank): 14 | super(MultiCamLoss, self).__init__(cfg, rank) 15 | 16 | self.lpips = LPIPS(net="vgg").cuda(rank) 17 | 18 | def compute_spatio_loss(self, inputs, target_view, cam=None, scale=None, ref_mask=None): 19 | """ 20 | This function computes spatial loss. 21 | """ 22 | # self occlusion mask * overlap region mask 23 | spatio_mask = ref_mask * target_view[('overlap_mask', 0, scale)] 24 | loss_args = { 25 | 'pred': target_view[('overlap', 0, scale)], 26 | 'target': inputs['color',0, 0][:,cam, ...] 27 | } 28 | spatio_loss = compute_photometric_loss(**loss_args) 29 | 30 | target_view[('overlap_mask', 0, scale)] = spatio_mask 31 | return compute_masked_loss(spatio_loss, spatio_mask) 32 | 33 | def compute_spatio_tempo_loss(self, inputs, target_view, cam=None, scale=None, ref_mask=None, reproj_loss_mask=None) : 34 | """ 35 | This function computes spatio-temporal loss. 36 | """ 37 | spatio_tempo_losses = [] 38 | spatio_tempo_masks = [] 39 | for frame_id in self.frame_ids[1:]: 40 | 41 | pred_mask = ref_mask * target_view[('overlap_mask', frame_id, scale)] 42 | pred_mask = pred_mask * reproj_loss_mask 43 | 44 | loss_args = { 45 | 'pred': target_view[('overlap', frame_id, scale)], 46 | 'target': inputs['color',0, 0][:,cam, ...] 
47 | } 48 | 49 | spatio_tempo_losses.append(compute_photometric_loss(**loss_args)) 50 | spatio_tempo_masks.append(pred_mask) 51 | 52 | # concatenate losses and masks 53 | spatio_tempo_losses = torch.cat(spatio_tempo_losses, 1) 54 | spatio_tempo_masks = torch.cat(spatio_tempo_masks, 1) 55 | 56 | # for the loss, take minimum value between reprojection loss and identity loss(moving object) 57 | # for the mask, take maximum value between reprojection mask and overlap mask to apply losses on all the True values of masks. 58 | spatio_tempo_loss, _ = torch.min(spatio_tempo_losses, dim=1, keepdim=True) 59 | spatio_tempo_mask, _ = torch.max(spatio_tempo_masks.float(), dim=1, keepdim=True) 60 | 61 | return compute_masked_loss(spatio_tempo_loss, spatio_tempo_mask) 62 | 63 | def compute_pose_con_loss(self, inputs, outputs, cam=None, scale=None, ref_mask=None, reproj_loss_mask=None) : 64 | """ 65 | This function computes pose consistency loss in "Full surround monodepth from multiple cameras" 66 | """ 67 | ref_output = outputs[('cam', 0)] 68 | ref_ext = inputs['extrinsics'][:, 0, ...] 69 | ref_ext_inv = inputs['extrinsics_inv'][:, 0, ...] 70 | 71 | cur_output = outputs[('cam', cam)] 72 | cur_ext = inputs['extrinsics'][:, cam, ...] 73 | cur_ext_inv = inputs['extrinsics_inv'][:, cam, ...] 74 | 75 | trans_loss = 0. 76 | angle_loss = 0. 77 | 78 | for frame_id in self.frame_ids[1:]: 79 | ref_T = ref_output[('cam_T_cam', 0, frame_id)] 80 | cur_T = cur_output[('cam_T_cam', 0, frame_id)] 81 | 82 | cur_T_aligned = ref_ext_inv@cur_ext@cur_T@cur_ext_inv@ref_ext 83 | 84 | ref_ang = matrix_to_euler_angles(ref_T[:,:3,:3], 'XYZ') 85 | cur_ang = matrix_to_euler_angles(cur_T_aligned[:,:3,:3], 'XYZ') 86 | 87 | ang_diff = torch.norm(ref_ang - cur_ang, p=2, dim=1).mean() 88 | t_diff = torch.norm(ref_T[:,:3,3] - cur_T_aligned[:,:3,3], p=2, dim=1).mean() 89 | 90 | trans_loss += t_diff 91 | angle_loss += ang_diff 92 | 93 | pose_loss = (trans_loss + 10 * angle_loss) / len(self.frame_ids[1:]) 94 | return pose_loss 95 | 96 | def compute_gaussian_loss(self, inputs, target_view, cam=0, scale=0): 97 | """ 98 | This function computes gaussian loss. 99 | """ 100 | # self occlusion mask * overlap region mask 101 | if self.novel_view_mode == 'MF': 102 | pred = target_view[('gaussian_color', 0, scale)] 103 | target_view = inputs['color', 0, 0][:,cam, ...] 104 | lpips_loss = self.lpips(pred, target_view, normalize=True).mean() 105 | l2_loss = ((pred - target_view)**2).mean() 106 | return 1 * l2_loss + 0.05 * lpips_loss 107 | elif self.novel_view_mode == 'SF': 108 | gaussian_loss = 0.0 109 | for frame_id in self.frame_ids[1:]: 110 | pred = target_view[('gaussian_color', frame_id, scale)] 111 | gt = inputs['color', frame_id, 0][:,cam, ...] 112 | lpips_loss = self.lpips(pred, gt, normalize=True).mean() 113 | l2_loss = ((pred - gt)**2).mean() 114 | gaussian_loss += 1 * l2_loss + 0.05 * lpips_loss 115 | return gaussian_loss / 2 116 | 117 | 118 | def forward(self, inputs, outputs, cam): 119 | loss_dict = {} 120 | cam_loss = 0. # loss across the multi-scale 121 | target_view = outputs[('cam', cam)] 122 | for scale in self.scales: 123 | kargs = { 124 | 'cam': cam, 125 | 'scale': scale, 126 | 'ref_mask': inputs['mask'][:,cam,...] 
127 | } 128 | 129 | reprojection_loss = self.compute_reproj_loss(inputs, target_view, **kargs) 130 | smooth_loss = self.compute_smooth_loss(inputs, target_view, **kargs) 131 | spatio_loss = self.compute_spatio_loss(inputs, target_view, **kargs) 132 | 133 | kargs['reproj_loss_mask'] = target_view[('reproj_mask', scale)] 134 | spatio_tempo_loss = self.compute_spatio_tempo_loss(inputs, target_view, **kargs) 135 | 136 | if self.gaussian: 137 | gaussian_loss = self.compute_gaussian_loss(inputs, target_view, cam, scale) 138 | 139 | pose_loss = 0 140 | 141 | cam_loss += reprojection_loss 142 | cam_loss += self.disparity_smoothness * smooth_loss / (2 ** scale) 143 | cam_loss += self.spatio_coeff * spatio_loss + self.spatio_tempo_coeff * spatio_tempo_loss 144 | if self.gaussian: 145 | cam_loss += self.gaussian_coeff * gaussian_loss 146 | cam_loss += self.pose_loss_coeff* pose_loss 147 | 148 | ########################## 149 | # for logger 150 | ########################## 151 | if scale == 0: 152 | loss_dict['reproj_loss'] = reprojection_loss.item() 153 | loss_dict['spatio_loss'] = spatio_loss.item() 154 | if self.gaussian: 155 | loss_dict['gaussian_loss'] = gaussian_loss.item() 156 | loss_dict['spatio_tempo_loss'] = spatio_tempo_loss.item() 157 | loss_dict['smooth'] = smooth_loss.item() 158 | 159 | # log statistics 160 | self.get_logs(loss_dict, target_view, cam) 161 | 162 | cam_loss /= len(self.scales) 163 | return cam_loss, loss_dict -------------------------------------------------------------------------------- /models/losses/single_cam_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .loss_util import compute_photometric_loss, compute_edg_smooth_loss, compute_masked_loss, compute_auto_masks 4 | from .base_loss import BaseLoss 5 | 6 | _EPSILON = 0.00001 7 | 8 | 9 | class SingleCamLoss(BaseLoss): 10 | """ 11 | Class for single camera(temporal only) loss calculation 12 | """ 13 | def __init__(self, cfg, rank): 14 | super().__init__(cfg, rank) 15 | 16 | def compute_reproj_loss(self, inputs, target_view, cam=0, scale=0, ref_mask=None): 17 | """ 18 | This function computes reprojection loss using auto mask. 19 | """ 20 | reprojection_losses = [] 21 | for frame_id in self.frame_ids[1:]: 22 | reproj_loss_args = { 23 | 'pred': target_view[('color', frame_id, scale)], 24 | 'target': inputs['color',0, 0][:, cam, ...] 25 | } 26 | reprojection_losses.append( 27 | compute_photometric_loss(**reproj_loss_args) 28 | ) 29 | 30 | reprojection_losses = torch.cat(reprojection_losses, 1) 31 | reprojection_loss, _ = torch.min(reprojection_losses, dim=1, keepdim=True) 32 | 33 | identity_reprojection_losses = [] 34 | for frame_id in self.frame_ids[1:]: 35 | identity_reproj_loss_args = { 36 | 'pred': inputs[('color', frame_id, 0)][:, cam, ...], 37 | 'target': inputs['color',0, 0][:, cam, ...] 
38 | } 39 | identity_reprojection_losses.append( 40 | compute_photometric_loss(**identity_reproj_loss_args) 41 | ) 42 | 43 | identity_reprojection_losses = torch.cat(identity_reprojection_losses, 1) 44 | identity_reprojection_losses = identity_reprojection_losses + \ 45 | _EPSILON * torch.randn(identity_reprojection_losses.shape).to(self.rank) 46 | identity_reprojection_loss, _ = torch.min(identity_reprojection_losses, dim=1, keepdim=True) 47 | 48 | # find minimum losses 49 | reprojection_auto_mask = compute_auto_masks(reprojection_loss, identity_reprojection_loss) 50 | reprojection_auto_mask *= ref_mask 51 | 52 | target_view[('reproj_loss', scale)] = reprojection_auto_mask * reprojection_loss 53 | target_view[('reproj_mask', scale)] = reprojection_auto_mask 54 | return compute_masked_loss(reprojection_loss, reprojection_auto_mask) 55 | 56 | def compute_smooth_loss(self, inputs, target_view, cam = 0, scale = 0, ref_mask=None): 57 | """ 58 | This function computes edge-aware smoothness loss for the disparity map. 59 | """ 60 | color = inputs['color', 0, scale][:, cam, ...] 61 | disp = target_view[('disp', scale)] 62 | mean_disp = disp.mean(2, True).mean(3, True) 63 | norm_disp = disp / (mean_disp + 1e-8) 64 | return compute_edg_smooth_loss(color, norm_disp) 65 | 66 | def forward(self, inputs, outputs, cam): 67 | loss_dict = {} 68 | cam_loss = 0. # loss across the multi-scale 69 | target_view = outputs[('cam', cam)] 70 | for scale in self.scales: 71 | kargs = { 72 | 'cam': cam, 73 | 'scale': scale, 74 | 'ref_mask': inputs['mask'][:,cam,...] 75 | } 76 | 77 | reprojection_loss = self.compute_reproj_loss(inputs, target_view, **kargs) 78 | smooth_loss = self.compute_smooth_loss(inputs, target_view, **kargs) 79 | 80 | cam_loss += reprojection_loss 81 | cam_loss += self.disparity_smoothness * smooth_loss / (2 ** scale) 82 | 83 | ########################## 84 | # for logger 85 | ########################## 86 | if scale == 0: 87 | loss_dict['reproj_loss'] = reprojection_loss.item() 88 | loss_dict['smooth'] = smooth_loss.item() 89 | 90 | # log statistics 91 | self.get_logs(loss_dict, target_view, cam) 92 | 93 | cam_loss /= len(self.scales) 94 | return cam_loss, loss_dict -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- 1 | from .pose_network import PoseNetwork 2 | from .depth_network import DepthNetwork 3 | 4 | __all__ = ['DepthNetwork', 'PoseNetwork'] -------------------------------------------------------------------------------- /network/blocks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | def pack_cam_feat(x): 6 | if isinstance(x, dict): 7 | for k, v in x.items(): 8 | b, n_cam = v.shape[:2] 9 | x[k] = v.view(b*n_cam, *v.shape[2:]) 10 | return x 11 | else: 12 | b, n_cam = x.shape[:2] 13 | x = x.view(b*n_cam, *x.shape[2:]) 14 | return x 15 | 16 | 17 | def unpack_cam_feat(x, b, n_cam): 18 | if isinstance(x, dict): 19 | for k, v in x.items(): 20 | x[k] = v.view(b, n_cam, *v.shape[1:]) 21 | return x 22 | else: 23 | x = x.view(b, n_cam, *x.shape[1:]) 24 | return x 25 | 26 | 27 | def upsample(x): 28 | return F.interpolate(x, scale_factor=2, mode='nearest') 29 | 30 | 31 | def conv2d(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, nonlin = 'LRU', padding_mode = 'reflect', norm = False): 32 | if nonlin== 'LRU': 33 | act = nn.LeakyReLU(0.1, 
inplace=True) 34 | elif nonlin == 'ELU': 35 | act = nn.ELU(inplace=True) 36 | else: 37 | act = nn.Identity() 38 | 39 | if norm: 40 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, 41 | padding=((kernel_size - 1) * dilation) // 2, bias=False, padding_mode=padding_mode) 42 | bnorm = nn.BatchNorm2d(out_planes) 43 | else: 44 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, 45 | padding=((kernel_size - 1) * dilation) // 2, bias=True, padding_mode=padding_mode) 46 | bnorm = nn.Identity() 47 | return nn.Sequential(conv, bnorm, act) 48 | 49 | 50 | def conv1d(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, nonlin='LRU', padding_mode='reflect', norm = False): 51 | if nonlin== 'LRU': 52 | act = nn.LeakyReLU(0.1, inplace=True) 53 | elif nonlin == 'ELU': 54 | act = nn.ELU(inplace=True) 55 | else: 56 | act = nn.Identity() 57 | 58 | if norm: 59 | conv = nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, 60 | padding=((kernel_size - 1) * dilation) // 2, bias=False, padding_mode=padding_mode) 61 | bnorm = nn.BatchNorm1d(out_planes) 62 | else: 63 | conv = nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, 64 | padding=((kernel_size - 1) * dilation) // 2, bias=True, padding_mode=padding_mode) 65 | bnorm = nn.Identity() 66 | 67 | return nn.Sequential(conv, bnorm, act) -------------------------------------------------------------------------------- /network/depth_network.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .blocks import upsample, conv2d, pack_cam_feat, unpack_cam_feat 8 | from .volumetric_fusionnet import VFNet 9 | 10 | from external.layers import ResnetEncoder 11 | 12 | class DepthNetwork(nn.Module): 13 | """ 14 | Depth fusion module 15 | """ 16 | def __init__(self, cfg): 17 | super(DepthNetwork, self).__init__() 18 | self.read_config(cfg) 19 | 20 | # feature encoder 21 | # resnet feat: 64(1/2), 64(1/4), 128(1/8), 256(1/16), 512(1/32) 22 | self.encoder = ResnetEncoder(self.num_layers, self.weights_init, 1) # number of layers, pretrained, number of input images 23 | del self.encoder.encoder.fc # del fc in weights_init 24 | enc_feat_dim = sum(self.encoder.num_ch_enc[self.fusion_level:]) 25 | self.conv1x1 = conv2d(enc_feat_dim, self.fusion_feat_in_dim, kernel_size=1, padding_mode = 'reflect') 26 | 27 | # fusion net 28 | fusion_feat_out_dim = self.encoder.num_ch_enc[self.fusion_level] 29 | self.fusion_net = VFNet(cfg, self.fusion_feat_in_dim, fusion_feat_out_dim, model ='depth') 30 | 31 | # depth decoder 32 | num_ch_enc = self.encoder.num_ch_enc[:(self.fusion_level+1)] 33 | num_ch_dec = [16, 32, 64, 128, 256] 34 | self.decoder = DepthDecoder(self.fusion_level, num_ch_enc, num_ch_dec, self.scales, use_skips = self.use_skips) 35 | 36 | def read_config(self, cfg): 37 | for attr in cfg.keys(): 38 | for k, v in cfg[attr].items(): 39 | setattr(self, k, v) 40 | 41 | def forward(self, inputs): 42 | ''' 43 | dict_keys(['idx', 'sensor_name', 'filename', 'extrinsics', 'mask', 44 | ('K', 0), ('inv_K', 0), ('color', 0, 0), ('color_aug', 0, 0), 45 | ('K', 1), ('inv_K', 1), ('color', 0, 1), ('color_aug', 0, 1), 46 | ('K', 2), ('inv_K', 2), ('color', 0, 2), ('color_aug', 0, 2), 47 | ('K', 3), ('inv_K', 3), ('color', 0, 3), ('color_aug', 0, 3), 48 | 
('color', -1, 0), ('color_aug', -1, 0), ('color', 1, 0), ('color_aug', 1, 0), 'extrinsics_inv']) 49 | ''' 50 | 51 | outputs = {} 52 | 53 | # dictionary initialize 54 | for cam in range(self.num_cams): # self.num_cames = 6 55 | outputs[('cam', cam)] = {} # outputs = {('cam', 0): {}, ..., ('cam', 5): {}} 56 | 57 | lev = self.fusion_level # 2 58 | 59 | # packed images for surrounding view 60 | sf_images = torch.stack([inputs[('color_aug', 0, 0)][:, cam, ...] for cam in range(self.num_cams)], 1) 61 | if self.novel_view_mode == 'MF': 62 | sf_images_last = torch.stack([inputs[('color_aug', -1, 0)][:, cam, ...] for cam in range(self.num_cams)], 1) 63 | sf_images_next = torch.stack([inputs[('color_aug', 1, 0)][:, cam, ...] for cam in range(self.num_cams)], 1) 64 | packed_input = pack_cam_feat(sf_images) 65 | if self.novel_view_mode == 'MF': 66 | packed_input_last = pack_cam_feat(sf_images_last) 67 | packed_input_next = pack_cam_feat(sf_images_next) 68 | 69 | # feature encoder 70 | packed_feats = self.encoder(packed_input) 71 | if self.novel_view_mode == 'MF': 72 | packed_feats_last = self.encoder(packed_input_last) 73 | packed_feats_next = self.encoder(packed_input_next) 74 | # aggregate feature H / 2^(lev+1) x W / 2^(lev+1) 75 | _, _, up_h, up_w = packed_feats[lev].size() 76 | 77 | packed_feats_list = packed_feats[lev:lev+1] \ 78 | + [F.interpolate(feat, [up_h, up_w], mode='bilinear', align_corners=True) for feat in packed_feats[lev+1:]] 79 | if self.novel_view_mode == 'MF': 80 | packed_feats_last_list = packed_feats_last[lev:lev+1] \ 81 | + [F.interpolate(feat, [up_h, up_w], mode='bilinear', align_corners=True) for feat in packed_feats_last[lev+1:]] 82 | packed_feats_next_list = packed_feats_next[lev:lev+1] \ 83 | + [F.interpolate(feat, [up_h, up_w], mode='bilinear', align_corners=True) for feat in packed_feats_next[lev+1:]] 84 | 85 | packed_feats_agg = self.conv1x1(torch.cat(packed_feats_list, dim=1)) 86 | if self.novel_view_mode == 'MF': 87 | packed_feats_agg_last = self.conv1x1(torch.cat(packed_feats_last_list, dim=1)) 88 | packed_feats_agg_next = self.conv1x1(torch.cat(packed_feats_next_list, dim=1)) 89 | 90 | feats_agg = unpack_cam_feat(packed_feats_agg, self.batch_size, self.num_cams) 91 | if self.novel_view_mode == 'MF': 92 | feats_agg_last = unpack_cam_feat(packed_feats_agg_last, self.batch_size, self.num_cams) 93 | feats_agg_next = unpack_cam_feat(packed_feats_agg_next, self.batch_size, self.num_cams) 94 | 95 | # fusion_net, backproject each feature into the 3D voxel space 96 | fusion_dict = self.fusion_net(inputs, feats_agg) 97 | if self.novel_view_mode == 'MF': 98 | fusion_dict_last = self.fusion_net(inputs, feats_agg_last) 99 | fusion_dict_next = self.fusion_net(inputs, feats_agg_next) 100 | 101 | feat_in = packed_feats[:lev] + [fusion_dict['proj_feat']] 102 | img_feat = [] 103 | for i in range(len(feat_in)): 104 | img_feat.append(unpack_cam_feat(feat_in[i], self.batch_size, self.num_cams)) 105 | 106 | if self.novel_view_mode == 'MF': 107 | feat_in_last = packed_feats_last[:lev] + [fusion_dict_last['proj_feat']] 108 | img_feat_last = [] 109 | for i in range(len(feat_in_last)): 110 | img_feat_last.append(unpack_cam_feat(feat_in_last[i], self.batch_size, self.num_cams)) 111 | 112 | feat_in_next = packed_feats_next[:lev] + [fusion_dict_next['proj_feat']] 113 | img_feat_next = [] 114 | for i in range(len(feat_in_next)): 115 | img_feat_next.append(unpack_cam_feat(feat_in_next[i], self.batch_size, self.num_cams)) 116 | 117 | packed_depth_outputs = self.decoder(feat_in) 118 | if 
self.novel_view_mode == 'MF': 119 | packed_depth_outputs_last = self.decoder(feat_in_last) 120 | packed_depth_outputs_next = self.decoder(feat_in_next) 121 | 122 | depth_outputs = unpack_cam_feat(packed_depth_outputs, self.batch_size, self.num_cams) 123 | if self.novel_view_mode == 'MF': 124 | depth_outputs_last = unpack_cam_feat(packed_depth_outputs_last, self.batch_size, self.num_cams) 125 | depth_outputs_next = unpack_cam_feat(packed_depth_outputs_next, self.batch_size, self.num_cams) 126 | 127 | for cam in range(self.num_cams): 128 | for k in depth_outputs.keys(): 129 | outputs[('cam', cam)][k] = depth_outputs[k][:, cam, ...] 130 | outputs[('cam', cam)][('img_feat', 0, 0)] = [feat[:, cam, ...] for feat in img_feat] 131 | if self.novel_view_mode == 'MF': 132 | outputs[('cam', cam)][('img_feat', -1, 0)] = [feat[:, cam, ...] for feat in img_feat_last] 133 | outputs[('cam', cam)][('img_feat', 1, 0)] = [feat[:, cam, ...] for feat in img_feat_next] 134 | outputs[('cam', cam)][('disp', -1, 0)] = depth_outputs_last[('disp', 0)][:, cam, ...] 135 | outputs[('cam', cam)][('disp', 1, 0)] = depth_outputs_next[('disp', 0)][:, cam, ...] 136 | 137 | return outputs 138 | 139 | 140 | class DepthDecoder(nn.Module): 141 | def __init__(self, level_in, num_ch_enc, num_ch_dec, scales=range(2), use_skips=False): 142 | super(DepthDecoder, self).__init__() 143 | 144 | self.num_output_channels = 1 145 | self.scales = scales 146 | self.use_skips = use_skips 147 | 148 | self.level_in = level_in 149 | self.num_ch_enc = num_ch_enc 150 | self.num_ch_dec = num_ch_dec 151 | 152 | self.convs = OrderedDict() 153 | for i in range(self.level_in, -1, -1): 154 | num_ch_in = self.num_ch_enc[-1] if i == self.level_in else self.num_ch_dec[i + 1] 155 | num_ch_out = self.num_ch_dec[i] 156 | self.convs[('upconv', i, 0)] = conv2d(num_ch_in, num_ch_out, kernel_size=3, nonlin = 'ELU') 157 | 158 | num_ch_in = self.num_ch_dec[i] 159 | if self.use_skips and i > 0: 160 | num_ch_in += self.num_ch_enc[i - 1] 161 | num_ch_out = self.num_ch_dec[i] 162 | self.convs[('upconv', i, 1)] = conv2d(num_ch_in, num_ch_out, kernel_size=3, nonlin = 'ELU') 163 | 164 | for s in self.scales: 165 | self.convs[('dispconv', s)] = conv2d(self.num_ch_dec[s], self.num_output_channels, 3, nonlin = None) 166 | 167 | self.decoder = nn.ModuleList(list(self.convs.values())) 168 | self.sigmoid = nn.Sigmoid() 169 | 170 | def forward(self, input_features): 171 | outputs = {} 172 | 173 | # decode 174 | x = input_features[-1] 175 | for i in range(self.level_in, -1, -1): 176 | x = self.convs[('upconv', i, 0)](x) 177 | x = [upsample(x)] 178 | if self.use_skips and i > 0: 179 | x += [input_features[i - 1]] 180 | x = torch.cat(x, 1) 181 | x = self.convs[('upconv', i, 1)](x) 182 | if i in self.scales: 183 | outputs[('disp', i)] = self.sigmoid(self.convs[('dispconv', i)](x)) 184 | return outputs -------------------------------------------------------------------------------- /network/pose_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .blocks import conv2d, pack_cam_feat, unpack_cam_feat 6 | from .volumetric_fusionnet import VFNet 7 | 8 | from external.layers import ResnetEncoder, PoseDecoder 9 | 10 | 11 | class PoseNetwork(nn.Module): 12 | """ 13 | Canonical motion estimation module. 
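Predicts a single canonical ego-motion (axis-angle rotation and translation) for the rig: the current and next surround-view frames are concatenated per camera, encoded, fused into a BEV feature by VFNet, and decoded by PoseDecoder, with the translation clamped to [-4, 4].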
14 | """ 15 | def __init__(self, cfg): 16 | super(PoseNetwork, self).__init__() 17 | self.read_config(cfg) 18 | 19 | # feature encoder 20 | # resnet feat: 64(1/2), 64(1/4), 128(1/8), 256(1/16), 512(1/32) 21 | self.encoder = ResnetEncoder(self.num_layers, self.weights_init, 2) # number of layers, pretrained, number of input images 22 | del self.encoder.encoder.fc 23 | enc_feat_dim = sum(self.encoder.num_ch_enc[self.fusion_level:]) 24 | self.conv1x1 = conv2d(enc_feat_dim, self.fusion_feat_in_dim, kernel_size=1, padding_mode='reflect') 25 | 26 | # fusion net 27 | fusion_feat_out_dim = self.encoder.num_ch_enc[self.fusion_level] 28 | self.fusion_net = VFNet(cfg, self.fusion_feat_in_dim, fusion_feat_out_dim, model = 'pose') 29 | 30 | # depth decoder 31 | self.pose_decoder = PoseDecoder(num_ch_enc = [fusion_feat_out_dim], 32 | num_input_features=1, 33 | num_frames_to_predict_for=1, 34 | stride=2) 35 | 36 | def read_config(self, cfg): 37 | for attr in cfg.keys(): 38 | for k, v in cfg[attr].items(): 39 | setattr(self, k, v) 40 | 41 | def forward(self, inputs, frame_ids, _): 42 | outputs = {} 43 | 44 | # initialize dictionary 45 | for cam in range(self.num_cams): 46 | outputs[('cam', cam)] = {} 47 | 48 | lev = self.fusion_level 49 | 50 | # packed images for surrounding view 51 | cur_image = inputs[('color_aug', frame_ids[0], 0)] 52 | next_image = inputs[('color_aug', frame_ids[1], 0)] 53 | 54 | pose_images = torch.cat([cur_image, next_image], 2) 55 | packed_pose_images = pack_cam_feat(pose_images) 56 | 57 | packed_feats = self.encoder(packed_pose_images) 58 | 59 | # aggregate feature H / 2^(lev+1) x W / 2^(lev+1) 60 | _, _, up_h, up_w = packed_feats[lev].size() 61 | 62 | packed_feats_list = packed_feats[lev:lev+1] \ 63 | + [F.interpolate(feat, [up_h, up_w], mode='bilinear', align_corners=True) for feat in packed_feats[lev+1:]] 64 | 65 | packed_feats_agg = self.conv1x1(torch.cat(packed_feats_list, dim=1)) 66 | feats_agg = unpack_cam_feat(packed_feats_agg, self.batch_size, self.num_cams) 67 | 68 | # fusion_net, backproject each feature into the 3D voxel space 69 | bev_feat = self.fusion_net(inputs, feats_agg) 70 | axis_angle, translation = self.pose_decoder([[bev_feat]]) 71 | return axis_angle, torch.clamp(translation, -4.0, 4.0) # for DDAD dataset -------------------------------------------------------------------------------- /network/volumetric_fusionnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from pytorch3d.transforms import axis_angle_to_matrix 5 | 6 | from .blocks import conv2d, conv1d, pack_cam_feat 7 | 8 | 9 | class VFNet(nn.Module): 10 | """ 11 | Surround-view fusion module that estimates a single 3D feature using surround-view images 12 | """ 13 | def __init__(self, cfg, feat_in_dim, feat_out_dim, model='depth'): 14 | super(VFNet, self).__init__() 15 | self.read_config(cfg) 16 | self.eps = 1e-8 17 | self.model = model 18 | # define the 3D voxel space(follows the DDAD extrinsic coordinate -- x: forward, y: left, z: up) 19 | # define voxel end range in accordance with voxel_str_p, voxel_size, voxel_unit_size 20 | self.voxel_end_p = [self.voxel_str_p[i] + self.voxel_unit_size[i] * (self.voxel_size[i] - 1) for i in range(3)] 21 | 22 | # define a voxel space, [1, 3, z, y, x], each voxel contains its 3D position 23 | voxel_grid = self.create_voxel_grid(self.voxel_str_p, self.voxel_end_p, self.voxel_size) 24 | b, _, self.z_dim, self.y_dim, self.x_dim = voxel_grid.size() 25 
| self.n_voxels = self.z_dim * self.y_dim * self.x_dim 26 | ones = torch.ones(self.batch_size, 1, self.n_voxels) 27 | self.voxel_pts = torch.cat([voxel_grid.view(b, 3, self.n_voxels), ones], dim=1) 28 | 29 | # define grids in pixel space 30 | self.img_h = self.height // (2 ** (self.fusion_level+1)) 31 | self.img_w = self.width // (2 ** (self.fusion_level+1)) 32 | self.num_pix = self.img_h * self.img_w 33 | self.pixel_grid = self.create_pixel_grid(self.batch_size, self.img_h, self.img_w) 34 | self.pixel_ones = torch.ones(self.batch_size, 1, self.proj_d_bins, self.num_pix) 35 | 36 | # define a depth grid for projection 37 | depth_bins = torch.linspace(self.proj_d_str, self.proj_d_end, self.proj_d_bins) 38 | self.depth_grid = self.create_depth_grid(self.batch_size, self.num_pix, self.proj_d_bins, depth_bins) 39 | 40 | # depth fusion(process overlap and non-overlap regions) 41 | if model == 'depth': 42 | # voxel - preprocessing layer 43 | self.v_dim_o = [(feat_in_dim + 1) * 2] + self.voxel_pre_dim 44 | self.v_dim_no = [feat_in_dim + 1] + self.voxel_pre_dim 45 | 46 | self.conv_overlap = conv1d(self.v_dim_o[0], self.v_dim_o[1], kernel_size=1) 47 | self.conv_non_overlap = conv1d(self.v_dim_no[0], self.v_dim_no[1], kernel_size=1) 48 | 49 | encoder_dims = self.proj_d_bins * self.v_dim_o[-1] 50 | stride = 1 51 | 52 | else: 53 | encoder_dims = (feat_in_dim + 1)*self.z_dim 54 | stride = 2 55 | 56 | # channel dimension reduction 57 | self.reduce_dim = nn.Sequential(*conv2d(encoder_dims, 256, kernel_size=3, stride = stride).children(), 58 | *conv2d(256, feat_out_dim, kernel_size=3, stride = stride).children()) 59 | 60 | def read_config(self, cfg): 61 | for attr in cfg.keys(): 62 | for k, v in cfg[attr].items(): 63 | setattr(self, k, v) 64 | 65 | def create_voxel_grid(self, str_p, end_p, v_size): 66 | grids = [torch.linspace(str_p[i], end_p[i], v_size[i]) for i in range(3)] 67 | 68 | x_dim, y_dim, z_dim = v_size 69 | grids[0] = grids[0].view(1, 1, 1, 1, x_dim) 70 | grids[1] = grids[1].view(1, 1, 1, y_dim, 1) 71 | grids[2] = grids[2].view(1, 1, z_dim, 1, 1) 72 | 73 | grids = [grid.expand(self.batch_size, 1, z_dim, y_dim, x_dim) for grid in grids] 74 | return torch.cat(grids, 1) 75 | 76 | def create_pixel_grid(self, batch_size, height, width): 77 | grid_xy = torch.meshgrid(torch.arange(width), torch.arange(height), indexing='xy') 78 | pix_coords = torch.stack(grid_xy, axis=0).unsqueeze(0).view(1, 2, height * width) 79 | pix_coords = pix_coords.repeat(batch_size, 1, 1) 80 | ones = torch.ones(batch_size, 1, height * width) 81 | pix_coords = torch.cat([pix_coords, ones], 1) 82 | return pix_coords 83 | 84 | def create_depth_grid(self, batch_size, n_pixels, n_depth_bins, depth_bins): 85 | depth_layers = [] 86 | for d in depth_bins: 87 | depth_layer = torch.ones((1, n_pixels)) * d 88 | depth_layers.append(depth_layer) 89 | depth_layers = torch.cat(depth_layers, dim=0).view(1, 1, n_depth_bins, n_pixels) 90 | depth_layers = depth_layers.expand(batch_size, 3, n_depth_bins, n_pixels) 91 | return depth_layers 92 | 93 | def type_check(self, sample_tensor): 94 | d_dtype, d_device = sample_tensor.dtype, sample_tensor.device 95 | if (self.voxel_pts.dtype != d_dtype) or (self.voxel_pts.device != d_device): 96 | self.voxel_pts = self.voxel_pts.to(device=d_device, dtype=d_dtype) 97 | self.pixel_grid = self.pixel_grid.to(device=d_device, dtype=d_dtype) 98 | self.depth_grid = self.depth_grid.to(device=d_device, dtype=d_dtype) 99 | self.pixel_ones = self.pixel_ones.to(device=d_device, dtype=d_dtype) 100 | 101 | def 
backproject_into_voxel(self, feats_agg, input_mask, intrinsics, extrinsics_inv): 102 | voxel_feat_list = [] 103 | voxel_mask_list = [] 104 | 105 | for cam in range(self.num_cams): 106 | feats_img = feats_agg[:, cam, ...] 107 | _, _, h_dim, w_dim = feats_img.size() 108 | 109 | mask_img = input_mask[:, cam, ...] 110 | mask_img = F.interpolate(mask_img, [h_dim, w_dim], mode='bilinear', align_corners=True) 111 | 112 | # 3D points in the voxel grid -> 3D points referenced at each view. [b, 3, n_voxels] 113 | ext_inv_mat = extrinsics_inv[:, cam, :3, :] 114 | v_pts_local = torch.matmul(ext_inv_mat, self.voxel_pts) 115 | 116 | # calculate pixel coordinate that each point are projected in the image. [b, n_voxels, 1, 2] 117 | K_mat = intrinsics[:, cam, :, :] 118 | pix_coords = self.calculate_sample_pixel_coords(K_mat, v_pts_local, w_dim, h_dim) 119 | 120 | # compute validity mask. [b, 1, n_voxels] 121 | valid_mask = self.calculate_valid_mask(mask_img, pix_coords, v_pts_local) 122 | 123 | # retrieve each per-pixel feature. [b, feat_dim, n_voxels, 1] 124 | feat_warped = F.grid_sample(feats_img, pix_coords, mode='bilinear', padding_mode='zeros', align_corners=True) 125 | # concatenate relative depth as the feature. [b, feat_dim + 1, n_voxels] 126 | feat_warped = torch.cat([feat_warped.squeeze(-1), v_pts_local[:, 2:3, :]/(self.voxel_size[0])], dim=1) 127 | feat_warped = feat_warped * valid_mask.float() 128 | 129 | voxel_feat_list.append(feat_warped) # [[batch_size, feat_dim, n_voxels], ...] 130 | voxel_mask_list.append(valid_mask) 131 | 132 | # compute overlap region 133 | voxel_mask_count = torch.sum(torch.cat(voxel_mask_list, dim=1), dim=1, keepdim=True) 134 | 135 | if self.model == 'depth': 136 | # discriminatively process overlap and non_overlap regions using different MLPs 137 | voxel_non_overlap = self.preprocess_non_overlap(voxel_feat_list, voxel_mask_list, voxel_mask_count) 138 | voxel_overlap = self.preprocess_overlap(voxel_feat_list, voxel_mask_list, voxel_mask_count) 139 | voxel_feat = voxel_non_overlap + voxel_overlap # [batch_size, feat_dim, n_voxels] 140 | 141 | elif self.model == 'pose': 142 | voxel_feat = torch.sum(torch.stack(voxel_feat_list, dim=1), dim=1, keepdim=False) 143 | voxel_feat = voxel_feat/(voxel_mask_count+1e-7) 144 | 145 | return voxel_feat 146 | 147 | def calculate_sample_pixel_coords(self, K, v_pts, w_dim, h_dim): 148 | """ 149 | This function calculates pixel coords for each point([batch, n_voxels, 1, 2]) to sample the per-pixel feature. 150 | """ 151 | cam_points = torch.matmul(K[:, :3, :3], v_pts) # [batch_size, 3, n_voxels] 152 | pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps) 153 | 154 | if not torch.all(torch.isfinite(pix_coords)): 155 | pix_coords = torch.clamp(pix_coords, min=-w_dim*2, max=w_dim*2) 156 | 157 | pix_coords = pix_coords.view(self.batch_size, 2, self.n_voxels, 1) #[batch_size, 2, n_voxels] -> [batch_size, 2, n_voxels, 1] 158 | pix_coords = pix_coords.permute(0, 2, 3, 1) # [batch_size, 2, n_voxels, 1] -> [batch_size, n_voxels, 1, 2] 159 | pix_coords[:, :, :, 0] = pix_coords[:, :, :, 0] / (w_dim - 1) 160 | pix_coords[:, :, :, 1] = pix_coords[:, :, :, 1] / (h_dim - 1) 161 | pix_coords = (pix_coords - 0.5) * 2 # [batch_size, n_voxels, 1, 2] 162 | return pix_coords 163 | 164 | def calculate_valid_mask(self, mask_img, pix_coords, v_pts_local): 165 | """ 166 | This function creates valid mask in voxel coordinate by projecting self-occlusion mask to 3D voxel coords. 
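A voxel is valid only if it projects inside the image bounds, lies in front of the camera (z > 0), and lands on a non-masked (valid) pixel of the self-occlusion mask.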
167 | """ 168 | # compute validity mask, [b, 1, n_voxels, 1] 169 | mask_selfocc = (F.grid_sample(mask_img, pix_coords, mode='nearest', padding_mode='zeros', align_corners=True) > 0.5) 170 | # discard points behind the camera, [b, 1, n_voxels] 171 | mask_depth = (v_pts_local[:, 2:3, :] > 0) 172 | # compute validity mask, [b, 1, n_voxels, 1] 173 | pix_coords_mask = pix_coords.permute(0, 3, 1, 2) 174 | mask_oob = ~(torch.logical_or(pix_coords_mask > 1, pix_coords_mask < -1).sum(dim=1, keepdim=True) > 0) 175 | valid_mask = mask_selfocc.squeeze(-1) * mask_depth * mask_oob.squeeze(-1) 176 | return valid_mask 177 | 178 | def preprocess_non_overlap(self, voxel_feat_list, voxel_mask_list, voxel_mask_count): 179 | """ 180 | This function applies 1x1 convolutions to features from non-overlapping features. 181 | """ 182 | non_overlap_mask = (voxel_mask_count == 1) 183 | voxel = sum(voxel_feat_list) 184 | voxel = voxel * non_overlap_mask.float() 185 | 186 | for conv_no in self.conv_non_overlap: 187 | voxel = conv_no(voxel) 188 | return voxel * non_overlap_mask.float() 189 | 190 | def preprocess_overlap(self, voxel_feat_list, voxel_mask_list, voxel_mask_count): 191 | """ 192 | This function applies 1x1 convolutions on overlapping features. 193 | Camera configuration [0,1,2] or [0,1,2,3,4,5]: 194 | 3 1 195 | rear cam <- 5 0 -> front cam 196 | 4 2 197 | """ 198 | overlap_mask = (voxel_mask_count == 2) 199 | if self.num_cams == 3: 200 | feat1 = voxel_feat_list[0] 201 | feat2 = voxel_feat_list[1] + voxel_feat_list[2] 202 | elif self.num_cams == 6: 203 | feat1 = voxel_feat_list[0] + voxel_feat_list[3] + voxel_feat_list[4] 204 | feat2 = voxel_feat_list[1] + voxel_feat_list[2] + voxel_feat_list[5] 205 | else: 206 | raise NotImplementedError 207 | 208 | voxel = torch.cat([feat1, feat2], dim=1) 209 | for conv_o in self.conv_overlap: 210 | voxel = conv_o(voxel) 211 | return voxel * overlap_mask.float() 212 | 213 | def project_voxel_into_image(self, voxel_feat, inv_K, extrinsics): 214 | """ 215 | This function projects voxels into 2D image coordinate. 216 | [b, feat_dim, n_voxels] -> [b, feat_dim, d, h, w] 217 | """ 218 | # define depth bin 219 | # [b, feat_dim, n_voxels] -> [b, feat_dim, d, h, w] 220 | b, feat_dim, _ = voxel_feat.size() 221 | voxel_feat = voxel_feat.view(b, feat_dim, self.z_dim, self.y_dim, self.x_dim) 222 | 223 | proj_feats = [] 224 | for cam in range(self.num_cams): 225 | # construct 3D point grid for each view 226 | cam_points = torch.matmul(inv_K[:, cam, :3, :3], self.pixel_grid) 227 | cam_points = self.depth_grid * cam_points.view(self.batch_size, 3, 1, self.num_pix) 228 | cam_points = torch.cat([cam_points, self.pixel_ones], dim=1) # [b, 4, n_depthbins, n_pixels] 229 | cam_points = cam_points.view(self.batch_size, 4, -1) # [b, 4, n_depthbins * n_pixels] 230 | 231 | # apply extrinsic: local 3D point -> global coordinate, [b, 3, n_depthbins * n_pixels] 232 | points = torch.matmul(extrinsics[:, cam, :3, :], cam_points) 233 | 234 | # 3D grid_sample [b, n_voxels, 3], value: (x, y, z) point 235 | grid = points.permute(0, 2, 1) 236 | 237 | for i in range(3): 238 | v_length = self.voxel_end_p[i] - self.voxel_str_p[i] 239 | grid[:, :, i] = (grid[:, :, i] - self.voxel_str_p[i]) / v_length * 2. - 1. 
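# The loop above rescales the world coordinates to [-1, 1] along each axis; after the
# view() below, F.grid_sample reads the last dimension of `grid` as (x, y, z), which
# matches the (W = x_dim, H = y_dim, D = z_dim) layout of `voxel_feat`.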
240 | 241 | grid = grid.view(self.batch_size, self.proj_d_bins, self.img_h, self.img_w, 3) 242 | proj_feat = F.grid_sample(voxel_feat, grid, mode='bilinear', padding_mode='zeros', align_corners=True) 243 | proj_feat = proj_feat.view(b, self.proj_d_bins * self.v_dim_o[-1], self.img_h, self.img_w) 244 | 245 | # conv, reduce dimension 246 | proj_feat = self.reduce_dim(proj_feat) 247 | proj_feats.append(proj_feat) 248 | return proj_feats 249 | 250 | def augment_extrinsics(self, ext): 251 | """ 252 | This function augments depth estimation results using augmented extrinsics [batch, cam, 4, 4] 253 | """ 254 | with torch.no_grad(): 255 | b, cam, _, _ = ext.size() 256 | ext_aug = ext.clone() 257 | 258 | # rotation augmentation 259 | angle = torch.rand(b, cam, 3) 260 | self.aug_angle = [15, 15, 40] 261 | for i in range(3): 262 | angle[:, :, i] = (angle[:, :, i] - 0.5) * self.aug_angle[i] 263 | angle_mat = axis_angle_to_matrix(angle) # 3x3 264 | tform_mat = torch.eye(4).repeat(b, cam, 1, 1) 265 | tform_mat[:, :, :3, :3] = angle_mat 266 | tform_mat = tform_mat.to(device=ext.device, dtype=ext.dtype) 267 | 268 | ext_aug = tform_mat @ ext_aug 269 | return ext_aug 270 | 271 | def forward(self, inputs, feats_agg): 272 | mask = inputs['mask'] 273 | K = inputs['K', self.fusion_level+1] 274 | inv_K = inputs['inv_K', self.fusion_level+1] 275 | extrinsics = inputs['extrinsics'] 276 | extrinsics_inv = inputs['extrinsics_inv'] 277 | 278 | fusion_dict = {} 279 | for cam in range(self.num_cams): 280 | fusion_dict[('cam', cam)] = {} 281 | 282 | # device, dtype check, match dtype and device 283 | sample_tensor = feats_agg[0, 0, ...] # B, n_cam, c, h, w 284 | self.type_check(sample_tensor) 285 | 286 | # backproject each per-pixel feature into 3D space (or sample per-pixel features for each voxel) 287 | voxel_feat = self.backproject_into_voxel(feats_agg, mask, K, extrinsics_inv) # [batch_size, feat_dim, n_voxels] 288 | 289 | if self.model == 'depth': 290 | # for each pixel, collect voxel features -> output image feature 291 | proj_feats = self.project_voxel_into_image(voxel_feat, inv_K, extrinsics) 292 | fusion_dict['proj_feat'] = pack_cam_feat(torch.stack(proj_feats, 1)) 293 | return fusion_dict 294 | 295 | elif self.model == 'pose': 296 | b, c, _ = voxel_feat.shape 297 | voxel_feat = voxel_feat.view(b, c*self.z_dim, 298 | self.y_dim, self.x_dim) 299 | bev_feat= self.reduce_dim(voxel_feat) 300 | return bev_feat 301 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | anyio==4.5.2 2 | argon2-cffi==23.1.0 3 | argon2-cffi-bindings==21.2.0 4 | arrow==1.3.0 5 | asttokens==2.4.1 6 | async-lru==2.0.4 7 | attrs==24.2.0 8 | babel==2.16.0 9 | backcall==0.2.0 10 | beautifulsoup4==4.12.3 11 | bleach==6.1.0 12 | boto3==1.34.74 13 | botocore==1.34.162 14 | cachetools==5.5.0 15 | certifi==2024.8.30 16 | cffi==1.17.1 17 | charset-normalizer==3.4.0 18 | comm==0.2.2 19 | cycler==0.12.1 20 | debugpy==1.8.7 21 | decorator==5.1.1 22 | defusedxml==0.7.1 23 | descartes==1.1.0 24 | diskcache==5.6.3 25 | e3nn==0.4.0 26 | einops==0.7.0 27 | exceptiongroup==1.2.2 28 | executing==2.1.0 29 | fastjsonschema==2.20.0 30 | fire==0.7.0 31 | fonttools==4.54.1 32 | fqdn==1.5.1 33 | fvcore==0.1.5.post20221221 34 | h11==0.14.0 35 | httpcore==1.0.6 36 | httpx==0.27.2 37 | idna==3.10 38 | imageio==2.35.1 39 | importlib_metadata==8.5.0 40 | importlib_resources==6.4.5 41 | iopath==0.1.10 42 | ipykernel==6.29.5 43 | 
ipython==8.12.3 44 | ipywidgets==8.1.5 45 | isoduration==20.11.0 46 | jaxtyping==0.2.19 47 | jedi==0.19.1 48 | Jinja2==3.1.4 49 | jmespath==1.0.1 50 | joblib==1.4.2 51 | json5==0.9.25 52 | jsonpointer==3.0.0 53 | jsonschema==4.23.0 54 | jsonschema-specifications==2023.12.1 55 | jupyter==1.1.1 56 | jupyter-console==6.6.3 57 | jupyter-events==0.10.0 58 | jupyter-lsp==2.2.5 59 | jupyter_client==8.6.3 60 | jupyter_core==5.7.2 61 | jupyter_server==2.14.2 62 | jupyter_server_terminals==0.5.3 63 | jupyterlab==4.2.5 64 | jupyterlab_pygments==0.3.0 65 | jupyterlab_server==2.27.3 66 | jupyterlab_widgets==3.0.13 67 | kiwisolver==1.4.7 68 | lazy_loader==0.4 69 | lpips==0.1.4 70 | MarkupSafe==2.1.5 71 | matplotlib==3.5.3 72 | matplotlib-inline==0.1.7 73 | mistune==3.0.2 74 | mpmath==1.3.0 75 | nbclient==0.10.0 76 | nbconvert==7.16.4 77 | nbformat==5.10.4 78 | nest-asyncio==1.6.0 79 | networkx==3.1 80 | notebook==7.2.2 81 | notebook_shim==0.2.4 82 | numpy==1.24.4 83 | nuscenes-devkit==1.1.9 84 | opencv-python==4.10.0.84 85 | opt-einsum-fx==0.1.4 86 | opt_einsum==3.4.0 87 | overrides==7.7.0 88 | packaging==24.1 89 | pandas==2.0.3 90 | pandocfilters==1.5.1 91 | parso==0.8.4 92 | pexpect==4.9.0 93 | pickleshare==0.7.5 94 | Pillow==9.5.0 95 | pkgutil_resolve_name==1.3.10 96 | platformdirs==4.3.6 97 | portalocker==2.10.1 98 | prometheus_client==0.21.0 99 | prompt_toolkit==3.0.48 100 | protobuf==3.20.3 101 | psutil==6.1.0 102 | ptyprocess==0.7.0 103 | pure_eval==0.2.3 104 | pycocotools==2.0.7 105 | pycparser==2.22 106 | Pygments==2.18.0 107 | pyparsing==3.1.4 108 | pyquaternion==0.9.5 109 | python-dateutil==2.9.0.post0 110 | python-json-logger==2.0.7 111 | pytorch3d==0.3.0 112 | pytz==2024.2 113 | PyWavelets==1.4.1 114 | PyYAML==6.0.1 115 | pyzmq==26.2.0 116 | referencing==0.35.1 117 | requests==2.32.3 118 | rfc3339-validator==0.1.4 119 | rfc3986-validator==0.1.1 120 | rpds-py==0.20.0 121 | s3transfer==0.10.3 122 | scikit-image==0.21.0 123 | scikit-learn==1.3.2 124 | scipy==1.10.1 125 | Send2Trash==1.8.3 126 | shapely==2.0.6 127 | six==1.16.0 128 | sniffio==1.3.1 129 | soupsieve==2.6 130 | stack-data==0.6.3 131 | sympy==1.13.3 132 | tabulate==0.9.0 133 | tenacity==8.2.3 134 | tensorboardX==2.6.2.2 135 | termcolor==2.4.0 136 | terminado==0.18.1 137 | threadpoolctl==3.5.0 138 | tifffile==2023.7.10 139 | tinycss2==1.3.0 140 | tomli==2.0.2 141 | tornado==6.4.1 142 | tqdm==4.66.5 143 | traitlets==5.14.3 144 | trimesh==4.5.2 145 | typeguard==4.3.0 146 | types-python-dateutil==2.9.0.20241003 147 | typing_extensions==4.12.2 148 | tzdata==2024.2 149 | uri-template==1.3.0 150 | urllib3==1.26.20 151 | wcwidth==0.2.13 152 | webcolors==24.8.0 153 | webencodings==0.5.1 154 | websocket-client==1.8.0 155 | widgetsnbextension==4.0.13 156 | xarray==2023.1.0 157 | yacs==0.1.8 158 | zipp==3.20.2 159 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | torch.backends.cudnn.deterministic = True 4 | torch.backends.cudnn.benchmark = False 5 | torch.backends.cuda.matmul.allow_tf32 = False 6 | torch.manual_seed(0) 7 | 8 | import utils 9 | from models import DrivingForwardModel 10 | from trainer import DrivingForwardTrainer 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser(description='training script') 14 | parser.add_argument('--config_file', default ='./configs/nuscenes/main.yaml', type=str, help='config yaml file') 15 | 
parser.add_argument('--novel_view_mode', default='MF', type=str, help='MF or SF') 16 | args = parser.parse_args() 17 | return args 18 | 19 | def train(cfg): 20 | model = DrivingForwardModel(cfg, 0) 21 | trainer = DrivingForwardTrainer(cfg, 0) 22 | trainer.learn(model) 23 | 24 | if __name__ == '__main__': 25 | args = parse_args() 26 | cfg = utils.get_config(args.config_file, mode='train', novel_view_mode=args.novel_view_mode) 27 | 28 | train(cfg) 29 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import DrivingForwardTrainer 2 | 3 | __all__ = ['DrivingForwardTrainer'] 4 | -------------------------------------------------------------------------------- /trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | from collections import defaultdict 3 | from tqdm import tqdm 4 | 5 | import torch 6 | import torch.distributed as dist 7 | from torch import Tensor 8 | 9 | from utils import Logger 10 | 11 | from lpips import LPIPS 12 | from jaxtyping import Float, UInt8 13 | from skimage.metrics import structural_similarity 14 | from einops import reduce 15 | 16 | from PIL import Image 17 | from pathlib import Path 18 | from einops import rearrange, repeat 19 | from typing import Union 20 | import numpy as np 21 | 22 | from tqdm import tqdm 23 | 24 | FloatImage = Union[ 25 | Float[Tensor, "height width"], 26 | Float[Tensor, "channel height width"], 27 | Float[Tensor, "batch channel height width"], 28 | ] 29 | 30 | class DrivingForwardTrainer: 31 | """ 32 | Trainer class for training and evaluation 33 | """ 34 | def __init__(self, cfg, rank, use_tb=True): 35 | self.read_config(cfg) 36 | self.rank = rank 37 | if rank == 0: 38 | self.logger = Logger(cfg, use_tb) 39 | self.depth_metric_names = self.logger.get_metric_names() 40 | 41 | self.lpips = LPIPS(net="vgg").cuda(rank) 42 | 43 | def read_config(self, cfg): 44 | for attr in cfg.keys(): 45 | for k, v in cfg[attr].items(): 46 | setattr(self, k, v) 47 | 48 | def learn(self, model): 49 | """ 50 | This function sets up the training process. 51 | """ 52 | train_dataloader = model.train_dataloader() 53 | if self.rank == 0: 54 | val_dataloader = model.val_dataloader() 55 | self.val_iter = iter(val_dataloader) 56 | 57 | self.step = 0 58 | start_time = time.time() 59 | for self.epoch in range(self.num_epochs): 60 | 61 | self.train(model, train_dataloader, start_time) 62 | 63 | # save model after each epoch using rank 0 gpu 64 | if self.rank == 0: 65 | model.save_model(self.epoch) 66 | print('-'*110) 67 | 68 | if self.ddp_enable: 69 | dist.barrier() 70 | 71 | if self.rank == 0: 72 | self.logger.close_tb() 73 | 74 | def train(self, model, data_loader, start_time): 75 | """ 76 | This function trains the model for one epoch.
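Each batch is forwarded through the model, the total loss is back-propagated, and the optimizer is stepped; rank 0 updates the logger and runs validation at checkpoint steps, and the learning-rate scheduler is stepped once the epoch ends.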
77 | """ 78 | model.set_train() 79 | pbar = tqdm(total=len(data_loader), desc='training on epoch {}'.format(self.epoch), mininterval=100) 80 | for batch_idx, inputs in enumerate(data_loader): 81 | before_op_time = time.time() 82 | model.optimizer.zero_grad(set_to_none=True) 83 | outputs, losses = model.process_batch(inputs, self.rank) 84 | losses['total_loss'].backward() 85 | model.optimizer.step() 86 | 87 | if self.rank == 0: 88 | self.logger.update( 89 | 'train', 90 | self.epoch, 91 | self.world_size, 92 | batch_idx, 93 | self.step, 94 | start_time, 95 | before_op_time, 96 | inputs, 97 | outputs, 98 | losses 99 | ) 100 | 101 | if self.logger.is_checkpoint(self.step): 102 | self.validate(model) 103 | 104 | self.step += 1 105 | pbar.update(1) 106 | 107 | pbar.close() 108 | model.lr_scheduler.step() 109 | 110 | @torch.no_grad() 111 | def validate(self, model, vis_results=False): 112 | """ 113 | This function validates models on the validation dataset to monitor training process. 114 | """ 115 | val_dataloader = model.val_dataloader() 116 | val_iter = iter(val_dataloader) 117 | 118 | # Ensure the model is in validation mode 119 | model.set_val() 120 | 121 | avg_reconstruction_metric = defaultdict(float) 122 | 123 | inputs = next(val_iter) 124 | outputs, _ = model.process_batch(inputs, self.rank) 125 | 126 | psnr, ssim, lpips= self.compute_reconstruction_metrics(inputs, outputs) 127 | 128 | avg_reconstruction_metric['psnr'] += psnr 129 | avg_reconstruction_metric['ssim'] += ssim 130 | avg_reconstruction_metric['lpips'] += lpips 131 | 132 | print('Validation reconstruction result...\n') 133 | print(f"\n{inputs['token'][0]}") 134 | self.logger.print_perf(avg_reconstruction_metric, 'reconstruction') 135 | 136 | # Set the model back to training mode 137 | model.set_train() 138 | 139 | @torch.no_grad() 140 | def evaluate(self, model): 141 | """ 142 | This function evaluates models on validation dataset of samples with context. 143 | """ 144 | eval_dataloader = model.eval_dataloader() 145 | 146 | # load model 147 | model.load_weights() 148 | model.set_eval() 149 | 150 | avg_reconstruction_metric = defaultdict(float) 151 | 152 | count = 0 153 | 154 | process = tqdm(eval_dataloader) 155 | for batch_idx, inputs in enumerate(process): 156 | outputs, _ = model.process_batch(inputs, self.rank) 157 | 158 | psnr, ssim, lpips= self.compute_reconstruction_metrics(inputs, outputs) 159 | 160 | avg_reconstruction_metric['psnr'] += psnr 161 | avg_reconstruction_metric['ssim'] += ssim 162 | avg_reconstruction_metric['lpips'] += lpips 163 | count += 1 164 | 165 | process.set_description(f"PSNR: {psnr:.4f}, SSIM: {ssim:.4f}, LPIPS: {lpips:.4f}") 166 | 167 | print(f"\n{inputs['token'][0]}") 168 | print(f"avg PSNR: {avg_reconstruction_metric['psnr']/count:.4f}, avg SSIM: {avg_reconstruction_metric['ssim']/count:.4f}, avg LPIPS: {avg_reconstruction_metric['lpips']/count:.4f}") 169 | 170 | avg_reconstruction_metric['psnr'] /= len(eval_dataloader) 171 | avg_reconstruction_metric['ssim'] /= len(eval_dataloader) 172 | avg_reconstruction_metric['lpips'] /= len(eval_dataloader) 173 | 174 | print('Evaluation reconstruction result...\n') 175 | self.logger.print_perf(avg_reconstruction_metric, 'reconstruction') 176 | 177 | def save_image( 178 | self, 179 | image: FloatImage, 180 | path: Union[Path, str], 181 | ) -> None: 182 | """Save an image. Assumed to be in range 0-1.""" 183 | 184 | # Create the parent directory if it doesn't already exist. 
185 | path = Path(path) 186 | path.parent.mkdir(exist_ok=True, parents=True) 187 | 188 | # Save the image. 189 | Image.fromarray(self.prep_image(image)).save(path) 190 | 191 | 192 | def prep_image(self, image: FloatImage) -> UInt8[np.ndarray, "height width channel"]: 193 | # Handle batched images. 194 | if image.ndim == 4: 195 | image = rearrange(image, "b c h w -> c h (b w)") 196 | 197 | # Handle single-channel images. 198 | if image.ndim == 2: 199 | image = rearrange(image, "h w -> () h w") 200 | 201 | # Ensure that there are 3 or 4 channels. 202 | channel, _, _ = image.shape 203 | if channel == 1: 204 | image = repeat(image, "() h w -> c h w", c=3) 205 | assert image.shape[0] in (3, 4) 206 | 207 | image = (image.detach().clip(min=0, max=1) * 255).type(torch.uint8) 208 | return rearrange(image, "c h w -> h w c").cpu().numpy() 209 | 210 | @torch.no_grad() 211 | def compute_reconstruction_metrics(self, inputs, outputs): 212 | """ 213 | This function computes reconstruction metrics. 214 | """ 215 | psnr = 0.0 216 | ssim = 0.0 217 | lpips = 0.0 218 | if self.novel_view_mode == 'SF': 219 | frame_id = 1 220 | elif self.novel_view_mode == 'MF': 221 | frame_id = 0 222 | else: 223 | raise ValueError(f"Invalid novel view mode: {self.novel_view_mode}") 224 | for cam in range(self.num_cams): 225 | rgb_gt = inputs[('color', frame_id, 0)][:, cam, ...] 226 | image = outputs[('cam', cam)][('gaussian_color', frame_id, 0)] 227 | psnr += self.compute_psnr(rgb_gt, image).mean() 228 | ssim += self.compute_ssim(rgb_gt, image).mean() 229 | lpips += self.compute_lpips(rgb_gt, image).mean() 230 | if self.save_images: 231 | assert self.eval_batch_size == 1 232 | if self.novel_view_mode == 'SF': 233 | self.save_image(image, Path(self.save_path) / inputs['token'][0] / f"{cam}.png") 234 | self.save_image(rgb_gt, Path(self.save_path) / inputs['token'][0] / f"{cam}_gt.png") 235 | self.save_image(inputs[('color', 0, 0)][:, cam, ...], Path(self.save_path) / inputs['token'][0] / f"{cam}_0_gt.png") 236 | elif self.novel_view_mode == 'MF': 237 | self.save_image(image, Path(self.save_path) / inputs['token'][0] / f"{cam}.png") 238 | self.save_image(rgb_gt, Path(self.save_path) / inputs['token'][0] / f"{cam}_gt.png") 239 | self.save_image(inputs[('color', -1, 0)][:, cam, ...], Path(self.save_path) / inputs['token'][0] / f"{cam}_prev_gt.png") 240 | self.save_image(inputs[('color', 1, 0)][:, cam, ...], Path(self.save_path) / inputs['token'][0] / f"{cam}_next_gt.png") 241 | psnr /= self.num_cams 242 | ssim /= self.num_cams 243 | lpips /= self.num_cams 244 | return psnr, ssim, lpips 245 | 246 | @torch.no_grad() 247 | def compute_psnr( 248 | self, 249 | ground_truth: Float[Tensor, "batch channel height width"], 250 | predicted: Float[Tensor, "batch channel height width"], 251 | ) -> Float[Tensor, " batch"]: 252 | ground_truth = ground_truth.clip(min=0, max=1) 253 | predicted = predicted.clip(min=0, max=1) 254 | mse = reduce((ground_truth - predicted) ** 2, "b c h w -> b", "mean") 255 | return -10 * mse.log10() 256 | 257 | @torch.no_grad() 258 | def compute_lpips( 259 | self, 260 | ground_truth: Float[Tensor, "batch channel height width"], 261 | predicted: Float[Tensor, "batch channel height width"], 262 | ) -> Float[Tensor, " batch"]: 263 | value = self.lpips.forward(ground_truth, predicted, normalize=True) 264 | return value[:, 0, 0, 0] 265 | 266 | @torch.no_grad() 267 | def compute_ssim( 268 | self, 269 | ground_truth: Float[Tensor, "batch channel height width"], 270 | predicted: Float[Tensor, "batch channel height width"], 271 
| ) -> Float[Tensor, " batch"]: 272 | ssim = [ 273 | structural_similarity( 274 | gt.detach().cpu().numpy(), 275 | hat.detach().cpu().numpy(), 276 | win_size=11, 277 | gaussian_weights=True, 278 | channel_axis=0, 279 | data_range=1.0, 280 | ) 281 | for gt, hat in zip(ground_truth, predicted) 282 | ] 283 | return torch.tensor(ssim, dtype=predicted.dtype, device=predicted.device) 284 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import Logger 2 | from .misc import get_config 3 | 4 | __all__ = ['Logger', 'get_config'] 5 | 6 | 7 | import sys 8 | 9 | 10 | _LIBS = ['./external/packnet_sfm', './external/dgp', './external/monodepth2'] 11 | 12 | def setup_env(): 13 | if not _LIBS[0] in sys.path: 14 | for lib in _LIBS: 15 | sys.path.append(lib) 16 | 17 | setup_env() -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | import PIL.Image as pil 9 | 10 | from tensorboardX import SummaryWriter 11 | 12 | from .visualize import colormap 13 | from .misc import pretty_ts, cal_depth_error 14 | 15 | 16 | def set_tb_title(*args): 17 | """ 18 | This function sets title for tensorboard plot. 19 | """ 20 | title = '' 21 | for i, s in enumerate(args): 22 | if not i%2: title += '/' 23 | s = s if isinstance(s, str) else str(s) 24 | title += s 25 | return title[1:] 26 | 27 | 28 | def resize_for_tb(image): 29 | """ 30 | This function resizes images for tensorboard plot. 31 | """ 32 | h, w = image.size()[-2:] 33 | return F.interpolate(image, [h//2, w//2], mode='bilinear', align_corners=True) 34 | 35 | 36 | def plot_tb(writer, step, img, title, j=0): 37 | """ 38 | This function plots images on tensorboard. 39 | """ 40 | img_resized = resize_for_tb(img) 41 | writer.add_image(title, img_resized[j].data, step) 42 | 43 | 44 | def plot_norm_tb(writer, step, img, title, j=0): 45 | """ 46 | This function plots normalized images on tensorboard. 47 | """ 48 | img_resized = torch.clamp(resize_for_tb(img), 0., 1.) 49 | writer.add_image(title, img_resized[j].data, step) 50 | 51 | 52 | def plot_disp_tb(writer, step, disp, title, j=0): 53 | """ 54 | This function plots disparity maps on tensorboard.
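The selected sample is resized, colorized with the 'plasma' colormap, and written via writer.add_image.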
55 | """ 56 | disp_resized = resize_for_tb(disp).float() 57 | disp_resized = colormap(disp_resized[j, 0]) 58 | writer.add_image(title, disp_resized, step) 59 | 60 | 61 | class Logger: 62 | """ 63 | Logger class to monitor training 64 | """ 65 | def __init__(self, cfg, use_tb=True): 66 | self.read_config(cfg) 67 | os.makedirs(self.log_path, exist_ok=True) 68 | 69 | if use_tb: 70 | self.init_tb() 71 | 72 | if self.eval_visualize: 73 | self.init_vis() 74 | 75 | self._metric_names = ['abs_rel', 'sq_rel', 'rms', 'log_rms', 'a1', 'a2', 'a3'] 76 | 77 | def read_config(self, cfg): 78 | for attr in cfg.keys(): 79 | for k, v in cfg[attr].items(): 80 | setattr(self, k, v) 81 | 82 | def init_tb(self): 83 | self.writers = {} 84 | for mode in ['train', 'val']: 85 | self.writers[mode] = SummaryWriter(os.path.join(self.log_path, mode)) 86 | 87 | def close_tb(self): 88 | for mode in ['train', 'val']: 89 | self.writers[mode].close() 90 | 91 | def init_vis(self): 92 | vis_path = os.path.join(self.log_path, 'vis_results') 93 | os.makedirs(vis_path, exist_ok=True) 94 | 95 | self.cam_paths = [] 96 | for cam_id in range(self.num_cams): 97 | cam_path = os.path.join(vis_path, f'cam{cam_id:d}') 98 | os.makedirs(cam_path, exist_ok=True) 99 | self.cam_paths.append(cam_path) 100 | 101 | def get_metric_names(self): 102 | return self._metric_names 103 | 104 | def to_disp(self, depth_in, K_in): 105 | """ 106 | This function transforms depth values into disparity values while dividing the value by the focal length. 107 | """ 108 | disp = depth_in * self.focal_length_scale / K_in[:, 0:1, 0:1].unsqueeze(2) 109 | disp = 1 / (disp + 0.00001) 110 | 111 | min_disp = 1/self.max_depth 112 | max_disp = 1/self.min_depth 113 | disp_range = max_disp - min_disp 114 | 115 | disp = (disp - min_disp) / disp_range 116 | 117 | return disp 118 | 119 | def update(self, mode, epoch, world_size, batch_idx, step, start_time, before_op_time, inputs, outputs, losses): 120 | """ 121 | Display logs with respect to the log frequency 122 | """ 123 | # iteration duration 124 | duration = time.time() - before_op_time 125 | 126 | if self.is_checkpoint(step): 127 | self.log_time(epoch, batch_idx * world_size, duration, losses, start_time) 128 | self.log_tb(mode, inputs, outputs, losses, step) 129 | 130 | def is_checkpoint(self, step): 131 | """ 132 | Log less frequently after the early phase steps 133 | """ 134 | early_phase = (step % self.log_frequency == 0) and (step < self.early_phase) 135 | late_phase = step % self.late_log_frequency == 0 136 | return (early_phase or late_phase) 137 | 138 | def log_time(self, epoch, batch_idx, duration, loss, start_time): 139 | """ 140 | This function prints epoch, iteration, duration, loss and spent time. 141 | """ 142 | rep_loss = loss['total_loss'].item() 143 | samples_per_sec = self.batch_size / duration 144 | time_sofar = time.time() - start_time 145 | print("") 146 | print(f'epoch: {epoch:2d} | batch: {batch_idx:6d} |' + \ 147 | f'examples/s: {samples_per_sec:5.1f} | loss: {rep_loss:.3f} | time elapsed: {pretty_ts(time_sofar)}') 148 | 149 | def log_tb(self, mode, inputs, outputs, losses, step): 150 | """ 151 | This function logs outputs for monitoring using tensorboard. 
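All loss terms are logged as scalars; for each camera, the input frames (-1, 0, 1), predicted and ground-truth disparity, reprojection loss and mask, the self-occlusion mask, and (when enabled) the spatio/spatio-temporal overlaps and the rendered Gaussian color are plotted at scale 0.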
152 | """ 153 | writer = self.writers[mode] 154 | # loss 155 | for l, v in losses.items(): 156 | writer.add_scalar(f'{l}', v, step) 157 | 158 | scale = 0 # plot the maximum scale 159 | for cam_id in range(self.num_cams): 160 | target_view = outputs[('cam', cam_id)] 161 | 162 | plot_tb(writer, step, inputs[('color', 0, scale)][:, cam_id, ...], set_tb_title('cam', cam_id, 'frame_0')) # frame_id 0 163 | plot_tb(writer, step, inputs[('color', 1, scale)][:, cam_id, ...], set_tb_title('cam', cam_id, 'frame_1')) # frame_id 1 164 | plot_tb(writer, step, inputs[('color', -1, scale)][:, cam_id, ...], set_tb_title('cam', cam_id, 'frame_-1')) # frame_id -1 165 | plot_disp_tb(writer, step, target_view[('disp', scale)], set_tb_title('cam', cam_id, 'disp')) # disparity 166 | depth_gt = inputs['depth'][:, cam_id, ...] 167 | far = depth_gt == 0 168 | depth_gt[far] += 80.0 169 | disp_gt = self.to_disp(depth_gt, inputs[('K', 0)][:, cam_id, ...]) 170 | plot_disp_tb(writer, step, disp_gt, set_tb_title('cam', cam_id, 'disp_gt')) 171 | 172 | plot_tb(writer, step, target_view[('reproj_loss', scale)], set_tb_title('cam', cam_id, 'reproj')) # reprojection image 173 | plot_tb(writer, step, target_view[('reproj_mask', scale)], set_tb_title('cam', cam_id, 'reproj_mask')) # reprojection mask 174 | plot_tb(writer, step, inputs['mask'][:, cam_id, ...], set_tb_title('cam', cam_id, 'self_occ_mask')) 175 | 176 | if self.spatio: 177 | plot_norm_tb(writer, step, target_view[('overlap', 0, scale)], set_tb_title('cam', cam_id, 'sp')) 178 | plot_tb(writer, step, target_view[('overlap_mask', 0, scale)], set_tb_title('cam', cam_id, 'sp_mask')) 179 | 180 | if self.spatio_temporal: 181 | for frame_id in self.frame_ids: 182 | if frame_id == 0: 183 | continue 184 | plot_norm_tb(writer, step, target_view[('color', frame_id, scale)], set_tb_title('cam', cam_id, 'pred_', frame_id)) 185 | plot_norm_tb(writer, step, target_view[('overlap', frame_id, scale)], set_tb_title('cam', cam_id, 'sp_tm_', frame_id)) 186 | plot_tb(writer, step, target_view[('overlap_mask', frame_id, scale)], set_tb_title('cam', cam_id, 'sp_tm_mask_', frame_id)) 187 | 188 | if self.gaussian: 189 | if self.novel_view_mode == 'SF': 190 | frame_id = 1 191 | elif self.novel_view_mode == 'MF': 192 | frame_id = 0 193 | else: 194 | raise ValueError(f'Novel view mode {self.novel_view_mode} not supported.') 195 | plot_norm_tb(writer, step, target_view[('gaussian_color', frame_id, scale)], set_tb_title('cam', cam_id, f'gaussian_pred_frame_{frame_id}')) 196 | 197 | def log_result(self, inputs, outputs, idx, syn_visualize=False): 198 | """ 199 | This function logs outputs for visualization. 200 | """ 201 | scale = 0 202 | for cam_id in range(self.num_cams): 203 | target_view = outputs[('cam', cam_id)] 204 | disps = target_view['disp', scale] 205 | for jdx, disp in enumerate(disps): 206 | disp = colormap(disp)[0,...].transpose(1,2,0) 207 | disp = pil.fromarray((disp * 255).astype(np.uint8)) 208 | cur_idx = idx*self.batch_size + jdx 209 | disp.save(os.path.join(self.cam_paths[cam_id], f'{cur_idx:03d}_disp.jpg')) 210 | 211 | def compute_depth_losses(self, inputs, outputs, vis_scale=False): 212 | """ 213 | This function computes depth metrics, to allow monitoring of training process on validation dataset. 
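Predictions are resized to the ground-truth resolution, clamped to [eval_min_depth, eval_max_depth], and evaluated both metrically and with per-sample median scaling; per-camera errors are averaged into the returned dictionaries.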
214 | """ 215 | min_eval_depth = self.eval_min_depth 216 | max_eval_depth = self.eval_max_depth 217 | 218 | med_scale = [] 219 | 220 | error_metric_dict = defaultdict(float) 221 | error_median_dict = defaultdict(float) 222 | 223 | for cam in range(self.num_cams): 224 | target_view = outputs['cam', cam] 225 | 226 | depth_gt = inputs['depth'][:, cam, ...] 227 | 228 | _, _, h, w = depth_gt.shape 229 | 230 | depth_pred = target_view[('depth', 0, 0)].to(depth_gt.device) 231 | depth_pred = torch.clamp(F.interpolate( 232 | depth_pred, [h, w], mode='bilinear', align_corners=False), 233 | min_eval_depth, max_eval_depth) 234 | depth_pred = depth_pred.detach() 235 | 236 | mask = (depth_gt > min_eval_depth) * (depth_gt < max_eval_depth) * inputs['mask'][:, cam, ...] 237 | mask = mask.bool() 238 | 239 | depth_gt = depth_gt[mask] 240 | depth_pred = depth_pred[mask] 241 | 242 | # calculate median scale 243 | scale_val = torch.median(depth_gt) / torch.median(depth_pred) 244 | med_scale.append(round(scale_val.cpu().numpy().item(), 2)) 245 | 246 | depth_pred_metric = torch.clamp(depth_pred, min=min_eval_depth, max=max_eval_depth) 247 | depth_errors_metric = cal_depth_error(depth_pred_metric, depth_gt) 248 | 249 | depth_pred_median = torch.clamp(depth_pred * scale_val, min=min_eval_depth, max=max_eval_depth) 250 | depth_errors_median = cal_depth_error(depth_pred_median, depth_gt) 251 | 252 | for i in range(len(depth_errors_metric)): 253 | key = self._metric_names[i] 254 | error_metric_dict[key] += depth_errors_metric[i] 255 | error_median_dict[key] += depth_errors_median[i] 256 | 257 | if vis_scale==True: 258 | # print median scale 259 | print(f' | median scale = {med_scale}') 260 | 261 | for key in error_metric_dict.keys(): 262 | error_metric_dict[key] = error_metric_dict[key].cpu().numpy() / self.num_cams 263 | error_median_dict[key] = error_median_dict[key].cpu().numpy() / self.num_cams 264 | 265 | return error_metric_dict, error_median_dict 266 | 267 | def print_perf(self, loss, scale): 268 | """ 269 | This function prints various metrics for depth estimation accuracy. 
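Each metric in `loss` is printed as 'key: value' on one line prefixed by `scale`, which acts as a label (e.g. 'reconstruction').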
270 | """ 271 | perf = ' '*3 + scale 272 | for k, v in loss.items(): 273 | perf += ' | ' + str(k) + f': {v:.3f}' 274 | print(perf) 275 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | from collections import defaultdict 4 | 5 | import torch 6 | 7 | _NUSC_CAM_LIST = ['CAM_FRONT', 'CAM_FRONT_LEFT', 'CAM_FRONT_RIGHT', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT', 'CAM_BACK'] 8 | _REL_CAM_DICT = {0: [1,2], 1: [0,3], 2: [0,4], 3: [1,5], 4: [2,5], 5: [3,4]} 9 | 10 | 11 | def camera2ind(cameras): 12 | """ 13 | This function transforms camera name list to indices 14 | """ 15 | indices = [] 16 | for cam in cameras: 17 | if cam in _NUSC_CAM_LIST: 18 | ind = _NUSC_CAM_LIST.index(cam) 19 | else: 20 | ind = None 21 | indices.append(ind) 22 | return indices 23 | 24 | 25 | def get_relcam(cameras): 26 | """ 27 | This function returns relative camera indices from given camera list 28 | """ 29 | relcam_dict = defaultdict(list) 30 | indices = camera2ind(cameras) 31 | for ind in indices: 32 | relcam_dict[ind] = [] 33 | relcam_cand = _REL_CAM_DICT[ind] 34 | for cand in relcam_cand: 35 | if cand in indices: 36 | relcam_dict[ind].append(cand) 37 | return relcam_dict 38 | 39 | 40 | def get_config(config, mode='train', weight_path='./', novel_view_mode='MF'): 41 | """ 42 | This function reads the configuration file and return as dictionary 43 | """ 44 | with open(config, 'r') as stream: 45 | cfg = yaml.load(stream, Loader=yaml.FullLoader) 46 | 47 | cfg_name = os.path.splitext(os.path.basename(config))[0] 48 | print('Experiment: ', cfg_name) 49 | 50 | _log_path = os.path.join(cfg['data']['log_dir'], cfg_name) 51 | cfg['data']['log_path'] = _log_path 52 | cfg['data']['save_weights_root'] = os.path.join(_log_path, 'models') 53 | cfg['data']['num_cams'] = len(cfg['data']['cameras']) 54 | cfg['data']['rel_cam_list'] = get_relcam(cfg['data']['cameras']) 55 | 56 | cfg['model']['mode'] = mode 57 | cfg['model']['novel_view_mode'] = novel_view_mode 58 | 59 | cfg['load']['load_weights_dir'] = weight_path 60 | 61 | if mode == 'eval': 62 | cfg['ddp']['world_size'] = 1 63 | cfg['ddp']['gpus'] = [0] 64 | cfg['training']['batch_size'] = cfg['eval']['eval_batch_size'] 65 | cfg['training']['depth_flip'] = False 66 | return cfg 67 | 68 | 69 | def pretty_ts(ts): 70 | """ 71 | This function prints amount of time taken in user friendly way. 72 | """ 73 | second = int(ts) 74 | minute = second // 60 75 | hour = minute // 60 76 | return f'{hour:02d}h{(minute%60):02d}m{(second%60):02d}s' 77 | 78 | 79 | def cal_depth_error(pred, target): 80 | """ 81 | This function calculates depth error using various metrics. 
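Returns (abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3), where a_n is the fraction of pixels with max(pred/target, target/pred) < 1.25**n.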
82 | """ 83 | abs_rel = torch.mean(torch.abs(pred-target) / target) 84 | sq_rel = torch.mean((pred-target).pow(2) / target) 85 | rmse = torch.sqrt(torch.mean((pred-target).pow(2))) 86 | rmse_log = torch.sqrt(torch.mean((torch.log(target) - torch.log(pred)).pow(2))) 87 | 88 | thresh = torch.max((target/pred), (pred/ target)) 89 | a1 = (thresh < 1.25).float().mean() 90 | a2 = (thresh < 1.25**2).float().mean() 91 | a3 = (thresh < 1.25**3).float().mean() 92 | return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3 -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | import torch 4 | 5 | _DEGTORAD = 0.0174533 6 | 7 | 8 | def aug_depth_params(K, n_steps= 75): 9 | """ 10 | This function augments camera parameters for depth synthesis. 11 | """ 12 | # augmented parameters for visualization 13 | aug_params = [] 14 | 15 | # roll augmentations 16 | roll_aug = [i for i in range(0, n_steps + 1, 2)] + [i for i in range(n_steps, -n_steps - 1, -2)] + [i for i in range(-n_steps, 1, 2)] 17 | ang_y, ang_z = 0.0, 0.0 18 | for angle in roll_aug: 19 | ang_x = _DEGTORAD * (angle / n_steps * 10.) 20 | aug_params.append([torch.inverse(K), ang_x, ang_y, ang_z]) 21 | 22 | # pitch augmentations 23 | pitch_aug = [i for i in range(0, 50 + 1, 2)] + [i for i in range(50, -50 - 1, -2)] + [i for i in range(-50, 1, 2)] 24 | ang_x, ang_z = 0.0, 0.0 25 | for angle in pitch_aug: 26 | ang_y = _DEGTORAD * (angle / 10.) 27 | aug_params.append([torch.inverse(K), ang_x, ang_y, ang_z]) 28 | 29 | # focal length augmentations 30 | focal_ratio = K[:, 1, 0, 0] / K[:, 0, 0, 0] 31 | focal_ratio_aug = focal_ratio / 1.5 32 | ang_x, ang_y, ang_z = 0.0, 0.0, 0.0 33 | 34 | for f_idx in range(100 + 1): 35 | f_scale = (f_idx / 100. * focal_ratio_aug + (1 - f_idx / 100.))[:, None] 36 | K_aug = K.clone() 37 | K_aug[:, :, 0, 0] *= f_scale 38 | K_aug[:, :, 1, 1] *= f_scale 39 | aug_params.append([torch.inverse(K_aug), ang_x, ang_y, ang_z]) 40 | 41 | for f_idx in range(50 + 1): 42 | f_scale = (f_idx / 50. * focal_ratio + (1 - f_idx / 50.) * focal_ratio_aug)[:, None] 43 | K_aug = K.clone() 44 | K_aug[:, :, 0, 0] *= f_scale 45 | K_aug[:, :, 1, 1] *= f_scale 46 | aug_params.append([torch.inverse(K_aug), ang_x, ang_y, ang_z]) 47 | 48 | # yaw augmentations 49 | yaw_aug = [i for i in range(360)] 50 | inv_K_aug = torch.inverse(K_aug) 51 | ang_x, ang_y = 0.0, 0.0 52 | for i in yaw_aug: 53 | ratio_i = i / 360. 54 | ang_z = _DEGTORAD * 360 * ratio_i 55 | aug_params.append([inv_K_aug, ang_x, ang_y, ang_z]) 56 | return aug_params 57 | 58 | 59 | def colormap(vis, normalize=True, torch_transpose=True): 60 | """ 61 | This function visualizes disparity map using colormap specified with disparity map variable. 
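Accepts 2D, 3D, or 4D tensors or arrays; values are (optionally) min-max normalized, mapped through matplotlib's 'plasma' colormap, and returned as an RGB array (channel-first when torch_transpose=True).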
62 | """ 63 | disparity_map = plt.get_cmap('plasma', 256) # for plotting 64 | 65 | if isinstance(vis, torch.Tensor): 66 | vis = vis.detach().cpu().numpy() 67 | 68 | if normalize: 69 | ma = float(vis.max()) 70 | mi = float(vis.min()) 71 | d = ma - mi if ma != mi else 1e5 72 | vis = (vis - mi) / d 73 | 74 | if vis.ndim == 4: 75 | vis = vis.transpose([0, 2, 3, 1]) 76 | vis = disparity_map(vis) 77 | vis = vis[:, :, :, 0, :3] 78 | if torch_transpose: 79 | vis = vis.transpose(0, 3, 1, 2) 80 | elif vis.ndim == 3: 81 | vis = disparity_map(vis) 82 | vis = vis[:, :, :, :3] 83 | if torch_transpose: 84 | vis = vis.transpose(0, 3, 1, 2) 85 | elif vis.ndim == 2: 86 | vis = disparity_map(vis) 87 | vis = vis[..., :3] 88 | if torch_transpose: 89 | vis = vis.transpose(2, 0, 1) 90 | return vis --------------------------------------------------------------------------------
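# --- Supplementary usage sketch (assumption: run from the repository root with the
# dependencies in requirements.txt installed; the tensor shape and output path below
# are illustrative placeholders, not values used by the repository). It shows how the
# colormap() helper above can turn a disparity map into a saved image, mirroring
# Logger.log_result.
import numpy as np
import torch
import PIL.Image as pil

from utils.visualize import colormap

disp = torch.rand(1, 1, 96, 160)                 # stand-in for a predicted disparity map [B, 1, H, W]
vis = colormap(disp)[0, ...].transpose(1, 2, 0)  # RGB array in [0, 1], shape [H, W, 3]
pil.fromarray((vis * 255).astype(np.uint8)).save('disp_example.jpg')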